import os
os.environ["OPENAI_API_KEY"] = ""  # fill in your OpenAI API key before running
from flask import Flask, Response, request
import threading
import queue
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import AIMessage, HumanMessage, SystemMessage

app = Flask(__name__)

@app.route('/')
def index():
    # HTML is inlined just for the example; move it into a template in real code.
    # Minimal page: a prompt box plus a fetch() reader that appends streamed
    # tokens to the page as they arrive.
    return Response('''<!DOCTYPE html>
<html>
<head><title>Flask Streaming Langchain Example</title></head>
<body>
<form onsubmit="run(); return false;"><input id="prompt" size="60"><button>Send</button></form>
<pre id="output"></pre>
<script>
async function run() {
  const resp = await fetch('/chain', {
    method: 'POST',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({prompt: document.getElementById('prompt').value}),
  });
  const reader = resp.body.getReader();
  const decoder = new TextDecoder();
  for (;;) {
    const {done, value} = await reader.read();
    if (done) break;
    document.getElementById('output').textContent += decoder.decode(value);
  }
}
</script>
</body>
</html>''', mimetype='text/html')

class ThreadedGenerator:
    """Thread-safe bridge: the worker thread send()s tokens, Flask iterates them."""
    def __init__(self):
        self.queue = queue.Queue()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()  # blocks until the producer sends something
        if item is StopIteration:
            raise item  # sentinel from close(): ends the iteration
        return item

    def send(self, data):
        self.queue.put(data)

    def close(self):
        self.queue.put(StopIteration)  # sentinel consumed by __next__
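
# A standalone sketch of the handshake (hypothetical, not part of the app):
#   g = ThreadedGenerator()
#   threading.Thread(target=lambda: (g.send("a"), g.send("b"), g.close())).start()
#   assert list(g) == ["a", "b"]
# Queue.get() blocks, so iteration simply waits for the producer thread.
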
class ChainStreamHandler(StreamingStdOutCallbackHandler):
    """Callback handler that forwards every new token into a ThreadedGenerator."""
    def __init__(self, gen):
        super().__init__()
        self.gen = gen

    def on_llm_new_token(self, token: str, **kwargs):
        self.gen.send(token)
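
# on_llm_new_token runs on the worker thread for every streamed token
# (enabled by streaming=True below); queue.Queue is thread-safe, so no
# extra locking is needed to hand tokens over to the request thread.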

def llm_thread(g, prompt):
    """Run the chat completion in a worker thread, streaming tokens into g."""
    try:
        chat = ChatOpenAI(
            verbose=True,
            streaming=True,  # required so on_llm_new_token fires per token
            callbacks=[ChainStreamHandler(g)],
            temperature=0.7,
        )
        chat([HumanMessage(content=prompt)])
    finally:
        g.close()  # always terminate the HTTP response, even on errors

def chain(prompt):
    """Start generation in a background thread and return the token iterator."""
    g = ThreadedGenerator()
    threading.Thread(target=llm_thread, args=(g, prompt)).start()
    return g  # returns immediately; the thread fills the queue as tokens arrive

@app.route('/chain', methods=['POST'])
def _chain():
    # Flask streams any iterator passed to Response, chunk by chunk
    return Response(chain(request.json['prompt']), mimetype='text/plain')
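
# Example request (assuming Flask's default port 5000):
#   curl -N -X POST http://localhost:5000/chain \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "Write a haiku about streams"}'
# -N disables curl's output buffering so tokens print as they arrive.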

if __name__ == '__main__':
    app.run(threaded=True, debug=True)  # threaded=True allows concurrent requests
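
# Note: debug=True and Flask's dev server are for local testing. Behind a
# reverse proxy such as nginx, response buffering (proxy_buffering) is on
# by default and would delay the stream until completion.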