examples/rag-retriever/rag_feast.ipynb
%pip install --quiet feast[milvus] sentence-transformers datasets
%pip install bigtree==0.19.2
%pip install marshmallow==3.10.0
from datasets import load_dataset
# load wikipedia dataset - 1% of the training split
dataset = load_dataset(
    "facebook/wiki_dpr",
    "psgs_w100.nq.exact",
    split="train[:1%]",
    with_index=False,
)
def chunk_dataset(examples, chunk_size=100, overlap=20, max_chars=500):
    all_chunks = []
    all_ids = []
    all_titles = []
    for i, text in enumerate(examples['text']):  # Iterate over texts in the batch
        words = text.split()
        chunks = []
        for j in range(0, len(words), chunk_size - overlap):
            chunk_words = words[j:j + chunk_size]
            if len(chunk_words) < 20:
                continue
            chunk_text_value = ' '.join(chunk_words)  # Store the chunk text
            chunk_text_value = chunk_text_value[:max_chars]
            chunks.append(chunk_text_value)
            all_ids.append(f"{examples['id'][i]}_{j}")  # Unique ID for the chunk
            all_titles.append(examples['title'][i])
        all_chunks.extend(chunks)
    return {'id': all_ids, 'title': all_titles, 'text': all_chunks}
chunked_dataset = dataset.map(
    chunk_dataset,
    batched=True,
    remove_columns=dataset.column_names,
    num_proc=1,
)
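# Optional sanity check (illustrative): how many chunks were produced, and what one looks like
print(f"Number of chunks: {len(chunked_dataset)}")
print(chunked_dataset[0])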
from sentence_transformers import SentenceTransformer
sentences = chunked_dataset["text"]
# Take the first 100 sentences
test_sentences = sentences[:100]
# load pretrained sentence transformer model and create embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedding_model.encode(test_sentences, show_progress_bar=True, batch_size=64, device="cuda")  # set device="cpu" if no GPU is available
print(f"Generated embeddings of shape: {embeddings.shape}")
%mkdir feature_repo/data
import pandas as pd
from datetime import datetime, timezone
# Create DataFrame
df = pd.DataFrame({
    "passage_id": list(range(len(test_sentences))),
    "passage_text": test_sentences,
    "embedding": pd.Series(
        [embedding.tolist() for embedding in embeddings],
        dtype=object,
    ),
    "event_timestamp": [datetime.now(timezone.utc) for _ in test_sentences],
})
print("DataFrame Info:")
print(df.head())
print(df["embedding"].apply(lambda x: len(x) if isinstance(x, list) else str(type(x))).value_counts()) # Check lengths
# Save to Parquet
df.to_parquet("feature_repo/data/wiki_dpr.parquet", index=False)
print("Saved to wiki_dpr.parquet")
%cd feature_repo
!feast apply
from feast import FeatureStore
import pandas as pd
store = FeatureStore(repo_path=".")
df = pd.read_parquet("./data/wiki_dpr.parquet")
store.write_to_online_store(feature_view_name='wiki_passages', df=df)
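# Optional sanity check (illustrative, not part of the original flow): read one passage back
# from the online store by its entity key to confirm the write succeeded.
online_response = store.get_online_features(
    features=["wiki_passages:passage_text"],
    entity_rows=[{"passage_id": 0}],
).to_dict()
print(online_response["passage_text"][0][:200])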
from transformers import AutoTokenizer, AutoModelForCausalLM, RagConfig, AutoModel
generator_model_id = "ibm-granite/granite-3.2-2b-instruct"
generator_model = AutoModelForCausalLM.from_pretrained(generator_model_id)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_id)
import sys
sys.path.append("..")
from ragproject_repo import wiki_passage_feature_view
from feast.vector_store import FeastVectorStore
from feast.rag_retriever import FeastIndex, FeastRAGRetriever
generator_config = generator_model.config
question_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
question_encoder_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
query_encoder_config = {
    "model_type": "bert",
    "hidden_size": 384,
}
vector_store = FeastVectorStore(
    repo_path=".",
    rag_view=wiki_passage_feature_view,
    features=["wiki_passages:passage_text", "wiki_passages:embedding", "wiki_passages:passage_id"],
)
feast_index = FeastIndex()
config = RagConfig(
    question_encoder=query_encoder_config,
    generator=generator_config.to_dict(),
    index=feast_index,
)
retriever = FeastRAGRetriever(
    question_encoder=question_encoder,
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
    feast_repo_path=".",
    feature_view=vector_store.rag_view,
    features=vector_store.features,
    generator_model=generator_model,
    search_type="vector",
    id_field="passage_id",
    text_field="passage_text",
    config=config,
    index=feast_index,
)
query = "What is the capital of Ireland?"
answer = retriever.generate_answer(query, top_k=10)
print("Generated Answer:", answer)