Last active: September 28, 2025 17:03
Revisions
w32zhong revised this gist
Sep 28, 2025: 1 changed file with 42 additions and 15 deletions.
Added `hot_replace_config()` / `hot_recover_config()` helpers and new `main()` arguments (`no_prefill_cache`, `draft_model_path`). The helpers temporarily patch the draft checkpoint's `config.json` (appending `Eagle` to the architecture name and untying the embeddings) while the engine loads; the engine now pulls `Qwen/Qwen3-4B-Instruct-2507` from the Hub and disables the radix and chunked-prefix caches.

```python
def batch_generate(llm, prompts, sampling_params):
    outputs = llm.generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


def hot_replace_config(path):
    # Back up the draft model's config.json, then patch it so sglang routes the
    # checkpoint to the Eagle architecture with its own (untied) lm_head.
    import os, json, shutil
    json_path = f'{path}/config.json'
    bkup_path = f'{path}/config.json.bkup'
    assert os.path.exists(json_path)
    shutil.copyfile(json_path, bkup_path)
    with open(json_path) as fh:
        j = json.load(fh)
    j['architectures'][0] += 'Eagle'
    j['tie_word_embeddings'] = False
    with open(json_path, 'w') as fh:
        json.dump(j, fh)


def hot_recover_config(path):
    # Restore the original config.json from the backup made by hot_replace_config().
    import os, shutil
    json_path = f'{path}/config.json'
    bkup_path = f'{path}/config.json.bkup'
    assert os.path.exists(bkup_path)
    shutil.copyfile(bkup_path, json_path)


def main(speculative_algorithm=None, bs=1, tp_size=1, no_prefill_cache=True, draft_model_path=None):
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
    questions = [
```

And, inside `main()`, the engine construction is now wrapped by the config hot-swap:

```python
    from sglang.srt.server_args import ServerArgs
    from sglang.srt.models.qwen3_spec import Qwen3ForCausalLMEagle

    if draft_model_path:
        hot_replace_config(draft_model_path)
    llm = sgl.Engine(
        model_path="Qwen/Qwen3-4B-Instruct-2507",
        tp_size=tp_size,
        cuda_graph_max_bs=bs,
        disable_cuda_graph=False,
        disable_radix_cache=no_prefill_cache,
        disable_chunked_prefix_cache=no_prefill_cache,
        speculative_algorithm=speculative_algorithm,
        speculative_draft_model_path=draft_model_path,
        speculative_num_steps=6,
        speculative_eagle_topk=10,
        speculative_num_draft_tokens=60,
        #speculative_attention_mode="prefill"
    )
    if draft_model_path:
        hot_recover_config(draft_model_path)

    sampling_params = {"temperature": 0, "max_new_tokens": 8000}
```
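For context, a minimal sketch of driving this revision of the benchmark; the module name `spec_bench` and the draft-model path are placeholders, not names from the gist. Since the entry point is wrapped in `fire.Fire(main)`, the same keyword arguments can also be passed as CLI flags (e.g. `python spec_bench.py --speculative_algorithm=EAGLE --draft_model_path=...`).

```python
# Hypothetical driver for the script above, assumed to be saved as spec_bench.py.
from spec_bench import main

# Baseline: plain decoding of one streamed prompt.
main(bs=1, tp_size=1)

# EAGLE run: hot_replace_config() patches the draft config so sglang instantiates
# Qwen3ForCausalLMEagle, and hot_recover_config() restores it after engine construction.
main(
    speculative_algorithm="EAGLE",
    bs=1,
    tp_size=1,
    draft_model_path="/path/to/draft_model",  # placeholder path
)
```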
w32zhong revised this gist
Sep 10, 2025: 1 changed file with 0 additions and 64 deletions.
Removed the 64 leftover lines from the previous revision's 150-line file: the duplicated import block and stale `async def generate(llm, tokenizer, prompt)` header near the top, and the copy of the old single-prompt `main()` that had been left after the accept-length statistics. Only the batch-capable benchmark remains.
w32zhong revised this gist
Sep 10, 2025: 3 changed files with 150 additions and 64 deletions.
One file was renamed without changes, the old 64-line single-prompt script was deleted (its content is the Sep 10 version shown further below), and a new 150-line benchmark script was added. The new script adds batching (`bs`), tensor parallelism (`tp_size`), and a `batch_generate()` path alongside the streaming `generate()`. It also still carried leftover duplicates of the old script (a repeated header near the top and the old `main()` at the end), which the revision above removes; the listing below omits those leftover lines.

```python
import time
from transformers import AutoTokenizer
import asyncio
import sglang as sgl
from sglang.utils import async_stream_and_merge, trim_overlap


async def generate(llm, tokenizer, prompt, sampling_params):
    final_text = ""
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    cnt_tokens = []
    print(prompt)
    async for chunk in generator:
        chunk_text = chunk["text"]
        cleaned_chunk = trim_overlap(final_text, chunk_text)
        final_text += cleaned_chunk
        print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
        cnt_tokens.append(len(chunk['output_ids']))
    return cnt_tokens


def batch_generate(llm, prompts, sampling_params):
    outputs = llm.generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


def main(speculative_algorithm=None, bs=1, tp_size=1):
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
    questions = [
        "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?",
        "Who is the president of the United States?",
        "Write an essay about the future of AI.",
        "What is your favorite book?",
        "What is your least favorite book?",
        "What is your favorite programming language?",
        "What is your least favorite programming language?",
        "Write a short, neutral self-introduction for a fictional character.",
        "Provide a concise factual statement about France’s capital city."
    ][:bs]
    messages = lambda question: [{"role": "user", "content": question}]
    prompts = [
        tokenizer.apply_chat_template(messages(Q), tokenize=False, add_generation_prompt=True)
        for Q in questions
    ]

    from sglang.srt.server_args import ServerArgs
    from sglang.srt.models.qwen3_spec import Qwen3ForCausalLMEagle
    llm = sgl.Engine(
        model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
        tp_size=tp_size,
        cuda_graph_max_bs=bs,
        disable_cuda_graph=False,
        speculative_algorithm=speculative_algorithm,
        speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
        speculative_num_steps=6,
        speculative_eagle_topk=10,
        speculative_num_draft_tokens=60,
        #speculative_attention_mode="prefill"
    )

    sampling_params = {"temperature": 0, "max_new_tokens": 8000}
    begin = time.perf_counter()
    if bs > 1:
        cnt_tokens = batch_generate(llm, prompts, sampling_params)
    else:
        cnt_tokens = asyncio.run(generate(llm, tokenizer, prompts[0], sampling_params))
        cnt_tokens.pop(0)
    time_cost = time.perf_counter() - begin
    llm.shutdown()

    print()
    print(cnt_tokens)
    print('tokens and time:', sum(cnt_tokens), time_cost)
    print('e2e speed:', sum(cnt_tokens) / time_cost)
    print('max accept length:', max(cnt_tokens))
    print('min accept length:', min(cnt_tokens))
    print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


if __name__ == '__main__':
    import fire
    fire.Fire(main)
```
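One caveat: `batch_generate()` prints the outputs but returns `None`, so the summary statistics at the bottom only make sense on the `bs == 1` streaming path. Below is a hedged sketch of a variant that also collects per-prompt counts; it assumes each output dict carries a `meta_info['completion_tokens']` field, which should be verified against the actual engine output format.

```python
def batch_generate_with_counts(llm, prompts, sampling_params):
    """Like batch_generate(), but also returns a per-prompt completion-token count."""
    outputs = llm.generate(prompts, sampling_params)
    cnt_tokens = []
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
        # Assumption: sglang attaches token accounting under 'meta_info'; fall back to 0.
        cnt_tokens.append(output.get('meta_info', {}).get('completion_tokens', 0))
    return cnt_tokens
```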
w32zhong revised this gist
Sep 10, 2025: 2 changed files with 135 additions and 0 deletions.
One file renamed without changes. Added a new 135-line model definition for the EAGLE draft head, `Qwen3ForCausalLMEagle`, registered with sglang via `EntryClass`:

```python
from sglang.srt.utils import add_prefix

# Adapted from
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""

from typing import Iterable, Optional, Tuple

import torch
from torch import nn

from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.qwen3 import Qwen3DecoderLayer, Qwen3ForCausalLM

Qwen3Config = None


class Qwen3DecoderLayer(Qwen3DecoderLayer):
    def __init__(
        self,
        config: Qwen3Config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, layer_id, quant_config, prefix=prefix)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if layer_id == 0:
            del self.input_layernorm
            setattr(self, "input_layernorm", lambda x: x)


class Qwen3Model(nn.Module):
    def __init__(
        self,
        config: Qwen3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.layers = nn.ModuleList(
            [
                Qwen3DecoderLayer(
                    config,
                    i,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{i}", prefix),
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        # EAGLE fusion: project [token embedding ; target hidden state] back to hidden_size.
        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        hidden_states = self.fc(
            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
        )

        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        return hidden_states + residual


class Qwen3ForCausalLMEagle(Qwen3ForCausalLM):
    def __init__(
        self,
        config: Qwen3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config
        self.pp_group = get_pp_group()
        self.model = Qwen3Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.capture_aux_hidden_states = False

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        print("[Qwen3ForCausalLMEagle] current model weights: ",
              [name for name, _ in self.named_parameters()])
        for name, loaded_weight in weights:
            name = name.replace("eagle_fc", "fc")
            if "lm_head" not in name:
                name = "model." + name
            print("[Qwen3ForCausalLMEagle] loading weight: ", name)
            super().load_weights([(name, loaded_weight)])


EntryClass = [Qwen3ForCausalLMEagle]
```
w32zhong revised this gist
Sep 10, 2025: 1 changed file with 30 additions and 0 deletions.
Added the draft model's `config.json` (30 lines):

```json
{
  "architectures": [
    "Qwen3ForCausalLMEagle"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "intermediate_size": 9728,
  "max_position_embeddings": 262144,
  "max_window_layers": 36,
  "model_type": "qwen3",
  "num_attention_heads": 32,
  "num_hidden_layers": 1,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 5000000,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.54.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
```
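A small sketch of how this config ties into the loader above; the path is illustrative. The `architectures` entry is what routes the checkpoint to `Qwen3ForCausalLMEagle` via `EntryClass`, `num_hidden_layers` is 1 because the EAGLE draft is a single decoder layer, and `tie_word_embeddings` being false means the draft keeps its own `ParallelLMHead`.

```python
import json

# Illustrative path; point this at the actual draft checkpoint directory.
with open("draft_model/config.json") as fh:
    cfg = json.load(fh)

assert cfg["architectures"] == ["Qwen3ForCausalLMEagle"]  # picked up by EntryClass routing
assert cfg["num_hidden_layers"] == 1                      # single-layer EAGLE draft
assert cfg["tie_word_embeddings"] is False                # draft has its own lm_head
print(cfg["hidden_size"], cfg["vocab_size"])              # 2560 151936
```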
w32zhong revised this gist
Sep 10, 2025: 1 changed file with 2 additions and 0 deletions.
Added max/min accept-length statistics to the summary output:

```python
    print(cnt_tokens)
    print('tokens and time:', sum(cnt_tokens), time_cost)
    print('e2e speed:', sum(cnt_tokens) / time_cost)
    print('max accept length:', max(cnt_tokens))
    print('min accept length:', min(cnt_tokens))
    print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))
```
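For reference, here is what these statistics look like on a made-up trace. Each streamed chunk contributes `len(chunk['output_ids'])` tokens, which under speculative decoding is roughly the number of tokens accepted in one draft-and-verify step, so the average acts as the effective accept length.

```python
cnt_tokens = [4, 6, 2, 5, 3, 6]   # hypothetical per-chunk token counts, not measured data
time_cost = 0.9                   # hypothetical seconds

print('tokens and time:', sum(cnt_tokens), time_cost)           # 26 0.9
print('e2e speed:', sum(cnt_tokens) / time_cost)                 # ~28.9 tokens/s
print('max accept length:', max(cnt_tokens))                     # 6
print('min accept length:', min(cnt_tokens))                     # 2
print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))   # ~4.33
```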
w32zhong revised this gist
Sep 10, 2025: 1 changed file with 31 additions and 12 deletions.
Refactored the script into a `generate()` / `main()` pair driven by fire, switched to greedy sampling, added per-chunk token-count tracking, and wired up the EAGLE speculative-decoding options:

```python
import sglang as sgl
from sglang.utils import async_stream_and_merge, trim_overlap


async def generate(llm, tokenizer, prompt):
    sampling_params = {"temperature": 0, "max_new_tokens": 8000}
    final_text = ""
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    cnt_tokens = []
    async for chunk in generator:
        chunk_text = chunk["text"]
        cleaned_chunk = trim_overlap(final_text, chunk_text)
        final_text += cleaned_chunk
        print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
        cnt_tokens.append(len(chunk['output_ids']))
    return cnt_tokens


def main(speculative_algorithm=None):
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
    question = "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?"
    messages = [
        {"role": "user", "content": question},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(prompt)

    from sglang.srt.server_args import ServerArgs
    llm = sgl.Engine(
        model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
        tp_size=1,
        cuda_graph_max_bs=1,
        disable_cuda_graph=True,
        speculative_algorithm=speculative_algorithm,
        speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
        speculative_num_steps=6,
        speculative_eagle_topk=10,
        speculative_num_draft_tokens=60,
        #speculative_attention_mode="prefill"
    )

    begin = time.perf_counter()
    cnt_tokens = asyncio.run(generate(llm, tokenizer, prompt))
    cnt_tokens.pop(0)
    time_cost = time.perf_counter() - begin
    llm.shutdown()

    print()
    print(cnt_tokens)
    print('tokens and time:', sum(cnt_tokens), time_cost)
    print('e2e speed:', sum(cnt_tokens) / time_cost)
    print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


if __name__ == '__main__':
    import fire
    fire.Fire(main)
```
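The streaming loop leans on `sglang.utils.trim_overlap` to avoid re-accumulating text that was already emitted, since `chunk["text"]` can repeat part of the running output. Below is a standalone mimic of that idea, reflecting my reading of the helper's purpose rather than its actual implementation.

```python
def trim_overlap_sketch(existing: str, chunk: str) -> str:
    """Drop the longest prefix of `chunk` that repeats the tail of `existing`."""
    for k in range(min(len(existing), len(chunk)), 0, -1):
        if existing.endswith(chunk[:k]):
            return chunk[k:]
    return chunk

final_text = "The capital of France is"
new_chunk = "France is Paris."  # hypothetical streamed chunk overlapping the tail
final_text += trim_overlap_sketch(final_text, new_chunk)
print(final_text)               # The capital of France is Paris.
```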
w32zhong created this gist
Sep 10, 2025.
Initial 43-line streaming benchmark: a single prompt on Qwen3-4B-Instruct-2507 with `tp_size=4`, measuring end-to-end tokens per second.

```python
import time
from transformers import AutoTokenizer
import asyncio
import sglang as sgl
from sglang.utils import async_stream_and_merge, trim_overlap


async def main(llm, tokenizer, prompt):
    sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 8000}
    final_text = ""
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    cnt_token = 0
    async for chunk in generator:
        chunk_text = chunk["text"]
        cleaned_chunk = trim_overlap(final_text, chunk_text)
        final_text += cleaned_chunk
        print(tokenizer.decode(chunk['output_ids']), end="")
        cnt_token += len(chunk['output_ids'])
    return cnt_token


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
    question = "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?"
    messages = [
        {"role": "user", "content": question},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(prompt)

    llm = sgl.Engine(
        model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
        tp_size=4,
        cuda_graph_max_bs=16
    )

    begin = time.perf_counter()
    cnt_token = asyncio.run(main(llm, tokenizer, prompt))
    time_cost = time.perf_counter() - begin
    llm.shutdown()

    print()
    print('tokens and time:', cnt_token, time_cost)
    print('e2e speed:', cnt_token / time_cost)
```
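For reference, the prompt printed by the script is the tokenizer's chat-template rendering of the single user turn. For Qwen3 instruct models this is ChatML-style, roughly as sketched below; the exact string depends on the template bundled with the tokenizer, so treat this as an approximation rather than the literal output.

```python
# Approximate result of tokenizer.apply_chat_template(messages, tokenize=False,
#                                                      add_generation_prompt=True):
expected_prompt = (
    "<|im_start|>user\n"
    "Thomas is very healthy, but he has to go to the hospital every day. "
    "What could be the reasons?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
print(expected_prompt)
```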