Spaces:
Running
Running
use models
Browse files- streamlit_app.py +8 -4
streamlit_app.py
CHANGED
|
@@ -19,6 +19,10 @@ from document_qa.document_qa_engine import DocumentQAEngine
|
|
| 19 |
from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
|
| 20 |
from grobid_client_generic import GrobidClientGeneric
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
if 'rqa' not in st.session_state:
|
| 23 |
st.session_state['rqa'] = {}
|
| 24 |
|
|
@@ -117,17 +121,17 @@ def clear_memory():
|
|
| 117 |
# @st.cache_resource
|
| 118 |
def init_qa(model, api_key=None):
|
| 119 |
## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
|
| 120 |
-
if model
|
| 121 |
st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
|
| 122 |
if api_key:
|
| 123 |
-
chat = ChatOpenAI(model_name=
|
| 124 |
temperature=0,
|
| 125 |
openai_api_key=api_key,
|
| 126 |
frequency_penalty=0.1)
|
| 127 |
embeddings = OpenAIEmbeddings(openai_api_key=api_key)
|
| 128 |
|
| 129 |
else:
|
| 130 |
-
chat = ChatOpenAI(model_name=
|
| 131 |
temperature=0,
|
| 132 |
frequency_penalty=0.1)
|
| 133 |
embeddings = OpenAIEmbeddings()
|
|
@@ -241,7 +245,7 @@ with st.sidebar:
|
|
| 241 |
# os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
|
| 242 |
st.session_state['rqa'][model] = init_qa(model)
|
| 243 |
|
| 244 |
-
elif model
|
| 245 |
if 'OPENAI_API_KEY' not in os.environ:
|
| 246 |
api_key = st.text_input('OpenAI API Key', type="password")
|
| 247 |
st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")
|
|
|
|
| 19 |
from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
|
| 20 |
from grobid_client_generic import GrobidClientGeneric
|
| 21 |
|
| 22 |
+
OPENAI_MODELS = ['chatgpt-3.5-turbo',
|
| 23 |
+
"gpt-4",
|
| 24 |
+
"gpt-4-1106-preview"]
|
| 25 |
+
|
| 26 |
if 'rqa' not in st.session_state:
|
| 27 |
st.session_state['rqa'] = {}
|
| 28 |
|
|
|
|
| 121 |
# @st.cache_resource
|
| 122 |
def init_qa(model, api_key=None):
|
| 123 |
## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
|
| 124 |
+
if model in OPENAI_MODELS:
|
| 125 |
st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
|
| 126 |
if api_key:
|
| 127 |
+
chat = ChatOpenAI(model_name=model,
|
| 128 |
temperature=0,
|
| 129 |
openai_api_key=api_key,
|
| 130 |
frequency_penalty=0.1)
|
| 131 |
embeddings = OpenAIEmbeddings(openai_api_key=api_key)
|
| 132 |
|
| 133 |
else:
|
| 134 |
+
chat = ChatOpenAI(model_name=model,
|
| 135 |
temperature=0,
|
| 136 |
frequency_penalty=0.1)
|
| 137 |
embeddings = OpenAIEmbeddings()
|
|
|
|
| 245 |
# os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
|
| 246 |
st.session_state['rqa'][model] = init_qa(model)
|
| 247 |
|
| 248 |
+
elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
|
| 249 |
if 'OPENAI_API_KEY' not in os.environ:
|
| 250 |
api_key = st.text_input('OpenAI API Key', type="password")
|
| 251 |
st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")
|