import base64
import functools
import json
import operator
from enum import Enum
from typing import *
from typing import Annotated, Any, Dict, List, Sequence

from dto.chat import *
from dto.graph import *
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from models import get_text_model
from tools._index import get_tools
from typing_extensions import TypedDict
from utils.common import *


class GraphState(TypedDict):
    """State shared by every node of the agent graph."""

    # Conversation so far; ``operator.add`` makes LangGraph append each node's
    # returned messages to the list instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent that produced the latest agent message; the tool node
    # uses it to route tool results back to the requesting agent.
    sender: str


class Agent:
    """An LLM runnable bound to a system prompt and a set of tools.

    Attributes:
        name: Node name of this agent in the graph.
        prompt: Raw system prompt text.
        agent: Runnable ``prompt_template | model`` invoked with the graph state.
        tools: Tools made available to the model.
    """

    def __init__(self, name: str, llm, prompt: str, tools: List[BaseTool]):
        self.name = name
        self.prompt = prompt
        self.tools = tools
        prompt_template = ChatPromptTemplate.from_messages([
            # A literal SystemMessage is not parsed as a template, so braces in
            # user-supplied prompts (e.g. JSON examples) cannot be mistaken for
            # template variables at invoke time.
            SystemMessage(content=prompt),
            MessagesPlaceholder(variable_name="messages"),
        ])
        # NOTE: some providers may reject an empty ``tools`` array, so tools are
        # only bound when there is at least one.
        model = llm.bind_tools(tools) if tools else llm
        self.agent = prompt_template | model


class GraphBuilder:
    """Build, compile, render and run a multi-agent LangGraph from a GraphConfig.

    Every agent gets a node plus conditional edges:
    tool calls go to the shared ``"call_tool"`` node, a reply containing
    ``"FINISHED"`` ends the run, and anything else continues to the agent's
    configured successor.
    """

    def __init__(self, graph_config: GraphConfig):
        self.graph_config = graph_config
        self.agents: List[Agent] = []
        self.graph: CompiledStateGraph | None = None

    @staticmethod
    def _message_text(content) -> str:
        """Return the plain text of a message ``content`` value.

        ``content`` may be a string or a list of content blocks (strings or
        dicts carrying a ``"text"`` key); other blocks are ignored.
        """
        if isinstance(content, str):
            return content
        parts = []
        for block in content or []:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "\n".join(parts)

    @staticmethod
    def _router(state) -> Literal["call_tool", "__end__", "continue"]:
        """Decide where to go after an agent node has replied."""
        last_message = state["messages"][-1]
        if getattr(last_message, "tool_calls", None):
            return "call_tool"
        if "FINISHED" in GraphBuilder._message_text(last_message.content):
            return "__end__"
        return "continue"

    @staticmethod
    def _agent_node(state, agent, name):
        """Invoke ``agent`` on ``state`` and return its reply tagged with ``name``."""
        result = agent.invoke(state)
        # Rebuild as an AIMessage carrying the agent's name; "name" is excluded
        # from the dump so it is not passed twice.
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
        return {
            "messages": [result],
            "sender": name,
        }

    def _create_agent_node(self, agent: Agent):
        """Return the node callable for ``agent``."""
        return functools.partial(self._agent_node, agent=agent.agent, name=agent.name)

    def _add_nodes(self, state_graph):
        """Register one node per agent plus the shared ``"call_tool"`` node."""
        for agent in self.agents:
            state_graph.add_node(agent.name, self._create_agent_node(agent))
        # A single tool node serves every agent; results are routed back to the
        # caller via ``state["sender"]`` (see ``_add_edges``).
        all_tools = [tool for agent in self.agents for tool in agent.tools]
        state_graph.add_node("call_tool", ToolNode(all_tools))

    def _add_edges(self, state_graph):
        """Wire the START edge, per-agent conditional edges and tool routing.

        Raises:
            KeyError: if the config has no edge whose source is ``'start'``.
        """
        # NOTE(review): one target per source — if several edges share a
        # source, only the last one listed is kept.
        graph_structure = {edge.source: edge.target for edge in self.graph_config.edges}

        def resolve(target: str):
            # The config spells the terminal node as 'end' (any case).
            return END if target.lower() == "end" else target

        state_graph.add_edge(START, resolve(graph_structure["start"]))

        for agent in self.agents:
            # Agents without an outgoing edge terminate the graph.
            next_node = resolve(graph_structure.get(agent.name, "end"))
            state_graph.add_conditional_edges(
                agent.name,
                self._router,
                {
                    "call_tool": "call_tool",
                    "continue": next_node,
                    "__end__": END,
                },
            )

        # Tool results go back to whichever agent issued the tool call.
        state_graph.add_conditional_edges(
            "call_tool",
            lambda state: state["sender"],
            {agent.name: agent.name for agent in self.agents},
        )

    def get_stream(self, initial_state: Dict[str, Any], user: User):
        """Stream graph events for ``initial_state``.

        The config starts with ``recursion_limit=150``; values from
        ``get_llm_config(user)`` are merged afterwards and so take precedence.

        Raises:
            ValueError: if ``build()`` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        return self.graph.stream(
            initial_state,
            config={"recursion_limit": 150, **get_llm_config(user)},
        )

    def get_image(self, xray=True, mermaid=True):
        """Return a base64-encoded PNG rendering of the compiled graph.

        Args:
            xray: Forwarded to ``get_graph``.
            mermaid: Render with ``draw_mermaid_png`` when true, else ``draw_png``.

        Raises:
            ValueError: if ``build()`` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        drawable = self.graph.get_graph(xray=xray)
        png = drawable.draw_mermaid_png() if mermaid else drawable.draw_png()
        return base64.b64encode(png).decode("utf-8")

    def get_compiled_graph(self):
        """Return the drawable graph of the compiled graph.

        Raises:
            ValueError: if ``build()`` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        return self.graph.get_graph()

    def build(self):
        """Create the agents, wire the graph, compile it and return it.

        Safe to call more than once: previously created agents are discarded,
        so the same node name is never registered twice.
        """
        self.agents = []
        for agent_params in self.graph_config.agents:
            # 'start' and 'end' are graph endpoints, not real agents.
            if agent_params.name in ["start", "end"]:
                continue
            # Build a fresh list instead of extending the one get_tools returned.
            tools: List[BaseTool] = [
                *get_tools(tool_names=agent_params.tool_names),
                *(agent_params.tools or []),
            ]
            self.agents.append(Agent(
                name=agent_params.name,
                llm=get_text_model(agent_params.llm),
                prompt=agent_params.prompt,
                tools=tools,
            ))

        state_graph = StateGraph(GraphState)
        self._add_nodes(state_graph)
        self._add_edges(state_graph)
        self.graph = state_graph.compile()
        return self.graph

    @staticmethod
    def _chat_to_json(chat) -> str:
        """Serialize ``chat`` to JSON, omitting unset fields and unwrapping top-level Enums."""
        data = chat.model_dump(exclude_unset=True)
        data = {k: v.value if isinstance(v, Enum) else v for k, v in data.items()}
        return json.dumps(data)

    def run(self, user: User, messages: List[Chat]):
        """Run the graph and yield every produced message as a JSON string.

        A message that cannot be converted yields an assistant ``Chat`` named
        ``"Error"`` describing the failure instead of aborting the stream.

        Raises:
            ValueError: on first iteration, if ``build()`` has not been called.
        """
        if not self.graph:
            raise ValueError("Graph not created. Call build() first.")
        events = self.get_stream({"messages": to_dict_list(messages)}, user)
        for event in events:
            for value in event.values():
                # Skip node updates that carry no messages (or no update at all).
                if not value or "messages" not in value:
                    continue
                for message in value["messages"]:
                    try:
                        tool_calls = getattr(message, "tool_calls", None)
                        chat_item = Chat(
                            role=map_to_chat_role(message.type),
                            content=message.content,
                            name=getattr(message, "name", None),
                            tool_calls=(
                                [ToolCall(**tool_call) for tool_call in tool_calls]
                                if tool_calls else None
                            ),
                            tool_call_id=getattr(message, "tool_call_id", None),
                        )
                        payload = self._chat_to_json(chat_item)
                    except Exception as e:
                        # Best effort: report the failure in-band and keep streaming.
                        error_chat = Chat(
                            role=ChatRole.assistant,
                            content=f"Error processing message: {str(e)}",
                            name="Error",
                        )
                        payload = self._chat_to_json(error_chat)
                    yield payload


if __name__ == "__main__":
    # Usage example: a two-agent Researcher -> Writer pipeline.
    body = GraphConfig(
        graph_id=3,
        name="Graph",
        description="Graph",
        session_id=None,
        agents=[
            AgentConfig(
                name="Researcher",
                prompt="You are a research assistant",
                llm="claude-sonnet-3.5",
                tools=[],
                color="#FF871F",
            ),
            AgentConfig(
                name="Writer",
                prompt="You are a great writer",
                llm="claude-sonnet-3.5",
                tools=[],
                color="#FF871F",
            ),
        ],
        edges=[
            GraphEdge(source="start", target="Researcher"),
            GraphEdge(source="Researcher", target="Writer"),
            GraphEdge(source="Writer", target="end"),
        ],
    )

    graph = GraphBuilder(graph_config=body)
    graph.build()
    response = GraphResponse(
        # NOTE(review): confirm GraphResponse.config expects the drawable graph
        # returned by get_compiled_graph().
        config=graph.get_compiled_graph(),
        base64=graph.get_image(xray=True, mermaid=True),
    )
    print(response)