This model was contributed to Hugging Face Transformers on 2026-06-03.
Gemma 4 Unified Assistant is a small, text-only model that enables speculative decoding for Gemma 4 Unified models using the Multi-Token Prediction (MTP) method and its associated candidate generator. Pre-trained models are provided for the IT variants of the Gemma 4 12B model.
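Speculative decoding pairs a small assistant that drafts tokens with the larger target model that verifies them. The sketch below is a generic, greedy draft-then-verify loop in plain Python. It assumes toy "models" that map a token list to their greedy next token. `speculative_generate`, `greedy_generate`, `num_draft`, and the toy models are hypothetical names for illustration only. This is not the Transformers candidate generator or the MTP implementation. With greedy verification, the output matches what the target alone would produce. The assistant only changes how many target passes are needed.

```python
from typing import Callable, List

# Toy "model": any function mapping a token sequence to its greedy next token.
NextTokenFn = Callable[[List[int]], int]


def greedy_generate(model: NextTokenFn, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Reference: plain autoregressive greedy decoding with a single model."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model(tokens))
    return tokens


def speculative_generate(
    target: NextTokenFn,
    assistant: NextTokenFn,
    prompt: List[int],
    max_new_tokens: int,
    num_draft: int = 4,
) -> List[int]:
    """Greedy draft-then-verify loop (generic sketch, not the Transformers code path)."""
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1) Drafting round: the assistant proposes `num_draft` tokens autoregressively.
        candidates: List[int] = []
        for _ in range(num_draft):
            candidates.append(assistant(tokens + candidates))
        # 2) Verification: a real target scores all candidates in one forward pass;
        #    this toy target is called once per position for clarity.
        accepted: List[int] = []
        for token in candidates:
            expected = target(tokens + accepted)
            if token != expected:
                # First mismatch: keep the target's token and end the round.
                accepted.append(expected)
                break
            accepted.append(token)
        else:
            # All candidates matched; the verification pass also yields one bonus token.
            accepted.append(target(tokens + accepted))
        tokens.extend(accepted)
    # The last round may overshoot, so trim to the requested length.
    return tokens[: len(prompt) + max_new_tokens]


def target_model(seq: List[int]) -> int:
    # Toy target: counts modulo 10.
    return (seq[-1] + 1) % 10


def assistant_model(seq: List[int]) -> int:
    # Toy assistant: agrees with the target except right after a 5.
    return 0 if seq[-1] == 5 else (seq[-1] + 1) % 10


print(speculative_generate(target_model, assistant_model, [0], 12))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2]
print(greedy_generate(target_model, [0], 12))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2]
```

Because each round accepts at least one token, the loop always makes progress. A good assistant lets the target confirm several tokens per round instead of producing them one at a time.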
For more information, see the Gemma4 Assistant documentation. Architecturally and conceptually, Gemma 4 Unified Assistant differs from its base model in the same ways as Gemma4 Assistant:
- Constant `position_ids`: the KV cache is shared and the assistant has no means of updating it, so the assistant predicts all tokens from the same position ID.
- Projected inputs: the model takes as input the concatenation of the embedding and `hidden_states` of the last seen token from the target model. It projects them into the assistant model space with an `nn.Linear` transform.
- Last seen token: the definition of the last seen token changes throughout the assisted decoding loop. For the first token drafted after pre-fill, it is the last token of the prompt. For subsequent drafting steps, it is either the last token generated by the assistant (within a drafting round) or the last token accepted by the target model (between drafting rounds).

The example below demonstrates how to generate text from an image with speculative decoding, using either [Pipeline] or the [AutoModel] class.
```py
import torch
from transformers import pipeline

# Pair the target model with its assistant to enable speculative decoding.
pipe = pipeline(
    task="image-text-to-text",
    model="google/gemma-4-12B-it",
    assistant_model="google/gemma-4-12B-it-assistant",
    dtype=torch.bfloat16,
    device_map="auto",
)
pipe(
    images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    text="<|image|>\n\nWhat is shown in this image?",
)
```

```py
import torch
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoProcessor

# Target model: processes the image and verifies the drafted tokens.
model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-4-12B-it",
    dtype=torch.bfloat16,
    device_map="auto",
)
# Text-only assistant that drafts candidate tokens.
assistant_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-12B-it-assistant",
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(
    "google/gemma-4-12B-it",
    padding_side="left",
)

messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)

input_len = inputs["input_ids"].shape[-1]
# Passing `assistant_model` makes generate() use assisted (speculative) decoding.
output = model.generate(**inputs, max_new_tokens=50, assistant_model=assistant_model)
# Decode only the newly generated tokens.
print(processor.decode(output[0][input_len:], skip_special_tokens=True))
```
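The assistant's input projection, constant position ID, and "last seen token" rule are described earlier. The small PyTorch sketch below makes those three ideas concrete. It is illustrative only. `HypotheticalAssistantInputProjection`, `last_seen_token`, the hidden sizes, and the position value `41` are all made up. The sketch also assumes that the token embedding and the target hidden state share the target model's hidden size. It is not the `Gemma4UnifiedAssistantForCausalLM` implementation.

```python
from typing import List

import torch
from torch import nn


class HypotheticalAssistantInputProjection(nn.Module):
    """Toy stand-in for the assistant's input projection (not the real module)."""

    def __init__(self, target_hidden_size: int, assistant_hidden_size: int) -> None:
        super().__init__()
        # Assumption: embedding and hidden state both use the target's hidden size.
        self.proj = nn.Linear(2 * target_hidden_size, assistant_hidden_size)

    def forward(self, token_embedding: torch.Tensor, target_hidden_state: torch.Tensor) -> torch.Tensor:
        # Both inputs are (batch, 1, target_hidden_size) for the last seen token.
        fused = torch.cat([token_embedding, target_hidden_state], dim=-1)
        # Result is (batch, 1, assistant_hidden_size), i.e. in assistant model space.
        return self.proj(fused)


def last_seen_token(prompt_ids: List[int], accepted_ids: List[int], drafted_ids: List[int]) -> int:
    """Hypothetical helper applying the 'last seen token' rule from the text."""
    if drafted_ids:  # within a drafting round: the assistant's latest draft
        return drafted_ids[-1]
    if accepted_ids:  # between rounds: the last token accepted by the target
        return accepted_ids[-1]
    return prompt_ids[-1]  # first draft after pre-fill: the last prompt token


batch, target_size, assistant_size = 2, 16, 8  # made-up sizes
projection = HypotheticalAssistantInputProjection(target_size, assistant_size)
assistant_inputs = projection(
    torch.randn(batch, 1, target_size),  # embedding of the last seen token
    torch.randn(batch, 1, target_size),  # target hidden state for that token
)
print(assistant_inputs.shape)  # torch.Size([2, 1, 8])

# The assistant never advances the shared KV cache, so every drafting step in a
# round reuses the same position ID (41 is an arbitrary example value).
position_ids = torch.full((batch, 1), 41, dtype=torch.long)
```

Keeping `position_ids` fixed reflects the constraint stated above: the assistant reads the target's shared cache but cannot extend it. As a result, it has no new positions to advance to while drafting.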
[[autodoc]] Gemma4UnifiedAssistantConfig
[[autodoc]] Gemma4UnifiedAssistantForCausalLM