Install vLLM and use its OpenAI-compatible server to generate multi-turn conversation data.
import os
import asyncio
import json
import logging
from threading import Lock

import aiohttp

# Logging setup (for better debugging)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Configuration
config = {
    "system_prompt_file": "system-prompt-3.txt",
    "system_prompt_file_2": "system-prompt-2.txt",
    "base_openai_url": "http://0.0.0.0:8000/v1/chat/completions",
    "max_tokens": 2048,
    "num_responses": 1,
    "model": "gpt-3.5-turbo",
    "temperature": 0.6,
    "presence_penalty": 0.9,
    "num_turns": 3,
    "num_conversations": 2,
    "max_connections": 1,
    "max_retries": 3,
    "logprob": True,
    "output_file": "conversation_history.jsonl",
}
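# Note: "model" is just a label here, not a real OpenAI model -- it must match
# the --served-model-name flag passed to the vLLM server (see the launch
# command further below).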

# Error mappings
error_mappings = {
    "invalid_request_error": {
        "code": 400,
        "description": "Your request was malformed or missing some required parameters.",
        "retry": False,
    },
    "rate_limit_error": {
        "code": 429,
        "description": "You have hit your assigned rate limit.",
        "retry": True,
    },
    "tokens_exceeded_error": {
        "code": 403,
        "description": "You have exceeded the allowed number of tokens in your request.",
        "retry": False,
    },
    "authentication_error": {
        "code": 401,
        "description": "Your API key or token was invalid, expired, or revoked.",
        "retry": False,
    },
    "not_found_error": {
        "code": 404,
        "description": "The requested resource was not found.",
        "retry": False,
    },
    "server_error": {
        "code": 500,
        "description": "An issue occurred on the OpenAI server side.",
        "retry": True,
    },
    "permission_error": {
        "code": 403,
        "description": "Your API key or token lacks the required permissions for the requested action.",
        "retry": False,
    },
}
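# Note: 403 appears twice above (tokens_exceeded_error and permission_error);
# the status-code lookup in capture_output() returns the first match. Both
# entries are non-retryable, so the outcome is the same either way.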

# Get OpenAI API key from environment variable
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY environment variable not set")

# Create semaphores and locks
connection_semaphore = asyncio.Semaphore(config["max_connections"])
file_lock = Lock()
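# connection_semaphore caps concurrent in-flight requests to the server;
# file_lock serializes appends to the shared JSONL output file.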

async def capture_output(prompt, role, session, min_length=3):
    """
    Captures the output from the OpenAI API for a given prompt and role.

    Args:
        prompt (str): The input prompt for the API.
        role (str): The role of the message sender (e.g., "system", "user", "assistant").
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.
        min_length (int): The minimum length of the output to be considered valid.

    Returns:
        dict: The captured output as a dictionary containing the complete API response,
            or None if an error occurs.
    """
    retries = 0
    while True:
        async with connection_semaphore:
            data = {
                "model": config["model"],
                "messages": [{"role": role, "content": prompt}],
                "max_tokens": config["max_tokens"],
                "n": config["num_responses"],
                "temperature": config["temperature"],
                "presence_penalty": config["presence_penalty"],
                "logprobs": config["logprob"],
                "top_logprobs": 1,
                # vLLM-specific extensions to the OpenAI chat schema:
                "include_stop_str_in_output": True,
                "best_of": 10,
            }
            headers = {"Authorization": f"Bearer {OPENAI_API_KEY}"}
            try:
                async with session.post(
                    config["base_openai_url"], json=data, headers=headers
                ) as response:
                    if response.status in [
                        mapping["code"] for mapping in error_mappings.values()
                    ]:
                        error_type = next(
                            key
                            for key, value in error_mappings.items()
                            if value["code"] == response.status
                        )
                        error_description = error_mappings[error_type]["description"]
                        if error_mappings[error_type]["retry"]:
                            # Give up after max_retries retryable failures.
                            retries += 1
                            if retries > config["max_retries"]:
                                logging.error(
                                    f"{error_type} ({response.status}): giving up "
                                    f"after {config['max_retries']} retries."
                                )
                                return None
                            logging.warning(
                                f"{error_type} ({response.status}): {error_description}. "
                                "Retrying in 5 seconds..."
                            )
                            await asyncio.sleep(5)
                            continue
                        logging.error(
                            f"{error_type} ({response.status}): {error_description}. "
                            "Skipping conversation."
                        )
                        return None
                    response.raise_for_status()  # Check for other API request errors
                    api_response = await response.json()
                    output = api_response["choices"][0]["message"]["content"]
                    if len(output) >= min_length:
                        return {"role": role, "api_response": api_response}
                    logging.warning("Output is too short. Skipping conversation update.")
                    return None
            except aiohttp.ClientConnectorError as e:
                logging.warning(f"Connection error: {e}. Retrying in 5 seconds...")
                await asyncio.sleep(5)
            except aiohttp.ClientError as e:
                logging.warning(f"API request error: {e}. Retrying in 5 seconds...")
                await asyncio.sleep(5)

async def save_conversation(conversation):
    """
    Saves the conversation to a JSON lines file.

    Args:
        conversation (dict): The conversation data to be saved.
    """
    try:
        # file_lock prevents concurrent tasks from interleaving JSONL records;
        # the short blocking write is acceptable for a script of this size.
        with file_lock:
            with open(config["output_file"], "a") as output_file:
                json.dump(conversation, output_file)
                output_file.write("\n")
    except Exception as e:
        logging.error(f"Error saving conversation: {e}")

async def run_conversation(session):
    """
    Runs a single conversation with the specified number of turns.

    Args:
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.

    Returns:
        dict: The conversation data, or None if an error occurs.
    """
    conversation = {"conversation": []}
    # This prompt writes a new system prompt for the conversation that follows.
    with open(config["system_prompt_file"], "r") as file:
        system_prompt = file.read()
    system_output = await capture_output(system_prompt, "system", session)
    if system_output:
        conversation["conversation"].append(system_output)
    else:
        return None
    # Conversation loop
    for _ in range(config["num_turns"]):
        # Assistant's response
        with open(config["system_prompt_file_2"], "r") as file:
            system_prompt_2 = file.read()
        last_user_prompt = next(
            (
                item["api_response"]["choices"][0]["message"]["content"]
                for item in reversed(conversation["conversation"])
                if item["role"] == "user"
            ),
            "",
        )
        prompt = (
            system_prompt_2 + "\n" + last_user_prompt
            if last_user_prompt
            else system_prompt_2
        )
        assistant_output = await capture_output(prompt, "assistant", session)
        if assistant_output:
            conversation["conversation"].append(assistant_output)
        else:
            return None
        # User's response
        user_output = await capture_output(
            system_prompt
            + assistant_output["api_response"]["choices"][0]["message"]["content"],
            "user",
            session,
        )
        if user_output:
            conversation["conversation"].append(user_output)
        else:
            return None
    return conversation

async def run_and_save_conversation(session, conversation_id):
    """
    Runs a conversation and saves it to the output file.

    Args:
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.
        conversation_id (int): The ID of the conversation.
    """
    logging.info(f"Starting conversation {conversation_id}")
    conversation = await run_conversation(session)
    if conversation:
        logging.info(f"Saving conversation {conversation_id}")
        await save_conversation(conversation)
    else:
        logging.warning(f"Conversation {conversation_id} is None, skipping save")


async def main():
    """
    The main function that runs the script.
    """
    async with aiohttp.ClientSession() as session:
        conversation_semaphore = asyncio.Semaphore(config["max_connections"])
        tasks = []

        async def run_conversation_with_semaphore(conversation_id):
            async with conversation_semaphore:
                await run_and_save_conversation(session, conversation_id)

        for conversation_id in range(config["num_conversations"]):
            task = asyncio.create_task(run_conversation_with_semaphore(conversation_id))
            tasks.append(task)
        await asyncio.gather(*tasks)
        logging.info(f"Collected {config['num_conversations']} conversations.")


if __name__ == "__main__":
    asyncio.run(main())
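Each line of conversation_history.jsonl holds one conversation, with the message text buried inside the full API response that capture_output() stores. A minimal sketch for recovering the plain text (field names are taken from the script above; adjust the path if you change config["output_file"]):

import json

# Print each (role, message) pair from the generated dataset.
with open("conversation_history.jsonl", "r") as f:
    for line in f:
        record = json.loads(line)
        for turn in record["conversation"]:
            text = turn["api_response"]["choices"][0]["message"]["content"]
            print(f"{turn['role']}: {text}")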
### Make a new Python 3.11 environment via pyenv or conda or whatever.

# Install vLLM with CUDA 11.8.
export CUDA_VER="cu118"
export VLLM_VERSION=0.4.2
export PYTHON_VERSION=311
pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+${CUDA_VER}-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/${CUDA_VER}

# Install flash-attn from a prebuilt wheel instead of compiling it from source.
export FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
pip install wheel
pip install --no-cache-dir flash-attn --no-build-isolation

# Launch the OpenAI-compatible server in the background. --served-model-name and
# --api-key must match config["model"] and OPENAI_API_KEY in the Python script above.
python -m vllm.entrypoints.openai.api_server -pp 1 -tp 1 --dtype float16 \
    --max-model-len 4096 --enable-prefix-caching --device cuda \
    --max-log-len 25 --max-logprobs 10 --enforce-eager \
    --model meta-llama/Meta-Llama-3-8B-Instruct --served-model-name gpt-3.5-turbo \
    --api-key=sk-1234 --disable-custom-all-reduce --disable-log-requests \
    --gpu-memory-utilization 1.0 --uvicorn-log-level critical &

### OpenAI CLI example. Tested on vLLM engine (v0.4.2).
openai -b http://0.0.0.0:8000/v1/ api completions.create -M 2048 -n1 -m gpt-3.5-turbo -t 0.6 -P0.9 --prompt "Hi. My name is..."
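The same request from Python, as a rough sketch: this assumes the legacy openai<1.0 client (matching the 0.x-style CLI call above) and the sk-1234 key passed to the server.

import openai  # legacy 0.x client, matching the CLI above

openai.api_base = "http://0.0.0.0:8000/v1"
openai.api_key = "sk-1234"  # must match the server's --api-key flag

# Equivalent of the CLI call, with the same sampling parameters.
response = openai.Completion.create(
    model="gpt-3.5-turbo",
    prompt="Hi. My name is...",
    max_tokens=2048,
    n=1,
    temperature=0.6,
    presence_penalty=0.9,
)
print(response["choices"][0]["text"])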