# Speculative Decoding
Speculative decoding is an inference acceleration technique that uses a smaller "draft" model to propose tokens, which are then validated in parallel by the larger "target" model. This can significantly speed up generation when the draft model frequently predicts tokens the target model would also choose.
Mistral.rs implements speculative decoding based on the paper *Fast Inference from Transformers via Speculative Decoding*.

The algorithm works as follows:

1. The draft model generates `gamma` candidate tokens autoregressively.
2. The target model evaluates all candidates in a single forward pass.
3. Each candidate token `x` is accepted with probability `min(1, p_target(x) / p_draft(x))`; on the first rejection, a corrected token is sampled from the target's adjusted distribution and the remaining candidates are discarded.

This approach guarantees the same output distribution as running the target model alone, while often achieving significant speedups.
The key parameter is `gamma`: the number of draft tokens to generate per speculation step. Higher values can increase throughput when the draft model is accurate, but waste computation when its predictions are frequently rejected.

Recommended values: start with `gamma` between 12 and 32 and tune based on your models and workload.
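To make the `gamma` trade-off concrete, here is a minimal, self-contained Python sketch of one draft-then-verify step, following the accept/reject rule from the paper. The `draft_probs`/`target_probs` callables and helper functions are toy stand-ins invented for illustration, not the mistral.rs API.

```python
import random

def sample(dist):
    """Sample a token from a {token: probability} dict."""
    r, acc = random.random(), 0.0
    for tok, p in dist.items():
        acc += p
        if r <= acc:
            return tok
    return tok  # guard against floating-point rounding

def residual_sample(target_dist, draft_dist):
    """Resample from norm(max(0, p_target - p_draft)); this correction is what
    keeps the overall output distribution identical to the target model's."""
    resid = {t: max(0.0, p - draft_dist.get(t, 0.0)) for t, p in target_dist.items()}
    z = sum(resid.values()) or 1.0
    return sample({t: p / z for t, p in resid.items()})

def speculative_step(draft_probs, target_probs, gamma):
    """One speculation step. draft_probs(i)/target_probs(i) return the models'
    next-token distributions at draft position i (toy stand-ins for the real
    conditional distributions, which depend on the drafted prefix)."""
    # 1. The draft model proposes gamma candidate tokens autoregressively.
    drafted = [sample(draft_probs(i)) for i in range(gamma)]

    accepted = []
    # 2. The target verifies all candidates (in practice via one parallel
    #    forward pass); each is kept with probability min(1, p_target/p_draft).
    for i, tok in enumerate(drafted):
        p_t = target_probs(i).get(tok, 0.0)
        p_d = draft_probs(i).get(tok, 1e-12)
        if random.random() < min(1.0, p_t / p_d):
            accepted.append(tok)
        else:
            # First rejection: take a corrected token, discard the rest.
            accepted.append(residual_sample(target_probs(i), draft_probs(i)))
            break
    else:
        # Every candidate accepted: the target's pass yields one bonus token.
        accepted.append(sample(target_probs(gamma)))
    return accepted

# Toy usage: a draft model that usually agrees with the target.
target = lambda i: {"a": 0.9, "b": 0.1}
draft = lambda i: {"a": 0.8, "b": 0.2}
print(speculative_step(draft, target, gamma=4))
```

With an accurate draft model most candidates are accepted, so each target forward pass emits several tokens; with a poor draft model, a large `gamma` mostly produces rejected work.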
The recommended way to configure speculative decoding is via TOML. Create a config file (e.g., `speculative.toml`):

```toml
[model]
model_id = "meta-llama/Llama-3.1-8B-Instruct"

[speculative]
gamma = 12

[speculative.draft_model]
model_id = "meta-llama/Llama-3.2-1B-Instruct"
```
Then run with:

```bash
mistralrs run --from-toml speculative.toml
```
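If the command above leaves mistral.rs serving its OpenAI-compatible HTTP API, any OpenAI client can then exercise the speculative pipeline. A minimal sketch (the port and `api_key` values are placeholder assumptions; match them to how you started the server):

```python
from openai import OpenAI

# Base URL and API key are assumptions for this sketch; point them at your
# running mistral.rs server.
client = OpenAI(base_url="http://localhost:1234/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Explain speculative decoding in one paragraph."}],
    max_tokens=128,
)
print(resp.choices[0].message.content)
```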
The draft model can use any supported format (Plain, GGUF, etc.) and can have different quantization than the target model.
For example, with a GGUF-quantized draft model:

```toml
[model]
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

[speculative]
gamma = 16

[speculative.draft_model]
model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
model_file = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
tok_model_id = "mistralai/Mistral-7B-Instruct-v0.1"
```
Or with in-situ quantization (ISQ) applied to the draft model:

```toml
[model]
model_id = "meta-llama/Llama-3.1-8B-Instruct"

[speculative]
gamma = 16

[speculative.draft_model]
model_id = "meta-llama/Llama-3.2-1B-Instruct"
isq = "Q8_0"
```
The Python API exposes the same options through the `Runner` constructor:

```python
from mistralrs import Runner, Which, ChatCompletionRequest, Architecture

runner = Runner(
    which=Which.Plain(
        model_id="mistralai/Mistral-7B-Instruct-v0.1",
        arch=Architecture.Mistral,
    ),
    which_draft=Which.GGUF(
        tok_model_id="mistralai/Mistral-7B-Instruct-v0.1",
        quantized_model_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
        quantized_filename="mistral-7b-instruct-v0.1.Q4_K_M.gguf",
    ),
    speculative_gamma=32,
)

res = runner.send_chat_completion_request(
    ChatCompletionRequest(
        model="default",
        messages=[
            {"role": "user", "content": "Tell me a story about the Rust type system."}
        ],
        max_tokens=256,
        presence_penalty=1.0,
        top_p=0.1,
        temperature=0.1,
    )
)
print(res.choices[0].message.content)
print(res.usage)
```
| Parameter | Type | Description |
|---|---|---|
| `which_draft` | `Which` | Draft model specification (Plain, GGUF, etc.) |
| `speculative_gamma` | `int` | Number of draft tokens per step (default: 32) |
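The draft does not have to be a GGUF file; any `Which` variant works. A sketch with a plain Hugging Face draft model (the `Architecture` values here are assumptions; check the enum for your models):

```python
from mistralrs import Runner, Which, Architecture

# Both `which` and `which_draft` accept any Which variant; here the draft is a
# plain Hugging Face model instead of a GGUF file. Architecture values are
# assumptions for this sketch.
runner = Runner(
    which=Which.Plain(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        arch=Architecture.Llama,
    ),
    which_draft=Which.Plain(
        model_id="meta-llama/Llama-3.2-1B-Instruct",
        arch=Architecture.Llama,
    ),
    speculative_gamma=16,
)
```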
In Rust, speculative decoding is set up with `TextSpeculativeBuilder`; you can find this example at `mistralrs/examples/advanced/speculative/main.rs`.
```rust
use anyhow::Result;
use mistralrs::{
    IsqType, RequestBuilder, SpeculativeConfig, TextMessageRole, TextMessages,
    TextModelBuilder, TextSpeculativeBuilder,
};

#[tokio::main]
async fn main() -> Result<()> {
    let target = TextModelBuilder::new("meta-llama/Llama-3.1-8B-Instruct")
        .with_logging();
    let draft = TextModelBuilder::new("meta-llama/Llama-3.2-1B-Instruct")
        .with_logging()
        .with_isq(IsqType::Q8_0);

    let spec_cfg = SpeculativeConfig { gamma: 16 };
    let model = TextSpeculativeBuilder::new(target, draft, spec_cfg)?
        .build()
        .await?;

    let messages = TextMessages::new()
        .add_message(
            TextMessageRole::System,
            "You are an AI agent with a specialty in programming.",
        )
        .add_message(
            TextMessageRole::User,
            "Hello! How are you? Please write generic binary search function in Rust.",
        );

    let response = model.send_chat_request(messages).await?;

    println!("{}", response.choices[0].message.content.as_ref().unwrap());
    dbg!(
        response.usage.avg_prompt_tok_per_sec,
        response.usage.avg_compl_tok_per_sec
    );

    Ok(())
}
```
For best performance, pair the target with a smaller draft model from the same family:
| Target Model | Draft Model | Notes |
|---|---|---|
| Llama 3.1-8B | Llama 3.2-1B | Same family, good acceptance |
| Llama 3.1-70B | Llama 3.1-8B | Large speedup potential |
| Mistral-7B | Mistral-7B (Q4_K_M GGUF) | Same model, quantized draft |
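For instance, the large-speedup pairing from the table expressed in the TOML format shown earlier (a sketch; adjust `gamma` for your workload):

```toml
[model]
model_id = "meta-llama/Llama-3.1-70B-Instruct"

[speculative]
gamma = 16

[speculative.draft_model]
model_id = "meta-llama/Llama-3.1-8B-Instruct"
```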
Speculative decoding can be combined with other mistral.rs features such as X-LoRA; see `examples/python/speculative_xlora.py` for an example pairing speculative decoding with X-LoRA.