megatron/core/models/mimo/README.md
MIMO is a model architecture that enables language models to understand and generate multiple modalities (text, images, audio, etc.). It achieves this by encoding each modality into the language model's embedding space and aligning the resulting embeddings with special tokens in the input sequence.

MIMO provides a flexible, canonical architecture that can be configured into various multimodal models, for example a vision-language model or an audio-language model.

The model architecture consists of two main components: the language model and the per-modality submodules (ModalitySubmodules).
The complete data flow:

Input → Encoder → Input Projection → Align input embeddings → Language Model → Hidden states at special generation tokens → Output Projection → Decoder → Output

Encoding: raw modality inputs are passed through their encoders and input projections, producing embeddings in the language model's hidden dimension.

Decoding: hidden states at the special generation token positions are passed through output projections and decoders to produce modality outputs.
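The toy sketch below illustrates these two paths. The module names, sizes, and shapes are hypothetical stand-ins chosen for illustration, not the actual Megatron-Core classes:

```python
import torch
import torch.nn as nn

hidden_size = 64

encoder = nn.Linear(768, 256)                     # stand-in for a modality encoder (e.g. a ViT)
input_projection = nn.Linear(256, hidden_size)    # projects encoder features into the LM hidden size
output_projection = nn.Linear(hidden_size, 256)   # maps LM hidden states back toward decoder space
decoder = nn.Linear(256, 768)                     # stand-in for a modality decoder

# Encoding: raw features -> encoder -> input projection -> LM embedding space
raw_features = torch.randn(4, 768)
modality_embeddings = input_projection(encoder(raw_features))               # [4, hidden_size]

# Decoding: LM hidden states at special generation tokens -> output projection -> decoder
generation_hidden_states = torch.randn(4, hidden_size)
generated_features = decoder(output_projection(generation_hidden_states))   # [4, 768]
```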
The language model is the core component: it processes all modality information in a unified embedding space.
ModalitySubmodules connect raw modality data with the language model:
```python
from abc import ABC
from typing import Dict, List, Optional

from torch import nn

# Base class constructor with named encoders and decoders
class ModalitySubmodules(ABC, nn.Module):
    def __init__(
        self,
        encoders: Optional[Dict[str, nn.Module]] = None,
        decoders: Optional[Dict[str, nn.Module]] = None,
        input_projections: Optional[List[nn.Module]] = None,
        output_projections: Optional[List[nn.Module]] = None,
    ):
        ...
```
MIMO provides default implementations (`VisionModalitySubmodules`, `AudioModalitySubmodules`), but you can create custom submodules for specialized processing:
```python
# Custom implementation
class CustomVisionSubmodules(ModalitySubmodules):
    def encode(self, inputs):
        # Specialized encoding logic that produces projected embeddings
        ...
        return projected_embeddings

# Use custom submodules when creating the model: the spec goes into MimoModelConfig
mimo_config = MimoModelConfig(
    language_model=lm_spec,  # defined as in the language model example below
    modality_submodules={"images": ModuleSpec(module=CustomVisionSubmodules, params={...})},
)
model = MimoModel(mimo_config)
```
The `MimoModel` handles the integration of the different modality embeddings through its `align_embeddings_by_token_positions` method:
Example of what happens internally:
```python
# Inside MimoModel's forward method
aligned_embeddings = self.align_embeddings_by_token_positions(
    modality_embeddings={"text": text_emb, "images": image_emb},
    input_ids=tokens,
    special_token_ids={"images": 32000},
)
```
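Conceptually, the alignment scatters each modality's embeddings into the positions its special tokens occupy in the input sequence. The following is a simplified, hypothetical sketch of that idea (names and shapes are illustrative, not the actual Megatron-Core implementation):

```python
import torch

def align_by_token_positions(input_ids, text_emb, image_emb, image_token_id):
    """Simplified illustration: place image embeddings at the positions of the
    image special token and text embeddings everywhere else."""
    seq_len, hidden = input_ids.shape[0], text_emb.shape[-1]
    combined = torch.empty(seq_len, hidden, dtype=text_emb.dtype)

    image_positions = input_ids == image_token_id   # boolean mask of special token positions
    combined[image_positions] = image_emb           # one image embedding per special token
    combined[~image_positions] = text_emb           # text embeddings fill the remaining positions
    return combined

# Example: a 6-token sequence with two image placeholder tokens (id 32000)
input_ids = torch.tensor([5, 32000, 32000, 9, 7, 2])
text_emb = torch.randn(4, 64)   # embeddings for the 4 non-image tokens
image_emb = torch.randn(2, 64)  # one projected embedding per image token
aligned = align_by_token_positions(input_ids, text_emb, image_emb, 32000)
print(aligned.shape)            # torch.Size([6, 64])
```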
```python
MimoModel(
    config: MimoModelConfig,  # Required: configuration for the model
)
```
MIMO models are instantiated with a `MimoModelConfig`, which contains:
```python
MimoModelConfig(
    language_model: ModuleSpec,                  # Specification for the language model
    modality_submodules: Dict[str, ModuleSpec],  # Maps modality names to their submodule specifications
    special_token_ids: Dict[str, int] = {},      # Maps modality names to their special token IDs
)
```
```python
# Language model specification
lm_spec = ModuleSpec(
    module=GPTModel,
    params={
        "config": language_config,
        "transformer_layer_spec": get_mock_language_layer_spec(),
        "vocab_size": 50304,
    },
)
```
```python
# Vision modality specification
vision_submodule_spec = ModuleSpec(
    module=VisionModalitySubmodules,
    params={
        # Any general parameters for the submodule can go here
    },
    submodules={
        "encoders": {
            "clip_encoder": ModuleSpec(
                module=CLIPViTModel,
                params={
                    "transformer_config": vision_config,
                    "transformer_layer_spec": get_mock_vision_layer_spec(),
                    "patch_dim": 16,
                    "img_h": 224,
                    "img_w": 224,
                },
            ),
        },
        "input_projections": [
            ModuleSpec(
                module=MultimodalProjector,
                params={
                    "config": get_mock_projection_config(),
                    "submodules": get_mock_projection_layer_spec().submodules,
                    "projector_type": "mlp",
                    "input_size": 128,
                },
            ),
        ],
    },
)
```
```python
# Instantiate the model
vlm = MimoModel(
    MimoModelConfig(
        language_model=lm_spec,
        modality_submodules={"images": vision_submodule_spec},
        special_token_ids={"images": 32000},
    )
)
```
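Before calling the model, `input_ids` must contain the special image token (32000 above) at the positions where image embeddings should be placed. A hypothetical preparation sketch, assuming one special token per image embedding produced by the vision submodule:

```python
import torch

# Hypothetical example: the number of image tokens is assumed to match the number of
# image embeddings the vision submodule produces (4 here, purely illustrative).
image_token_id = 32000
num_image_tokens = 4
text_token_ids = [101, 2023, 2003, 102]  # placeholder text token ids

input_ids = torch.tensor([[image_token_id] * num_image_tokens + text_token_ids])
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
```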
```python
# Prepare inputs for multiple modalities and encoders
modality_inputs = {
    # Modality names and encoder names should match the keys used in the MIMO config during initialization.
    "images": {
        "clip_encoder": {"pixel_values": images},  # Encoder-specific inputs
        "vit_encoder": {"images": vit_images},
    },
    "audio": {
        "whisper_encoder": {"input_features": audio_features},
    },
}

# Call the forward method
outputs, _ = mimo_model(
    input_ids=input_ids,
    position_ids=position_ids,
    modality_inputs=modality_inputs,
)
```