@w32zhong
Last active September 28, 2025 17:03
sglang
{
  "architectures": [
    "Qwen3ForCausalLMEagle"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "intermediate_size": 9728,
  "max_position_embeddings": 262144,
  "max_window_layers": 36,
  "model_type": "qwen3",
  "num_attention_heads": 32,
  "num_hidden_layers": 1,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 5000000,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.54.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
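
This config.json describes the EAGLE draft head: "architectures" names the Qwen3ForCausalLMEagle entry class defined in the model file below, and "num_hidden_layers" is 1 because the draft model is a single decoder layer stacked on the target model's hidden states. A quick sanity check of a draft checkpoint (the directory path here is a placeholder) could look like:

import json

# Hypothetical checkpoint directory; substitute your own draft model path.
with open("draft_model_sglang/config.json") as fh:
    cfg = json.load(fh)

# EAGLE drafts with a single decoder layer on top of the target model.
assert cfg["num_hidden_layers"] == 1
# sglang dispatches on this name to the EntryClass exported by qwen3_spec.py.
assert cfg["architectures"] == ["Qwen3ForCausalLMEagle"]
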
# Adapted from
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
"""Inference-only Qwen3-EAGLE model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple

import torch
from torch import nn

from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.qwen3 import Qwen3DecoderLayer, Qwen3ForCausalLM
from sglang.srt.utils import add_prefix

# Placeholder used only for type annotations; the actual config object is
# supplied by the model loader at runtime.
Qwen3Config = None

class Qwen3DecoderLayer(Qwen3DecoderLayer):
    def __init__(
        self,
        config: Qwen3Config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, layer_id, quant_config, prefix=prefix)

        # Skip the input_layernorm on layer 0, following EAGLE:
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if layer_id == 0:
            del self.input_layernorm
            setattr(self, "input_layernorm", lambda x: x)

class Qwen3Model(nn.Module):
    def __init__(
        self,
        config: Qwen3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.layers = nn.ModuleList(
            [
                Qwen3DecoderLayer(
                    config,
                    i,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{i}", prefix),
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        # EAGLE fusion layer: projects [token embedding; target hidden state]
        # back down to hidden_size before the draft decoder layer.
        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Concatenate the draft token embeddings with the target model's
        # hidden states and fuse them, as in EAGLE.
        hidden_states = self.fc(
            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
        )

        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        return hidden_states + residual

class Qwen3ForCausalLMEagle(Qwen3ForCausalLM):
    def __init__(
        self,
        config: Qwen3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config
        self.pp_group = get_pp_group()
        self.model = Qwen3Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.capture_aux_hidden_states = False

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        print(
            "[Qwen3ForCausalLMEagle] current model weights: ",
            [name for name, _ in self.named_parameters()],
        )
        for name, loaded_weight in weights:
            # Checkpoint names use "eagle_fc" and omit the "model." prefix;
            # remap them onto this module hierarchy before delegating.
            name = name.replace("eagle_fc", "fc")
            if "lm_head" not in name:
                name = "model." + name
            print("[Qwen3ForCausalLMEagle] loading weight: ", name)
            super().load_weights([(name, loaded_weight)])


EntryClass = [Qwen3ForCausalLMEagle]
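
sglang resolves model classes by scanning the modules under sglang.srt.models for an EntryClass export and matching each class name against the "architectures" field of config.json; that is why the benchmark script below temporarily appends "Eagle" to the architecture name before engine startup. A rough sketch of the lookup (simplified, not sglang's actual registry code):

import importlib

def resolve_entry_class(architecture: str):
    # The real registry walks every module in sglang.srt.models; here we
    # only consult the qwen3_spec module defined above.
    module = importlib.import_module("sglang.srt.models.qwen3_spec")
    for cls in module.EntryClass:
        if cls.__name__ == architecture:
            return cls
    raise ValueError(f"unknown architecture: {architecture}")

# config.json's "architectures": ["Qwen3ForCausalLMEagle"] selects the class above.
draft_cls = resolve_entry_class("Qwen3ForCausalLMEagle")

The benchmark script that follows drives this model end to end through sgl.Engine.
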
import time
from transformers import AutoTokenizer
import asyncio
import sglang as sgl
from sglang.utils import async_stream_and_merge, trim_overlap
async def generate(llm, tokenizer, prompt, sampling_params):
    final_text = ""
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    cnt_tokens = []
    print(prompt)
    async for chunk in generator:
        chunk_text = chunk["text"]
        cleaned_chunk = trim_overlap(final_text, chunk_text)
        final_text += cleaned_chunk
        print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
        # Each streamed chunk carries the tokens emitted by one draft-verify
        # step, so its length approximates the accept length for that step.
        cnt_tokens.append(len(chunk['output_ids']))
    return cnt_tokens

def batch_generate(llm, prompts, sampling_params):
    outputs = llm.generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
    # Return per-request completion lengths so the caller's statistics also
    # work in the batched path (originally this function returned None,
    # which crashed the stats below when bs > 1).
    return [output['meta_info']['completion_tokens'] for output in outputs]

def hot_replace_config(path):
    import os, json, shutil
    json_path = f'{path}/config.json'
    bkup_path = f'{path}/config.json.bkup'
    assert os.path.exists(json_path)
    shutil.copyfile(json_path, bkup_path)
    with open(json_path) as fh:
        j = json.load(fh)
    # Rename the architecture so sglang resolves Qwen3ForCausalLMEagle, and
    # disable tied embeddings so the draft's lm_head weights are loaded.
    j['architectures'][0] += 'Eagle'
    j['tie_word_embeddings'] = False
    with open(json_path, 'w') as fh:
        json.dump(j, fh)


def hot_recover_config(path):
    import os, shutil
    json_path = f'{path}/config.json'
    bkup_path = f'{path}/config.json.bkup'
    assert os.path.exists(bkup_path)
    shutil.copyfile(bkup_path, json_path)

def main(speculative_algorithm=None, bs=1, tp_size=1,
         no_prefill_cache=True, draft_model_path=None):
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
    questions = [
        "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?",
        "Who is the president of the United States?",
        "Write an essay about the future of AI.",
        "What is your favorite book?",
        "What is your least favorite book?",
        "What is your favorite programming language?",
        "What is your least favorite programming language?",
        "Write a short, neutral self-introduction for a fictional character.",
        "Provide a concise factual statement about France’s capital city.",
    ][:bs]
    messages = lambda question: [{"role": "user", "content": question}]
    prompts = [
        tokenizer.apply_chat_template(messages(Q), tokenize=False, add_generation_prompt=True)
        for Q in questions
    ]

    from sglang.srt.server_args import ServerArgs
    from sglang.srt.models.qwen3_spec import Qwen3ForCausalLMEagle

    # Temporarily point the draft checkpoint's config.json at the Eagle
    # entry class for engine startup, then restore the original config.
    if draft_model_path:
        hot_replace_config(draft_model_path)
    llm = sgl.Engine(
        model_path="Qwen/Qwen3-4B-Instruct-2507",
        tp_size=tp_size,
        cuda_graph_max_bs=bs,
        disable_cuda_graph=False,
        disable_radix_cache=no_prefill_cache,
        disable_chunked_prefix_cache=no_prefill_cache,
        speculative_algorithm=speculative_algorithm,
        speculative_draft_model_path=draft_model_path,
        speculative_num_steps=6,
        speculative_eagle_topk=10,
        speculative_num_draft_tokens=60,
        # speculative_attention_mode="prefill",
    )
    if draft_model_path:
        hot_recover_config(draft_model_path)

    sampling_params = {"temperature": 0, "max_new_tokens": 8000}
    begin = time.perf_counter()
    if bs > 1:
        cnt_tokens = batch_generate(llm, prompts, sampling_params)
    else:
        cnt_tokens = asyncio.run(generate(llm, tokenizer, prompts[0], sampling_params))
        # Drop the first chunk, which covers the prefill step rather than a
        # draft-verify step.
        cnt_tokens.pop(0)
    time_cost = time.perf_counter() - begin
    llm.shutdown()

    print()
    print(cnt_tokens)
    print('tokens and time:', sum(cnt_tokens), time_cost)
    print('e2e speed:', sum(cnt_tokens) / time_cost)
    print('max accept length:', max(cnt_tokens))
    print('min accept length:', min(cnt_tokens))
    print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


if __name__ == '__main__':
    import fire
    fire.Fire(main)
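
For reference, main can also be called directly instead of through the fire CLI; the draft path below is the same one used in the server command further down, shown here purely as an example:

# Hypothetical direct invocation, equivalent to:
#   python <this_script>.py --speculative_algorithm EAGLE --draft_model_path <draft_dir>
main(
    speculative_algorithm="EAGLE",
    bs=1,
    draft_model_path="/workspace/mnt/specforge_PoC/output/deft-bee-66/draft_model_sglang",
)
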
sglang==0.5.1

related PR: sgl-project/sglang#10846

CUDA_VISIBLE_DEVICES=0 uv run python -m sglang.launch_server --speculative-algo EAGLE \
    --model Qwen/Qwen3-4B-Instruct-2507 \
    --speculative-draft-model-path /workspace/mnt/specforge_PoC/output/deft-bee-66/draft_model_sglang \
    --speculative-num-steps 6 \
    --speculative-eagle-topk 10 \
    --speculative-num-draft-tokens 60 \
    --cuda-graph-max-bs 2 \
    --mem-fraction-static 0.8 \
    --chunked-prefill-size 4
cd benchmark/mtbench
python3 bench_sglang_eagle.py --parallel 1 --num-questions 10
# or: python3 -m sglang.test.send_one --batch-size 2
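
Once launch_server is up, the EAGLE-accelerated model can also be exercised over HTTP; a minimal sketch against sglang's native /generate endpoint (port 30000 is the sglang default and changes with --port):

import requests

# Query the running sglang server; the payload follows sglang's native
# generate API (prompt text plus a sampling_params dict).
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "Provide a concise factual statement about France's capital city.",
        "sampling_params": {"temperature": 0, "max_new_tokens": 128},
    },
)
print(resp.json()["text"])
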
