Back to Feast

%pip install --quiet feast[milvus] sentence-transformers datasets

examples/rag-retriever/rag_feast_docembedder.ipynb

0.63.0 · 4.0 KB
Original Source
python
# %pip install --quiet feast[milvus] sentence-transformers datasets
# %pip install bigtree==0.19.2
# %pip install marshmallow==3.10.0 
python
from datasets import load_dataset

# Wikipedia passages with DPR preprocessing ("psgs_w100.nq.exact").
# Only the first 1% of the train split is pulled to keep the demo light.
_SPLIT = "train[:1%]"
dataset = load_dataset(
    "facebook/wiki_dpr",
    "psgs_w100.nq.exact",
    split=_SPLIT,
    with_index=False,       # skip downloading the prebuilt FAISS index
    trust_remote_code=True,
)

python
# Show the schema (bare expression so the notebook displays it), then
# materialize a small 100-row sample as a pandas DataFrame for embedding.
dataset.column_names
sample = dataset.select(range(100))
df = sample.to_pandas()
df.head()
python
import yaml
import os


def write_feature_store_yaml(
    file_path: str, project_name: str, embedding_dim: int = 384
) -> str:
    """
    Write a feature_store.yaml file to the specified path.

    Args:
        file_path: Full path where the YAML file should be written
                   (e.g. "feature_repo/feature_store.yaml").
        project_name: The project name to use in the YAML.
        embedding_dim: Dimensionality of the stored vectors. Defaults to 384,
                       which matches sentence-transformers/all-MiniLM-L6-v2
                       used later in this notebook.

    Returns:
        The absolute path of the written file.
    """
    # Minimal local Feast config: Milvus online store with vector search
    # enabled, file-based offline store, no auth.
    config = {
        "project": project_name,
        "provider": "local",
        "registry": "data/registry.db",
        "online_store": {
            "type": "milvus",
            "host": "http://localhost",
            "port": 19530,
            "vector_enabled": True,
            "embedding_dim": embedding_dim,
            "index_type": "FLAT",
            "metric_type": "COSINE",
        },
        "offline_store": {
            "type": "file",
        },
        "entity_key_serialization_version": 3,
        "auth": {
            "type": "no_auth",
        },
    }

    # Resolve once; used for both directory creation and the return value.
    abs_path = os.path.abspath(file_path)
    os.makedirs(os.path.dirname(abs_path), exist_ok=True)

    # Explicit encoding so the emitted file is identical across platforms.
    with open(abs_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    return abs_path
python
%mkdir feature_repo_docebedder
!pwd
python
# Generate the Feast repo config for project "my_project" in the repo dir.
path = write_feature_store_yaml(
    "feature_repo_docebedder/feature_store.yaml",
    "my_project",
)
print(f"YAML written to: {path}")
python
from feast import DocEmbedder

# Embedder bound to the repo created above; embeddings will be written
# into the "text_feature_view" feature view.
de = DocEmbedder(
    repo_path="feature_repo_docebedder",
    feature_view_name="text_feature_view",
    yaml_file="feature_store.yaml",
)
python
# Embed the 100 sampled passages ("text" column) and write them to the store.
# NOTE(review): column_mapping is passed as a tuple ("text", "text_embedding");
# confirm DocEmbedder.embed_documents accepts a tuple here rather than a
# dict-style mapping such as {"text": "text_embedding"}.
de.embed_documents(documents=df, id_column="id", source_column="text", column_mapping= ("text", "text_embedding"))
python
%cd feature_repo_docebedder
python
from feast import FeatureStore
import pandas as pd

# Open the feature store from the current directory — this relies on the
# preceding `%cd feature_repo_docebedder` having been executed.
store = FeatureStore(repo_path=".")
python
from transformers import AutoTokenizer, AutoModelForCausalLM, RagConfig, AutoModel

# Generator LLM for answer synthesis (RagConfig/AutoModel are used in the
# next cell to wire up the retriever and the question encoder).
generator_model_id = "ibm-granite/granite-3.2-2b-instruct"
generator_model = AutoModelForCausalLM.from_pretrained(generator_model_id)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_id)
python
import sys
# The feature-view definition lives one directory up (we cd'd into the repo).
sys.path.append("..")
from text_feature_view import text_feature_view
from feast.vector_store import FeastVectorStore
from feast.rag_retriever import FeastIndex, FeastRAGRetriever

generator_config=generator_model.config
# Question encoder: same model family used to produce the stored embeddings,
# so query vectors live in the same 384-dim space.
question_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
question_encoder_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")


# hidden_size=384 matches all-MiniLM-L6-v2's output dimension.
query_encoder_config = {
    "model_type": "bert",
    "hidden_size": 384
}

# Vector store over the Feast feature view; lists every feature the
# retriever may need to fetch alongside the embedding.
vector_store = FeastVectorStore(
    repo_path=".",
    rag_view=text_feature_view,
    features=["text_feature_view:text", "text_feature_view:embedding", "text_feature_view:passage_id","text_feature_view:source_id"]
)

feast_index = FeastIndex()

# NOTE(review): RagConfig's `index` argument receives a FeastIndex instance
# rather than a dict — confirm this is the shape FeastRAGRetriever expects.
config = RagConfig(
    question_encoder=query_encoder_config,
    generator=generator_config.to_dict(),
    index=feast_index
)
# Retriever: embeds the query with the question encoder, searches Milvus
# ("vector" search on the "embedding" field), and hands retrieved "text"
# passages to the Granite generator.
retriever = FeastRAGRetriever(
    question_encoder=question_encoder,
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
    feast_repo_path=".",
    feature_view=vector_store.rag_view,
    features=vector_store.features,
    generator_model=generator_model, 
    search_type="vector",
    id_field="passage_id",
    text_field="text",
    config=config,
    index=feast_index,
)
python
# End-to-end check: retrieve the top-10 passages for the query and let the
# generator produce an answer conditioned on them.
query = "What is the capital of Ireland?"
top_k = 10
answer = retriever.generate_answer(query, top_k=top_k)
print("Generated Answer:", answer)