docs/examples/workflow/rag.ipynb
This notebook walks through setting up a Workflow to perform basic RAG with reranking.
!pip install -U llama-index
import os

# Placeholder credential: replace the value with a real OpenAI API key before
# running the cells below.
os.environ.update({"OPENAI_API_KEY": "sk-proj-..."})
Set up tracing to visualize each step in the workflow.
%pip install "openinference-instrumentation-llama-index>=3.0.0" "opentelemetry-proto>=1.12.0" opentelemetry-exporter-otlp opentelemetry-sdk
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as HTTPSpanExporter,
)
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
# Phoenix API key, forwarded to the collector through the OTLP headers
# environment variable. Set it before the exporter is constructed below.
PHOENIX_API_KEY = "<YOUR-PHOENIX-API-KEY>"
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={PHOENIX_API_KEY}"

# Exporter pointed at the hosted Phoenix trace endpoint, wrapped in a span
# processor.
phoenix_exporter = HTTPSpanExporter(
    endpoint="https://app.phoenix.arize.com/v1/traces"
)
span_phoenix_processor = SimpleSpanProcessor(phoenix_exporter)

# Tracer provider with the Phoenix processor attached.
tracer_provider = trace_sdk.TracerProvider()
tracer_provider.add_span_processor(span_phoenix_processor)

# Instrument LlamaIndex with this tracer provider.
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
# Fetch the sample document: arXiv paper 2307.09288, saved as data/llama2.pdf.
!mkdir -p data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
Since workflows are async-first, this all runs fine in a notebook. If you were running this in your own code, you would want to use `asyncio.run()` to start an async event loop if one isn't already running:
async def main():
<async code>
if __name__ == "__main__":
import asyncio
asyncio.run(main())
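RAG + reranking consists of a few clearly defined steps: indexing the documents, retrieving the nodes relevant to a query, reranking those nodes against the query, and synthesizing a final response.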
With this in mind, we can create events and workflow steps to follow this process!
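To handle these steps, we need to define a few events: one to pass the retrieved nodes to the reranker, and one to pass the reranked nodes to the synthesizer.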
The other steps will use the built-in StartEvent and StopEvent events.
from llama_index.core.workflow import Event
from llama_index.core.schema import NodeWithScore
class RetrieverEvent(Event):
    """Event carrying the nodes produced by the retrieval step."""

    # Retrieved nodes together with their scores.
    nodes: list[NodeWithScore]
class RerankEvent(Event):
    """Event carrying the nodes that remain after reranking the retrieved set."""

    # Reranked nodes together with their scores.
    nodes: list[NodeWithScore]
With our events defined, we can construct our workflow and steps.
Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.core.postprocessor.llm_rerank import LLMRerank
from llama_index.core.workflow import (
Context,
Workflow,
StartEvent,
StopEvent,
step,
)
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
class RAGWorkflow(Workflow):
    """RAG workflow with two entry points: ingestion and query answering.

    Both ``ingest`` and ``retrieve`` accept the ``StartEvent``; each returns
    ``None`` unless the keyword it needs was passed to ``run()``:

    * ``run(dirname=...)``: ``ingest`` builds a vector index and stops.
    * ``run(query=..., index=...)``: ``retrieve`` -> ``rerank`` ->
      ``synthesize``.
    """

    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Entry point to ingest a document, triggered by a StartEvent with `dirname`.

        Returns a ``StopEvent`` whose result is the built index, or ``None``
        when the start event carries no ``dirname``.
        """
        dirname = ev.get("dirname")
        if not dirname:
            return None

        documents = SimpleDirectoryReader(dirname).load_data()
        index = VectorStoreIndex.from_documents(
            documents=documents,
            embed_model=OpenAIEmbedding(model_name="text-embedding-3-small"),
        )
        return StopEvent(result=index)

    @step
    async def retrieve(
        self, ctx: Context, ev: StartEvent
    ) -> RetrieverEvent | None:
        """Entry point for RAG, triggered by a StartEvent with `query`.

        Returns a ``RetrieverEvent`` with the retrieved nodes, or ``None``
        when the start event has no ``query`` or no ``index``.
        """
        query = ev.get("query")
        index = ev.get("index")

        if not query:
            return None

        print(f"Query the database with: {query}")

        # Store the query in the workflow context so later steps can read it.
        await ctx.store.set("query", query)

        # The index is supplied on the StartEvent itself (e.g. the result of
        # an earlier ``ingest`` run), not read from the context.
        if index is None:
            print("Index is empty, load some documents before querying!")
            return None

        retriever = index.as_retriever(similarity_top_k=2)
        nodes = await retriever.aretrieve(query)
        print(f"Retrieved {len(nodes)} nodes.")
        return RetrieverEvent(nodes=nodes)

    @step
    async def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
        """Rerank the retrieved nodes against the stored query using an LLM."""
        # NOTE(review): ``retrieve`` requests only 2 nodes, so ``top_n=3``
        # presumably never trims the list here; confirm the intended values.
        ranker = LLMRerank(
            choice_batch_size=5, top_n=3, llm=OpenAI(model="gpt-4o-mini")
        )

        # Read the query once and reuse it for the log line and the reranker.
        query = await ctx.store.get("query", default=None)
        print(query, flush=True)

        new_nodes = ranker.postprocess_nodes(ev.nodes, query_str=query)
        print(f"Reranked nodes to {len(new_nodes)}")
        return RerankEvent(nodes=new_nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RerankEvent) -> StopEvent:
        """Return a streaming response using reranked nodes."""
        llm = OpenAI(model="gpt-4o-mini")
        summarizer = CompactAndRefine(llm=llm, streaming=True, verbose=True)
        query = await ctx.store.get("query", default=None)

        response = await summarizer.asynthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)
And that's it! Let's explore the workflow we wrote a bit.
StartEvent)w = RAGWorkflow()
# Ingest the documents
index = await w.run(dirname="data")
# Run a query
result = await w.run(query="How was Llama2 trained?", index=index)
async for chunk in result.async_response_gen():
print(chunk, end="", flush=True)