Back to Llama Index

Context-Augmented Function Calling Agent

docs/examples/agent/openai_agent_context_retrieval.ipynb

0.14.217.3 KB
Original Source

<a href="https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/agent/openai_agent_context_retrieval.ipynb" target="_parent"></a>

Context-Augmented Function Calling Agent

In this tutorial, we show you how to to make your agent context-aware.

Our indexing/retrieval modules help to remove the complexity of having too many functions to fit in the prompt.

Initial Setup

Here we setup a normal FunctionAgent, and then augment it with context. This agent will perform retrieval first before calling any tools. This can help ground the agent's tool picking and answering capabilities in context.

If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙.

python
%pip install llama-index
python
import os

os.environ["OPENAI_API_KEY"] = "sk-..."
python
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core.settings import Settings

Settings.llm = OpenAI(model="gpt-4o-mini")
Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small")
python
import json
from typing import Sequence

from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.tools import QueryEngineTool
python
try:
    storage_context = StorageContext.from_defaults(
        persist_dir="./storage/march"
    )
    march_index = load_index_from_storage(storage_context)

    storage_context = StorageContext.from_defaults(
        persist_dir="./storage/june"
    )
    june_index = load_index_from_storage(storage_context)

    storage_context = StorageContext.from_defaults(
        persist_dir="./storage/sept"
    )
    sept_index = load_index_from_storage(storage_context)

    index_loaded = True
except:
    index_loaded = False

Download Data

python
!mkdir -p 'data/10q/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10q/uber_10q_march_2022.pdf' -O 'data/10q/uber_10q_march_2022.pdf'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10q/uber_10q_june_2022.pdf' -O 'data/10q/uber_10q_june_2022.pdf'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10q/uber_10q_sept_2022.pdf' -O 'data/10q/uber_10q_sept_2022.pdf'
python
# build indexes across the three data sources
if not index_loaded:
    # load data
    march_docs = SimpleDirectoryReader(
        input_files=["./data/10q/uber_10q_march_2022.pdf"]
    ).load_data()
    june_docs = SimpleDirectoryReader(
        input_files=["./data/10q/uber_10q_june_2022.pdf"]
    ).load_data()
    sept_docs = SimpleDirectoryReader(
        input_files=["./data/10q/uber_10q_sept_2022.pdf"]
    ).load_data()

    # build index
    march_index = VectorStoreIndex.from_documents(march_docs)
    june_index = VectorStoreIndex.from_documents(june_docs)
    sept_index = VectorStoreIndex.from_documents(sept_docs)

    # persist index
    march_index.storage_context.persist(persist_dir="./storage/march")
    june_index.storage_context.persist(persist_dir="./storage/june")
    sept_index.storage_context.persist(persist_dir="./storage/sept")
python
march_engine = march_index.as_query_engine(similarity_top_k=3)
june_engine = june_index.as_query_engine(similarity_top_k=3)
sept_engine = sept_index.as_query_engine(similarity_top_k=3)
python
query_engine_tools = [
    QueryEngineTool.from_defaults(
        query_engine=march_engine,
        name="uber_march_10q",
        description=(
            "Provides information about Uber 10Q filings for March 2022. "
            "Use a detailed plain text question as input to the tool."
        ),
    ),
    QueryEngineTool.from_defaults(
        query_engine=june_engine,
        name="uber_june_10q",
        description=(
            "Provides information about Uber financials for June 2021. "
            "Use a detailed plain text question as input to the tool."
        ),
    ),
    QueryEngineTool.from_defaults(
        query_engine=sept_engine,
        name="uber_sept_10q",
        description=(
            "Provides information about Uber financials for Sept 2021. "
            "Use a detailed plain text question as input to the tool."
        ),
    ),
]

Try Context-Augmented Agent

Here we augment our agent with context in different settings:

  • toy context: we define some abbreviations that map to financial terms (e.g. R=Revenue). We supply this as context to the agent
python
from llama_index.core import Document
from llama_index.core.agent.workflow import FunctionAgent
python
# toy index - stores a list of abbreviations
texts = [
    "Abbreviation: 'Y' = Revenue",
    "Abbreviation: 'X' = Risk Factors",
    "Abbreviation: 'Z' = Costs",
]
docs = [Document(text=t) for t in texts]
context_index = VectorStoreIndex.from_documents(docs)
context_retriever = context_index.as_retriever(similarity_top_k=2)
python
from llama_index.core.tools import BaseTool

system_prompt_template = """You are a helpful assistant. 
Here is some context that you can use to answer the user's question and for help with picking the right tool:

{context}
"""


async def get_agent_with_context_awareness(
    query: str, context_retriever, tools: list[BaseTool]
) -> FunctionAgent:
    context_nodes = await context_retriever.aretrieve(query)
    context_text = "\n----\n".join([n.node.text for n in context_nodes])

    return FunctionAgent(
        tools=tools,
        llm=OpenAI(model="gpt-4o"),
        system_prompt=system_prompt_template.format(context=context_text),
    )
python
query = "What is the 'X' of March 2022?"
agent = await get_agent_with_context_awareness(
    query, context_retriever, query_engine_tools
)

response = await agent.run(query)
python
print(str(response))
python
query = "What is the 'Y' and 'Z' in September 2022?"
agent = await get_agent_with_context_awareness(
    query, context_retriever, query_engine_tools
)

response = await agent.run(query)
python
print(str(response))

Managing Context/Memory

By default, each .run() call is stateless. We can manage context by using a serializable Context object.

python
from llama_index.core.workflow import Context

ctx = Context(agent)

query = "What is the 'Y' and 'Z' in September 2022?"
agent = await get_agent_with_context_awareness(
    query, context_retriever, query_engine_tools
)
response = await agent.run(query, ctx=ctx)

query = "What did I just ask?"
agent = await get_agent_with_context_awareness(
    query, context_retriever, query_engine_tools
)
response = await agent.run(query, ctx=ctx)
print(str(response))

Use Uber 10-Q as context, use Calculator as Tool

python
from llama_index.core.tools import FunctionTool


def magic_formula(revenue: int, cost: int) -> int:
    """Runs MAGIC_FORMULA on revenue and cost."""
    return revenue - cost


magic_tool = FunctionTool.from_defaults(magic_formula)
python
context_retriever = sept_index.as_retriever(similarity_top_k=3)
python
query = "Can you run MAGIC_FORMULA on Uber's revenue and cost?"
agent = await get_agent_with_context_awareness(
    query, context_retriever, [magic_tool]
)
response = await agent.run(query)
print(str(response))