Skip to content

Instantly share code, notes, and snippets.

@data2json
Last active May 12, 2024 20:44
Show Gist options
  • Select an option

  • Save data2json/db67528117f1507449848e7ef2f4bb5d to your computer and use it in GitHub Desktop.

Select an option

Save data2json/db67528117f1507449848e7ef2f4bb5d to your computer and use it in GitHub Desktop.
Install VLLM.
import os
import asyncio
import aiohttp
import json
import logging
from threading import Lock
# Logging setup (for better debugging): INFO and above, each record rendered
# as "<timestamp> - <level name> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Configuration
# Central settings dictionary for the script, looked up by key.
config = dict(
    # Prompt template files.
    system_prompt_file="system-prompt-3.txt",
    system_prompt_file_2="system-prompt-2.txt",
    # Chat-completions endpoint the requests are posted to.
    base_openai_url="http://0.0.0.0:8000/v1/chat/completions",
    # Request / sampling parameters.
    max_tokens=2048,
    num_responses=1,
    model="gpt-3.5-turbo",
    temperature=0.6,
    presence_penalty=0.9,
    # Run sizing and concurrency.
    num_turns=3,
    num_conversations=2,
    max_connections=1,
    max_retries=3,
    logprob=True,
    # JSON lines file that finished conversations are appended to.
    output_file="conversation_history.jsonl",
)
# Error mappings
# Keyed by error type; each value carries the HTTP status code, a human-readable
# description and whether a request failing this way should be retried.
# Note that two entries (tokens_exceeded_error, permission_error) share code 403.
error_mappings = {
    error_type: {"code": status, "description": text, "retry": should_retry}
    for error_type, status, should_retry, text in (
        (
            "invalid_request_error",
            400,
            False,
            "Your request was malformed or missing some required parameters.",
        ),
        (
            "rate_limit_error",
            429,
            True,
            "You have hit your assigned rate limit.",
        ),
        (
            "tokens_exceeded_error",
            403,
            False,
            "You have exceeded the allowed number of tokens in your request.",
        ),
        (
            "authentication_error",
            401,
            False,
            "Your API key or token was invalid, expired, or revoked.",
        ),
        (
            "not_found_error",
            404,
            False,
            "The requested resource was not found.",
        ),
        (
            "server_error",
            500,
            True,
            "An issue occurred on the OpenAI server side.",
        ),
        (
            "permission_error",
            403,
            False,
            "Your API key or token lacks the required permissions for the requested action.",
        ),
    )
}
# Get OpenAI API key from environment variable.
# The script refuses to start when the variable is missing or empty.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY is None or OPENAI_API_KEY == "":
    raise ValueError("OPENAI_API_KEY environment variable not set")
# Create semaphores and locks
# file_lock is taken around appends to the output file.
file_lock = Lock()
# connection_semaphore admits at most config["max_connections"] holders at once.
connection_semaphore = asyncio.Semaphore(config["max_connections"])
async def capture_output(prompt, role, session, min_length=3):
    """
    Captures the output from the OpenAI API for a given prompt and role.

    The prompt is sent as a single-message chat completion request. Failures
    that are marked retryable in ``error_mappings``, connection errors and
    other aiohttp client errors are retried after a 5 second pause, for at most
    ``config["max_retries"]`` retries (so ``1 + max_retries`` attempts in
    total). Non-retryable mapped errors, too-short outputs and responses that
    do not have the expected shape end the call immediately.

    Args:
        prompt (str): The input prompt for the API.
        role (str): The role of the message sender (e.g., "system", "user", "assistant").
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.
        min_length (int): The minimum length of the output to be considered valid.

    Returns:
        dict: ``{"role": role, "api_response": <decoded JSON response>}``, or
        None if the request failed, the retries were exhausted, or the output
        was shorter than ``min_length``.
    """
    # Map each status code to the first error type declaring it. Two entries
    # share 403; keeping the first one matches a linear scan of error_mappings.
    error_type_by_status = {}
    for name, details in error_mappings.items():
        error_type_by_status.setdefault(details["code"], name)

    # The request is identical on every attempt, so build it once.
    data = {
        "model": config["model"],
        "messages": [{"role": role, "content": prompt}],
        "max_tokens": config["max_tokens"],
        "n": config["num_responses"],
        "temperature": config["temperature"],
        "presence_penalty": config["presence_penalty"],
        "logprobs": True,
        "top_logprobs": 1,
        "include_stop_str_in_output": True,
        "best_of": 10,
    }
    headers = {"Authorization": f"Bearer {OPENAI_API_KEY}"}

    max_attempts = 1 + max(0, config["max_retries"])
    failure = None

    for attempt in range(1, max_attempts + 1):
        async with connection_semaphore:
            try:
                async with session.post(
                    config["base_openai_url"], json=data, headers=headers
                ) as response:
                    error_type = error_type_by_status.get(response.status)
                    if error_type is None:
                        # Unmapped error statuses raise here and are handled
                        # as a generic (retryable) client error below.
                        response.raise_for_status()
                        api_response = await response.json()
                        try:
                            output = api_response["choices"][0]["message"]["content"]
                        except (KeyError, IndexError, TypeError) as e:
                            logging.error(
                                f"Unexpected API response shape: {e!r}. Skipping conversation."
                            )
                            return None
                        # content may be null in the JSON body; treat it as too short.
                        if output is not None and len(output) >= min_length:
                            return {"role": role, "api_response": api_response}
                        logging.warning(
                            "Output is too short. Skipping conversation update."
                        )
                        return None

                    error_description = error_mappings[error_type]["description"]
                    if not error_mappings[error_type]["retry"]:
                        logging.error(
                            f"{error_type} ({response.status}): {error_description}. Skipping conversation."
                        )
                        return None
                    failure = f"{error_type} ({response.status}): {error_description}"
            except aiohttp.ClientConnectorError as e:
                failure = f"Connection error: {e}"
            except aiohttp.ClientError as e:
                failure = f"API request error: {e}"

        # The semaphore is released before sleeping so the back-off does not
        # occupy a connection slot.
        if attempt < max_attempts:
            logging.warning(f"{failure}. Retrying in 5 seconds...")
            await asyncio.sleep(5)

    logging.error(f"{failure}. Giving up after {max_attempts} attempts.")
    return None
async def save_conversation(conversation):
    """
    Saves the conversation to a JSON lines file.

    The conversation is appended as one line to ``config["output_file"]`` while
    holding ``file_lock``. Saving is best-effort: any failure is logged and
    swallowed rather than raised.

    Args:
        conversation (dict): The conversation data to be saved.
    """
    try:
        # Serialize completely before opening the file: if the data is not
        # JSON-serializable nothing is written, instead of leaving a truncated
        # line that would corrupt the JSONL file.
        line = json.dumps(conversation) + "\n"
        with file_lock:
            with open(config["output_file"], "a") as output_file:
                output_file.write(line)
    except Exception as e:
        logging.error(f"Error saving conversation: {e}")
async def run_conversation(session):
    """
    Runs a single conversation with the specified number of turns.

    The first request sends the contents of ``config["system_prompt_file"]``
    with role "system". Each of the ``config["num_turns"]`` turns then makes
    two requests:

    * role "assistant": the contents of ``config["system_prompt_file_2"]``,
      followed by a newline and the previous user output when there is one;
    * role "user": the first prompt file's contents concatenated directly
      (no separator) with the assistant output just received.

    Every successful result is appended to the conversation in order. The
    conversation is abandoned as soon as any request yields no result.

    Args:
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.

    Returns:
        dict: ``{"conversation": [...]}`` holding the captured outputs, or None
        if any request failed.
    """

    def content_of(entry):
        # Text of the first choice in a captured API response.
        return entry["api_response"]["choices"][0]["message"]["content"]

    # Both prompt files are the same for every turn, so each is read exactly
    # once up front instead of re-reading the second one inside the loop.
    with open(config["system_prompt_file"], "r") as file:
        system_prompt = file.read()
    with open(config["system_prompt_file_2"], "r") as file:
        system_prompt_2 = file.read()

    conversation = {"conversation": []}
    history = conversation["conversation"]

    # This prompt writes a new system prompt for the conversation that follows.
    system_output = await capture_output(system_prompt, "system", session)
    if not system_output:
        return None
    history.append(system_output)

    # User outputs are only ever appended at the end of a turn, so remembering
    # the latest one is equivalent to searching the history backwards.
    last_user_prompt = ""

    # Conversation loop
    for _ in range(config["num_turns"]):
        # Assistant's response
        if last_user_prompt:
            prompt = system_prompt_2 + "\n" + last_user_prompt
        else:
            prompt = system_prompt_2
        assistant_output = await capture_output(prompt, "assistant", session)
        if not assistant_output:
            return None
        history.append(assistant_output)

        # User's response
        user_output = await capture_output(
            system_prompt + content_of(assistant_output),
            "user",
            session,
        )
        if not user_output:
            return None
        history.append(user_output)
        last_user_prompt = content_of(user_output)

    return conversation
async def run_and_save_conversation(session, conversation_id):
    """
    Runs one conversation and, when it produced a result, appends it to the output file.

    Args:
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.
        conversation_id (int): The ID of the conversation, used in log messages.
    """
    logging.info(f"Starting conversation {conversation_id}")
    result = await run_conversation(session)

    # Nothing to persist when the conversation was abandoned.
    if not result:
        logging.warning(f"Conversation {conversation_id} is None, skipping save")
        return

    logging.info(f"Saving conversation {conversation_id}")
    await save_conversation(result)
async def main():
    """
    Entry point: runs config["num_conversations"] conversations over one shared
    aiohttp session, with at most config["max_connections"] running at a time.
    """
    async with aiohttp.ClientSession() as session:
        conversation_semaphore = asyncio.Semaphore(config["max_connections"])

        async def run_conversation_with_semaphore(conversation_id):
            # Wait for a free slot before starting this conversation.
            async with conversation_semaphore:
                await run_and_save_conversation(session, conversation_id)

        # Schedule every conversation up front, then wait for all of them.
        tasks = [
            asyncio.create_task(run_conversation_with_semaphore(conversation_id))
            for conversation_id in range(config["num_conversations"])
        ]
        await asyncio.gather(*tasks)
    logging.info(f"Collected {config['num_conversations']} conversations.")


if __name__ == "__main__":
    asyncio.run(main())
### Make a new python 3.11 environment via pyenv or conda or whatever.
# Install vLLM with CUDA 11.8.
# These three variables select the prebuilt wheel in the release URL below.
export CUDA_VER="cu118"
export VLLM_VERSION=0.4.2
export PYTHON_VERSION=311
# The wheel URL must stay on a single line; splitting it breaks the file name.
pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+${CUDA_VER}-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/${CUDA_VER}
export FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
pip install wheel
pip install --no-cache-dir flash-attn --no-build-isolation
# Start the OpenAI-compatible API server in the background (trailing &).
# The model is served under the name "gpt-3.5-turbo" with API key "sk-1234".
python -m vllm.entrypoints.openai.api_server -pp 1 -tp 1 --dtype float16 --max-model-len 4096 --enable-prefix-caching --device cuda --max-log-len 25 --max-logprobs 10 --enforce-eager --model meta-llama/Meta-Llama-3-8B-Instruct --served-model-name gpt-3.5-turbo --api-key=sk-1234 --disable-custom-all-reduce --disable-log-requests --gpu-memory-utilization 1.0 --uvicorn-log-level critical &
### Tested on vLLM engine (v0.4.2)
# Smoke test against the local server using the openai command-line client.
openai -b http://0.0.0.0:8000/v1/ api completions.create -M 2048 -n1 -m gpt-3.5-turbo -t 0.6 -P0.9 --prompt "Hi. My name is..."
@data2json
Copy link
Author

OpenAI CLI example.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment