@vsevolodl
Created July 1, 2024 20:08
LLM throughput benchmarking
# Sample usage:
# Download this .py file and a sample formatted_prompts.json file.
#
# Run from the command line:
#   python llm_benchmark_throughput.py --prompts_file formatted_prompts.json --model meta-llama/Meta-Llama-3-8B-Instruct \
#     --api_base http://vllm:8000/v1 --api_key vllm_key_here --profile vllm050_fp16_tp2 --stream --iterations 3 --max_tokens 128 --qps "16, 32, 64"
#
# formatted_prompts.json sample record (one JSON object per line, i.e. JSON Lines):
# {"prompt": "Your task is blah blah blah."}
#
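# Illustrative prompts file contents (a sketch; the prompt texts below are placeholders —
# any records in this one-object-per-line format will work with the loader below):
# {"prompt": "Summarize the following article in two sentences: ..."}
# {"prompt": "Explain the difference between TCP and UDP to a new engineer."}
# {"prompt": "Write a short product description for a noise-cancelling headset."}
#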
import asyncio
import json
import os
import time
import random
from typing import List, Dict, Tuple, Optional
import argparse
import datetime
import statistics
import tiktoken
import logging
import aiohttp
import nest_asyncio
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from aiohttp import ClientSession, TCPConnector
# Apply the patch from nest_asyncio to enable nested use of asyncio
nest_asyncio.apply()
# Load environment variables
# load_dotenv()
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class Config:
    """Configuration class to store all settings."""

    def __init__(self, model_path: str, max_tokens: int, temperature: float, api_base: str, api_key: str,
                 no_stop: bool, profile: str):
        self.OPENAI_API_KEY = api_key
        self.OPENAI_API_BASE = api_base
        self.MODEL_PATH = model_path
        self.MAX_TOKENS = max_tokens
        self.TEMPERATURE = temperature
        self.NO_STOP = no_stop
        self.PROFILE = profile
        self.HEADERS = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.OPENAI_API_KEY}"
        }


def load_prompts_from_json(file_path: str) -> List[Dict[str, str]]:
    """Load prompts from a JSON Lines file (one JSON object per line)."""
    with open(file_path, 'r') as f:
        return [json.loads(line) for line in f]


async def process_streaming_data(response: aiohttp.ClientResponse, prompt: str) -> Tuple[Dict, Optional[float]]:
    """Process streaming data from the API response and rebuild a chat.completion-style result."""
    content = ""
    enc = tiktoken.get_encoding("cl100k_base")
    prompt_tokens = len(enc.encode(prompt))
    completion_tokens = 0
    total_tokens = 0
    chunk = None
    ttft = None
    start_time = time.perf_counter()
    async for line in response.content:
        if line.startswith(b'data: '):
            json_data = line.decode('utf-8')[6:]
            if json_data.strip() == "[DONE]":
                break
            chunk = json.loads(json_data)
            if ("choices" in chunk and len(chunk["choices"]) > 0
                    and "delta" in chunk["choices"][0]
                    and "content" in chunk["choices"][0]["delta"]):
                if ttft is None:
                    ttft = time.perf_counter() - start_time
                content += chunk["choices"][0]["delta"]["content"]
    if chunk:
        completion_tokens = len(enc.encode(content))
        total_tokens = prompt_tokens + completion_tokens
        result = {
            "id": chunk.get("id", ""),
            "object": "chat.completion",
            "created": chunk.get("created", 0),
            "model": chunk.get("model", ""),
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "logprobs": None,
                    "finish_reason": chunk["choices"][0].get("finish_reason", None),
                    "stop_reason": None
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens
            }
        }
        return result, ttft
    else:
        raise Exception("No valid streaming data received")
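

# For reference, each streaming line consumed above is expected to follow the OpenAI-compatible
# SSE format, roughly as below (field set may vary by server; values are illustrative only):
#   data: {"id": "...", "object": "chat.completion.chunk", "created": 0, "model": "...",
#          "choices": [{"index": 0, "delta": {"content": "next piece of text"}, "finish_reason": null}]}
#   ...
#   data: [DONE]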


async def send_request(
    session: ClientSession,
    prompt: Dict[str, str],
    config: Config,
    stream: bool,
    query_number: int,
    semaphore: asyncio.Semaphore
) -> Dict:
    """Send a request to the LLM API and process the response."""
    message = {
        "model": config.MODEL_PATH,
        "messages": [
            {"role": "user", "content": prompt['prompt']}
        ],
        "n": 1,
        "frequency_penalty": 1.0,
        "max_tokens": config.MAX_TOKENS,
        "temperature": config.TEMPERATURE,
        "stream": stream,
        # ignore_eos is a vLLM extension to the OpenAI schema; other backends may ignore or reject it.
        "ignore_eos": config.NO_STOP,
    }
    # start_time is taken before the semaphore is acquired, so query_run_time includes queueing time.
    start_time = time.perf_counter()
    async with semaphore:
        try:
            async with session.post(f"{config.OPENAI_API_BASE}/chat/completions", json=message) as response:
                response.raise_for_status()
                if stream:
                    result, ttft = await process_streaming_data(response, prompt['prompt'])
                    raw_response = result['choices'][0]['message']['content']
                    llm_generated_text = raw_response
                    prompt_tokens = result['usage']['prompt_tokens']
                    completion_tokens = result['usage']['completion_tokens']
                    total_tokens = result['usage']['total_tokens']
                else:
                    result = await response.json()
                    raw_response = await response.text()
                    llm_generated_text = result['choices'][0]['message']['content']
                    usage = result.get('usage', {})
                    prompt_tokens = usage.get('prompt_tokens', 0)
                    completion_tokens = usage.get('completion_tokens', 0)
                    total_tokens = usage.get('total_tokens', 0)
                    ttft = None  # TTFT is not applicable for non-streaming requests
                end_time = time.perf_counter()
                duration = end_time - start_time
                return {
                    "query_number": query_number,
                    "query_start_time": start_time,
                    "query_end_time": end_time,
                    "query_run_time": duration,
                    "query_success_status": True,
                    "query_prompt": prompt['prompt'],
                    "llm_raw_response": raw_response,
                    "llm_generated_text": llm_generated_text,
                    "prompt_tokens": prompt_tokens,
                    "generated_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                    "tokens_per_second": total_tokens / duration,
                    "model_path": config.MODEL_PATH,
                    "api_base": config.OPENAI_API_BASE,
                    "ttft": ttft,
                }
        except Exception as e:
            end_time = time.perf_counter()
            logger.error(f"Error in Query {query_number}: {str(e)}")
            return {
                "query_number": query_number,
                "query_start_time": start_time,
                "query_end_time": end_time,
                "query_run_time": end_time - start_time,
                "query_success_status": False,
                "error": str(e),
                "model_path": config.MODEL_PATH,
                "api_base": config.OPENAI_API_BASE,
            }


async def run_batch(
    prompts: List[Dict[str, str]],
    config: Config,
    stream: bool,
    batch_number: int,
    qps: int,
    batch_results_file: str,
    batch_summary_file: str
) -> None:
    """Run a batch of queries and save the results."""
    batch_start_time = time.perf_counter()
    connector = TCPConnector(limit=qps)
    async with ClientSession(headers=config.HEADERS, connector=connector) as session:
        batch_results = []
        tasks = []
        semaphore = asyncio.Semaphore(qps)
        with tqdm(total=len(prompts), desc=f"Submitting queries for batch {batch_number}",
                  unit="query") as submission_pbar:
            for i, prompt in enumerate(prompts):
                task = asyncio.create_task(send_request(session, prompt, config, stream, i, semaphore))
                tasks.append(task)
                submission_pbar.update(1)
                if (i + 1) % qps == 0 or i == len(prompts) - 1:
                    await asyncio.sleep(5)  # Wait for 5 seconds after each QPS-sized group of submissions
        ttft_bar = tqdm(total=len(tasks), desc="Time to First Token", unit="query", leave=False)
        for f in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                      desc=f"Completing queries for Batch {batch_number}", unit="query"):
            result = await f
            batch_results.append(result)
            if result.get('ttft'):
                ttft_bar.update(1)
        ttft_bar.close()
    batch_end_time = time.perf_counter()
    # Save detailed batch results
    with open(batch_results_file, 'w') as f:
        json.dump(batch_results, f, indent=2)
    # Calculate and save batch summary
    successful_queries = [r for r in batch_results if r['query_success_status']]
    total_prompt_tokens = sum(r.get('prompt_tokens', 0) for r in successful_queries)
    total_generated_tokens = sum(r.get('generated_tokens', 0) for r in successful_queries)
    total_tokens = sum(r.get('total_tokens', 0) for r in successful_queries)
    batch_runtime = batch_end_time - batch_start_time
    tokens_per_second_list = [r['tokens_per_second'] for r in successful_queries if 'tokens_per_second' in r]
    avg_tokens_per_second = statistics.mean(tokens_per_second_list) if tokens_per_second_list else 0
    batch_tokens_per_second = total_tokens / batch_runtime if batch_runtime > 0 else 0
    ttft_list = [r['ttft'] for r in successful_queries if r.get('ttft') is not None]
    avg_ttft = statistics.mean(ttft_list) if ttft_list else None
    batch_summary = {
        "batch_start_time": batch_start_time,
        "batch_end_time": batch_end_time,
        "batch_runtime": batch_runtime,
        "total_prompt_tokens": total_prompt_tokens,
        "total_generated_tokens": total_generated_tokens,
        "total_tokens": total_tokens,
        "avg_tokens_per_second": avg_tokens_per_second,
        "batch_tokens_per_second": batch_tokens_per_second,
        "model_path": config.MODEL_PATH,
        "api_base": config.OPENAI_API_BASE,
        "queries_in_batch": qps,
        "successful_queries": len(successful_queries),
        "qps": qps,
        "avg_ttft": avg_ttft
    }
    with open(batch_summary_file, 'a') as f:
        json.dump(batch_summary, f)
        f.write('\n')
    logger.info(f"Batch {batch_number} completed: {len(successful_queries)}/{len(prompts)} successful queries")
    logger.info(f"Average tokens per query per second: {avg_tokens_per_second:.2f}")
    logger.info(f"Iteration total tokens per second: {batch_tokens_per_second:.2f}")
    if avg_ttft is not None:
        logger.info(f"Average Time to First Token: {avg_ttft:.4f} seconds")


def run_experiment(
    prompts: List[Dict[str, str]],
    config: Config,
    stream: bool,
    iterations: int,
    qps: int,
    results_dir: str
) -> None:
    """Run the experiment for a specific QPS value."""
    os.makedirs(results_dir, exist_ok=True)
    model_name = os.path.basename(config.MODEL_PATH)
    profile_suffix = f"_{config.PROFILE}" if config.PROFILE else ""
    batch_summary_file = os.path.join(results_dir, f'batch_summaries_{model_name}{profile_suffix}_qps_{qps}.jsonl')
    all_batch_tokens_per_second = []
    # Reuse random prompts if QPS is higher than the number of supplied prompts
    if len(prompts) < qps:
        random_prompts = []
        while len(random_prompts) < qps:
            random_prompts.extend(random.choices(prompts, k=min(qps - len(random_prompts), len(prompts))))
        prompts = random_prompts
    elif len(prompts) > qps:
        prompts = prompts[:qps]
    for i in range(iterations):
        logger.info(f"Starting iteration {i + 1}/{iterations} for QPS {qps}")
        batch_results_file = os.path.join(results_dir,
                                          f'batch_results_{model_name}{profile_suffix}_iteration_{i + 1}_qps_{qps}.json')
        asyncio.run(run_batch(prompts, config, stream, i + 1, qps, batch_results_file, batch_summary_file))
        # Read the batch summary to get the batch_tokens_per_second
        with open(batch_summary_file, 'r') as f:
            lines = f.readlines()
            last_batch_summary = json.loads(lines[-1])
            all_batch_tokens_per_second.append(last_batch_summary.get('batch_tokens_per_second', 0))
    # Calculate and print average tokens per second across all iterations
    if all_batch_tokens_per_second:
        avg_tokens_per_second = statistics.mean(all_batch_tokens_per_second)
        logger.info(f"Average total tokens per second across all iterations: {avg_tokens_per_second:.2f}")
    else:
        logger.warning("No valid data for tokens per second across iterations.")
    # Generate performance graph
    generate_performance_graph(all_batch_tokens_per_second, qps, results_dir)


def generate_performance_graph(tokens_per_second: List[float], qps: int, results_dir: str) -> None:
    """Generate a performance graph for the experiment."""
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(tokens_per_second) + 1), tokens_per_second, marker='o')
    plt.title(f'Performance Graph for QPS {qps}')
    plt.xlabel('Iteration')
    plt.ylabel('Tokens per Second')
    plt.grid(True)
    plt.savefig(os.path.join(results_dir, f'performance_graph_qps_{qps}.png'))
    plt.close()


def export_results_to_csv(results_dir: str) -> None:
    """Export all batch summaries to a CSV file."""
    all_summaries = []
    for filename in os.listdir(results_dir):
        if filename.startswith('batch_summaries_') and filename.endswith('.jsonl'):
            with open(os.path.join(results_dir, filename), 'r') as f:
                summaries = [json.loads(line) for line in f]
                all_summaries.extend(summaries)
    if all_summaries:
        df = pd.DataFrame(all_summaries)
        csv_file = os.path.join(results_dir, 'all_batch_summaries.csv')
        df.to_csv(csv_file, index=False)
        logger.info(f"Exported all batch summaries to {csv_file}")
    else:
        logger.warning("No batch summaries found to export.")


def main():
    """Main function to run the experiment."""
    parser = argparse.ArgumentParser(description="Run throughput test for LLM API")
    parser.add_argument("--model", type=str, required=True, help="Model path")
    parser.add_argument("--max_tokens", type=int, default=128, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--prompts_file", type=str, required=True,
                        help="Path to the JSON Lines file containing prompts")
    parser.add_argument("--api_base", type=str, required=True, help="LLM API endpoint")
    parser.add_argument("--api_key", type=str, required=True, help="API key for the LLM service")
    parser.add_argument("--stream", action="store_true", help="Use streaming mode")
    parser.add_argument("--iterations", type=int, default=1, help="Number of times to run the experiment")
    parser.add_argument("--qps", type=str, required=True, help="Comma-separated list of QPS values to test")
    parser.add_argument("--no_stop", action="store_true", help="Don't stop generation early (sent as ignore_eos)")
    parser.add_argument("--profile", type=str, help="Optional profile name to include in result file names")
    args = parser.parse_args()
    config = Config(args.model, args.max_tokens, args.temperature, args.api_base, args.api_key, args.no_stop,
                    args.profile)
    prompts = load_prompts_from_json(args.prompts_file)
    logger.info(f"Loaded {len(prompts)} prompts")
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f'results_{timestamp}'
    qps_values = [int(qps.strip()) for qps in args.qps.split(',')]
    for qps in qps_values:
        logger.info(f"Running experiment for QPS: {qps}")
        run_experiment(prompts, config, args.stream, args.iterations, qps, results_dir)
    export_results_to_csv(results_dir)
    logger.info(f"All experiments completed. Results saved in {results_dir}")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        logger.info("Experiment interrupted by user. Shutting down gracefully...")
    except Exception as e:
        logger.exception(f"An unexpected error occurred: {str(e)}")
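

# Example follow-up analysis (a sketch, separate from this script): the exported CSV can be
# compared across QPS values with pandas, e.g.
#
#   import pandas as pd
#   df = pd.read_csv("results_<timestamp>/all_batch_summaries.csv")
#   print(df.groupby("qps")["batch_tokens_per_second"].mean())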