
Contextual Retrieval With Llama Index

docs/examples/metadata_extraction/DocumentContextExtractor.ipynb



This notebook covers contextual retrieval with the llama_index DocumentContextExtractor.

Based on an Anthropic blog post, the concept is to:

  1. Use an LLM to generate a 'context' for each chunk based on the entire document
  2. Embed the chunk + context together
  3. Reap the benefits of higher RAG accuracy
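Before reaching for the extractor, the idea can be sketched manually in a few lines of plain Python. Here, generate_context is a hypothetical stand-in for a real LLM call that would situate the chunk within the full document:

```python
# A minimal sketch of the contextual-retrieval idea (no LLM calls;
# generate_context is a placeholder for a real LLM prompt).
def generate_context(document: str, chunk: str) -> str:
    # In practice: prompt an LLM with the full document plus the chunk,
    # asking for a short context that situates the chunk. Faked here.
    return f"This chunk comes from a document beginning: {document[:40]}..."


def contextualize(document: str, chunks: list[str]) -> list[str]:
    # Prepend the generated context to each chunk before embedding.
    return [f"{generate_context(document, c)}\n\n{c}" for c in chunks]


doc = "Paul Graham's essay describes his early programming on the IBM 1401."
chunks = ["his early programming on the IBM 1401."]
print(contextualize(doc, chunks)[0])
```

The DocumentContextExtractor automates exactly this chunk-plus-context step, with batching and error handling on top.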

While you can also do this manually, the DocumentContextExtractor offers a lot of convenience and error handling, plus you can integrate it into your llama_index pipelines! Let's get started.

NOTE: This notebook costs about $0.02 every time you run it.

Install Packages

```python
%pip install llama-index
%pip install llama-index-readers-file
%pip install llama-index-embeddings-huggingface
%pip install llama-index-llms-openai
```

Setup an LLM

You can use the MockLLM or a real LLM of your choice here. Gemini 2.0 Flash and gpt-4o-mini both work well.

```python
from llama_index.llms.openai import OpenAI
from llama_index.core import Settings

OPENAI_API_KEY = "sk-..."
llm = OpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY)
Settings.llm = llm
```

Setup a data pipeline

We'll need an embedding model, a document store, a vector store, and a way to split text into tokens.

Build Pipeline & Index

```python
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.storage.docstore.simple_docstore import (
    SimpleDocumentStore,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Initialize document store and embedding model
docstore = SimpleDocumentStore()
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Create storage contexts
storage_context = StorageContext.from_defaults(docstore=docstore)
storage_context_no_extra_context = StorageContext.from_defaults()
text_splitter = TokenTextSplitter(
    separator=" ", chunk_size=256, chunk_overlap=10
)
```

DocumentContextExtractor

```python
# This is the new part!

from llama_index.core.extractors import DocumentContextExtractor

context_extractor = DocumentContextExtractor(
    # these two arguments are required
    docstore=docstore,
    max_context_length=128000,
    # the rest are optional
    llm=llm,  # defaults to Settings.llm
    oversized_document_strategy="warn",
    max_output_tokens=100,
    key="context",
    prompt=DocumentContextExtractor.SUCCINCT_CONTEXT_PROMPT,
)
```
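The key argument names the metadata field where each generated context is stored, and llama_index prepends node metadata to the chunk text when building the embedding input. A rough plain-Python illustration of that combination (no llama_index APIs; the strings are made up):

```python
# Rough illustration: the extractor writes its output into node metadata
# under `key`, and the embedding model sees metadata + chunk text together.
metadata = {"context": "From Paul Graham's essay; describes the IBM 1401."}
chunk_text = "He wrote programs on punch cards."

embed_input = (
    "\n".join(f"{k}: {v}" for k, v in metadata.items()) + "\n\n" + chunk_text
)
print(embed_input)
```

This is why the generated contexts improve retrieval: the vector for each chunk reflects both the chunk and its document-level context.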

Load Data

```python
!wget "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay_ambiguated.txt" -O "paul_graham_essay_ambiguated.txt"
```

```python
from llama_index.core import SimpleDirectoryReader

reader = SimpleDirectoryReader(
    input_files=["./paul_graham_essay_ambiguated.txt"]
)
documents = reader.load_data()
```

Run the pipeline, then search

```python
import nest_asyncio

nest_asyncio.apply()  # allow async LLM calls inside the notebook's event loop

# the documents must be added to the docstore directly for the
# DocumentContextExtractor to work
storage_context.docstore.add_documents(documents)
index = VectorStoreIndex.from_documents(
    documents=documents,
    storage_context=storage_context,
    embed_model=embed_model,
    transformations=[text_splitter, context_extractor],
)

index_nocontext = VectorStoreIndex.from_documents(
    documents=documents,
    storage_context=storage_context_no_extra_context,
    embed_model=embed_model,
    transformations=[text_splitter],
)
```
```python
test_question = "Which chunks of text discuss the IBM 704?"
retriever = index.as_retriever(similarity_top_k=2)
nodes_fromcontext = retriever.retrieve(test_question)

retriever_nocontext = index_nocontext.as_retriever(similarity_top_k=2)
nodes_nocontext = retriever_nocontext.retrieve(test_question)
```
```python
# Print each node's content
print("==========")
print("NO CONTEXT")
for i, node in enumerate(nodes_nocontext, 1):
    print(f"\nChunk {i}:")
    print(f"Score: {node.score}")  # similarity score
    print(f"Content: {node.node.text}")  # the actual text content

print("==========")
print("WITH CONTEXT")
for i, node in enumerate(nodes_fromcontext, 1):
    print(f"\nChunk {i}:")
    print(f"Score: {node.score}")  # similarity score
    print(f"Content: {node.node.text}")  # the actual text content
```
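Beyond eyeballing the printouts, one simple way to compare the two retrievers is the fraction of retrieved chunks that actually mention the target term. A standalone sketch (the chunk strings below are made up for illustration, not real notebook output):

```python
# Hypothetical evaluation helper: fraction of retrieved chunks that
# mention the search target.
def hit_rate(chunks: list[str], needle: str) -> float:
    hits = sum(needle.lower() in c.lower() for c in chunks)
    return hits / len(chunks)


# Made-up examples: contextualized chunks name the machine explicitly,
# while the raw chunks may only say "it".
with_context = [
    "Context: the essay discusses early computers. The IBM 704 ran Fortran.",
    "Context: recollections of school computing. The IBM 704 filled a room.",
]
no_context = [
    "It filled a room and read punch cards.",
    "The IBM 704 ran Fortran.",
]
print(hit_rate(with_context, "IBM 704"))  # 1.0
print(hit_rate(no_context, "IBM 704"))  # 0.5
```

In the real notebook you would apply the same check to the text of nodes_fromcontext and nodes_nocontext: the contextualized index tends to surface chunks that only refer to the IBM 704 indirectly, because the generated context names it explicitly.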