Back to Feast

RAG with Feast Feature Store

examples/rag-retriever/rag_feast.ipynb

0.63.0 · 4.9 KB
Original Source

RAG with Feast Feature Store

Install required dependencies

python
# Core dependencies: Feast with the Milvus extra (used as the online store below),
# sentence-transformers for embeddings, and Hugging Face datasets for the corpus.
# transformers, used later for the generator, is pulled in by sentence-transformers.
%pip install --quiet feast[milvus] sentence-transformers datasets
# NOTE(review): the exact pins below are not explained in this notebook
# (presumably compatibility workarounds for the installed Feast) - confirm before bumping.
%pip install bigtree==0.19.2
%pip install marshmallow==3.10.0 

Load the test dataset and split it into chunks

python
from datasets import load_dataset

# Pull only the first 1% of the "train" split of the wiki_dpr passage corpus
# (psgs_w100.nq.exact configuration) to keep the demo small.
# with_index=False presumably skips the prebuilt retrieval index, since this
# notebook computes its own embeddings - confirm against the wiki_dpr card.
dataset = load_dataset(
    path="facebook/wiki_dpr",
    name="psgs_w100.nq.exact",
    split="train[:1%]",
    with_index=False,
)
python
def chunk_dataset(examples, chunk_size=100, overlap=20, max_chars=500, min_words=20):
    """Split every passage in a batch into overlapping word-window chunks.

    Designed for ``Dataset.map(..., batched=True)``: ``examples`` is one batch
    holding parallel ``'id'``, ``'title'`` and ``'text'`` lists.

    Args:
        examples: Batch mapping with parallel ``'id'``, ``'title'`` and
            ``'text'`` sequences.
        chunk_size: Maximum number of words per chunk.
        overlap: Words shared by consecutive windows; each window starts
            ``chunk_size - overlap`` words after the previous one.
        max_chars: Each chunk's joined text is truncated to this many
            characters (the cut may fall mid-word).
        min_words: Windows with fewer words than this are skipped, which
            drops short passage tails. Defaults to 20.

    Returns:
        Dict with parallel ``'id'``, ``'title'`` and ``'text'`` lists, one
        entry per kept chunk. Chunk ids have the form
        ``"<passage id>_<start word offset>"``.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size`` (the window would never advance).
    """
    step = chunk_size - overlap
    # Guard against a non-advancing window: step == 0 makes range() raise a
    # cryptic error, and step < 0 would silently produce no chunks at all.
    if chunk_size <= 0 or step <= 0:
        raise ValueError(
            "need chunk_size > 0 and overlap < chunk_size, "
            f"got chunk_size={chunk_size}, overlap={overlap}"
        )

    ids = examples['id']
    titles = examples['title']
    all_ids = []
    all_titles = []
    all_chunks = []

    for i, text in enumerate(examples['text']):  # iterate over texts in the batch
        words = text.split()
        for start in range(0, len(words), step):
            window = words[start:start + chunk_size]
            if len(window) < min_words:
                continue
            all_chunks.append(' '.join(window)[:max_chars])
            all_ids.append(f"{ids[i]}_{start}")  # unique per (passage, offset)
            all_titles.append(titles[i])

    return {'id': all_ids, 'title': all_titles, 'text': all_chunks}


# Expand each batch of passages into chunk rows. chunk_dataset returns a batch
# with a different row count and its own id/title/text schema, so every input
# column is removed rather than carried over.
chunked_dataset = dataset.map(
    function=chunk_dataset,
    batched=True,
    num_proc=1,
    remove_columns=dataset.column_names,
)

Define the embedding model and generate embeddings

python
from sentence_transformers import SentenceTransformer
import torch

# Use the GPU when one is available and fall back to CPU otherwise, so this
# cell does not crash on machines without CUDA.
embedding_device = "cuda" if torch.cuda.is_available() else "cpu"

sentences = chunked_dataset["text"]
# Take the first 100 chunks to keep the demo small.
test_sentences = sentences[:100]
# Load the pretrained sentence-transformer model and embed the sample chunks.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=embedding_device)
embeddings = embedding_model.encode(
    test_sentences,
    show_progress_bar=True,
    batch_size=64,
    device=embedding_device,
)

print(f"Generated embeddings of shape: {embeddings.shape}")

Create a Parquet file to serve as the historical data source

python
import os

# Create the directory that holds the Parquet offline source. exist_ok=True
# makes the cell safe to re-run, and missing parent directories are created too.
os.makedirs("feature_repo/data", exist_ok=True)
python
from datetime import datetime, timezone
import pandas as pd

# One row per embedded chunk: passage_id is simply the row position, and each
# row gets a timezone-aware UTC event timestamp for the historical source.
df = pd.DataFrame(
    {
        "passage_id": list(range(len(test_sentences))),
        "passage_text": test_sentences,
        # Each vector is stored as a plain Python list in an object column.
        "embedding": pd.Series([vec.tolist() for vec in embeddings], dtype=object),
        "event_timestamp": [datetime.now(timezone.utc) for _sentence in test_sentences],
    }
)

print("DataFrame Info:")
print(df.head())
# Sanity check: tally embedding lengths (or the type name of any non-list entry).
print(
    df["embedding"]
    .apply(lambda vec: len(vec) if isinstance(vec, list) else str(type(vec)))
    .value_counts()
)

# Persist the rows as the Parquet file the feature repo reads from.
df.to_parquet("feature_repo/data/wiki_dpr.parquet", index=False)
print("Saved to wiki_dpr.parquet")

Change into the feature_repo directory, then run feast apply

python
# Move the kernel's working directory into the Feast repo; later cells use
# paths relative to it ("." and "./data/...").
# NOTE(review): not idempotent - re-running looks for feature_repo/feature_repo.
%cd feature_repo
python
# Register this repo's feature definitions (current directory) with Feast.
!feast apply

Write the Parquet data to the Milvus online store

python
from feast import FeatureStore
import pandas as pd

# Open the Feast repo in the current working directory (expected to be feature_repo).
store = FeatureStore(repo_path=".")

# Read back the Parquet file written earlier and push its rows into the online
# store for the 'wiki_passages' feature view.
df = pd.read_parquet("./data/wiki_dpr.parquet")
store.write_to_online_store(feature_view_name='wiki_passages', df=df)

Define the generator model

python
# RagConfig and AutoModel are imported here but used in the retriever setup cell.
from transformers import AutoTokenizer, AutoModelForCausalLM, RagConfig, AutoModel

# Causal LM (and its tokenizer) that the retriever uses to generate answers.
# NOTE(review): no device/dtype is given, so transformers' defaults apply.
generator_model_id = "ibm-granite/granite-3.2-2b-instruct"
generator_model = AutoModelForCausalLM.from_pretrained(generator_model_id)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_id)

Initialize the Feast vector store, the Feast index, and the FeastRAGRetriever

python
import sys
# Make the parent directory importable so ragproject_repo (which provides
# wiki_passage_feature_view) can be found.
# NOTE(review): the path is appended again on every re-run of this cell.
sys.path.append("..")
from ragproject_repo import wiki_passage_feature_view
from feast.vector_store import FeastVectorStore
from feast.rag_retriever import FeastIndex, FeastRAGRetriever

# The generator's own config is embedded into the RagConfig below.
generator_config=generator_model.config
# Question encoder and tokenizer: the same all-MiniLM-L6-v2 checkpoint used to
# embed the passages, loaded here as a plain transformers model.
question_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
question_encoder_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")


# Minimal question-encoder description for RagConfig.
# NOTE(review): hidden_size is hard-coded, presumably to match MiniLM-L6's
# 384-dim output; question_encoder.config holds the authoritative values.
query_encoder_config = {
    "model_type": "bert",
    "hidden_size": 384
}

# Wraps the wiki_passages feature view; its rag_view and features are passed
# to the retriever below.
vector_store = FeastVectorStore(
    repo_path=".",
    rag_view=wiki_passage_feature_view,
    features=["wiki_passages:passage_text", "wiki_passages:embedding", "wiki_passages:passage_id"]
)

# The same FeastIndex instance is given to both RagConfig and the retriever.
feast_index = FeastIndex()

config = RagConfig(
    question_encoder=query_encoder_config,
    generator=generator_config.to_dict(),
    index=feast_index
)
# search_type="vector" selects vector search; id_field/text_field name the
# feature columns used as the passage id and passage text.
retriever = FeastRAGRetriever(
    question_encoder=question_encoder,
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
    feast_repo_path=".",
    feature_view=vector_store.rag_view,
    features=vector_store.features,
    generator_model=generator_model, 
    search_type="vector",
    id_field="passage_id",
    text_field="passage_text",
    config=config,
    index=feast_index,
)

Submit a query

python
# End-to-end RAG call; top_k=10 presumably sets how many passages are retrieved
# as context for the generator.
query = "What is the capital of Ireland?"
answer = retriever.generate_answer(query, top_k=10)
print("Generated Answer:", answer)