Skip to content

Instantly share code, notes, and snippets.

@virattt
Created November 19, 2024 23:51
Show Gist options
  • Save virattt/9b4b792329f6a1dfd37f1758c979c908 to your computer and use it in GitHub Desktop.
Save virattt/9b4b792329f6a1dfd37f1758c979c908 to your computer and use it in GitHub Desktop.

Revisions

  1. virattt revised this gist Nov 19, 2024. 1 changed file with 1 addition and 10 deletions.
    11 changes: 1 addition & 10 deletions hedge-fund-agent-team-v1-4.ipynb
    Original file line number Diff line number Diff line change
    @@ -4,7 +4,7 @@
    "metadata": {
    "colab": {
    "provenance": [],
    "authorship_tag": "ABX9TyOyH1IFXdKyrc5YNQsKV4D/",
    "authorship_tag": "ABX9TyMxDOZeeIz9VQ497kQOFQVG",
    "include_colab_link": true
    },
    "kernelspec": {
    @@ -918,15 +918,6 @@
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [],
    "metadata": {
    "id": "6BFJLGik_oRu"
    },
    "execution_count": null,
    "outputs": []
    }
    ]
    }
  2. virattt revised this gist Nov 19, 2024. 1 changed file with 12 additions and 1 deletion.
    13 changes: 12 additions & 1 deletion hedge-fund-agent-team-v1-4.ipynb
    Original file line number Diff line number Diff line change
    @@ -4,7 +4,8 @@
    "metadata": {
    "colab": {
    "provenance": [],
    "authorship_tag": "ABX9TyOyH1IFXdKyrc5YNQsKV4D/"
    "authorship_tag": "ABX9TyOyH1IFXdKyrc5YNQsKV4D/",
    "include_colab_link": true
    },
    "kernelspec": {
    "name": "python3",
    @@ -15,6 +16,16 @@
    }
    },
    "cells": [
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
    },
    "source": [
    "<a href=\"https://colab.research.google.com/gist/virattt/9b4b792329f6a1dfd37f1758c979c908/hedge-fund-agent-team-v1-4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
    },
    {
    "cell_type": "markdown",
    "source": [
  3. virattt created this gist Nov 19, 2024.
    921 changes: 921 additions & 0 deletions hedge-fund-agent-team-v1-4.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,921 @@
    {
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {
    "colab": {
    "provenance": [],
    "authorship_tag": "ABX9TyOyH1IFXdKyrc5YNQsKV4D/"
    },
    "kernelspec": {
    "name": "python3",
    "display_name": "Python 3"
    },
    "language_info": {
    "name": "python"
    }
    },
    "cells": [
    {
    "cell_type": "markdown",
    "source": [
    "This notebook provides a tutorial on how to use multi-agents with LangGraph.\n",
    "\n",
    "Specifically, we use the **supervisor** pattern, where we have 1 supervisor agent and 3 analyst agents:\n",
    "1. fundamental analyst\n",
    "2. technical analyst\n",
    "3. sentiment analyst\n",
    "\n",
    "This code will be a part of an evolving series.\n",
    "\n",
    "If you have any questions, please message me on X at [virattt](https://twitter.com/virattt)."
    ],
    "metadata": {
    "id": "Xp0Uq2g0uLxb"
    }
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Setup"
    ],
    "metadata": {
    "id": "xYivxWv2b6SW"
    }
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "unPP2JGqble1"
    },
    "outputs": [],
    "source": [
    "%%capture --no-stderr\n",
    "%pip install -U langgraph langchain langchain_openai langchain_experimental langsmith pandas ta"
    ]
    },
    {
    "cell_type": "code",
    "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "\n",
    "def _set_if_undefined(var: str):\n",
    " if not os.environ.get(var):\n",
    " os.environ[var] = getpass.getpass(f\"Please provide your {var}\")\n",
    "\n",
    "\n",
    "_set_if_undefined(\"OPENAI_API_KEY\") # For the agent. Get from https://platform.openai.com\n",
    "_set_if_undefined(\"FINANCIAL_DATASETS_API_KEY\") # For getting financial data. Get from https://financialdatasets.ai\n",
    "_set_if_undefined(\"TAVILY_API_KEY\") # For surfing the web. Get from https://tavily.com"
    ],
    "metadata": {
    "id": "5zJ1jU9-b9WS"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Define agent tools"
    ],
    "metadata": {
    "id": "iADG-Tp3b--h"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "from langchain_core.tools import tool\n",
    "from typing import List, Dict, Optional, Union\n",
    "import requests\n",
    "import os\n",
    "from typing import Dict, Union\n",
    "from pydantic import BaseModel, Field\n",
    "import requests\n",
    "from langchain_core.tools import tool\n",
    "\n",
    "import pandas as pd\n",
    "import ta\n",
    "from datetime import datetime, timedelta\n",
    "\n",
    "class GetIncomeStatementsInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    " period: str = Field(default=\"ttm\", description=\"The period of the income statements. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
    " limit: int = Field(default=10, description=\"The maximum number of income statements to return. Default is 10.\")\n",
    "\n",
    "@tool(\"get_income_statements\", args_schema=GetIncomeStatementsInput, return_direct=True)\n",
    "def get_income_statements(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get income statements for a ticker with specified period and limit.\n",
    " \"\"\"\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    "\n",
    " url = (\n",
    " f'https://api.financialdatasets.ai/financials/income-statements'\n",
    " f'?ticker={ticker}'\n",
    " f'&period={period}'\n",
    " f'&limit={limit}'\n",
    " )\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={'X-API-Key': api_key})\n",
    " return response.json()\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"income_statements\": [], \"error\": str(e)}\n",
    "\n",
    "class GetBalanceSheetsInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    " period: str = Field(default=\"ttm\", description=\"The period of the balance sheets. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
    " limit: int = Field(default=10, description=\"The maximum number of balance sheets to return. Default is 10.\")\n",
    "\n",
    "@tool(\"get_balance_sheets\", args_schema=GetBalanceSheetsInput, return_direct=True)\n",
    "def get_balance_sheets(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get balance sheets for a ticker with specified period and limit.\n",
    " \"\"\"\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    "\n",
    " url = (\n",
    " f'https://api.financialdatasets.ai/financials/balance-sheets'\n",
    " f'?ticker={ticker}'\n",
    " f'&period={period}'\n",
    " f'&limit={limit}'\n",
    " )\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={'X-API-Key': api_key})\n",
    " return response.json()\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"balance_sheets\": [], \"error\": str(e)}\n",
    "\n",
    "class GetCashFlowStatementsInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    " period: str = Field(default=\"ttm\", description=\"The period of the cash flow statements. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
    " limit: int = Field(default=10, description=\"The maximum number of cash flow statements to return. Default is 10.\")\n",
    "\n",
    "@tool(\"get_cash_flow_statements\", args_schema=GetCashFlowStatementsInput, return_direct=True)\n",
    "def get_cash_flow_statements(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get cash flow statements for a ticker with specified period and limit.\n",
    " \"\"\"\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    "\n",
    " url = (\n",
    " f'https://api.financialdatasets.ai/financials/cash-flow-statements'\n",
    " f'?ticker={ticker}'\n",
    " f'&period={period}'\n",
    " f'&limit={limit}'\n",
    " )\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={'X-API-Key': api_key})\n",
    " return response.json()\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"cash_flow_statements\": [], \"error\": str(e)}\n",
    "\n",
    "class GetPricesInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    " start_date: str = Field(..., description=\"The start of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.\")\n",
    " end_date: str = Field(..., description=\"The end of the aggregate time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.\")\n",
    " interval: str = Field(default=\"day\", description=\"The time interval of the prices. Valid values are second', 'minute', 'day', 'week', 'month', 'quarter', 'year'.\")\n",
    " interval_multiplier: int = Field(default=1, description=\"The multiplier for the interval. For example, if interval is 'day' and interval_multiplier is 1, the prices will be daily. If interval is 'minute' and interval_multiplier is 5, the prices will be every 5 minutes.\")\n",
    " limit: int = Field(default=5000, description=\"The maximum number of prices to return. The default is 5000 and the maximum is 50000.\")\n",
    "\n",
    "@tool(\"get_stock_prices\", args_schema=GetPricesInput, return_direct=True)\n",
    "def get_stock_prices(ticker: str, start_date: str, end_date: str, interval: str = 'day', interval_multiplier: int = 1, limit: int = 5000) -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get prices for a ticker over a given date range and interval.\n",
    " \"\"\"\n",
    "\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    " url = (\n",
    " f\"https://api.financialdatasets.ai/prices\"\n",
    " f\"?ticker={ticker}\"\n",
    " f\"&start_date={start_date}\"\n",
    " f\"&end_date={end_date}\"\n",
    " f\"&interval={interval}\"\n",
    " f\"&interval_multiplier={interval_multiplier}\"\n",
    " f\"&limit={limit}\"\n",
    " )\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={'X-API-Key': api_key})\n",
    " data = response.json()\n",
    " return data\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"prices\": [], \"error\": str(e)}\n",
    "\n",
    "class GetCurrentPriceInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    "\n",
    "@tool(\"get_current_stock_price\", args_schema=GetCurrentPriceInput, return_direct=True)\n",
    "def get_current_stock_price(ticker: str) -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get the current (latest) stock price for a ticker.\n",
    " \"\"\"\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    "\n",
    " url = f\"https://api.financialdatasets.ai/prices/snapshot?ticker={ticker}\"\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={'X-API-Key': api_key})\n",
    " return response.json()\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"price\": None, \"error\": str(e)}\n",
    "\n",
    "class GetOptionsChainInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    " limit: int = Field(default=10, description=\"The maximum number of options to return. Default is 10.\")\n",
    " strike_price: Optional[float] = Field(default=None, description=\"Optional filter for specific strike price.\")\n",
    " option_type: Optional[str] = Field(default=None, description=\"Optional filter for option type. Valid values are 'call' or 'put'.\")\n",
    "\n",
    "@tool(\"get_options_chain\", args_schema=GetOptionsChainInput, return_direct=True)\n",
    "def get_options_chain(\n",
    " ticker: str,\n",
    " limit: int = 10,\n",
    " strike_price: Optional[float] = None,\n",
    " option_type: Optional[str] = None\n",
    ") -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get options chain data for a ticker with optional filters for strike price and option type.\n",
    " \"\"\"\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    "\n",
    " params = {\n",
    " 'ticker': ticker,\n",
    " 'limit': limit\n",
    " }\n",
    "\n",
    " if strike_price is not None:\n",
    " params['strike_price'] = strike_price\n",
    " if option_type is not None:\n",
    " params['option_type'] = option_type\n",
    "\n",
    " url = 'https://api.financialdatasets.ai/options/chain'\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={'X-API-Key': api_key}, params=params)\n",
    " return response.json()\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"options_chain\": [], \"error\": str(e)}\n",
    "\n",
    "class GetInsiderTradesInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    " limit: int = Field(default=10, description=\"The maximum number of insider transactions to return. Default is 10.\")\n",
    "\n",
    "@tool(\"get_insider_trades\", args_schema=GetInsiderTradesInput, return_direct=True)\n",
    "def get_insider_trades(ticker: str, limit: int = 10) -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get insider trading transactions for a ticker.\n",
    " \"\"\"\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    "\n",
    " url = (\n",
    " f'https://api.financialdatasets.ai/insider-transactions'\n",
    " f'?ticker={ticker}'\n",
    " f'&limit={limit}'\n",
    " )\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={'X-API-Key': api_key})\n",
    " return response.json()\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"insider_transactions\": [], \"error\": str(e)}\n",
    "\n",
    "class GetTechnicalIndicatorsInput(BaseModel):\n",
    " \"\"\"Input schema for technical indicators calculation.\"\"\"\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    " indicator: str = Field(..., description=\"The technical indicator to calculate. Valid values are 'rsi', 'macd', 'sma', 'ema', 'bbands'.\")\n",
    " period: Optional[int] = Field(default=14, description=\"The period for the indicator calculation. Default is 14.\")\n",
    " start_date: Optional[str] = Field(default=None, description=\"Start date in YYYY-MM-DD format.\")\n",
    " end_date: Optional[str] = Field(default=None, description=\"End date in YYYY-MM-DD format.\")\n",
    " interval: Optional[str] = Field(default=\"day\", description=\"The time interval for price data.\")\n",
    " interval_multiplier: Optional[int] = Field(default=1, description=\"Multiplier for the time interval.\")\n",
    "\n",
    "@tool(\"get_technical_indicators\", args_schema=GetTechnicalIndicatorsInput)\n",
    "def get_technical_indicators(\n",
    " ticker: str,\n",
    " indicator: str,\n",
    " period: int = 14,\n",
    " interval: str = \"day\",\n",
    " interval_multiplier: int = 1,\n",
    " start_date: Optional[str] = None,\n",
    " end_date: Optional[str] = None,\n",
    ") -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Calculate technical indicators for a given ticker and time period.\n",
    " Supports RSI, MACD, SMA, EMA, and Bollinger Bands calculations.\n",
    " \"\"\"\n",
    " try:\n",
    " # Fetch historical price data with padding for calculations\n",
    " adjusted_start = (datetime.strptime(start_date, \"%Y-%m-%d\") - timedelta(days=period * 2)).strftime(\"%Y-%m-%d\")\n",
    "\n",
    " price_data = get_stock_prices.invoke({\n",
    " \"ticker\": ticker,\n",
    " \"start_date\": adjusted_start,\n",
    " \"end_date\": end_date,\n",
    " \"interval\": interval,\n",
    " \"interval_multiplier\": interval_multiplier\n",
    " })\n",
    "\n",
    " if \"error\" in price_data:\n",
    " return price_data\n",
    "\n",
    " # Convert to pandas DataFrame with proper datetime handling\n",
    " df = pd.DataFrame(price_data[\"prices\"])\n",
    "\n",
    " # Clean datetime strings by removing timezone\n",
    " df['time'] = df['time'].apply(lambda x: x.split(' EDT')[0].split(' EST')[0])\n",
    " # Convert to datetime after cleaning\n",
    " df['time'] = pd.to_datetime(df['time'])\n",
    " df.set_index('time', inplace=True)\n",
    "\n",
    " result = {\n",
    " \"ticker\": ticker,\n",
    " \"indicator\": indicator,\n",
    " \"period\": period,\n",
    " \"data\": []\n",
    " }\n",
    "\n",
    " # Calculate indicators (no changes here)\n",
    " if indicator.lower() == \"rsi\":\n",
    " rsi = ta.momentum.RSIIndicator(df['close'], window=period)\n",
    " df['indicator_value'] = rsi.rsi()\n",
    " elif indicator.lower() == \"macd\":\n",
    " macd = ta.trend.MACD(\n",
    " df['close'],\n",
    " window_slow=26,\n",
    " window_fast=12,\n",
    " window_sign=9\n",
    " )\n",
    " df['macd_line'] = macd.macd()\n",
    " df['signal_line'] = macd.macd_signal()\n",
    " df['histogram'] = macd.macd_diff()\n",
    " df['indicator_value'] = df['macd_line']\n",
    " elif indicator.lower() == \"sma\":\n",
    " df['indicator_value'] = ta.trend.SMAIndicator(\n",
    " df['close'],\n",
    " window=period\n",
    " ).sma_indicator()\n",
    " elif indicator.lower() == \"ema\":\n",
    " df['indicator_value'] = ta.trend.EMAIndicator(\n",
    " df['close'],\n",
    " window=period\n",
    " ).ema_indicator()\n",
    " elif indicator.lower() == \"bbands\":\n",
    " bb = ta.volatility.BollingerBands(\n",
    " df['close'],\n",
    " window=period,\n",
    " window_dev=2\n",
    " )\n",
    " df['middle_band'] = bb.bollinger_mavg()\n",
    " df['upper_band'] = bb.bollinger_hband()\n",
    " df['lower_band'] = bb.bollinger_lband()\n",
    " df['indicator_value'] = df['middle_band']\n",
    "\n",
    " # Filter to requested date range\n",
    " df = df[start_date:end_date]\n",
    "\n",
    " # Handle NaN values using newer pandas methods\n",
    " df = df.ffill().bfill() # Using newer methods instead of fillna(method=...)\n",
    "\n",
    " for idx, row in df.iterrows():\n",
    " data_point = {\n",
    " \"time\": idx.strftime(\"%Y-%m-%d %H:%M:%S\"), # Clean datetime format\n",
    " \"time_milliseconds\": int(idx.timestamp() * 1000),\n",
    " \"value\": float(row['indicator_value'])\n",
    " }\n",
    "\n",
    " if indicator.lower() == \"macd\":\n",
    " data_point.update({\n",
    " \"signal_line\": float(row['signal_line']),\n",
    " \"histogram\": float(row['histogram'])\n",
    " })\n",
    " elif indicator.lower() == \"bbands\":\n",
    " data_point.update({\n",
    " \"upper_band\": float(row['upper_band']),\n",
    " \"lower_band\": float(row['lower_band'])\n",
    " })\n",
    "\n",
    " result[\"data\"].append(data_point)\n",
    "\n",
    " return result\n",
    "\n",
    " except Exception as e:\n",
    " return {\n",
    " \"ticker\": ticker,\n",
    " \"indicator\": indicator,\n",
    " \"error\": f\"Error calculating {indicator}: {str(e)}\"\n",
    " }\n",
    "\n",
    "class GetFinancialMetricsInput(BaseModel):\n",
    " ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
    "\n",
    "@tool(\"get_financial_metrics\", args_schema=GetFinancialMetricsInput, return_direct=True)\n",
    "def get_financial_metrics(ticker: str) -> Union[Dict, str]:\n",
    " \"\"\"\n",
    " Get key financial metrics snapshot for a ticker, including valuation ratios,\n",
    " profitability margins, returns, and growth metrics.\n",
    " \"\"\"\n",
    " api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
    " if not api_key:\n",
    " raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
    "\n",
    " url = f\"https://api.financialdatasets.ai/financial-metrics/snapshot?ticker={ticker}\"\n",
    "\n",
    " try:\n",
    " response = requests.get(url, headers={\"X-API-Key\": api_key})\n",
    " return response.json()\n",
    " except Exception as e:\n",
    " return {\"ticker\": ticker, \"snapshot\": None, \"error\": str(e)}\n"
    ],
    "metadata": {
    "id": "twLVNqHMb_w9"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "from langchain_community.tools.tavily_search import TavilySearchResults\n",
    "\n",
    "# News tool\n",
    "get_news_tool = TavilySearchResults(max_results=5)"
    ],
    "metadata": {
    "id": "OQ650p7nM6ad"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "# Team 1: Traditional Analysis Track\n",
    "fundamental_analyst_tools = [\n",
    " get_income_statements,\n",
    " get_balance_sheets,\n",
    " get_cash_flow_statements,\n",
    " get_financial_metrics\n",
    "]\n",
    "\n",
    "technical_analyst_tools = [\n",
    " get_stock_prices,\n",
    " get_current_stock_price,\n",
    " get_technical_indicators\n",
    "]\n",
    "\n",
    "sentiment_analyst_tools = [\n",
    " get_options_chain,\n",
    " get_insider_trades,\n",
    " get_news_tool\n",
    "]\n",
    "\n",
    "# Team 2: Specialized Analysis Track\n",
    "quant_strategist_tools = [\n",
    " get_stock_prices,\n",
    " get_technical_indicators,\n",
    " get_financial_metrics\n",
    "]\n",
    "\n",
    "macro_analyst_tools = [\n",
    " get_financial_metrics,\n",
    " get_news_tool,\n",
    " get_technical_indicators\n",
    "]\n",
    "\n",
    "event_driven_analyst_tools = [\n",
    " get_news_tool,\n",
    " get_insider_trades,\n",
    " get_financial_metrics,\n",
    " get_current_stock_price\n",
    "]\n",
    "\n",
    "derivative_analyst_tools = [\n",
    " get_options_chain,\n",
    " get_technical_indicators,\n",
    " get_current_stock_price\n",
    "]"
    ],
    "metadata": {
    "id": "ntBR8eulNR72"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Helper functions"
    ],
    "metadata": {
    "id": "F0N8OtbbcG0L"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "# from langchain_core.messages import HumanMessage\n",
    "\n",
    "# def agent_node(state, agent, name):\n",
    "# result = agent.invoke(state)\n",
    "# return {\n",
    "# \"messages\": [HumanMessage(content=result[\"messages\"][-1].content, name=name)]\n",
    "# }"
    ],
    "metadata": {
    "id": "I7LhGDuEcCpC"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Create LangGraph"
    ],
    "metadata": {
    "id": "v5LtjrI-cJ1P"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_openai import ChatOpenAI\n",
    "from pydantic import BaseModel\n",
    "from typing import Literal, Sequence, List, Annotated\n",
    "from typing_extensions import TypedDict\n",
    "import functools\n",
    "import operator\n",
    "from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage\n",
    "from langgraph.graph import END, StateGraph, START\n",
    "from langgraph.prebuilt import create_react_agent\n",
    "\n",
    "# Define team members\n",
    "members = [\"quant_strategist\", \"macro_analyst\", \"event_driven_analyst\", \"derivative_analyst\"]\n",
    "\n",
    "# Create routing prompt template\n",
    "routing_prompt = ChatPromptTemplate.from_messages([\n",
    " (\n",
    " \"system\",\n",
    " \"You are a portfolio manager supervising a hedge fund team with the following analysts:\"\n",
    " \"\\n- quant_strategist: Analyzes price patterns, technical indicators, and financial metrics using quantitative methods\"\n",
    " \"\\n- macro_analyst: Analyzes broad market trends, economic factors, and their impact on financial metrics\"\n",
    " \"\\n- event_driven_analyst: Analyzes special situations, corporate events, insider activity, and real-time price movements\"\n",
    " \"\\n- derivative_analyst: Analyzes options flow, volatility patterns, and derivative pricing\"\n",
    " \"\\nDetermine which analyst(s) should analyze the request. Respond with ONLY the analyst names\"\n",
    " \" separated by commas (e.g., 'quant_strategist,macro_analyst'). Choose analysts based on:\"\n",
    " \"\\n- Use quant_strategist for questions about statistical analysis, factor modeling, or technical patterns\"\n",
    " \"\\n- Use macro_analyst for questions about economic trends, market-wide impacts, or sector analysis\"\n",
    " \"\\n- Use event_driven_analyst for questions about corporate events, news impact, or insider activity\"\n",
    " \"\\n- Use derivative_analyst for questions about options activity, volatility, or derivative strategies\"\n",
    " ),\n",
    " MessagesPlaceholder(variable_name=\"messages\"),\n",
    "])\n",
    "\n",
    "# Create the summary prompt template\n",
    "summary_prompt = ChatPromptTemplate.from_messages(\n",
    " [\n",
    " (\n",
    " \"system\",\n",
    " \"You are a portfolio manager responsible for synthesizing analysis from your team of analysts. \"\n",
    " \"Review all the analysts' reports and provide a comprehensive summary including:\\n\"\n",
    " \"1. Quantitative and statistical insights (when available)\\n\"\n",
    " \"2. Macro and market trend analysis (when available)\\n\"\n",
    " \"3. Event-driven factors and news impact (when available)\\n\"\n",
    " \"4. Derivatives and volatility analysis (when available)\\n\"\n",
    " \"5. Overall investment recommendation\\n\"\n",
    " \"Make sure to highlight any discrepancies or conflicting signals between different analyses.\"\n",
    " ),\n",
    " MessagesPlaceholder(variable_name=\"messages\"),\n",
    " (\n",
    " \"human\",\n",
    " \"Based on all the analyst reports above, provide a comprehensive summary and investment recommendation.\"\n",
    " ),\n",
    " ]\n",
    ")\n",
    "\n",
    "# Initialize LLM\n",
    "llm = ChatOpenAI(model=\"gpt-4\")\n",
    "\n",
    "class AgentState(TypedDict):\n",
    " messages: Annotated[Sequence[BaseMessage], operator.add]\n",
    " selected_analysts: List[str]\n",
    " current_analyst_idx: int\n",
    "\n",
    "def supervisor_router(state):\n",
    " \"\"\"Route to appropriate analyst(s) based on the query\"\"\"\n",
    " # Create the routing chain\n",
    " routing_chain = routing_prompt | llm\n",
    "\n",
    " # Get the routing decision\n",
    " result = routing_chain.invoke(state)\n",
    " selected_analysts = [a.strip() for a in result.content.strip().split(',')]\n",
    "\n",
    " # Add routing message to state\n",
    " message = SystemMessage(\n",
    " content=f\"Routing query to: {', '.join(selected_analysts)}\",\n",
    " name=\"supervisor\"\n",
    " )\n",
    "\n",
    " return {\n",
    " \"messages\": state[\"messages\"] + [message],\n",
    " \"selected_analysts\": selected_analysts,\n",
    " \"current_analyst_idx\": 0\n",
    " }\n",
    "\n",
    "def get_next_step(state):\n",
    " \"\"\"Determine the next step in the workflow\"\"\"\n",
    " if not state[\"selected_analysts\"]:\n",
    " return \"final_summary\"\n",
    "\n",
    " current_idx = state[\"current_analyst_idx\"]\n",
    " if current_idx >= len(state[\"selected_analysts\"]):\n",
    " return \"final_summary\"\n",
    "\n",
    " return state[\"selected_analysts\"][current_idx]\n",
    "\n",
    "def agent_node(state, agent, name):\n",
    " \"\"\"Generic analyst node that updates the current_analyst_idx after completion\"\"\"\n",
    " result = agent.invoke(state)\n",
    "\n",
    " return {\n",
    " \"messages\": [HumanMessage(content=result[\"messages\"][-1].content, name=name)],\n",
    " \"selected_analysts\": state[\"selected_analysts\"],\n",
    " \"current_analyst_idx\": state[\"current_analyst_idx\"] + 1\n",
    " }\n",
    "\n",
    "def final_summary_agent(state):\n",
    " \"\"\"Create final summary of all analyst reports\"\"\"\n",
    " summary_chain = summary_prompt | llm\n",
    " result = summary_chain.invoke(state)\n",
    " return {\n",
    " \"messages\": [HumanMessage(content=result.content, name=\"portfolio_manager\")],\n",
    " \"selected_analysts\": state[\"selected_analysts\"],\n",
    " \"current_analyst_idx\": state[\"current_analyst_idx\"]\n",
    " }\n",
    "\n",
    "# Initialize workflow\n",
    "workflow = StateGraph(AgentState)\n",
    "\n",
    "# Create the analysts with their specific tools\n",
    "quant_strategist = create_react_agent(llm, tools=quant_strategist_tools)\n",
    "quant_strategist_node = functools.partial(agent_node, agent=quant_strategist, name=\"quant_strategist\")\n",
    "\n",
    "macro_analyst = create_react_agent(llm, tools=macro_analyst_tools)\n",
    "macro_analyst_node = functools.partial(agent_node, agent=macro_analyst, name=\"macro_analyst\")\n",
    "\n",
    "event_driven_analyst = create_react_agent(llm, tools=event_driven_analyst_tools)\n",
    "event_driven_analyst_node = functools.partial(agent_node, agent=event_driven_analyst, name=\"event_driven_analyst\")\n",
    "\n",
    "derivative_analyst = create_react_agent(llm, tools=derivative_analyst_tools)\n",
    "derivative_analyst_node = functools.partial(agent_node, agent=derivative_analyst, name=\"derivative_analyst\")\n",
    "\n",
    "# Add nodes\n",
    "workflow.add_node(\"supervisor\", supervisor_router)\n",
    "workflow.add_node(\"quant_strategist\", quant_strategist_node)\n",
    "workflow.add_node(\"macro_analyst\", macro_analyst_node)\n",
    "workflow.add_node(\"event_driven_analyst\", event_driven_analyst_node)\n",
    "workflow.add_node(\"derivative_analyst\", derivative_analyst_node)\n",
    "workflow.add_node(\"final_summary\", final_summary_agent)\n",
    "\n",
    "# Add conditional edges\n",
    "workflow.add_conditional_edges(\n",
    " \"supervisor\",\n",
    " get_next_step,\n",
    " {\n",
    " \"quant_strategist\": \"quant_strategist\",\n",
    " \"macro_analyst\": \"macro_analyst\",\n",
    " \"event_driven_analyst\": \"event_driven_analyst\",\n",
    " \"derivative_analyst\": \"derivative_analyst\",\n",
    " \"final_summary\": \"final_summary\"\n",
    " }\n",
    ")\n",
    "\n",
    "# Add conditional edges from each analyst back to the router function\n",
    "for analyst in members:\n",
    " workflow.add_conditional_edges(\n",
    " analyst,\n",
    " get_next_step,\n",
    " {\n",
    " \"quant_strategist\": \"quant_strategist\",\n",
    " \"macro_analyst\": \"macro_analyst\",\n",
    " \"event_driven_analyst\": \"event_driven_analyst\",\n",
    " \"derivative_analyst\": \"derivative_analyst\",\n",
    " \"final_summary\": \"final_summary\"\n",
    " }\n",
    " )\n",
    "\n",
    "# Add entry point and final edges\n",
    "workflow.add_edge(START, \"supervisor\")\n",
    "workflow.add_edge(\"final_summary\", END)\n",
    "\n",
    "# Compile the graph\n",
    "graph = workflow.compile()"
    ],
    "metadata": {
    "id": "TT2AggDicQt6"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Code to pretty print Agent output"
    ],
    "metadata": {
    "id": "V7_AEtWz56-n"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "from typing import Dict, Any\n",
    "import json\n",
    "import re\n",
    "from langchain_core.messages import HumanMessage\n",
    "from rich.console import Console\n",
    "from rich.panel import Panel\n",
    "from rich.text import Text\n",
    "from rich.rule import Rule\n",
    "\n",
    "console = Console()\n",
    "\n",
    "def format_bold_text(content: str) -> Text:\n",
    " \"\"\"Convert **text** to rich Text with bold formatting.\"\"\"\n",
    " text = Text()\n",
    " pattern = r'\\*\\*(.*?)\\*\\*'\n",
    "\n",
    " # Split the text by the bold markers\n",
    " parts = re.split(pattern, content)\n",
    "\n",
    " # Alternate between regular and bold text\n",
    " for i, part in enumerate(parts):\n",
    " if i % 2 == 0:\n",
    " text.append(part)\n",
    " else:\n",
    " text.append(part, style=\"bold\")\n",
    "\n",
    " return text\n",
    "\n",
    "def format_message_content(content: str) -> Union[str, Text]:\n",
    " \"\"\"Format the message content, handling JSON and text with bold markers.\"\"\"\n",
    " try:\n",
    " # Try to parse as JSON for prettier formatting\n",
    " data = json.loads(content)\n",
    " return json.dumps(data, indent=2)\n",
    " except:\n",
    " # If not JSON, check for bold markers\n",
    " if '**' in content:\n",
    " return format_bold_text(content)\n",
    " return content\n",
    "\n",
    "def format_agent_message(message: HumanMessage) -> Union[str, Text]:\n",
    " \"\"\"Format a single agent message.\"\"\"\n",
    " return format_message_content(message.content)\n",
    "\n",
    "def get_agent_title(agent: str, message: HumanMessage) -> str:\n",
    " \"\"\"Get the title for the agent panel, with fallback handling.\"\"\"\n",
    " base_title = agent.replace('_', ' ').title()\n",
    "\n",
    " if hasattr(message, 'name') and message.name is not None:\n",
    " try:\n",
    " return message.name.replace('_', ' ').title()\n",
    " except:\n",
    " return base_title\n",
    " return base_title\n",
    "\n",
    "def print_step(step: Dict[str, Any]) -> None:\n",
    " \"\"\"Pretty print a single step of the agent execution.\"\"\"\n",
    " for agent, data in step.items():\n",
    " # Handle supervisor steps\n",
    " if 'next' in data:\n",
    " next_agent = data['next']\n",
    " text = Text()\n",
    " text.append(\"Portfolio Manager \", style=\"bold magenta\")\n",
    " text.append(\"assigns next task to \", style=\"white\")\n",
    "\n",
    " if next_agent == \"final_summary\":\n",
    " text.append(\"FINAL SUMMARY\", style=\"bold yellow\")\n",
    " elif next_agent == \"END\":\n",
    " text.append(\"END\", style=\"bold red\")\n",
    " else:\n",
    " text.append(f\"{next_agent}\", style=\"bold green\")\n",
    "\n",
    " console.print(Panel(\n",
    " text,\n",
    " title=\"[bold blue]Supervision Step\",\n",
    " border_style=\"blue\"\n",
    " ))\n",
    "\n",
    " # Handle agent responses and final summary\n",
    " if 'messages' in data:\n",
    " message = data['messages'][0]\n",
    " formatted_content = format_agent_message(message)\n",
    "\n",
    " if agent == \"final_summary\":\n",
    " # Final summary formatting\n",
    " console.print(Rule(style=\"yellow\", title=\"Portfolio Analysis\"))\n",
    " console.print(Panel(\n",
    " formatted_content,\n",
    " title=\"[bold yellow]Investment Summary and Recommendation\",\n",
    " border_style=\"yellow\",\n",
    " padding=(1, 2)\n",
    " ))\n",
    " console.print(Rule(style=\"yellow\"))\n",
    " else:\n",
    " # Regular analyst reports\n",
    " title = get_agent_title(agent, message)\n",
    " console.print(Panel(\n",
    " formatted_content,\n",
    " title=f\"[bold blue]{title} Report\",\n",
    " border_style=\"green\"\n",
    " ))\n",
    "\n",
    "def stream_agent_execution(graph, input_data: Dict, config: Dict) -> None:\n",
    " \"\"\"Stream and pretty print the agent execution.\"\"\"\n",
    " console.print(\"\\n[bold blue]Starting Agent Execution...[/bold blue]\\n\")\n",
    "\n",
    " for step in graph.stream(input_data, config):\n",
    " if \"__end__\" not in step:\n",
    " print_step(step)\n",
    " console.print(\"\\n\")\n",
    "\n",
    " console.print(\"[bold blue]Analysis Complete[/bold blue]\\n\")"
    ],
    "metadata": {
    "id": "t2E2mnnJ5LaN"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Run the Hedge Fund team"
    ],
    "metadata": {
    "id": "Y1_IZnAUTAHw"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "input_data = {\n",
    " \"messages\": [HumanMessage(content=\"What is AAPL's current price and latest revenue?\")]\n",
    "}\n",
    "config = {\"recursion_limit\": 10}\n",
    "stream_agent_execution(graph, input_data, config)"
    ],
    "metadata": {
    "id": "gLUCOhL85Lip"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "input_data = {\n",
    " \"messages\": [HumanMessage(content=\"What is AAPL's latest news?\")]\n",
    "}\n",
    "config = {\"recursion_limit\": 10}\n",
    "stream_agent_execution(graph, input_data, config)"
    ],
    "metadata": {
    "id": "pYyfbkCLNmGD"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [],
    "metadata": {
    "id": "6BFJLGik_oRu"
    },
    "execution_count": null,
    "outputs": []
    }
    ]
    }