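Standalone Python script that profiles a single vLLM generate() call for meta-llama/Meta-Llama-3.1-8B-Instruct (fp8 quantization, tensor parallel size 1, chunked prefill) with the PyTorch profiler and exports a Chrome trace.
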
    import os
    import time

    import torch
    import transformers
    from torch.profiler import ProfilerActivity, profile, record_function

    from vllm import LLM, SamplingParams

    os.environ["HOST_IP"] = "10.42.10.16"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False"

    if __name__ == "__main__":
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    tensor_parallel_size = 1
    # Build the vLLM engine with fp8 quantization and chunked prefill enabled.
    llm = LLM(
        model=model_id,
        tensor_parallel_size=tensor_parallel_size,
        quantization="fp8",
        # kv_cache_dtype="fp8",
        enforce_eager=False,
        enable_chunked_prefill=True,
        max_num_batched_tokens=2048,
        gpu_memory_utilization=0.90,
    )

    # model = llm.llm_engine.model_executor.driver_worker.model_runner.model
    # print(model)
    batch_size = 1
    prompts = [
        # "Hello, my name is",
        "The president of the United States is",
        # "The capital of France is",
        # "The future of AI is",
    ] * batch_size

    # Greedy decoding, 5 new tokens per prompt.
    sampling_params = SamplingParams(temperature=0, top_p=1.0, top_k=-1, max_tokens=5)

    # Warm-up run so one-time initialization is excluded from the profiled trace.
    outputs = llm.generate(prompts, sampling_params)

    # Profile a second generate() call on CPU and CUDA, then export a Chrome trace.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True
    ) as prof:
        outputs = llm.generate(prompts, sampling_params)
    timestamp = int(time.time())
    prof.export_chrome_trace(f"trace_8b_tp1_bs{batch_size}_{timestamp}.json")

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")