@w32zhong
Last active September 28, 2025 17:03

Revisions

  1. w32zhong revised this gist Sep 28, 2025. 1 changed file with 42 additions and 15 deletions.
    57 changes: 42 additions & 15 deletions test.py
    @@ -24,8 +24,30 @@ def batch_generate(llm, prompts, sampling_params):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


    def main(speculative_algorithm=None, bs=1, tp_size=1):

    def hot_replace_config(path):
    import os, json, shutil
    json_path = f'{path}/config.json'
    bkup_path = f'{path}/config.json.bkup'
    assert os.path.exists(json_path)
    shutil.copyfile(json_path, bkup_path)
    with open(json_path) as fh:
    j = json.load(fh)
    j['architectures'][0] += 'Eagle'
    j['tie_word_embeddings'] = False
    with open(json_path, 'w') as fh:
    json.dump(j, fh)

    def hot_recover_config(path):
    import os, shutil
    json_path = f'{path}/config.json'
    bkup_path = f'{path}/config.json.bkup'
    assert os.path.exists(bkup_path)
    shutil.copyfile(bkup_path, json_path)


    def main(speculative_algorithm=None, bs=1, tp_size=1,
    no_prefill_cache=True, draft_model_path=None):

    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
    questions = [
    @@ -47,19 +69,24 @@ def main(speculative_algorithm=None, bs=1, tp_size=1):

        from sglang.srt.server_args import ServerArgs
        from sglang.srt.models.qwen3_spec import Qwen3ForCausalLMEagle
        llm = sgl.Engine(
            model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
            tp_size=tp_size,
            cuda_graph_max_bs=bs,
            disable_cuda_graph=False,

            speculative_algorithm=speculative_algorithm,
            speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
            speculative_num_steps=6,
            speculative_eagle_topk=10,
            speculative_num_draft_tokens=60,
            #speculative_attention_mode="prefill"
        )
        if draft_model_path:
            hot_replace_config(draft_model_path)
        llm = sgl.Engine( model_path="Qwen/Qwen3-4B-Instruct-2507",
            tp_size=tp_size,
            cuda_graph_max_bs=bs,
            disable_cuda_graph=False,
            disable_radix_cache=no_prefill_cache,
            disable_chunked_prefix_cache=no_prefill_cache,

            speculative_algorithm=speculative_algorithm,
            speculative_draft_model_path=draft_model_path,
            speculative_num_steps=6,
            speculative_eagle_topk=10,
            speculative_num_draft_tokens=60,
            #speculative_attention_mode="prefill"
        )
        if draft_model_path:
            hot_recover_config(draft_model_path)

        sampling_params = {"temperature": 0, "max_new_tokens": 8000}
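The script dispatches main through fire.Fire, so the keyword arguments above map directly to CLI flags. A minimal sketch of driving this revision with speculative decoding enabled, assuming "EAGLE" as the algorithm name and a local draft-model directory (both are assumptions, not values taken from the diff):

    # sketch: roughly equivalent to
    #   python test.py --speculative_algorithm EAGLE --bs 1 --tp_size 1 --draft_model_path ./draft_model
    from test import main  # assumes the file is saved as test.py
    main(speculative_algorithm="EAGLE", bs=1, tp_size=1, draft_model_path="./draft_model")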

  2. w32zhong revised this gist Sep 10, 2025. 1 changed file with 0 additions and 64 deletions.
    64 changes: 0 additions & 64 deletions test.py
    @@ -4,13 +4,6 @@
    import sglang as sgl
    from sglang.utils import async_stream_and_merge, trim_overlap

    async def generate(llm, tokenizer, prompt):
        sampling_params = {"temperature": 0, "max_new_tokens": 8000}import time
    from transformers import AutoTokenizer
    import asyncio
    import sglang as sgl
    from sglang.utils import async_stream_and_merge, trim_overlap

    async def generate(llm, tokenizer, prompt, sampling_params):
        final_text = ""
        generator = await llm.async_generate(prompt, sampling_params, stream=True)
    @@ -88,63 +81,6 @@ def main(speculative_algorithm=None, bs=1, tp_size=1):
        print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


    if __name__ == '__main__':
        import fire
        fire.Fire(main)


        final_text = ""
        generator = await llm.async_generate(prompt, sampling_params, stream=True)
        cnt_tokens = []
        async for chunk in generator:
            chunk_text = chunk["text"]
            cleaned_chunk = trim_overlap(final_text, chunk_text)
            final_text += cleaned_chunk
            print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
            cnt_tokens.append(len(chunk['output_ids']))
        return cnt_tokens


    def main(speculative_algorithm=None):

        tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
        question = "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?"
        messages = [
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print(prompt)

        from sglang.srt.server_args import ServerArgs
        llm = sgl.Engine(
            model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
            tp_size=1,
            cuda_graph_max_bs=1,
            disable_cuda_graph=True,

            speculative_algorithm=speculative_algorithm,
            speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
            speculative_num_steps=6,
            speculative_eagle_topk=10,
            speculative_num_draft_tokens=60,
            #speculative_attention_mode="prefill"
        )

        begin = time.perf_counter()
        cnt_tokens = asyncio.run(generate(llm, tokenizer, prompt))
        cnt_tokens.pop(0)
        time_cost = time.perf_counter() - begin
        llm.shutdown()

        print()
        print(cnt_tokens)
        print('tokens and time:', sum(cnt_tokens), time_cost)
        print('e2e speed:', sum(cnt_tokens) / time_cost)
        print('max accept length:', max(cnt_tokens))
        print('min accept length:', min(cnt_tokens))
        print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


    if __name__ == '__main__':
        import fire
        fire.Fire(main)
  3. w32zhong revised this gist Sep 10, 2025. 3 changed files with 150 additions and 64 deletions.
    File renamed without changes.
    64 changes: 0 additions & 64 deletions sglang.py
    @@ -1,64 +0,0 @@
    import time
    from transformers import AutoTokenizer
    import asyncio
    import sglang as sgl
    from sglang.utils import async_stream_and_merge, trim_overlap

    async def generate(llm, tokenizer, prompt):
        sampling_params = {"temperature": 0, "max_new_tokens": 8000}

        final_text = ""
        generator = await llm.async_generate(prompt, sampling_params, stream=True)
        cnt_tokens = []
        async for chunk in generator:
            chunk_text = chunk["text"]
            cleaned_chunk = trim_overlap(final_text, chunk_text)
            final_text += cleaned_chunk
            print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
            cnt_tokens.append(len(chunk['output_ids']))
        return cnt_tokens


    def main(speculative_algorithm=None):

        tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
        question = "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?"
        messages = [
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print(prompt)

        from sglang.srt.server_args import ServerArgs
        llm = sgl.Engine(
            model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
            tp_size=1,
            cuda_graph_max_bs=1,
            disable_cuda_graph=True,

            speculative_algorithm=speculative_algorithm,
            speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
            speculative_num_steps=6,
            speculative_eagle_topk=10,
            speculative_num_draft_tokens=60,
            #speculative_attention_mode="prefill"
        )

        begin = time.perf_counter()
        cnt_tokens = asyncio.run(generate(llm, tokenizer, prompt))
        cnt_tokens.pop(0)
        time_cost = time.perf_counter() - begin
        llm.shutdown()

        print()
        print(cnt_tokens)
        print('tokens and time:', sum(cnt_tokens), time_cost)
        print('e2e speed:', sum(cnt_tokens) / time_cost)
        print('max accept length:', max(cnt_tokens))
        print('min accept length:', min(cnt_tokens))
        print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


    if __name__ == '__main__':
        import fire
        fire.Fire(main)
    150 changes: 150 additions & 0 deletions test.py
    @@ -0,0 +1,150 @@
    import time
    from transformers import AutoTokenizer
    import asyncio
    import sglang as sgl
    from sglang.utils import async_stream_and_merge, trim_overlap

    async def generate(llm, tokenizer, prompt):
        sampling_params = {"temperature": 0, "max_new_tokens": 8000}import time
    from transformers import AutoTokenizer
    import asyncio
    import sglang as sgl
    from sglang.utils import async_stream_and_merge, trim_overlap

    async def generate(llm, tokenizer, prompt, sampling_params):
        final_text = ""
        generator = await llm.async_generate(prompt, sampling_params, stream=True)
        cnt_tokens = []
        print(prompt)
        async for chunk in generator:
            chunk_text = chunk["text"]
            cleaned_chunk = trim_overlap(final_text, chunk_text)
            final_text += cleaned_chunk
            print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
            cnt_tokens.append(len(chunk['output_ids']))
        return cnt_tokens


    def batch_generate(llm, prompts, sampling_params):
        outputs = llm.generate(prompts, sampling_params)
        for prompt, output in zip(prompts, outputs):
            print("===============================")
            print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


    def main(speculative_algorithm=None, bs=1, tp_size=1):

        tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
        questions = [
            "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?",
            "Who is the president of the United States?",
            "Write an essay about the future of AI.",
            "What is your favorite book?",
            "What is your least favorite book?",
            "What is your favorite programming language?",
            "What is your least favorite programming language?",
            "Write a short, neutral self-introduction for a fictional character.",
            "Provide a concise factual statement about France’s capital city."
        ][:bs]
        messages = lambda question: [{"role": "user", "content": question}]
        prompts = [
            tokenizer.apply_chat_template(messages(Q), tokenize=False, add_generation_prompt=True)
            for Q in questions
        ]

        from sglang.srt.server_args import ServerArgs
        from sglang.srt.models.qwen3_spec import Qwen3ForCausalLMEagle
        llm = sgl.Engine(
            model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
            tp_size=tp_size,
            cuda_graph_max_bs=bs,
            disable_cuda_graph=False,

            speculative_algorithm=speculative_algorithm,
            speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
            speculative_num_steps=6,
            speculative_eagle_topk=10,
            speculative_num_draft_tokens=60,
            #speculative_attention_mode="prefill"
        )

        sampling_params = {"temperature": 0, "max_new_tokens": 8000}

        begin = time.perf_counter()
        if bs > 1:
            cnt_tokens = batch_generate(llm, prompts, sampling_params)
        else:
            cnt_tokens = asyncio.run(generate(llm, tokenizer, prompts[0], sampling_params))
        cnt_tokens.pop(0)
        time_cost = time.perf_counter() - begin
        llm.shutdown()

        print()
        print(cnt_tokens)
        print('tokens and time:', sum(cnt_tokens), time_cost)
        print('e2e speed:', sum(cnt_tokens) / time_cost)
        print('max accept length:', max(cnt_tokens))
        print('min accept length:', min(cnt_tokens))
        print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


    if __name__ == '__main__':
        import fire
        fire.Fire(main)


        final_text = ""
        generator = await llm.async_generate(prompt, sampling_params, stream=True)
        cnt_tokens = []
        async for chunk in generator:
            chunk_text = chunk["text"]
            cleaned_chunk = trim_overlap(final_text, chunk_text)
            final_text += cleaned_chunk
            print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
            cnt_tokens.append(len(chunk['output_ids']))
        return cnt_tokens


    def main(speculative_algorithm=None):

        tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
        question = "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?"
        messages = [
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print(prompt)

        from sglang.srt.server_args import ServerArgs
        llm = sgl.Engine(
            model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
            tp_size=1,
            cuda_graph_max_bs=1,
            disable_cuda_graph=True,

            speculative_algorithm=speculative_algorithm,
            speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
            speculative_num_steps=6,
            speculative_eagle_topk=10,
            speculative_num_draft_tokens=60,
            #speculative_attention_mode="prefill"
        )

        begin = time.perf_counter()
        cnt_tokens = asyncio.run(generate(llm, tokenizer, prompt))
        cnt_tokens.pop(0)
        time_cost = time.perf_counter() - begin
        llm.shutdown()

        print()
        print(cnt_tokens)
        print('tokens and time:', sum(cnt_tokens), time_cost)
        print('e2e speed:', sum(cnt_tokens) / time_cost)
        print('max accept length:', max(cnt_tokens))
        print('min accept length:', min(cnt_tokens))
        print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


    if __name__ == '__main__':
        import fire
        fire.Fire(main)
  4. w32zhong revised this gist Sep 10, 2025. 2 changed files with 135 additions and 0 deletions.
    File renamed without changes.
    135 changes: 135 additions & 0 deletions qwen3_draft.py
    @@ -0,0 +1,135 @@
    from sglang.srt.utils import add_prefix

    # Adapted from
    # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
    """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""

    from typing import Iterable, Optional, Tuple

    import torch
    from torch import nn

    from sglang.srt.distributed import get_pp_group
    from sglang.srt.layers.logits_processor import LogitsProcessor
    from sglang.srt.layers.quantization.base_config import QuantizationConfig
    from sglang.srt.layers.vocab_parallel_embedding import (
        ParallelLMHead,
        VocabParallelEmbedding,
    )
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
    from sglang.srt.models.qwen3 import Qwen3DecoderLayer, Qwen3ForCausalLM

    Qwen3Config = None


    class Qwen3DecoderLayer(Qwen3DecoderLayer):
        def __init__(
            self,
            config: Qwen3Config,
            layer_id: int = 0,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ) -> None:
            super().__init__(config, layer_id, quant_config, prefix=prefix)

            # Skip the input_layernorm
            # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
            if layer_id == 0:
                del self.input_layernorm
                setattr(self, "input_layernorm", lambda x: x)


    class Qwen3Model(nn.Module):
        def __init__(
            self,
            config: Qwen3Config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ) -> None:
            super().__init__()
            self.config = config
            self.vocab_size = config.vocab_size
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=add_prefix("embed_tokens", prefix),
            )
            self.layers = nn.ModuleList(
                [
                    Qwen3DecoderLayer(
                        config,
                        i,
                        quant_config=quant_config,
                        prefix=add_prefix(f"layers.{i}", prefix),
                    )
                    for i in range(config.num_hidden_layers)
                ]
            )
            self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)

        def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            forward_batch: ForwardBatch,
            input_embeds: torch.Tensor = None,
            pp_proxy_tensors: Optional[PPProxyTensors] = None,
        ) -> torch.Tensor:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds

            hidden_states = self.fc(
                torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
            )

            residual = None
            for i in range(len(self.layers)):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions,
                    hidden_states,
                    forward_batch,
                    residual,
                )
            return hidden_states + residual


    class Qwen3ForCausalLMEagle(Qwen3ForCausalLM):
        def __init__(
            self,
            config: Qwen3Config,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
        ) -> None:
            nn.Module.__init__(self)
            self.config = config
            self.quant_config = quant_config
            self.pp_group = get_pp_group()
            self.model = Qwen3Model(
                config, quant_config=quant_config, prefix=add_prefix("model", prefix)
            )
            if self.config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=add_prefix("lm_head", prefix),
                )
            self.logits_processor = LogitsProcessor(config)
            self.capture_aux_hidden_states = False

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            print("[Qwen3ForCausalLMEagle] current model weights: ", [name for name, _ in self.named_parameters()])
            for name, loaded_weight in weights:
                name = name.replace("eagle_fc", "fc")
                if "lm_head" not in name:
                    name = "model." + name
                print("[Qwen3ForCausalLMEagle] loading weight: ", name)
                super().load_weights([(name, loaded_weight)])


    EntryClass = [Qwen3ForCausalLMEagle]
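test.py imports this class as sglang.srt.models.qwen3_spec.Qwen3ForCausalLMEagle, while the file appears here as qwen3_draft.py; presumably it is placed inside the installed SGLang package under srt/models/ with a module name matching that import, so the EntryClass list above registers the draft architecture with SGLang's model loader. A minimal sketch of that step, where the destination filename is an assumption:

    # Assumed setup step: make this file importable as sglang.srt.models.qwen3_spec
    # so that the import in test.py resolves (destination name is hypothetical).
    import os, shutil
    import sglang.srt.models as models_pkg
    shutil.copyfile("qwen3_draft.py",
                    os.path.join(os.path.dirname(models_pkg.__file__), "qwen3_spec.py"))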
  5. w32zhong revised this gist Sep 10, 2025. 1 changed file with 30 additions and 0 deletions.
    30 changes: 30 additions & 0 deletions gistfile1.txt
    @@ -0,0 +1,30 @@
    {
        "architectures": [
            "Qwen3ForCausalLMEagle"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 151643,
        "eos_token_id": 151645,
        "head_dim": 128,
        "hidden_act": "silu",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 9728,
        "max_position_embeddings": 262144,
        "max_window_layers": 36,
        "model_type": "qwen3",
        "num_attention_heads": 32,
        "num_hidden_layers": 1,
        "num_key_value_heads": 8,
        "rms_norm_eps": 1e-06,
        "rope_scaling": null,
        "rope_theta": 5000000,
        "sliding_window": null,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.54.1",
        "use_cache": true,
        "use_sliding_window": false,
        "vocab_size": 151936
    }
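The architectures and tie_word_embeddings values in this draft-model config match what the latest revision's hot_replace_config in test.py produces at run time: the architectures entry carries the Eagle suffix so SGLang instantiates Qwen3ForCausalLMEagle, and tie_word_embeddings is false so the draft model keeps its own ParallelLMHead instead of reusing the embedding matrix. A minimal sketch of the same patch applied to a draft model's config.json, with a hypothetical local path:

    # hypothetical path; mirrors hot_replace_config in test.py
    import json
    cfg_path = './draft_model/config.json'
    with open(cfg_path) as fh:
        cfg = json.load(fh)
    cfg['architectures'][0] += 'Eagle'   # e.g. Qwen3ForCausalLM -> Qwen3ForCausalLMEagle
    cfg['tie_word_embeddings'] = False   # keep a separate lm_head for the draft model
    with open(cfg_path, 'w') as fh:
        json.dump(cfg, fh, indent=2)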
  6. w32zhong revised this gist Sep 10, 2025. 1 changed file with 2 additions and 0 deletions.
    2 changes: 2 additions & 0 deletions sglang.py
    @@ -54,6 +54,8 @@ def main(speculative_algorithm=None):
        print(cnt_tokens)
        print('tokens and time:', sum(cnt_tokens), time_cost)
        print('e2e speed:', sum(cnt_tokens) / time_cost)
        print('max accept length:', max(cnt_tokens))
        print('min accept length:', min(cnt_tokens))
        print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


  7. w32zhong revised this gist Sep 10, 2025. 1 changed file with 31 additions and 12 deletions.
    43 changes: 31 additions & 12 deletions sglang.py
    @@ -4,40 +4,59 @@
    import sglang as sgl
    from sglang.utils import async_stream_and_merge, trim_overlap

    async def main(llm, tokenizer, prompt):
        sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 8000}
    async def generate(llm, tokenizer, prompt):
        sampling_params = {"temperature": 0, "max_new_tokens": 8000}

        final_text = ""
        generator = await llm.async_generate(prompt, sampling_params, stream=True)
        cnt_token = 0
        cnt_tokens = []
        async for chunk in generator:
            chunk_text = chunk["text"]
            cleaned_chunk = trim_overlap(final_text, chunk_text)
            final_text += cleaned_chunk
            print(tokenizer.decode(chunk['output_ids']), end="")
            cnt_token += len(chunk['output_ids'])
        return cnt_token
            print(tokenizer.decode(chunk['output_ids']), end="", flush=True)
            cnt_tokens.append(len(chunk['output_ids']))
        return cnt_tokens


    if __name__ == '__main__':
    def main(speculative_algorithm=None):

        tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
        question = "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?"
        messages = [
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print(prompt)

        from sglang.srt.server_args import ServerArgs
        llm = sgl.Engine(
            model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
            tp_size=4,
            cuda_graph_max_bs=16
            tp_size=1,
            cuda_graph_max_bs=1,
            disable_cuda_graph=True,

            speculative_algorithm=speculative_algorithm,
            speculative_draft_model_path="/mnt/asus_card/hfdownloader/w32zhong_deft-bee-66/draft_model",
            speculative_num_steps=6,
            speculative_eagle_topk=10,
            speculative_num_draft_tokens=60,
            #speculative_attention_mode="prefill"
        )

        begin = time.perf_counter()
        cnt_token = asyncio.run(main(llm, tokenizer, prompt))
        cnt_tokens = asyncio.run(generate(llm, tokenizer, prompt))
        cnt_tokens.pop(0)
        time_cost = time.perf_counter() - begin
        llm.shutdown()

        print()
        print('tokens and time:', cnt_token, time_cost)
        print('e2e speed:', cnt_token / time_cost)
        print(cnt_tokens)
        print('tokens and time:', sum(cnt_tokens), time_cost)
        print('e2e speed:', sum(cnt_tokens) / time_cost)
        print('avg accept length:', sum(cnt_tokens) / len(cnt_tokens))


    if __name__ == '__main__':
        import fire
        fire.Fire(main)
  8. w32zhong created this gist Sep 10, 2025.
    43 changes: 43 additions & 0 deletions sglang.py
    @@ -0,0 +1,43 @@
    import time
    from transformers import AutoTokenizer
    import asyncio
    import sglang as sgl
    from sglang.utils import async_stream_and_merge, trim_overlap

    async def main(llm, tokenizer, prompt):
        sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 8000}

        final_text = ""
        generator = await llm.async_generate(prompt, sampling_params, stream=True)
        cnt_token = 0
        async for chunk in generator:
            chunk_text = chunk["text"]
            cleaned_chunk = trim_overlap(final_text, chunk_text)
            final_text += cleaned_chunk
            print(tokenizer.decode(chunk['output_ids']), end="")
            cnt_token += len(chunk['output_ids'])
        return cnt_token


    if __name__ == '__main__':
        tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B-Instruct-2507')
        question = "Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?"
        messages = [
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print(prompt)
        llm = sgl.Engine(
            model_path="/mnt/asus_card/hfdownloader/Qwen_Qwen3-4B-Instruct-2507",
            tp_size=4,
            cuda_graph_max_bs=16
        )

        begin = time.perf_counter()
        cnt_token = asyncio.run(main(llm, tokenizer, prompt))
        time_cost = time.perf_counter() - begin
        llm.shutdown()

        print()
        print('tokens and time:', cnt_token, time_cost)
        print('e2e speed:', cnt_token / time_cost)