Comprehensive guide to optimizing LLM inference with TensorRT-LLM.
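Benefits: roughly half the weight memory of FP16 and about 2× the speed, at an accuracy cost of about 0.5% in perplexity (figures from the quantization comparison table).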
Usage:
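from tensorrt_llm import LLM
from tensorrt_llm.llmapi import QuantAlgo, QuantConfig

# Automatic FP8 quantization (precision is set through quant_config, not dtype)
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",
    quant_config=QuantConfig(quant_algo=QuantAlgo.FP8),
)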
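To make the FP8 format concrete, the sketch below emulates per-tensor FP8 (E4M3) quantization in plain Python. A single scale maps the tensor's largest magnitude onto 448, the largest finite E4M3 value. Each scaled value is then rounded to the nearest representable E4M3 number. With 3 mantissa bits, the worst-case relative rounding error for normal values is 2^-4, about 6%. This illustrates the number format only, assuming simple per-tensor scaling. It is not how TensorRT-LLM's calibration or FP8 kernels are implemented.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest FP8 E4M3 value, saturating at +/-448 (no NaN/inf handling)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    # Binade exponent; below 2**-6 E4M3 is subnormal with a fixed spacing of 2**-9.
    exp = max(math.floor(math.log2(mag)), -6)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits -> 8 representable values per binade
    rounded = round(mag / step) * step  # Python's round() breaks ties to even
    return math.copysign(min(rounded, E4M3_MAX), x)


def quantize_fp8_per_tensor(values):
    """Per-tensor scaling: the largest |value| maps onto E4M3_MAX, then every value is rounded."""
    amax = max(abs(v) for v in values) or 1.0
    scale = amax / E4M3_MAX
    return [round_to_e4m3(v / scale) for v in values], scale


weights = [0.013, -0.42, 0.9, -1.7, 0.0051]  # made-up example weights
codes, scale = quantize_fp8_per_tensor(weights)
restored = [c * scale for c in codes]
for w, r in zip(weights, restored):
    print(f"{w:+.4f} -> {r:+.4f}  (relative error {abs(r - w) / abs(w):.2%})")
```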
Performance (Llama 3-70B on 8× H100):
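Benefits: about a quarter of the FP16 weight memory and roughly 3-4× the speed (figures from the quantization comparison table).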
Usage:
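from tensorrt_llm import LLM
from tensorrt_llm.llmapi import QuantAlgo, QuantConfig

# INT4 weights (16-bit activations) with AWQ calibration
llm = LLM(
    model="meta-llama/Llama-3.1-405B",
    quant_config=QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ),
)

# INT4 weights in GPTQ format
llm = LLM(
    model="meta-llama/Llama-3.1-405B",
    quant_config=QuantConfig(quant_algo=QuantAlgo.W4A16_GPTQ),
)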
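Trade-offs: a larger accuracy cost than FP8. According to the quantization comparison table, perplexity degrades by about 1.5% with AWQ and about 2% with GPTQ.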
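The accuracy cost comes from squeezing each weight into 16 levels. The sketch below shows plain round-to-nearest INT4 quantization with one scale per group of weights, the kind of group-wise layout that AWQ and GPTQ checkpoints typically use. AWQ and GPTQ differ in how they choose scales and compensate for rounding error, and this simplified version does not attempt either. The group size of 4 only keeps the example small; real deployments commonly use groups of 128 weights.

```python
def quantize_int4_groupwise(weights, group_size=4):
    """Symmetric round-to-nearest INT4: one floating-point scale per group of weights."""
    codes, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        amax = max(abs(w) for w in group) or 1.0
        scale = amax / 7  # map the group's largest |weight| onto code 7
        scales.append(scale)
        # Signed INT4 holds [-8, 7]; clamp defensively even though |w / scale| <= 7.
        codes.extend(max(-8, min(7, round(w / scale))) for w in group)
    return codes, scales


def dequantize_int4_groupwise(codes, scales, group_size=4):
    return [c * scales[i // group_size] for i, c in enumerate(codes)]


weights = [0.12, -0.03, 0.05, 0.41, -0.9, 0.02, 0.33, -0.07]  # made-up weights
codes, scales = quantize_int4_groupwise(weights)
restored = dequantize_int4_groupwise(codes, scales)
print("codes:", codes)
print("max abs error:", max(abs(w - r) for w, r in zip(weights, restored)))

# Storage per weight, assuming FP16 scales: 4 bits plus 16 bits shared by each group.
for group_size in (4, 128):
    print(f"group size {group_size}: {4 + 16 / group_size:.3f} bits per weight")
```

Tiny groups fit the weights more closely but spend more bits on scales. That is why practical group sizes are large.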
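What it does: In-flight (continuous) batching dynamically batches requests during generation. Finished sequences leave the batch and waiting requests join it at every step, so the batch never waits for all of its sequences to finish.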
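A toy model of why this helps: the sketch below counts decode steps for the same requests under two schemes. Under static batching, a batch runs until its longest sequence finishes. Under in-flight batching, a finished sequence's slot is refilled before the next step. The sketch assumes every request arrives at once, every step costs the same regardless of batch size, and prefill is free. The output lengths are made up for illustration.

```python
from collections import deque
from typing import List


def static_batching_steps(output_lens: List[int], max_batch_size: int) -> int:
    """Static batching: each batch occupies the GPU until its longest sequence finishes."""
    steps = 0
    for i in range(0, len(output_lens), max_batch_size):
        steps += max(output_lens[i:i + max_batch_size])
    return steps


def inflight_batching_steps(output_lens: List[int], max_batch_size: int) -> int:
    """In-flight batching: free slots are refilled from the queue before every step."""
    waiting = deque(output_lens)
    active: List[int] = []  # tokens still to generate for each running sequence
    steps = 0
    while waiting or active:
        while waiting and len(active) < max_batch_size:
            active.append(waiting.popleft())
        # One decode step yields one token per running sequence; sequences at 1 finish now.
        active = [left - 1 for left in active if left > 1]
        steps += 1
    return steps


output_lens = [5, 50, 8, 12, 45, 6, 7, 40]  # made-up output lengths, in arrival order
print("static batching:   ", static_batching_steps(output_lens, max_batch_size=4), "steps")
print("in-flight batching:", inflight_batching_steps(output_lens, max_batch_size=4), "steps")
```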
Configuration:
# Server configuration
#   --max_batch_size          maximum concurrent sequences
#   --max_num_tokens          total tokens in a batch
#   --enable_chunked_context  split long prompts
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --max_batch_size 256 \
  --max_num_tokens 4096 \
  --enable_chunked_context \
  --scheduler_policy max_utilization
Performance:
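What it does: The paged KV cache manages KV cache memory the way an OS manages virtual memory. It allocates fixed-size blocks (pages) as sequences grow instead of reserving one contiguous region per sequence.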
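The paging idea can be sketched as a small block manager. Each sequence owns a block table of physical block ids and takes a new fixed-size block only when its last one fills up. Memory is therefore reserved in proportion to a sequence's actual length rather than the maximum sequence length. This is a simplified illustration with no prefix sharing, eviction, or real tensors. The 16-token block size is an arbitrary choice, not TensorRT-LLM's default.

```python
from typing import Dict, List


class PagedKVCache:
    """Toy block manager: KV memory is handed out in fixed-size blocks on demand."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[str, List[int]] = {}  # sequence id -> physical block ids
        self.lengths: Dict[str, int] = {}  # sequence id -> tokens stored

    def append_token(self, seq_id: str) -> None:
        """Make room for one more token, taking a new block only when the last one is full."""
        length = self.lengths.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        if length % self.block_size == 0:  # first token, or the last block is full
            if not self.free_blocks:
                raise MemoryError("KV cache is out of free blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1

    def free(self, seq_id: str) -> None:
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_blocks=8, block_size=16)
for _ in range(20):
    cache.append_token("a")  # 20 tokens -> 2 blocks
for _ in range(5):
    cache.append_token("b")  # 5 tokens -> 1 block
print(cache.block_tables, "free:", len(cache.free_blocks))
cache.free("a")
print("after freeing a:", len(cache.free_blocks), "free blocks")
```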
Benefits:
Configuration:
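from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Paged KV cache is on by default; these settings size it and enable prefix reuse
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    kv_cache_config=KvCacheConfig(
        free_gpu_memory_fraction=0.9,  # Use 90% of free GPU memory for the cache
        enable_block_reuse=True,  # Reuse cached blocks for common prefixes
    ),
)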
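To see what a 0.9 memory fraction buys, note that each token needs 2 (K and V) × layers × KV heads × head dimension × bytes per element of cache. The sketch below plugs in Llama 3 8B's shape (32 layers, 8 KV heads, head dimension 128) with an FP16 cache. The 60 GB of free memory after loading weights is a hypothetical figure. The calculation also assumes the fraction applies to the memory left free after the weights load.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2) -> int:
    """K and V tensors for every layer and KV head (bytes_per_elem=2 for FP16/BF16)."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


def kv_cache_token_capacity(free_bytes: float, fraction: float, per_token: int) -> int:
    """How many tokens fit in the share of free memory given to the KV cache."""
    return int(free_bytes * fraction) // per_token


per_token = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)
capacity = kv_cache_token_capacity(60e9, 0.9, per_token)  # 60 GB free is hypothetical
print(f"{per_token / 1024:.0f} KiB per token, about {capacity:,} tokens in the cache")
print(f"enough for {capacity // 8192} full 8K-token sequences")
```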
What it does: Speculative decoding uses a small draft model to propose several tokens ahead, and the target model verifies them in parallel in a single pass.
Speedup: 2-3× faster for long generations
Usage:
from tensorrt_llm import LLM
# Target model (Llama 3-70B)
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",
    speculative_model="meta-llama/Meta-Llama-3-8B",  # Draft model
    num_speculative_tokens=5,  # Tokens to predict ahead
)
# Same API, 2-3× faster
prompts = ["Explain speculative decoding in one sentence."]
outputs = llm.generate(prompts)
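The draft-then-verify loop can be sketched with greedy decoding. In this toy, the "models" are plain functions from a token list to the next token id, a hypothetical stand-in for real networks. The draft proposes `k` tokens and the target checks them. The longest agreeing prefix is kept, and the target's own token replaces the first mismatch. The output therefore always matches what the target alone would produce greedily, and each verification round adds at least one token. Sampling-based speculative decoding uses a rejection-sampling acceptance rule instead, which this sketch does not cover.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]  # toy greedy "model": context -> next token id


def speculative_generate(target: NextToken, draft: NextToken, prompt: List[int],
                         max_new_tokens: int, k: int = 5) -> Tuple[List[int], int]:
    tokens = list(prompt)
    verify_rounds = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1. The cheap draft model proposes k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(tokens + proposal))
        # 2. The target checks each proposed position. This toy calls it once per
        #    position; a real engine scores all positions in one forward pass.
        verify_rounds += 1
        accepted: List[int] = []
        for i in range(k):
            expected = target(tokens + proposal[:i])
            if expected != proposal[i]:
                accepted.append(expected)  # target's token replaces the first mismatch
                break
            accepted.append(proposal[i])
        else:
            accepted.append(target(tokens + proposal))  # all k accepted: one bonus token
        tokens.extend(accepted)
    return tokens[:len(prompt) + max_new_tokens], verify_rounds


def greedy_generate(target: NextToken, prompt: List[int], max_new_tokens: int) -> List[int]:
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(target(tokens))
    return tokens


def toy_target(ctx: List[int]) -> int:
    return (ctx[-1] + 1) % 10  # hypothetical target: counts up modulo 10


def toy_draft(ctx: List[int]) -> int:
    # Hypothetical draft: agrees with the target except when len(ctx) is a multiple of 4.
    return 0 if len(ctx) % 4 == 0 else (ctx[-1] + 1) % 10


out, rounds = speculative_generate(toy_target, toy_draft, prompt=[3], max_new_tokens=12)
print(out, f"({rounds} verification rounds for 12 new tokens)")
```

The speedup depends on how often the draft agrees with the target. A draft that never agrees still yields one correct token per round, but it also adds its own cost.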
Best models for drafting:
What it does: CUDA graphs reduce kernel launch overhead by recording a sequence of GPU operations once and replaying it as a single launch.
Benefits:
Configuration (automatic by default):
llm = LLM(
model="meta-llama/Meta-Llama-3-8B",
enable_cuda_graph=True, # Default: True
cuda_graph_cache_size=2 # Cache 2 graph variants
)
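A captured CUDA graph replays fixed shapes. Graph-based runtimes therefore typically capture a few batch sizes and pad each live batch up to the nearest captured one. The helper below sketches that selection step. The captured sizes are hypothetical, and how TensorRT-LLM picks or caches its graph variants may differ.

```python
import bisect
from typing import Optional, Sequence


def pick_graph_batch_size(batch_size: int, captured: Sequence[int]) -> Optional[int]:
    """Smallest captured batch size >= batch_size, or None to run without a graph."""
    sizes = sorted(captured)
    i = bisect.bisect_left(sizes, batch_size)
    return sizes[i] if i < len(sizes) else None


captured_sizes = [1, 2, 4, 8, 16, 32]  # hypothetical set of captured graph variants
for live in (3, 8, 20, 50):
    graph = pick_graph_batch_size(live, captured_sizes)
    if graph is None:
        print(f"batch {live}: no captured graph, launch kernels individually")
    else:
        print(f"batch {live}: replay graph for {graph} ({graph - live} padded slots)")
```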
What it does: Chunked context splits long prompts into chunks that are processed over several steps, reducing memory spikes.
Use case: Prompts >8K tokens with limited GPU memory
Configuration:
trtllm-serve meta-llama/Meta-Llama-3-8B \
--max_num_tokens 4096 \
--enable_chunked_context \
--max_chunked_prefill_length 2048 # Process 2K tokens at a time
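The token-budget arithmetic can be sketched as follows, assuming one token per already-decoding sequence is scheduled first in each step. The remaining `max_num_tokens` budget, capped at the chunk length, then goes to the next slice of the long prompt. The function is a hypothetical simplification that ignores how the real scheduler orders requests.

```python
def prefill_steps(prompt_len: int, num_decoding: int, max_num_tokens: int,
                  chunk_size: int) -> int:
    """Steps to prefill one long prompt while num_decoding other sequences keep decoding."""
    budget = max_num_tokens - num_decoding  # assumed: decode tokens are scheduled first
    if budget <= 0:
        raise ValueError("decoding sequences alone fill max_num_tokens")
    per_step = min(chunk_size, budget)  # prompt tokens processed per step
    return -(-prompt_len // per_step)  # ceiling division


print(prefill_steps(10_000, num_decoding=100, max_num_tokens=4096, chunk_size=2048))
print(prefill_steps(10_000, num_decoding=3000, max_num_tokens=4096, chunk_size=2048))
```

In the second call, the decoding sequences leave only 1,096 tokens of budget per step, so the prompt needs 10 steps instead of 5. In both cases no step processes more than `max_num_tokens` tokens, which is what keeps the per-step memory bounded.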
What it does: Overlaps compute and memory operations.
Benefits:
No configuration needed; it is enabled automatically.
| Method | Memory (vs FP16) | Speed (vs FP16) | Accuracy | Use case |
|---|---|---|---|---|
| FP16 | 1× (baseline) | 1× | Best | High accuracy needed |
| FP8 | 0.5× | 2× | -0.5% ppl | H100 default |
| INT4 AWQ | 0.25× | 3-4× | -1.5% ppl | Memory critical |
| INT4 GPTQ | 0.25× | 3-4× | -2% ppl | Maximum speed |
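The memory column follows from bits per weight. Below is a rough weights-only estimate for a 70B-parameter model. It ignores the KV cache, activations, and the small overhead of quantization scales, and it assumes tensor parallelism splits the weights evenly across GPUs.

```python
BITS_PER_WEIGHT = {"FP16": 16, "FP8": 8, "INT4 AWQ": 4, "INT4 GPTQ": 4}


def weight_memory_gb(num_params: float, bits: int, tensor_parallel_size: int = 1) -> float:
    """Weights only, split evenly across GPUs; ignores KV cache, activations, and scales."""
    return num_params * bits / 8 / tensor_parallel_size / 1e9


for method, bits in BITS_PER_WEIGHT.items():
    total = weight_memory_gb(70e9, bits)
    per_gpu = weight_memory_gb(70e9, bits, tensor_parallel_size=8)
    print(f"{method:>9}: {total:6.1f} GB total, {per_gpu:5.2f} GB per GPU with 8-way TP")
```

By this estimate, FP16 weights for a 70B model (140 GB) cannot fit on a single 80 GB GPU even before any KV cache is allocated. FP8 weights (70 GB) fit but leave little room for the cache.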
Start with defaults:
llm = LLM(model="meta-llama/Meta-Llama-3-70B")
Enable FP8 (if H100):
llm = LLM(model="...", quant_config=QuantConfig(quant_algo=QuantAlgo.FP8))
Tune batch size:
# Increase until OOM, then reduce by 20%
trtllm-serve ... --max_batch_size 256
Enable chunked context (if long prompts):
--enable_chunked_context --max_chunked_prefill_length 2048
Try speculative decoding (if latency critical):
llm = LLM(model="...", speculative_model="...")
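The "increase until OOM, then back off 20%" rule can be automated as a doubling-then-bisection search. In the sketch, `toy_fits` is a hypothetical stand-in for a check you would implement yourself, for example by starting a test server with a given `--max_batch_size` and running a load test. The search assumes that if a batch size fits, every smaller one does too.

```python
from typing import Callable


def find_max_batch_size(fits: Callable[[int], bool], start: int = 8, limit: int = 4096) -> int:
    """Largest batch size in [start, limit] for which fits(batch_size) is True."""
    if not fits(start):
        raise ValueError(f"even batch size {start} does not fit")
    lo, hi = start, start * 2  # lo always fits; hi is the next candidate
    while hi <= limit and fits(hi):
        lo, hi = hi, hi * 2
    hi = min(hi, limit + 1)  # first size known (or, past the limit, assumed) not to fit
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    return lo


def toy_fits(batch_size: int) -> bool:
    return batch_size <= 300  # hypothetical: "server starts and survives a load test"


largest = find_max_batch_size(toy_fits)
recommended = int(largest * 0.8)  # keep about 20% headroom below the OOM point
print(largest, recommended)
```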
# Install benchmark tool
pip install "tensorrt_llm[benchmark]"
# Run benchmark
python benchmarks/python/benchmark.py \
--model meta-llama/Meta-Llama-3-8B \
--batch_size 64 \
--input_len 128 \
--output_len 256 \
--dtype fp8
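A benchmark run is usually summarized with throughput and per-request latency figures. The sketch below computes three commonly tracked metrics from per-request timestamps: output tokens per second, mean time to first token (TTFT), and mean time per output token after the first (TPOT). The `RequestTiming` record and the sample numbers are hypothetical, not the benchmark script's output format.

```python
from dataclasses import dataclass
from statistics import mean
from typing import Dict, List


@dataclass
class RequestTiming:
    """Hypothetical per-request record; times in seconds from benchmark start."""
    sent: float
    first_token: float
    finished: float
    output_tokens: int


def summarize(timings: List[RequestTiming]) -> Dict[str, float]:
    wall = max(t.finished for t in timings) - min(t.sent for t in timings)
    ttft = [t.first_token - t.sent for t in timings]  # time to first token
    # Time per output token after the first: the decode speed one user sees.
    tpot = [(t.finished - t.first_token) / (t.output_tokens - 1)
            for t in timings if t.output_tokens > 1]
    return {
        "output_tokens_per_s": sum(t.output_tokens for t in timings) / wall,
        "mean_ttft_s": mean(ttft),
        "mean_tpot_s": mean(tpot),
    }


timings = [RequestTiming(0.0, 0.5, 2.5, 101), RequestTiming(0.0, 0.3, 4.3, 201)]
print(summarize(timings))
```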
Metrics to track:
OOM errors: reduce `max_batch_size` or `max_num_tokens`, or increase `tensor_parallel_size` to shard the model across more GPUs.
Low throughput: increase `max_batch_size`.
High latency: adjust `max_batch_size` (less queueing).