"""Generate multi-turn conversations by calling an OpenAI-compatible chat
completions endpoint and append each conversation to a JSON Lines file."""

import os
import asyncio
import aiohttp
import json
import logging
from threading import Lock

# Logging setup (for better debugging)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Configuration
config = {
    "system_prompt_file": "system-prompt-3.txt",
    "system_prompt_file_2": "system-prompt-2.txt",
    "base_openai_url": "http://0.0.0.0:8000/v1/chat/completions",
    "max_tokens": 2048,
    "num_responses": 1,
    "model": "gpt-3.5-turbo",
    "temperature": 0.6,
    "presence_penalty": 0.9,
    "num_turns": 3,
    "num_conversations": 2,
    "max_connections": 1,
    "max_retries": 3,
    "logprob": True,
    "output_file": "conversation_history.jsonl",
}

# Error mappings
error_mappings = {
    "invalid_request_error": {
        "code": 400,
        "description": "Your request was malformed or missing some required parameters.",
        "retry": False,
    },
    "rate_limit_error": {
        "code": 429,
        "description": "You have hit your assigned rate limit.",
        "retry": True,
    },
    "tokens_exceeded_error": {
        "code": 403,
        "description": "You have exceeded the allowed number of tokens in your request.",
        "retry": False,
    },
    "authentication_error": {
        "code": 401,
        "description": "Your API key or token was invalid, expired, or revoked.",
        "retry": False,
    },
    "not_found_error": {
        "code": 404,
        "description": "The requested resource was not found.",
        "retry": False,
    },
    "server_error": {
        "code": 500,
        "description": "An issue occurred on the OpenAI server side.",
        "retry": True,
    },
    "permission_error": {
        "code": 403,
        "description": "Your API key or token lacks the required permissions for the requested action.",
        "retry": False,
    },
}

# Get OpenAI API key from environment variable
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY environment variable not set")

# Create semaphores and locks
connection_semaphore = asyncio.Semaphore(config["max_connections"])
file_lock = Lock()


async def capture_output(prompt, role, session, min_length=3):
    """
    Captures the output from the OpenAI API for a given prompt and role.

    Args:
        prompt (str): The input prompt for the API.
        role (str): The role of the message sender (e.g., "system", "user", "assistant").
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.
        min_length (int): The minimum length of the output to be considered valid.

    Returns:
        dict: The captured output as a dictionary containing the complete API response,
            or None if an unrecoverable error occurs or the output is too short.
    """
    while True:
        async with connection_semaphore:
            data = {
                "model": config["model"],
                "messages": [{"role": role, "content": prompt}],
                "max_tokens": config["max_tokens"],
                "n": config["num_responses"],
                "temperature": config["temperature"],
                "presence_penalty": config["presence_penalty"],
                "logprobs": config["logprob"],
                "top_logprobs": 1,
                # The two parameters below are not part of the official OpenAI API;
                # they are extensions accepted by some OpenAI-compatible servers
                # (e.g. vLLM), such as the local endpoint configured above.
                "include_stop_str_in_output": True,
                "best_of": 10,
            }
            headers = {"Authorization": f"Bearer {OPENAI_API_KEY}"}

            try:
                async with session.post(
                    config["base_openai_url"], json=data, headers=headers
                ) as response:
                    if response.status in [
                        mapping["code"] for mapping in error_mappings.values()
                    ]:
                        error_type = next(
                            key
                            for key, value in error_mappings.items()
                            if value["code"] == response.status
                        )
                        error_description = error_mappings[error_type]["description"]
                        if error_mappings[error_type]["retry"]:
                            logging.warning(
                                f"{error_type} ({response.status}): {error_description}. Retrying in 5 seconds..."
                            )
                            await asyncio.sleep(5)
                            continue
                        else:
                            logging.error(
                                f"{error_type} ({response.status}): {error_description}. Skipping conversation."
                            )
                            return None
                    else:
                        response.raise_for_status()  # Check for other API request errors
                        api_response = await response.json()
                        output = api_response["choices"][0]["message"]["content"]
                        if len(output) >= min_length:
                            return {"role": role, "api_response": api_response}
                        else:
                            logging.warning(
                                "Output is too short. Skipping conversation update."
                            )
                            return None
            except aiohttp.ClientConnectorError as e:
                logging.warning(f"Connection error: {e}. Retrying in 5 seconds...")
                await asyncio.sleep(5)
            except aiohttp.ClientError as e:
                logging.warning(f"API request error: {e}. Retrying in 5 seconds...")
                await asyncio.sleep(5)


async def save_conversation(conversation):
    """
    Saves the conversation to a JSON Lines file.

    Args:
        conversation (dict): The conversation data to be saved.
    """
    try:
        with file_lock:
            with open(config["output_file"], "a") as output_file:
                json.dump(conversation, output_file)
                output_file.write("\n")
    except Exception as e:
        logging.error(f"Error saving conversation: {e}")


async def run_conversation(session):
    """
    Runs a single conversation with the specified number of turns.

    Args:
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.

    Returns:
        dict: The conversation data, or None if any turn fails.
    """
    conversation = {"conversation": []}

    # This prompt writes a new system prompt for the conversation that follows.
    with open(config["system_prompt_file"], "r") as file:
        system_prompt = file.read()

    system_output = await capture_output(system_prompt, "system", session)
    if system_output:
        conversation["conversation"].append(system_output)
    else:
        return None

    # Conversation loop
    for _ in range(config["num_turns"]):
        # Assistant's response
        with open(config["system_prompt_file_2"], "r") as file:
            system_prompt_2 = file.read()

        last_user_prompt = next(
            (
                item["api_response"]["choices"][0]["message"]["content"]
                for item in reversed(conversation["conversation"])
                if item["role"] == "user"
            ),
            "",
        )
        prompt = (
            system_prompt_2 + "\n" + last_user_prompt
            if last_user_prompt
            else system_prompt_2
        )
        assistant_output = await capture_output(prompt, "assistant", session)
        if assistant_output:
            conversation["conversation"].append(assistant_output)
        else:
            return None

        # User's response
        user_output = await capture_output(
            system_prompt
            + assistant_output["api_response"]["choices"][0]["message"]["content"],
            "user",
            session,
        )
        if user_output:
            conversation["conversation"].append(user_output)
        else:
            return None

    return conversation


async def run_and_save_conversation(session, conversation_id):
    """
    Runs a conversation and saves it to the output file.

    Args:
        session (aiohttp.ClientSession): The aiohttp session object for making API requests.
        conversation_id (int): The ID of the conversation.
    """
    logging.info(f"Starting conversation {conversation_id}")
    conversation = await run_conversation(session)
    if conversation:
        logging.info(f"Saving conversation {conversation_id}")
        await save_conversation(conversation)
    else:
        logging.warning(f"Conversation {conversation_id} is None, skipping save")


async def main():
    """
    The main function that runs the script.
    """
    async with aiohttp.ClientSession() as session:
        conversation_semaphore = asyncio.Semaphore(config["max_connections"])
        tasks = []

        async def run_conversation_with_semaphore(conversation_id):
            async with conversation_semaphore:
                await run_and_save_conversation(session, conversation_id)

        for conversation_id in range(config["num_conversations"]):
            task = asyncio.create_task(
                run_conversation_with_semaphore(conversation_id)
            )
            tasks.append(task)

        await asyncio.gather(*tasks)

    logging.info(f"Collected {config['num_conversations']} conversations.")


if __name__ == "__main__":
    asyncio.run(main())