docs_new/docs/advanced_features/speculative_decoding.ipynb
SGLang now provides an EAGLE-based (EAGLE-2/EAGLE-3) speculative decoding option. Our implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.
The table below shows the throughput improvements that EAGLE-3 decoding achieves for Llama 3.1 8B Instruct on MT-Bench. For further details, please see the EAGLE-3 paper.
| Method | Throughput (tokens/s) |
|---|---|
| SGLang (w/o speculative, 1x H100) | 158.34 |
| SGLang + EAGLE-2 (1x H100) | 244.10 |
| SGLang + EAGLE-3 (1x H100) | 373.25 |
To enable EAGLE speculative decoding, the following parameters are relevant:

- speculative_draft_model_path: Specifies the draft model. This parameter is required.
- speculative_num_steps: Depth of autoregressive drafting. Increases the speculation range but risks rejection cascades. Default is 5.
- speculative_eagle_topk: Branching factor per step. Improves candidate diversity and leads to a higher acceptance rate, but larger values also increase memory and compute consumption. Default is 4.
- speculative_num_draft_tokens: Maximum parallel verification capacity. Allows deeper tree evaluation but leads to higher GPU memory usage. Default is 8.

These parameters are the same for EAGLE-2 and EAGLE-3.
You can find the best combinations of these parameters with bench_speculative.py.
In the documentation below, we set --cuda-graph-max-bs to a small value for faster engine startup. For your own workloads, please tune the above parameters together with --cuda-graph-max-bs, --max-running-requests, and --mem-fraction-static for the best performance.
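If you prefer the offline Engine API over launching a server, the same parameters can be passed as keyword arguments. The snippet below is a minimal sketch, assuming sglang.Engine accepts the keyword equivalents of the server flags above; it is not a tuned configuration.

import sglang as sgl

# Minimal sketch (assumption: these keyword arguments mirror the server flags above).
llm = sgl.Engine(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",
    speculative_num_steps=3,          # drafting depth
    speculative_eagle_topk=4,         # branching factor per step
    speculative_num_draft_tokens=16,  # parallel verification capacity
)
print(llm.generate("List 3 countries and their capitals.", {"temperature": 0, "max_new_tokens": 64}))
llm.shutdown()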
You can enable EAGLE-2 decoding by setting --speculative-algorithm EAGLE and choosing an appropriate model.
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
import openai
server_process, port = launch_server_cmd("""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \
--speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --log-level warning
""")
wait_for_server(f"http://localhost:{port}")
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Llama-2-7b-chat-hf",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
terminate_process(server_process)
You can also enable torch.compile for further optimizations and optionally set --torch-compile-max-bs:
server_process, port = launch_server_cmd("""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
--enable-torch-compile --torch-compile-max-bs 2 --log-level warning
""")
wait_for_server(f"http://localhost:{port}")
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Llama-2-7b-chat-hf",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
terminate_process(server_process)
By employing a truncated high-frequency token vocabulary in the draft model, EAGLE speculative decoding reduces the lm_head computational overhead and accelerates the pipeline without quality degradation. For more details, check out the paper.
In our implementation, set --speculative-token-map to enable this optimization. You can get the high-frequency tokens used by FR-Spec from this model, or download the token files directly from this repo.
Thanks to Weilin Zhao and Zhousx for the contribution.
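If you want to take a look at the token map before passing it to --speculative-token-map, the snippet below is a small sketch; it assumes the .pt file stores a tensor of high-frequency token IDs.

from huggingface_hub import hf_hub_download
import torch

# Assumption: freq_32768.pt contains a tensor of high-frequency token IDs.
token_map_path = hf_hub_download("thunlp/LLaMA3-Instruct-8B-FR-Spec", "freq_32768.pt")
token_map = torch.load(token_map_path)
print(len(token_map))  # size of the truncated draft vocabulary, e.g. 32768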
server_process, port = launch_server_cmd("""
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
--mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 --log-level warning
""")
wait_for_server(f"http://localhost:{port}")
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
terminate_process(server_process)
You can enable EAGLE-3 decoding by setting --speculative-algorithm EAGLE3 and choosing an appropriate model.
server_process, port = launch_server_cmd("""
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algorithm EAGLE3 \
--speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \
--cuda-graph-max-bs 2 --dtype float16 --log-level warning
""")
wait_for_server(f"http://localhost:{port}")
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
terminate_process(server_process)
SGLang supports MTP (Multi-Token Prediction) via speculative decoding. We use the XiaomiMiMo/MiMo-7B-RL model as an example here (for DeepSeek MTP usage, refer to the DeepSeek doc).
server_process, port = launch_server_cmd("""
python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \
--speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \
--mem-fraction 0.5 --log-level warning
""")
wait_for_server(f"http://localhost:{port}")
import requests
url = f"http://localhost:{port}/v1/chat/completions"
data = {
"model": "XiaomiMiMo/MiMo-7B-RL",
"messages": [{"role": "user", "content": "What is the capital of France?"}],
}
response = requests.post(url, json=data)
print_highlight(response.json())
terminate_process(server_process)
The EAGLE process is as follows:

- Within EAGLE, the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence $(f_1, ..., f_k)$ and the token sequence $(t_2, ..., t_{k+1})$.
- The next token is then sampled from $p_{k+2}=\text{LMHead}(f_{k+1})$. Afterwards, the two sequences are extended in a tree style, branching out into multiple potential continuations (the branching factor per step is controlled by the speculative_eagle_topk parameter) to ensure a more coherent connection of context, and are given as input again.
- EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top speculative_num_draft_tokens final nodes as draft tokens.
- EAGLE-3 additionally removes the feature-prediction constraint on the draft model and fuses low-, mid-, and high-level features from the target model.

This enhances drafting accuracy by operating on features instead of tokens, which gives more regular inputs, and by additionally passing the tokens from the next timestep to minimize randomness effects from sampling. Furthermore, the dynamic adjustment of the draft tree and the selection of reranked final nodes further increase the acceptance rate of draft tokens. For more details, see the EAGLE-2 and EAGLE-3 papers.
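As a rough, self-contained illustration (not SGLang's actual implementation), the toy sketch below expands a draft tree with a fixed per-step branching factor and then reranks the leaves by cumulative log-probability, mirroring the roles of speculative_num_steps, speculative_eagle_topk, and speculative_num_draft_tokens; the real EAGLE-2 expansion is dynamic and prunes unlikely branches during drafting.

# Toy illustration of tree-style drafting (NOT SGLang internals): each step
# expands every frontier node into its top-k most likely next tokens, then the
# leaves are reranked by cumulative log-probability and only the best
# num_draft_tokens candidates are kept for verification.
import heapq
import math

def draft_tree(next_token_logprobs, prefix, num_steps=3, topk=4, num_draft_tokens=8):
    # next_token_logprobs(seq) -> list of (token, logprob) pairs from the draft model.
    frontier = [(0.0, list(prefix))]  # (cumulative logprob, token sequence)
    for _ in range(num_steps):
        new_frontier = []
        for score, seq in frontier:
            # Branch out the topk most likely continuations of this node.
            for tok, lp in sorted(next_token_logprobs(seq), key=lambda x: -x[1])[:topk]:
                new_frontier.append((score + lp, seq + [tok]))
        frontier = new_frontier
    # Rerank all leaves and keep only the most probable draft candidates.
    return heapq.nlargest(num_draft_tokens, frontier, key=lambda x: x[0])

# Tiny fake draft model over a 3-token vocabulary, for demonstration only.
def fake_model(seq):
    return [(t, math.log((t + 1) / 6)) for t in range(3)]

for score, seq in draft_tree(fake_model, prefix=[0], num_steps=2, topk=2, num_draft_tokens=4):
    print(round(score, 3), seq)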
For guidance on how to train your own EAGLE model, please see the EAGLE repo.