Back to Graphrag

Function Tool Calling

packages/graphrag-llm/notebooks/10_tool_calling.ipynb

3.0.9 · 7.6 KB
Original Source

Function Tool Calling

To use function tools, the completion endpoint needs a JSON schema describing each function. This notebook uses pydantic to describe a function's parameters and the OpenAI SDK's built-in `pydantic_function_tool` helper to generate the required JSON schema. Other techniques, including writing the schema by hand, can also be used to define your functions.

Manual Function Tool Calling

This example demonstrates function tool calling by using pydantic and `pydantic_function_tool` manually. See the FunctionToolManager example below for a simplified approach.

python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import json
import os

from dotenv import load_dotenv
from graphrag_llm.completion import LLMCompletion, create_completion
from graphrag_llm.config import AuthMethod, ModelConfig
from graphrag_llm.types import LLMCompletionResponse
from graphrag_llm.utils import (
    CompletionMessagesBuilder,
)
from openai import pydantic_function_tool
from pydantic import BaseModel, ConfigDict, Field

# Populate os.environ from a local .env file before reading GRAPHRAG_* settings.
load_dotenv()

# One env var names both the model and the Azure deployment (default: gpt-4o).
model_name = os.getenv("GRAPHRAG_MODEL", "gpt-4o")
api_key = os.getenv("GRAPHRAG_API_KEY")

# An API key, when present, selects key auth; otherwise Azure managed identity.
model_config = ModelConfig(
    model_provider="azure",
    model=model_name,
    azure_deployment_name=model_name,
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version=os.getenv("GRAPHRAG_API_VERSION", "2025-04-01-preview"),
    api_key=api_key,
    auth_method=AuthMethod.ApiKey if api_key else AuthMethod.AzureManagedIdentity,
)
llm_completion: LLMCompletion = create_completion(model_config)


class AddTwoNumbers(BaseModel):
    """Input Argument for add two numbers."""

    # NOTE(review): pydantic usually copies the class docstring into the
    # generated JSON schema's "description", so rewording it may change what
    # the model receives. Verify before editing.

    # Reject keys not declared below, so unexpected tool-call arguments fail
    # validation instead of being silently dropped. Presumably this also emits
    # "additionalProperties": false in the schema. Confirm against the output
    # printed by the schema-view cell.
    model_config = ConfigDict(
        extra="forbid",
    )

    # Field descriptions likely surface as per-parameter descriptions in the
    # tool schema, so they are written for the model to read.
    a: int = Field(description="The first number to add.")
    b: int = Field(description="The second number to add.")


# Plain Python implementation that the model's tool call is dispatched to.
def add_two_numbers(options: AddTwoNumbers) -> int:
    """Return the sum of ``options.a`` and ``options.b``.

    The arguments arrive already validated as an ``AddTwoNumbers`` instance.
    """
    first, second = options.a, options.b
    return first + second


# Build the tool definition (JSON schema) from the pydantic argument model.
add_definition = pydantic_function_tool(
    AddTwoNumbers,
    # Function name and description
    name="my_add_two_numbers_function",
    description="Add two numbers.",
)

# Mapping of available functions, keyed by the tool name the model will call.
# Each entry pairs the Python callable with the pydantic model used to
# validate the model-supplied JSON arguments before the call is made.
available_functions = {
    "my_add_two_numbers_function": {
        "function": add_two_numbers,
        "input_model": AddTwoNumbers,
    },
}

messages_builder = CompletionMessagesBuilder().add_user_message(
    "Add 5 and 7 using a function call."
)

# First round trip: offer the tool and let the model decide to call it.
response: LLMCompletionResponse = llm_completion.completion(
    messages=messages_builder.build(),
    tools=[add_definition],
)  # type: ignore

assistant_message = response.choices[0].message
if not assistant_message.tool_calls:
    msg = "No function call found in response."
    raise ValueError(msg)

# Add the assistant message with the function call to the message history
messages_builder.add_assistant_message(
    message=assistant_message,
)

for tool_call in assistant_message.tool_calls:
    # Only function tools were offered above; ignore any other tool-call type.
    if tool_call.type != "function":
        continue

    function_name = tool_call.function.name
    registered = available_functions.get(function_name)
    if registered is None:
        # Reply with an error instead of crashing with a bare KeyError. Each
        # tool call in the assistant message should receive a tool reply
        # before the next completion request.
        result = f"Error: unknown function {function_name!r}."
    else:
        # Arguments arrive as a JSON string. Validate them through the
        # pydantic model before invoking the Python function.
        args_dict = json.loads(tool_call.function.arguments)
        input_options = registered["input_model"](**args_dict)
        result = registered["function"](input_options)

    messages_builder.add_tool_message(
        content=str(result),
        tool_call_id=tool_call.id,
    )

# Second round trip: the model turns the tool result(s) into a final answer.
final_response: LLMCompletionResponse = llm_completion.completion(
    messages=messages_builder.build(),
)  # type: ignore
print(final_response.content)

Function Tool Definition

python
# Show the tool definition built in the previous cell. This dict is exactly
# what the completion call receives in its `tools` list. It was generated
# with pydantic + pydantic_function_tool, but a hand-written dict of the same
# shape works just as well.
print(json.dumps(obj=add_definition, indent=2))

Tool Calling with FunctionToolManager

If you use pydantic to describe function arguments, you can use the FunctionToolManager to register functions, produce definitions, and call functions in response to the LLM. This automates much of the work shown above.

The following example demonstrates calling multiple functions in one LLM call.

python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import os

from dotenv import load_dotenv
from graphrag_llm.completion import LLMCompletion, create_completion
from graphrag_llm.config import AuthMethod, ModelConfig
from graphrag_llm.types import LLMCompletionResponse
from graphrag_llm.utils import (
    CompletionMessagesBuilder,
    FunctionToolManager,
)
from pydantic import BaseModel, ConfigDict, Field

# Populate os.environ from a local .env file before reading GRAPHRAG_* settings.
load_dotenv()

# One env var names both the model and the Azure deployment (default: gpt-4o).
model_name = os.getenv("GRAPHRAG_MODEL", "gpt-4o")
api_key = os.getenv("GRAPHRAG_API_KEY")

# An API key, when present, selects key auth; otherwise Azure managed identity.
model_config = ModelConfig(
    model_provider="azure",
    model=model_name,
    azure_deployment_name=model_name,
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version=os.getenv("GRAPHRAG_API_VERSION", "2025-04-01-preview"),
    api_key=api_key,
    auth_method=AuthMethod.ApiKey if api_key else AuthMethod.AzureManagedIdentity,
)
llm_completion: LLMCompletion = create_completion(model_config)


class NumbersInput(BaseModel):
    """Numbers input."""

    # Shared argument model for both the `add` and `multiply` tools.
    # NOTE(review): pydantic usually copies the class docstring into the
    # generated JSON schema, so rewording it may change what the model sees.

    # Reject keys not declared below, so unexpected tool-call arguments fail
    # validation instead of being silently dropped.
    model_config = ConfigDict(
        extra="forbid",
    )

    # Field descriptions likely surface as per-parameter descriptions in the
    # tool schema, so they are written for the model to read.
    a: int = Field(description="The first number.")
    b: int = Field(description="The second number.")


def add(options: NumbersInput) -> str:
    """Add two numbers.

    Returns the sum rendered as text, because tool results are sent back to
    the model as string content.
    """
    lhs, rhs = options.a, options.b
    # Visible side effect confirming the tool manager actually dispatched here.
    print("Adding numbers:", lhs, rhs)
    total = lhs + rhs
    return str(total)


def multiply(options: NumbersInput) -> str:
    """Multiply two numbers.

    Returns the product rendered as text, because tool results are sent back
    to the model as string content.
    """
    lhs, rhs = options.a, options.b
    # Visible side effect confirming the tool manager actually dispatched here.
    print("Multiplying numbers:", lhs, rhs)
    product = lhs * rhs
    return str(product)


class TextInput(BaseModel):
    """Text input."""

    # Reject keys not declared below, so unexpected tool-call arguments fail
    # validation instead of being silently dropped.
    model_config = ConfigDict(
        extra="forbid",
    )

    # The field name is the argument name the model sees in the tool schema,
    # so it should describe its content ("text", not the misspelled "test").
    text: str = Field(description="The string to reverse.")


def reverse_text(options: TextInput) -> str:
    """Reverse a string.

    Reverses by code point, so multi-code-point characters such as combining
    accents or emoji sequences may not survive intact.
    """
    # Print something to ensure function is called for verification
    print("Reversing text:", options.text)
    return options.text[::-1]


# Register each Python function with its tool name, description, and pydantic
# argument model. The manager produces the tool definitions and dispatches
# calls from these registrations.
function_tool_manager = FunctionToolManager()

function_tool_manager.register_function_tool(
    name="add",
    description="Add two numbers.",
    function=add,
    input_model=NumbersInput,
)
function_tool_manager.register_function_tool(
    name="multiply",
    description="Multiply two numbers.",
    function=multiply,
    input_model=NumbersInput,
)
function_tool_manager.register_function_tool(
    name="reverse_text",
    description="Reverse a string.",
    function=reverse_text,
    input_model=TextInput,
)


messages_builder = CompletionMessagesBuilder().add_user_message(
    "What is 3 + 8 and 9 * 5? Also, reverse the string 'GraphRAG'."
)

# Multiple tool calls in parallel
response: LLMCompletionResponse = llm_completion.completion(
    messages=messages_builder.build(),
    tools=function_tool_manager.definitions(),
    parallel_tool_calls=True,
)  # type: ignore

# Fail fast if the model answered directly without calling any tool.
# Otherwise the follow-up request below would not demonstrate tool use.
assistant_message = response.choices[0].message
if not assistant_message.tool_calls:
    msg = "No function call found in response."
    raise ValueError(msg)

# Add the assistant message with the function call to the message history
messages_builder.add_assistant_message(
    message=assistant_message,
)

# Run every requested function. Each result is a mapping of keyword arguments
# for add_tool_message (presumably content + tool_call_id; verify against
# FunctionToolManager.call_functions).
tool_results = function_tool_manager.call_functions(response)

for tool_message in tool_results:
    messages_builder.add_tool_message(**tool_message)

# Second round trip: the model combines the tool results into a final answer.
final_response: LLMCompletionResponse = llm_completion.completion(
    messages=messages_builder.build(),
)  # type: ignore
print(final_response.content)

MCP Tools

Not currently supported: graphrag_llm only implements the completion endpoints, which do not support MCP tools.