Back to Graphrag

Batching

packages/graphrag-llm/notebooks/08_batching.ipynb

3.0.98.7 KB
Original Source

Batching

Completion Batching

python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import json
import os

from dotenv import load_dotenv
from graphrag_llm.completion import LLMCompletion, create_completion
from graphrag_llm.config import AuthMethod, ModelConfig
from graphrag_llm.types import LLMCompletionArgs

load_dotenv()

# Use the API key when one is configured; otherwise fall back to
# Azure managed identity.
api_key = os.getenv("GRAPHRAG_API_KEY")
model_config = ModelConfig(
    model_provider="azure",
    model=os.getenv("GRAPHRAG_MODEL", "gpt-4o"),
    azure_deployment_name=os.getenv("GRAPHRAG_MODEL", "gpt-4o"),
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version=os.getenv("GRAPHRAG_API_VERSION", "2025-04-01-preview"),
    api_key=api_key,
    auth_method=AuthMethod.ApiKey if api_key else AuthMethod.AzureManagedIdentity,
)
llm_completion: LLMCompletion = create_completion(model_config)


# Ten identical prompts to exercise the batching utilities.
completion_requests: list[LLMCompletionArgs] = [
    {
        "messages": "Write a 1000 word poem about the night sky and all the wonders and mysteries of the universe."
    },
] * 10

# Spins up to 25 concurrent requests — more than the number of requests
# being made. Since rate limiting is not enabled, every request fires off
# immediately and completes as fast as the LLM provider allows.
responses = llm_completion.completion_batch(completion_requests, concurrency=25)
for response in responses:
    if not isinstance(response, Exception):
        # Print the first 100 characters of the first successful response
        print(response.content[:100])  # type: ignore
        break
    print(f"Error: {response}")

print(f"Metrics for: {llm_completion.metrics_store.id}")
print(json.dumps(llm_completion.metrics_store.get_metrics(), indent=2))

Notice the difference between compute_duration_seconds and runtime_duration_seconds. The former indicates how long all the network requests took to complete and would be how long the whole process took to complete if running the requests in series. The latter indicates how long the batch as a whole took to complete when running with concurrency.

With Rate Limiting

python
from graphrag_llm.config import RateLimitConfig, RateLimitType

# Throttle to at most 20 requests per 60-second sliding window,
# i.e. roughly one request every 3 seconds.
model_config.rate_limit = RateLimitConfig(
    type=RateLimitType.SlidingWindow,
    period_in_seconds=60,
    requests_per_period=20,
)
llm_completion: LLMCompletion = create_completion(model_config)
llm_completion.metrics_store.clear_metrics()

# Same batch as before, but now the rate limiter paces the requests.
responses = llm_completion.completion_batch(completion_requests, concurrency=25)

print(f"Metrics for: {llm_completion.metrics_store.id}")
print(json.dumps(llm_completion.metrics_store.get_metrics(), indent=2))

Notice that runtime_duration_seconds is now much longer, because the requests are being throttled by the rate limit.

With Cache

python
from graphrag_cache import create_cache

cache = create_cache()

# Turn rate limiting back off for the caching demo.
model_config.rate_limit = None

# Recreate the completion client, this time backed by a cache.
llm_completion: LLMCompletion = create_completion(model_config, cache=cache)
llm_completion.metrics_store.clear_metrics()

responses = llm_completion.completion_batch(completion_requests, concurrency=4)

print(f"Metrics for: {llm_completion.metrics_store.id}")
print(json.dumps(llm_completion.metrics_store.get_metrics(), indent=2))

Notice that cached_responses == 6 because we are spinning up 4 threads. The first 4 requests fire immediately, before the cache holds any data. Identical requests fired in the same thread cycle therefore all hit the model, since the cache is not yet populated; the remaining 6 requests are served from the cache.

The cached_responses indicates how many cache hits occurred but the rest of the metrics exist as if a cache was not used. For example, compute_duration_seconds and all the tokens and cost counts are as if cache was not used so compute_duration_seconds includes network timings for the cached responses. This is because both the response and metrics are cached and retrieved from the cache when a cache hit occurs. This means the above metrics should closely match the metrics from the first example in this notebook other than the runtime_duration_seconds which gives the true idea of how long a job takes to run. Rerunning a job with a fully hydrated cache should result in a quick runtime_duration_seconds. Metrics were designed to give an idea of how long and costly a job would be if there were no cache.

Embedding Batching

python
from graphrag_llm.embedding import LLMEmbedding, create_embedding
from graphrag_llm.types import LLMEmbeddingArgs

# Embedding model configuration. The deployment name mirrors the model
# name and reads the same environment variable — this cell previously
# read "GRAPHRAG_LLM_EMBEDDING_MODEL" for the deployment, which was
# inconsistent with the "GRAPHRAG_EMBEDDING_MODEL" variable used for
# `model` (the completion cell above uses one variable for both fields).
embedding_config = ModelConfig(
    model_provider="azure",
    model=os.getenv("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"),
    azure_deployment_name=os.getenv(
        "GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small"
    ),
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version=os.getenv("GRAPHRAG_API_VERSION", "2025-04-01-preview"),
    api_key=api_key,
    auth_method=AuthMethod.AzureManagedIdentity if not api_key else AuthMethod.ApiKey,
)

llm_embedding: LLMEmbedding = create_embedding(embedding_config)

# A single embedding request already accepts a list of inputs to embed.
# Here we demonstrate batching multiple embedding requests concurrently:
# the first request has two inputs to embed and the second has one input.
embedding_requests: list[LLMEmbeddingArgs] = [
    {"input": ["Hello World.", "The quick brown fox jumps over the lazy dog."]},
    {"input": ["GraphRag is an amazing LLM framework."]},
]

responses = llm_embedding.embedding_batch(embedding_requests, concurrency=4)
for response in responses:
    if isinstance(response, Exception):
        print(f"Error: {response}")
    else:
        for embedding in response.embeddings:
            print(f"Embedding vector length: {len(embedding)}")
            print(embedding[0:5])  # Print first 5 dimensions of the embedding vector

print(f"Metrics for: {llm_embedding.metrics_store.id}")
print(json.dumps(llm_embedding.metrics_store.get_metrics(), indent=2))

Details

The batch utils start up `concurrency` threads in a thread pool and push all requests into an input queue, where free threads pick up the next request to process. The threads process requests within any defined rate limits and retry failed requests according to the retry settings. If a request still fails after all retries, the thread captures the exception and returns it — so the batch result may contain exceptions.

Thread Pool

The batch utils are convenient if all your requests are loaded in memory. If you wish to stream over an input source then you can use the lower level thread pool utils.

Completion Thread Pool

python
from collections.abc import Iterator

from graphrag_llm.types import LLMCompletionChunk, LLMCompletionResponse

llm_completion.metrics_store.clear_metrics()


# The response handler may also be asynchronous if needed
def _handle_response(
    request_id: str,
    resp: LLMCompletionResponse | Iterator[LLMCompletionChunk] | Exception,
):
    # Imagine streaming responses to disk or elsewhere
    status = "Failed" if isinstance(resp, Exception) else "Succeeded"
    print(f"{request_id}: {status}")


with llm_completion.completion_thread_pool(
    response_handler=_handle_response,
    concurrency=25,
    # set queue_limit to create backpressure on reading the requests
    queue_limit=10,
) as completion:
    # Requests here come from an in-memory list, but you could just as
    # easily stream them from disk or another source. The `completion`
    # callable returned by the context manager blocks once queue_limit
    # is reached, until some requests complete. It also requires a
    # request_id so each request can be identified in the handler.
    for index, request in enumerate(completion_requests):
        completion(request_id=f"request_number_{index}", **request)

# Using the same request that was used in the caching example so
# this should complete instantly from cache
print(f"Metrics for: {llm_completion.metrics_store.id}")
print(json.dumps(llm_completion.metrics_store.get_metrics(), indent=2))

Embedding Thread Pool

python
from graphrag_llm.types import LLMEmbeddingResponse

llm_embedding.metrics_store.clear_metrics()


# The response handler may also be asynchronous if needed
def _handle_response(
    request_id: str,
    resp: LLMEmbeddingResponse | Exception,
):
    # Report per-request success/failure as results arrive.
    status = "Failed" if isinstance(resp, Exception) else "Succeeded"
    print(f"{request_id}: {status}")


with llm_embedding.embedding_thread_pool(
    response_handler=_handle_response,
    concurrency=25,
    queue_limit=10,
) as embedding:
    # Feed the same two embedding requests used in the batching example.
    for index, request in enumerate(embedding_requests):
        embedding(request_id=f"embedding_request_number_{index}", **request)