import os
from typing import List, Dict, Union
from dotenv import load_dotenv
from chain import MinimalChainable, FusionChain
import llm
import json
def build_models():
    """Return the Claude models used by the demos, each configured with the API key.

    Calls ``load_dotenv()`` and then reads ``ANTHROPIC_API_KEY`` from the
    environment; the value (``None`` when the variable is unset) is assigned
    to every model's ``key`` attribute.

    Returns:
        A list of three ``llm.Model`` objects, in this order:
        ``claude-3.5-sonnet``, ``claude-3-sonnet``, ``claude-3-haiku``.
    """
    load_dotenv()
    api_key = os.getenv("ANTHROPIC_API_KEY")

    # Add more models here for FusionChain.
    # NOTE: both demo functions in this file unpack exactly three values from
    # the returned list, so they must be updated if this tuple grows.
    model_names = ("claude-3.5-sonnet", "claude-3-sonnet", "claude-3-haiku")

    configured = []
    for model_name in model_names:
        model: llm.Model = llm.get_model(model_name)
        model.key = api_key
        configured.append(model)
    return configured
def prompt(model: llm.Model, prompt: str):
    """Send ``prompt`` to ``model`` at temperature 0.5 and return the reply text.

    Args:
        model: Object exposing ``prompt(text, temperature=...)`` whose result
            has a ``text()`` method.
        prompt: The prompt text to send.

    Returns:
        Whatever ``text()`` on the model's response returns.
    """
    # NOTE: the ``prompt`` parameter shadows this function's own name inside
    # the body; that is harmless here because the function never calls itself.
    return model.prompt(prompt, temperature=0.5).text()
def prompt_chainable_poc():
    """Run the three-step blog-post prompt chain on a single model.

    Takes the first model returned by ``build_models()`` and passes it, the
    ``prompt`` callable, a ``{"topic": "AI Agents"}`` context and three prompt
    templates to ``MinimalChainable.run``. The returned results and
    context-filled prompts are each handed to
    ``MinimalChainable.to_delim_text_file`` and the values it returns are
    printed.

    The ``{{topic}}`` / ``{{output[-1].title}}`` placeholders are left for
    ``MinimalChainable`` to resolve (presumably against the context and the
    earlier outputs of the chain -- see the ``chain`` module).
    """
    sonnet_3_5_model, _, _ = build_models()

    result, context_filled_prompts = MinimalChainable.run(
        context={"topic": "AI Agents"},
        model=sonnet_3_5_model,
        callable=prompt,
        prompts=[
            # prompt #1
            # Kept on one line: a plain double-quoted literal cannot span
            # lines. Same text as the first prompt in fusion_chain_poc().
            "Generate one blog post title about: {{topic}}. Respond in strictly in JSON in this format: {'title': ''}",
            # prompt #2
            "Generate one hook for the blog post title: {{output[-1].title}}",
            # prompt #3
            """Based on the BLOG_TITLE and BLOG_HOOK, generate the first paragraph of the blog post.
BLOG_TITLE:
{{output[-2].title}}
BLOG_HOOK:
{{output[-1]}}""",
        ],
    )

    chained_prompts = MinimalChainable.to_delim_text_file(
        "poc_context_filled_prompts", context_filled_prompts
    )
    chainable_result = MinimalChainable.to_delim_text_file("poc_prompt_results", result)

    print(f"\n\nš Prompts~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n\n{chained_prompts}")
    print(f"\n\nš Results~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n\n{chainable_result}")
def fusion_chain_poc():
    """Run the three-step blog-post prompt chain on all models via FusionChain.

    Passes every model from ``build_models()``, the ``prompt`` callable, a
    ``{"topic": "AI Agents"}`` context, three prompt templates and a
    length-based evaluator to ``FusionChain.run``. The result's ``dict()``
    form is printed as indented JSON and also written to
    ``poc_fusion_chain_result.json`` in the current working directory.
    """
    sonnet_3_5_model, sonnet_3_model, haiku_3_model = build_models()

    def evaluator(outputs: List[str]) -> tuple[str, List[float]]:
        """Pick the longest output and score every output relative to it.

        Args:
            outputs: Non-empty list of outputs to compare (``max()`` raises
                ``ValueError`` on an empty list).

        Returns:
            ``(top_response, normalized_scores)`` where each score is
            ``len(output) / len(longest output)``. Ties resolve to the first
            output of maximal length.
        """
        # Simple evaluator that chooses the longest output as the top response
        scores = [len(output) for output in outputs]
        max_score = max(scores)
        # If every output is empty the longest length is 0; score them all
        # 0.0 instead of dividing by zero.
        normalized_scores = [
            score / max_score if max_score else 0.0 for score in scores
        ]
        top_response = outputs[scores.index(max_score)]
        return top_response, normalized_scores

    result = FusionChain.run(
        context={"topic": "AI Agents"},
        models=[sonnet_3_5_model, sonnet_3_model, haiku_3_model],
        callable=prompt,
        prompts=[
            # prompt #1
            "Generate one blog post title about: {{topic}}. Respond in strictly in JSON in this format: {'title': ''}",
            # prompt #2
            "Generate one hook for the blog post title: {{output[-1].title}}",
            # prompt #3
            """Based on the BLOG_TITLE and BLOG_HOOK, generate the first paragraph of the blog post.
BLOG_TITLE:
{{output[-2].title}}
BLOG_HOOK:
{{output[-1]}}""",
        ],
        evaluator=evaluator,
        get_model_name=lambda model: model.model_id,
    )

    result_dump = result.dict()

    print("\n\nš FusionChain Results~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print(json.dumps(result_dump, indent=4))

    # Write the result to a JSON file
    with open("poc_fusion_chain_result.json", "w") as json_file:
        json.dump(result_dump, json_file, indent=4)
def main():
    """Script entry point: run the single-model prompt-chain demo."""
    prompt_chainable_poc()
    # The multi-model demo is disabled by default; uncomment to run it too.
    # fusion_chain_poc()
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()