@atzamis — forked from daveebbelaar/llm_factory.py, created August 17, 2024
LLM Factory with Instructor
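A single factory class that puts OpenAI, Anthropic, and a local Llama model (served through an OpenAI-compatible endpoint) behind Instructor, so every provider returns structured, Pydantic-validated responses from the same create_completion call.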
llm_factory.py

from typing import Any, Dict, List, Type

import instructor
from anthropic import Anthropic
from config.settings import get_settings
from openai import OpenAI
from pydantic import BaseModel, Field


class LLMFactory:
    def __init__(self, provider: str):
        self.provider = provider
        self.settings = getattr(get_settings(), provider)
        self.client = self._initialize_client()

    def _initialize_client(self) -> Any:
        # Each provider gets an Instructor-patched client; the local Llama
        # endpoint uses JSON mode since it may not support tool calling.
        client_initializers = {
            "openai": lambda s: instructor.from_openai(OpenAI(api_key=s.api_key)),
            "anthropic": lambda s: instructor.from_anthropic(
                Anthropic(api_key=s.api_key)
            ),
            "llama": lambda s: instructor.from_openai(
                OpenAI(base_url=s.base_url, api_key=s.api_key),
                mode=instructor.Mode.JSON,
            ),
        }

        initializer = client_initializers.get(self.provider)
        if initializer:
            return initializer(self.settings)
        raise ValueError(f"Unsupported LLM provider: {self.provider}")

    def create_completion(
        self, response_model: Type[BaseModel], messages: List[Dict[str, str]], **kwargs
    ) -> Any:
        # Per-call kwargs override the provider defaults from settings.
        completion_params = {
            "model": kwargs.get("model", self.settings.default_model),
            "temperature": kwargs.get("temperature", self.settings.temperature),
            "max_retries": kwargs.get("max_retries", self.settings.max_retries),
            "max_tokens": kwargs.get("max_tokens", self.settings.max_tokens),
            "response_model": response_model,
            "messages": messages,
        }
        return self.client.chat.completions.create(**completion_params)


if __name__ == "__main__":

    class CompletionModel(BaseModel):
        response: str = Field(description="Your response to the user.")
        reasoning: str = Field(description="Explain your reasoning for the response.")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "If it takes 2 hours to dry 1 shirt out in the sun, how long will it take to dry 5 shirts?",
        },
    ]

    llm = LLMFactory("openai")
    completion = llm.create_completion(
        response_model=CompletionModel,
        messages=messages,
    )
    assert isinstance(completion, CompletionModel)
    print(f"Response: {completion.response}\n")
    print(f"Reasoning: {completion.reasoning}")
config/settings.py

import os
from functools import lru_cache
from typing import Optional

from dotenv import load_dotenv
from pydantic_settings import BaseSettings

load_dotenv()


class LLMProviderSettings(BaseSettings):
    # Shared defaults inherited by every provider.
    temperature: float = 0.0
    max_tokens: Optional[int] = None
    max_retries: int = 3


class OpenAISettings(LLMProviderSettings):
    api_key: str = os.getenv("OPENAI_API_KEY")
    default_model: str = "gpt-4o"


class AnthropicSettings(LLMProviderSettings):
    api_key: str = os.getenv("ANTHROPIC_API_KEY")
    default_model: str = "claude-3-5-sonnet-20240620"
    max_tokens: int = 1024  # the Anthropic Messages API requires max_tokens


class LlamaSettings(LLMProviderSettings):
    api_key: str = "key"  # the OpenAI client requires a key, but Ollama ignores it
    default_model: str = "llama3"
    base_url: str = "http://localhost:11434/v1"


class Settings(BaseSettings):
    app_name: str = "GenAI Project Template"
    openai: OpenAISettings = OpenAISettings()
    anthropic: AnthropicSettings = AnthropicSettings()
    llama: LlamaSettings = LlamaSettings()


@lru_cache
def get_settings():
    return Settings()
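The key-based providers read their secrets from the environment via load_dotenv(), so a .env file at the project root is the expected setup. A minimal sketch; the variable names come from the os.getenv calls above, and the values are placeholders:

    OPENAI_API_KEY=sk-...
    ANTHROPIC_API_KEY=sk-ant-...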