|
|
"""QA Chatbot streaming using FastAPI, LangChain Expression Language, OpenAI, and Chroma.

Features
--------
- Persistent Chat Memory:
    Stores chat history in a local file.
- Persistent Vector Store:
    Stores document embeddings in a local vector store.
- Standalone Question Generation:
    Rephrases follow-up questions to standalone questions in their original language.
- Document Retrieval:
    Searches and retrieves relevant documents based on user queries.
- Context-Aware Responses:
    Generates responses based on a combined context from relevant documents.
- Streaming Responses:
    Streams responses in real time either as plain text or as Server-Sent Events (SSE).
    SSE also sends the relevant documents as context.

Next Steps
----------
- Add a proper exception handling mechanism during the streaming process.
- Add pruning to the conversation buffer memory to prevent it from growing too large.
- Combine documents using a more sophisticated method than simply concatenating them.

Usage
-----
1. Install dependencies:
```bash
pip install fastapi==0.99.1 uvicorn==0.23.2 python-dotenv==1.0.0 chromadb==0.4.5 tiktoken==0.4.0 langchain==0.0.257 openai==0.27.8
```

or

```bash
poetry install
```

2. Run the server:
```bash
uvicorn main:app --reload
```

3. Query the server with curl:

With plain text:

```bash
curl --no-buffer -X 'POST' \
  'http://localhost:8000/chat' \
  -H 'accept: text/plain' \
  -H 'Content-Type: application/json' \
  -d '{
  "session_id": "session_1",
  "message": "who'\''s playing in the river?"
}'
```

With SSE:

```bash
curl --no-buffer -X 'POST' \
  'http://localhost:8000/chat/sse/' \
  -H 'accept: text/event-stream' \
  -H 'Content-Type: application/json' \
  -d '{
  "session_id": "session_2",
  "message": "who'\''s playing in the garden?"
}'
```
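
Each SSE event is a JSON-serialized `ChatSSEResponse`. An illustrative, abridged
event sequence, assuming the sample texts seeded at startup (exact field layout
may vary with the pydantic/langchain versions):

```json
{"type": "context", "value": [{"page_content": "Cats are playing in the garden.", "metadata": {}}]}
{"type": "start", "value": ""}
{"type": "streaming", "value": "Cats"}
{"type": "end", "value": ""}
```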

Cheers!
@jvelezmagic"""

import os
from functools import lru_cache
from typing import AsyncGenerator, Literal

from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory, FileChatMessageHistory
from langchain.prompts import PromptTemplate
from langchain.schema import BaseChatMessageHistory, Document, format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.vectorstores import Chroma
from pydantic import BaseModel, BaseSettings


class Settings(BaseSettings):
    openai_api_key: str

    class Config:  # type: ignore
        env_file = ".env"
        env_file_encoding = "utf-8"


class ChatRequest(BaseModel):
    session_id: str
    message: str


class ChatSSEResponse(BaseModel):
    type: Literal["context", "start", "streaming", "end", "error"]
    value: str | list[Document]


@lru_cache()
def get_settings() -> Settings:
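    # Cached so the .env file is read only once per process.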
    return Settings()  # type: ignore


@lru_cache()
def get_vectorstore() -> Chroma:
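    # Cached singleton: OpenAI embeddings over a Chroma collection persisted to ./chroma.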
    settings = get_settings()

    embeddings = OpenAIEmbeddings(openai_api_key=settings.openai_api_key)  # type: ignore

    vectorstore = Chroma(
        collection_name="chroma",
        embedding_function=embeddings,
        persist_directory="chroma",
    )

    return vectorstore


def combine_documents(
    docs: list[Document],
    document_prompt: PromptTemplate = PromptTemplate.from_template("{page_content}"),
    document_separator: str = "\n\n",
) -> str:
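    # Render each document through document_prompt and join the results with the separator.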
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


app = FastAPI(
    title="QA Chatbot Streaming using FastAPI, LangChain Expression Language, OpenAI, and Chroma",
    version="0.1.0",
)


@app.on_event("startup")
async def startup_event() -> None:
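    # Seed the vector store with sample texts on first run and make sure the chat-history directory exists.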
    vectorstore = get_vectorstore()
    is_collection_empty: bool = vectorstore._collection.count() == 0  # type: ignore

    if is_collection_empty:
        vectorstore.add_texts(  # type: ignore
            texts=[
                "Cats are playing in the garden.",
                "Dogs are playing in the river.",
                "Dogs and cats are mortal enemies, but they often play together.",
            ]
        )

    if not os.path.exists("message_store"):
        os.mkdir("message_store")


async def generate_standalone_question(
    chat_history: str, question: str, settings: Settings
) -> str:
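    # Condense the chat history plus the follow-up into a self-contained question for retrieval.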
    prompt = PromptTemplate.from_template(
        template="""Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
    )
    llm = ChatOpenAI(temperature=0, openai_api_key=settings.openai_api_key)

    chain = prompt | llm | StrOutputParser()  # type: ignore

    return await chain.ainvoke(  # type: ignore
        {
            "chat_history": chat_history,
            "question": question,
        }
    )


async def search_relevant_documents(query: str, k: int = 5) -> list[Document]:
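    # Retrieve the k documents most similar to the query from the vector store.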
    vectorstore = get_vectorstore()
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})

    return await retriever.aget_relevant_documents(query=query)


async def generate_response(
    context: str, chat_memory: BaseChatMessageHistory, message: str, settings: Settings
) -> AsyncGenerator[str, None]:
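    # Stream the answer token by token, then persist the full exchange to chat memory.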
    prompt = PromptTemplate.from_template(
        """Answer the question based only on the following context:
{context}
Question: {question}"""
    )

    llm = ChatOpenAI(temperature=0, openai_api_key=settings.openai_api_key)

    chain = prompt | llm  # type: ignore

    response = ""
    async for token in chain.astream({"context": context, "question": message}):  # type: ignore
        yield token.content
        response += token.content

    chat_memory.add_user_message(message=message)
    chat_memory.add_ai_message(message=response)


async def generate_sse_response(
    context: list[Document],
    chat_memory: BaseChatMessageHistory,
    message: str,
    settings: Settings,
) -> AsyncGenerator[str, None]:
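    # Emit JSON-serialized SSE events: context first, then start, streaming tokens, and end (or error).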
    prompt = PromptTemplate.from_template(
        """Answer the question based only on the following context:
{context}
Question: {question}"""
    )

    llm = ChatOpenAI(temperature=0, openai_api_key=settings.openai_api_key)

    chain = prompt | llm  # type: ignore

    response = ""
    yield ChatSSEResponse(type="context", value=context).json()
    try:
        yield ChatSSEResponse(type="start", value="").json()
        async for token in chain.astream({"context": context, "question": message}):  # type: ignore
            yield ChatSSEResponse(type="streaming", value=token.content).json()
            response += token.content

        yield ChatSSEResponse(type="end", value="").json()
        chat_memory.add_user_message(message=message)
        chat_memory.add_ai_message(message=response)
    except Exception as e:  # TODO: Add proper exception handling
        yield ChatSSEResponse(type="error", value=str(e)).json()


@app.post("/chat")
async def chat(
    request: ChatRequest, settings: Settings = Depends(get_settings)
) -> StreamingResponse:
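    # Plain-text endpoint: chat history lives in a per-session JSON file under ./message_store.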
    memory_key = f"./message_store/{request.session_id}.json"

    chat_memory = FileChatMessageHistory(file_path=memory_key)
    memory = ConversationBufferMemory(chat_memory=chat_memory, return_messages=False)

    standalone_question = await generate_standalone_question(
        chat_history=memory.buffer, question=request.message, settings=settings
    )

    relevant_documents = await search_relevant_documents(query=standalone_question)

    combined_documents = combine_documents(relevant_documents)

    return StreamingResponse(
        generate_response(
            context=combined_documents,
            chat_memory=chat_memory,
            message=request.message,
            settings=settings,
        ),
        media_type="text/plain",
    )


@app.post("/chat/sse/")
async def chat_sse(
    request: ChatRequest, settings: Settings = Depends(get_settings)
) -> StreamingResponse:
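    # SSE endpoint: same pipeline as /chat, but streams typed events including the retrieved context.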
    memory_key = f"./message_store/{request.session_id}.json"

    chat_memory = FileChatMessageHistory(file_path=memory_key)
    memory = ConversationBufferMemory(chat_memory=chat_memory, return_messages=False)

    standalone_question = await generate_standalone_question(
        chat_history=memory.buffer, question=request.message, settings=settings
    )

    relevant_documents = await search_relevant_documents(query=standalone_question, k=2)

    return StreamingResponse(
        generate_sse_response(
            context=relevant_documents,
            chat_memory=chat_memory,
            message=request.message,
            settings=settings,
        ),
        media_type="text/event-stream",
    )