Streamlit、LangChain(OpenAI API)で、ChatGPTのようなストリーミング応答を実装するためのコードをメモしておきます。
LangChain 0.0.260で動作確認しました。
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.base import BaseCallbackHandler
import streamlit as st
model_name = 'gpt-4'
class StreamHandler(BaseCallbackHandler):
"""
新しいトークンをテキストに追加し、コンテナ内に更新されたテキストを表示するためのコールバックハンドラ。
"""
def __init__(self, container, init_text=""):
self.container = container
self.text = init_text
def on_llm_new_token(self, token: str, **kwargs) -> None:
self.text += token
self.container.markdown(self.text)
def main():
query = st.text_input("メッセージを入力してください。")
send_button = st.button("送信")
if send_button and query:
send_button = False
container = st.empty()
stream_handler = StreamHandler(container)
llm = ChatOpenAI(model_name=model_name, streaming=True, callbacks=[stream_handler], temperature=0)
prompt = PromptTemplate(
input_variables=["query"],
template="""
以下の質問に答えてください。
質問:
{query}
""",
)
chain = LLMChain(llm=llm, prompt=prompt)
chain.run(query)
if __name__ == '__main__':
main()
StreamHandler
のコンストラクタで、2つの引数 (container
と init_text
) を受け取ります。init_text
はデフォルトで空の文字列です。このコンストラクタは、渡された container
と init_text
をインスタンス変数に割り当てるだけのシンプルなものです。
on_llm_new_token
というメソッドは、新しい token
(文字列として型指定)を受け取り、現在のテキスト (self.text
) にそのトークンを追加し、その後、self.container.markdown(self.text)
で更新されたテキストをマークダウンとして container
に表示します。**kwargs
は、追加のキーワード引数を受け取るためのものですが、このメソッドの中では特に使用していません。
動作は以下のような感じです。
英語のプロンプトのコードをgithubにあげてあります。
GitHub - rysk310/gpt-streaming
Contribute to rysk310/gpt-streaming development by creating an account on GitHub.