LLM throughput benchmarking

# Sample usage:
# Download this .py file and a sample formatted_prompts.json file.
#
# Run from the command line: python llm_benchmark_throughput.py --prompts_file formatted_prompts.json --model meta-llama/Meta-Llama-3-8B-Instruct
# --api_base http://vllm:8000/v1 --api_key vllm_key_here --profile vllm050_fp16_tp2 --stream --iterations 3 --max_tokens 128 --qps "16, 32, 64"
#
# formatted_prompts.json sample record:
# {"prompt": "Your task is blah blah blah."}
#
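# Note: the prompts file is parsed as JSON Lines (one {"prompt": ...} object per
# line), despite the .json extension. A minimal sketch (not part of the original
# script; the prompt strings are made up) for producing a compatible file:
#
#   import json
#   with open("formatted_prompts.json", "w") as f:
#       for p in ["Summarize the following article.", "Explain asyncio semaphores."]:
#           f.write(json.dumps({"prompt": p}) + "\n")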
import asyncio
import json
import os
import time
import random
from typing import List, Dict, Tuple, Optional
import argparse
import datetime
import statistics
import tiktoken
import logging
import aiohttp
import nest_asyncio
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from aiohttp import ClientSession, TCPConnector

# Apply the patch from nest_asyncio to enable nested use of asyncio
nest_asyncio.apply()

# Load environment variables
# load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class Config:
    """Configuration class to store all settings."""

    def __init__(self, model_path: str, max_tokens: int, temperature: float, api_base: str, api_key: str, no_stop: bool, profile: str):
        self.OPENAI_API_KEY = api_key
        self.OPENAI_API_BASE = api_base
        self.MODEL_PATH = model_path
        self.MAX_TOKENS = max_tokens
        self.TEMPERATURE = temperature
        self.NO_STOP = no_stop
        self.PROFILE = profile
        self.HEADERS = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.OPENAI_API_KEY}"
        }
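
# Example construction (hypothetical values, mirroring the sample command above):
#   config = Config(
#       model_path="meta-llama/Meta-Llama-3-8B-Instruct", max_tokens=128,
#       temperature=0.0, api_base="http://vllm:8000/v1", api_key="vllm_key_here",
#       no_stop=False, profile="vllm050_fp16_tp2",
#   )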

def load_prompts_from_json(file_path: str) -> List[Dict[str, str]]:
    """Load prompts from a JSON Lines file (one JSON object per line)."""
    with open(file_path, 'r') as f:
        return [json.loads(line) for line in f]

async def process_streaming_data(response: aiohttp.ClientResponse, prompt: str) -> Tuple[Dict, float]:
    """Process streaming data from the API response."""
    content = ""
    enc = tiktoken.get_encoding("cl100k_base")
    prompt_tokens = len(enc.encode(prompt))
    completion_tokens = 0
    total_tokens = 0
    chunk = None
    ttft = None
    start_time = time.perf_counter()
    async for line in response.content:
        if line.startswith(b'data: '):
            json_data = line.decode('utf-8')[6:]
            if json_data.strip() == "[DONE]":
                break
            chunk = json.loads(json_data)
            if "choices" in chunk and len(chunk["choices"]) > 0 and "delta" in chunk["choices"][0] and "content" in \
                    chunk["choices"][0]["delta"]:
                if ttft is None:
                    ttft = time.perf_counter() - start_time
                content += chunk["choices"][0]["delta"]["content"]
    if chunk:
        completion_tokens = len(enc.encode(content))
        total_tokens = prompt_tokens + completion_tokens
        result = {
            "id": chunk.get("id", ""),
            "object": "chat.completion",
            "created": chunk.get("created", 0),
            "model": chunk.get("model", ""),
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "logprobs": None,
                    "finish_reason": chunk["choices"][0].get("finish_reason", None),
                    "stop_reason": None
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens
            }
        }
        return result, ttft
    else:
        raise Exception("No valid streaming data received")
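
# Notes on the streaming path above:
# - TTFT is measured from the start of consuming the SSE stream to the first
#   chunk whose delta carries "content", not from request submission.
# - Completion tokens are re-counted locally with tiktoken's cl100k_base
#   encoding, which is only an approximation when the served model (e.g. Llama 3)
#   uses a different tokenizer.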

async def send_request(
        session: ClientSession,
        prompt: Dict[str, str],
        config: Config,
        stream: bool,
        query_number: int,
        semaphore: asyncio.Semaphore
) -> Dict:
    """Send a request to the LLM API and process the response."""
    message = {
        "model": config.MODEL_PATH,
        "messages": [
            {"role": "user", "content": prompt['prompt']}
        ],
        "n": 1,
        "frequency_penalty": 1.0,
        "max_tokens": config.MAX_TOKENS,
        "temperature": config.TEMPERATURE,
        "stream": stream,
        "ignore_eos": config.NO_STOP,
    }
    start_time = time.perf_counter()
    async with semaphore:
        try:
            async with session.post(f"{config.OPENAI_API_BASE}/chat/completions", json=message) as response:
                response.raise_for_status()
                if stream:
                    result, ttft = await process_streaming_data(response, prompt['prompt'])
                    raw_response = result['choices'][0]['message']['content']
                    llm_generated_text = raw_response
                    prompt_tokens = result['usage']['prompt_tokens']
                    completion_tokens = result['usage']['completion_tokens']
                    total_tokens = result['usage']['total_tokens']
                else:
                    result = await response.json()
                    raw_response = await response.text()
                    llm_generated_text = result['choices'][0]['message']['content']
                    usage = result.get('usage', {})
                    prompt_tokens = usage.get('prompt_tokens', 0)
                    completion_tokens = usage.get('completion_tokens', 0)
                    total_tokens = usage.get('total_tokens', 0)
                    ttft = None  # TTFT is not applicable for non-streaming requests
            end_time = time.perf_counter()
            duration = end_time - start_time
            return {
                "query_number": query_number,
                "query_start_time": start_time,
                "query_end_time": end_time,
                "query_run_time": duration,
                "query_success_status": True,
                "query_prompt": prompt['prompt'],
                "llm_raw_response": raw_response,
                "llm_generated_text": llm_generated_text,
                "prompt_tokens": prompt_tokens,
                "generated_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "tokens_per_second": total_tokens / duration,
                "model_path": config.MODEL_PATH,
                "api_base": config.OPENAI_API_BASE,
                "ttft": ttft,
            }
        except Exception as e:
            end_time = time.perf_counter()
            logger.error(f"Error in Query {query_number}: {str(e)}")
            return {
                "query_number": query_number,
                "query_start_time": start_time,
                "query_end_time": end_time,
                "query_run_time": end_time - start_time,
                "query_success_status": False,
                "error": str(e),
                "model_path": config.MODEL_PATH,
                "api_base": config.OPENAI_API_BASE,
            }
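
# Notes on send_request:
# - start_time is taken before the semaphore is acquired, so query_run_time
#   includes any time spent waiting for a concurrency slot.
# - "ignore_eos" is a vLLM-style extension to the OpenAI chat-completions
#   schema; other OpenAI-compatible servers may reject or silently ignore it.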

async def run_batch(
        prompts: List[Dict[str, str]],
        config: Config,
        stream: bool,
        batch_number: int,
        qps: int,
        batch_results_file: str,
        batch_summary_file: str
) -> None:
    """Run a batch of queries and save the results."""
    batch_start_time = time.perf_counter()
    connector = TCPConnector(limit=qps)
    async with ClientSession(headers=config.HEADERS, connector=connector) as session:
        batch_results = []
        tasks = []
        semaphore = asyncio.Semaphore(qps)
        with tqdm(total=len(prompts), desc=f"Submitting queries for batch {batch_number}",
                  unit="query") as submission_pbar:
            for i, prompt in enumerate(prompts):
                task = asyncio.create_task(send_request(session, prompt, config, stream, i, semaphore))
                tasks.append(task)
                submission_pbar.update(1)
                if (i + 1) % qps == 0 or i == len(prompts) - 1:
                    await asyncio.sleep(5)  # Wait 5 seconds after each group of qps submissions
        ttft_bar = tqdm(total=len(tasks), desc="Time to First Token", unit="query", leave=False)
        for f in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                      desc=f"Completing queries for Batch {batch_number}", unit="query"):
            result = await f
            batch_results.append(result)
            if result.get('ttft'):
                ttft_bar.update(1)
        ttft_bar.close()
    batch_end_time = time.perf_counter()

    # Save detailed batch results
    with open(batch_results_file, 'w') as f:
        json.dump(batch_results, f, indent=2)

    # Calculate and save batch summary
    successful_queries = [r for r in batch_results if r['query_success_status']]
    total_prompt_tokens = sum(r.get('prompt_tokens', 0) for r in successful_queries)
    total_generated_tokens = sum(r.get('generated_tokens', 0) for r in successful_queries)
    total_tokens = sum(r.get('total_tokens', 0) for r in successful_queries)
    batch_runtime = batch_end_time - batch_start_time
    tokens_per_second_list = [r['tokens_per_second'] for r in successful_queries if 'tokens_per_second' in r]
    avg_tokens_per_second = statistics.mean(tokens_per_second_list) if tokens_per_second_list else 0
    batch_tokens_per_second = total_tokens / batch_runtime if batch_runtime > 0 else 0
    ttft_list = [r['ttft'] for r in successful_queries if r.get('ttft') is not None]
    avg_ttft = statistics.mean(ttft_list) if ttft_list else None
    batch_summary = {
        "batch_start_time": batch_start_time,
        "batch_end_time": batch_end_time,
        "batch_runtime": batch_runtime,
        "total_prompt_tokens": total_prompt_tokens,
        "total_generated_tokens": total_generated_tokens,
        "total_tokens": total_tokens,
        "avg_tokens_per_second": avg_tokens_per_second,
        "batch_tokens_per_second": batch_tokens_per_second,
        "model_path": config.MODEL_PATH,
        "api_base": config.OPENAI_API_BASE,
        "queries_in_batch": qps,
        "successful_queries": len(successful_queries),
        "qps": qps,
        "avg_ttft": avg_ttft
    }
    with open(batch_summary_file, 'a') as f:
        json.dump(batch_summary, f)
        f.write('\n')
    logger.info(f"Batch {batch_number} completed: {len(successful_queries)}/{len(prompts)} successful queries")
    logger.info(f"Average tokens per query per second: {avg_tokens_per_second:.2f}")
    logger.info(f"Iteration total tokens per second: {batch_tokens_per_second:.2f}")
    if avg_ttft:
        logger.info(f"Average Time to First Token: {avg_ttft:.4f} seconds")

def run_experiment(
        prompts: List[Dict[str, str]],
        config: Config,
        stream: bool,
        iterations: int,
        qps: int,
        results_dir: str
) -> None:
    """Run the experiment for a specific QPS value."""
    os.makedirs(results_dir, exist_ok=True)
    model_name = os.path.basename(config.MODEL_PATH)
    profile_suffix = f"_{config.PROFILE}" if config.PROFILE else ""
    batch_summary_file = os.path.join(results_dir, f'batch_summaries_{model_name}{profile_suffix}_qps_{qps}.jsonl')
    all_batch_tokens_per_second = []

    # Reuse random prompts if QPS is higher than the number of supplied prompts
    if len(prompts) < qps:
        random_prompts = []
        while len(random_prompts) < qps:
            random_prompts.extend(random.choices(prompts, k=min(qps - len(random_prompts), len(prompts))))
        prompts = random_prompts
    elif len(prompts) > qps:
        prompts = prompts[:qps]

    for i in range(iterations):
        logger.info(f"Starting iteration {i + 1}/{iterations} for QPS {qps}")
        batch_results_file = os.path.join(results_dir,
                                          f'batch_results_{model_name}{profile_suffix}_iteration_{i + 1}_qps_{qps}.json')
        asyncio.run(run_batch(prompts, config, stream, i + 1, qps, batch_results_file, batch_summary_file))
        # Read the batch summary to get the batch_tokens_per_second
        with open(batch_summary_file, 'r') as f:
            lines = f.readlines()
            last_batch_summary = json.loads(lines[-1])
            all_batch_tokens_per_second.append(last_batch_summary.get('batch_tokens_per_second', 0))

    # Calculate and print average tokens per second across all iterations
    if all_batch_tokens_per_second:
        avg_tokens_per_second = statistics.mean(all_batch_tokens_per_second)
        logger.info(f"Average total tokens per second across all iterations: {avg_tokens_per_second:.2f}")
    else:
        logger.warning("No valid data for tokens per second across iterations.")

    # Generate performance graph
    generate_performance_graph(all_batch_tokens_per_second, qps, results_dir)
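
# Note: each iteration runs against exactly `qps` prompts; the supplied prompt
# list is truncated or randomly re-sampled (with replacement) to reach that size.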

def generate_performance_graph(tokens_per_second: List[float], qps: int, results_dir: str) -> None:
    """Generate a performance graph for the experiment."""
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(tokens_per_second) + 1), tokens_per_second, marker='o')
    plt.title(f'Performance Graph for QPS {qps}')
    plt.xlabel('Iteration')
    plt.ylabel('Tokens per Second')
    plt.grid(True)
    plt.savefig(os.path.join(results_dir, f'performance_graph_qps_{qps}.png'))
    plt.close()

def export_results_to_csv(results_dir: str) -> None:
    """Export all batch summaries to a CSV file."""
    all_summaries = []
    for filename in os.listdir(results_dir):
        if filename.startswith('batch_summaries_') and filename.endswith('.jsonl'):
            with open(os.path.join(results_dir, filename), 'r') as f:
                summaries = [json.loads(line) for line in f]
                all_summaries.extend(summaries)
    if all_summaries:
        df = pd.DataFrame(all_summaries)
        csv_file = os.path.join(results_dir, 'all_batch_summaries.csv')
        df.to_csv(csv_file, index=False)
        logger.info(f"Exported all batch summaries to {csv_file}")
    else:
        logger.warning("No batch summaries found to export.")

def main():
    """Main function to run the experiment."""
    parser = argparse.ArgumentParser(description="Run throughput test for LLM API")
    parser.add_argument("--model", type=str, required=True, help="Model path")
    parser.add_argument("--max_tokens", type=int, default=128, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--prompts_file", type=str, required=True, help="Path to the JSON file containing prompts")
    parser.add_argument("--api_base", type=str, required=True, help="LLM API endpoint")
    parser.add_argument("--api_key", type=str, required=True, help="API key for the LLM service")
    parser.add_argument("--stream", action="store_true", help="Use streaming mode")
    parser.add_argument("--iterations", type=int, default=1, help="Number of times to run the experiment")
    parser.add_argument("--qps", type=str, required=True, help="Comma-separated list of QPS values to test")
    parser.add_argument("--no_stop", action="store_true", help="Don't stop generation early")
    parser.add_argument("--profile", type=str, help="Optional profile name to include in result file names")
    args = parser.parse_args()
    config = Config(args.model, args.max_tokens, args.temperature, args.api_base, args.api_key, args.no_stop,
                    args.profile)
    prompts = load_prompts_from_json(args.prompts_file)
    logger.info(f"Loaded {len(prompts)} prompts")
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f'results_{timestamp}'
    qps_values = [int(qps.strip()) for qps in args.qps.split(',')]
    for qps in qps_values:
        logger.info(f"Running experiment for QPS: {qps}")
        run_experiment(prompts, config, args.stream, args.iterations, qps, results_dir)
    export_results_to_csv(results_dir)
    logger.info(f"All experiments completed. Results saved in {results_dir}")

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        logger.info("Experiment interrupted by user. Shutting down gracefully...")
    except Exception as e:
        logger.exception(f"An unexpected error occurred: {str(e)}")
  