Back to Graphrag

Templating

packages/graphrag-llm/notebooks/11_templating.ipynb

3.0.9 · 2.7 KB
Original Source

Templating

python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import os

from dotenv import load_dotenv
from graphrag_llm.completion import LLMCompletion, create_completion
from graphrag_llm.config import (
    AuthMethod,
    ModelConfig,
    TemplateEngineConfig,
    TemplateEngineType,
    TemplateManagerType,
)
from graphrag_llm.templating import create_template_engine
from graphrag_llm.types import LLMCompletionResponse
from pydantic import BaseModel, Field

# Pull settings from a local .env file (if present) into the process environment
# before reading any GRAPHRAG_* variables below.
load_dotenv()

api_key = os.getenv("GRAPHRAG_API_KEY")
# A single GRAPHRAG_MODEL value serves as both the model name and the Azure
# deployment name.
model_name = os.getenv("GRAPHRAG_MODEL", "gpt-4o")

# Use key-based auth only when a non-empty API key is configured; otherwise
# fall back to Azure managed identity.
if api_key:
    auth_method = AuthMethod.ApiKey
else:
    auth_method = AuthMethod.AzureManagedIdentity

model_config = ModelConfig(
    model_provider="azure",
    model=model_name,
    azure_deployment_name=model_name,
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version=os.getenv("GRAPHRAG_API_VERSION", "2025-04-01-preview"),
    api_key=api_key,
    auth_method=auth_method,
)
llm_completion: LLMCompletion = create_completion(model_config)


# Called with no arguments, create_template_engine() uses its default settings.
template_engine = create_template_engine()

# The above default is the same as the following configuration:
default_engine_config = TemplateEngineConfig(
    type=TemplateEngineType.Jinja,
    template_manager=TemplateManagerType.File,
    base_dir="templates",
    template_extension=".jinja",
    encoding="utf-8",
)
template_engine = create_template_engine(default_engine_config)

# Values exposed to the template. The readings use a `temperature_f` key
# (presumably degrees Fahrenheit).
weather_context = {
    "weather_reports": [
        {"city": "Seattle", "temperature_f": 52, "condition": "sunny"},
        {"city": "San Francisco", "temperature_f": 75, "condition": "cloudy"},
    ]
}

# template_name is the template file name without its extension. With the
# config above that means "weather_listings.jinja" under "templates"
# (NOTE(review): presumably resolved relative to the working directory).
msg = template_engine.render(
    template_name="weather_listings",
    context=weather_context,
)


print(f"The rendered message to parse: {msg}")


# Structured response parsing using pydantic.
# WeatherReports (below) nests this model and is passed as `response_format` to
# the completion call, so the LLM's reply is parsed into typed objects.
# NOTE(review): pydantic normally includes the class docstring and each
# Field(description=...) in the generated JSON schema. Treat them as part of
# the prompt; editing them may change what the model is asked to produce.
class LocalWeather(BaseModel):
    """City weather information model."""

    city: str = Field(description="The name of the city")
    # The rendered input carries `temperature_f` values, while this description
    # asks for Celsius, so the model is expected to perform the conversion.
    temperature: float = Field(description="The temperature in Celsius")
    condition: str = Field(description="The weather condition description")


class WeatherReports(BaseModel):
    """Weather information model."""

    # Top-level structured-output schema passed as `response_format` to
    # llm_completion.completion. It is expected to hold one LocalWeather entry
    # per city listed in the rendered prompt.
    reports: list[LocalWeather] = Field(
        description="The weather reports for multiple cities"
    )


# Request a structured reply. With `response_format` set, the parsed
# WeatherReports instance is expected on `response.formatted_response`.
response: LLMCompletionResponse[WeatherReports] = llm_completion.completion(
    messages=msg,
    response_format=WeatherReports,
)  # type: ignore

local_weather_reports: WeatherReports = response.formatted_response  # type: ignore
# NOTE(review): presumably `formatted_response` can be None when the reply
# could not be parsed. Fail with an explicit message instead of an opaque
# AttributeError on `.reports` below.
if local_weather_reports is None:
    raise RuntimeError(
        "The completion returned no parsed WeatherReports; "
        "inspect `response` for the raw model output."
    )

for report in local_weather_reports.reports:
    print(f"City: {report.city}")
    print(f"  Temperature: {report.temperature} °C")
    print(f"  Condition: {report.condition}")