examples/inference/README.md
This guide provides an example of running model inference with Megatron Core.
This example runs statically-batched inference on a model trained using Megatron Core. The entrypoint is gpt_static_inference.py. A similar workflow can be adapted for gpt_dynamic_inference.py.
STEP 1 - Initialize model parallel and other default arguments
The micro batch size defaults to 1. It is unused when only tensor parallelism is enabled, and for pipeline-parallel models it is calculated at runtime.
# Initialize Megatron model using the same model provider from training.
initialize_megatron(
    args_defaults={'no_load_rng': True, 'no_load_optim': True, 'micro_batch_size': 1}
)
STEP 2 - Load the model using the model_provider_function
The model provider function supports both MCore and Legacy models.
# Load the model checkpoint
model = get_model(model_provider, wrap_with_ddp=False)
load_checkpoint(model, None, None)
model.eval()
model = model[0]
STEP 3 - Choose an engine
Text generation requires an inference engine, which includes a scheduler. The default engine is the Megatron Core engine with a text generation controller. TRTLLMEngine will be supported in the future.
# Create an inference wrapper to setup the model.
inference_wrapped_model = GPTInferenceWrapper(model, args)
# Define a sampling loop.
text_generation_controller = TextGenerationController(
    inference_wrapped_model=inference_wrapped_model,
    tokenizer=tokenizer
)
# Create a static or dynamic inference engine.
inference_engine = StaticInferenceEngine(
    text_generation_controller=text_generation_controller,
    max_batch_size=args.max_batch_size
)
STEP 4 - Run text generation
The SamplingParams class uses suggested defaults. Customize it to change top_p, top_k, the number of tokens to generate, etc. The result is returned as a list of InferenceRequests.
results: List[InferenceRequest] = inference_engine.generate(
    prompts=args.prompts, sampling_params=sampling_params
)
if torch.distributed.get_rank() == 0:
    for idx, result in enumerate(results):
        print(f' ------------- RESULT FOR PROMPT {idx} --------------- ')
        result = {
            'id': result.request_id,
            'input_prompt': result.prompt,
            'generated_text': result.generated_text,
            'generated_tokens': result.generated_tokens
        }
        print(result)
An example Slurm script is shown below. Set the tokenizer paths, inference params, and other settings appropriately.
For a recap on sampling parameters, refer to this blog.
# Slurm cluster settings
ACCOUNT=<account>
MLM_PATH=/path/to/megatron-lm
GPT_CKPT=/path/to/gpt/ckpt
VOCAB_MERGE_FILE_PATH=/path/to/vocab/and/merge/file
CONTAINER_IMAGE=nvcr.io/ea-bignlp/ga-participants/nemofw-training:23.11
srun --account $ACCOUNT \
    --job-name=$ACCOUNT:inference \
    --partition=batch \
    --time=01:00:00 \
    --container-image $CONTAINER_IMAGE \
    --container-mounts $MLM_PATH:/workspace/megatron-lm/,$GPT_CKPT:/workspace/mcore_gpt_ckpt,$VOCAB_MERGE_FILE_PATH:/workspace/tokenizer \
    --no-container-mount-home \
    --pty /bin/bash
# Inside the container run the following.
cd megatron-lm/
export CUDA_DEVICE_MAX_CONNECTIONS=1
TOKENIZER_ARGS=(
--vocab-file /workspace/tokenizer/gpt2-vocab.json
--merge-file /workspace/tokenizer/gpt2-merges.txt
--tokenizer-type GPT2BPETokenizer
)
MODEL_ARGS=(
--use-checkpoint-args
--use-mcore-models
--load /workspace/mcore_gpt_ckpt
)
INFERENCE_SPECIFIC_ARGS=(
--attention-dropout 0.0
--hidden-dropout 0.0
--num-tokens-to-generate 20
--max-batch-size 4
)
torchrun --nproc-per-node=4 examples/inference/gpt/gpt_static_inference.py \
${TOKENIZER_ARGS[@]} \
${MODEL_ARGS[@]} \
${INFERENCE_SPECIFIC_ARGS[@]} \
--prompts "prompt one " "sample prompt two" "sample prompt 3"
NOTE: Other parameters which can be customized for inference:
--temperature (Sampling temperature)
--top_k (top_k sampling)
--top_p (top_p sampling)
--num-tokens-to-generate (Number of tokens to generate for each prompt)
--inference-batch-times-seqlen-threshold (During inference, if batch size times sequence length is smaller than this threshold, microbatched pipelining is not used)
--use-dist-ckpt (If using dist checkpoint format for the model)
--use-legacy-models (If using legacy models instead of MCore models)
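The threshold rule described for --inference-batch-times-seqlen-threshold can be sketched as a one-line check. This is only an illustration of the rule as worded above: the helper name `use_microbatched_pipelining` is hypothetical and is not a Megatron Core function.

```python
def use_microbatched_pipelining(batch_size: int, seq_len: int, threshold: int) -> bool:
    """Return True when batch_size * seq_len reaches the threshold.

    Hypothetical helper for illustration only; not part of Megatron Core.
    Below the threshold, microbatched pipelining is skipped.
    """
    return batch_size * seq_len >= threshold


print(use_microbatched_pipelining(4, 128, 512))  # 4 * 128 = 512, not below the threshold
print(use_microbatched_pipelining(2, 128, 512))  # 2 * 128 = 256, below the threshold
```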
An example of inference with static batching is provided in gpt_static_inference.py.
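To make the idea of static batching concrete, here is a minimal, self-contained sketch of a scheduler that admits requests to an active pool until a maximum batch size is reached and queues the rest. The class `ToyStaticScheduler` and its methods are assumptions made for illustration; they do not reproduce the scheduler shipped with Megatron Core.

```python
from collections import OrderedDict, deque
from typing import List


class ToyStaticScheduler:
    """Toy scheduler: fills an active pool up to max_batch_size, queues the rest."""

    def __init__(self, max_batch_size: int) -> None:
        self.max_batch_size = max_batch_size
        self.active = OrderedDict()  # request_id -> prompt
        self.waiting = deque()       # (request_id, prompt) pairs
        self._next_id = 0

    def add_request(self, prompt: str) -> int:
        request_id = self._next_id
        self._next_id += 1
        if len(self.active) < self.max_batch_size:
            self.active[request_id] = prompt
        else:
            self.waiting.append((request_id, prompt))
        return request_id

    def complete_active_batch(self) -> List[int]:
        """Mark the whole active batch as done, then refill from the waiting pool."""
        finished = list(self.active)
        self.active.clear()
        while self.waiting and len(self.active) < self.max_batch_size:
            request_id, prompt = self.waiting.popleft()
            self.active[request_id] = prompt
        return finished

    def has_unfinished_requests(self) -> bool:
        return bool(self.active or self.waiting)


scheduler = ToyStaticScheduler(max_batch_size=2)
for prompt in ["prompt one", "sample prompt two", "sample prompt 3"]:
    scheduler.add_request(prompt)

batches = []
while scheduler.has_unfinished_requests():
    batches.append(scheduler.complete_active_batch())
print(batches)  # [[0, 1], [2]]
```

With a maximum batch size of 2, the first two prompts run together and the third waits for the next batch.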
The scheduler in the engine adds these prompts to the active requests pool (see ../../megatron/core/inference/inference_request.py) until the max batch size is hit. Remaining requests are added to the waiting requests pool. The model's .forward() method is called to get the output logits.
The inference pipeline supports three levels of customization:
The abstract_engine.py file contains a generate method that can be extended to support a new backend.
class AbstractEngine(ABC):
    @staticmethod
    def generate(self) -> dict:
        """The abstract backend's generate function.

        To define a new backend, implement this method and return the outputs as a dictionary.
        """
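As a rough sketch of what adding a backend could look like, the example below defines a local abstract base class and one trivial subclass that returns its outputs as a dictionary. Both `ToyAbstractEngine` and `UppercaseEngine` are hypothetical stand-ins, and the `prompts` argument is an assumption; the real AbstractEngine signature may differ.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ToyAbstractEngine(ABC):
    """Local stand-in for an abstract engine; not the Megatron Core class."""

    @abstractmethod
    def generate(self, prompts: List[str]) -> Dict[int, str]:
        """Return generated text keyed by prompt index."""


class UppercaseEngine(ToyAbstractEngine):
    """Hypothetical backend that 'generates' by upper-casing each prompt."""

    def generate(self, prompts: List[str]) -> Dict[int, str]:
        return {idx: prompt.upper() for idx, prompt in enumerate(prompts)}


engine = UppercaseEngine()
outputs = engine.generate(["prompt one", "sample prompt two"])
print(outputs)  # {0: 'PROMPT ONE', 1: 'SAMPLE PROMPT TWO'}
```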
The TextGenerationController contains the main sampling loop and can be modified to support new tokenization, detokenization, or sampling strategies.
class TextGenerationController:
    def tokenize_prompt(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Utility to tokenize the input prompts"""

    def sample_from_logits(
        self,
        last_token_logits: torch.Tensor,
        sampling_params: SamplingParams,
        vocab_size: int,
        generation_started: Optional[torch.Tensor] = None,
        top_n_logprobs_dict: Dict[int, List[Dict[str, float]]] = None,
    ) -> torch.Tensor:
        """Samples the logits to generate outputs

        Given the logits of the last token, this function samples according to the parameters defined in sampling_params and returns the sampled tokens. If sampling_params.top_n_logprobs > 0, it also updates the top_n_logprobs_dict at each step.
        """

    def update_generation_status(
        self,
        updated_prompts_tokens: torch.Tensor,
        generation_started: torch.Tensor,
        current_context_end_position: int,
        is_generation_done_tensor: torch.Tensor,
        generated_sequence_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Function to check which prompts have reached an end condition

        We check which prompts have reached an end condition and set the corresponding flags of the is_generation_done_tensor to True. The generated sequence lengths increase as we keep generating, until a prompt hits an eod condition. The generation started status tensor helps us determine which prompts have started generating.
        """

    def generate_all_output_tokens_static_batch(
        self, active_requests: OrderedDict[int, InferenceRequest],
    ) -> OrderedDict[int, InferenceRequest]:
        """Utility to generate all the output tokens and probabilities for the prompts.

        This utility generates the output tokens for a static batch. It runs the forward steps until all prompts complete generation, updates the status of these requests to completed, adds the generated result, and returns these requests.
        """

    def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
        """Detokenize the output generations"""
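For readers who want to experiment with a custom sampling strategy, the following is a generic, pure-Python sketch of temperature, top-k, and top-p sampling over a single row of logits. It is an assumption-laden illustration: it uses plain lists instead of tensors, the function and its arguments are defined locally, and it does not reproduce the controller's actual implementation.

```python
import math
import random
from typing import List, Optional


def sample_from_logits(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample one token id from a single row of logits (illustrative only)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    if top_k == 1:
        # Greedy decoding: pick the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])

    scaled = [x / temperature for x in logits]
    # Token ids sorted by descending scaled logit.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Numerically stable softmax over the remaining candidates.
    max_logit = scaled[order[0]]
    weights = [math.exp(scaled[i] - max_logit) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    if 0.0 < top_p < 1.0:
        # Keep the smallest prefix whose cumulative probability reaches top_p.
        kept, cumulative = 0, 0.0
        for p in probs:
            kept += 1
            cumulative += p
            if cumulative >= top_p:
                break
        order, probs = order[:kept], probs[:kept]

    return rng.choices(order, weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
print(sample_from_logits(logits, top_k=1))  # greedy, always index 0
print(sample_from_logits(logits, temperature=0.5, top_k=2, rng=random.Random(0)))
```

Lower temperatures sharpen the distribution, while top-k and top-p restrict which tokens can be drawn at all.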
Extend abstract_model_inference_wrapper.py to support other models. The abstract model wrapper implements a forward method that depends on the model parallel settings, and it puts the model in .eval() mode.
The following methods should be implemented:
class AbstractModelInferenceWrapper:
    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
        """A utility function for preparing the model for inference

        The function gets called once before the autoregressive inference loop. It puts the model in eval mode and gets some model and inference data parameters. Extend this to build position ids, attention mask, etc., so that the required slices can be extracted during the forward pass.
        """

    @abc.abstractmethod
    def get_batch_for_context_window(self) -> List:
        """Returns the input data for inference

        This function gets called iteratively in the inference loop. It can be used to extract the relevant input from the prompt tokens, attention mask, etc. required for each step in inference.
        """
Refer to gpt_inference_wrapper.py for an example of implementing this for GPTModel.
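The prepare-once, slice-per-step pattern described above can be sketched without any framework. In this toy version, `ToyInferenceWrapper` is a hypothetical class, tokens are plain Python lists, and the `start` and `end` arguments are assumptions added so the slicing is visible; the real wrapper works on tensors and its signature may differ.

```python
from typing import List, Optional


class ToyInferenceWrapper:
    """Toy wrapper illustrating a prepare-once / slice-per-step pattern."""

    def __init__(self) -> None:
        self.prompts_tokens: Optional[List[List[int]]] = None
        self.position_ids: Optional[List[int]] = None

    def prep_model_for_inference(self, prompts_tokens: List[List[int]]) -> None:
        """Called once before the loop: cache tokens and build position ids."""
        self.prompts_tokens = prompts_tokens
        max_len = max(len(row) for row in prompts_tokens)
        self.position_ids = list(range(max_len))

    def get_batch_for_context_window(self, start: int, end: int) -> List:
        """Called every step: slice tokens and position ids for [start, end)."""
        if self.prompts_tokens is None or self.position_ids is None:
            raise RuntimeError("call prep_model_for_inference() first")
        tokens = [row[start:end] for row in self.prompts_tokens]
        return [tokens, self.position_ids[start:end]]


wrapper = ToyInferenceWrapper()
wrapper.prep_model_for_inference([[11, 12, 13, 0], [21, 22, 0, 0]])
first_window = wrapper.get_batch_for_context_window(0, 2)
next_window = wrapper.get_batch_for_context_window(2, 3)
print(first_window)  # [[[11, 12], [21, 22]], [0, 1]]
print(next_window)   # [[[13], [0]], [2]]
```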
We use common inference params for text generation. Customize these to change top_p, top_k, the number of tokens to generate, etc. Other attributes can be added for the inference loop as shown below.
from megatron.core.inference.sampling_params import SamplingParams
c = SamplingParams(temperature=0.5)
c.add_attributes({'min_length':4, 'eod_id':153})
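One way such a parameter container could accept extra attributes is sketched below with a local dataclass. `ToySamplingParams` is a stand-in written for this example, and its field names and default values are assumptions, not the defaults of the real SamplingParams class.

```python
from dataclasses import dataclass


@dataclass
class ToySamplingParams:
    """Local stand-in for a sampling-parameter container (defaults are illustrative)."""

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    num_tokens_to_generate: int = 30

    def add_attributes(self, attribute_value_pair: dict) -> None:
        """Attach arbitrary extra attributes for use inside a custom inference loop."""
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)


params = ToySamplingParams(temperature=0.5)
params.add_attributes({'min_length': 4, 'eod_id': 153})
print(params.temperature, params.min_length, params.eod_id)  # 0.5 4 153
```

A custom sampling loop can then read `params.min_length` or `params.eod_id` like any other field.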
The following features are planned for future releases.