move settings on the sidebar, allow env variables
streamlit_app.py  +78 -57
@@ -18,11 +18,14 @@ from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
 from grobid_client_generic import GrobidClientGeneric
 
 if 'rqa' not in st.session_state:
-    st.session_state['rqa'] =
+    st.session_state['rqa'] = {}
 
 if 'api_key' not in st.session_state:
     st.session_state['api_key'] = False
 
+if 'api_keys' not in st.session_state:
+    st.session_state['api_keys'] = {}
+
 if 'doc_id' not in st.session_state:
     st.session_state['doc_id'] = None
 
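
The `rqa` entry becomes a dict so a single session can hold one QA engine per selected model. A minimal sketch of that get-or-create pattern, with hypothetical names (`engines`, `get_engine`, and the lambda factory are illustrative, not part of the commit):

    import streamlit as st

    # Streamlit re-executes the script on every interaction, so each key
    # is guarded to avoid clobbering existing state on reruns.
    if 'engines' not in st.session_state:  # 'engines' is a hypothetical key
        st.session_state['engines'] = {}

    def get_engine(model_name, factory):
        # Create the engine for this model once per session, then reuse it.
        engines = st.session_state['engines']
        if model_name not in engines:
            engines[model_name] = factory(model_name)
        return engines[model_name]

    engine = get_engine('mistral-7b-instruct-v0.1', lambda name: {'model': name})
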
@@ -42,13 +45,16 @@ if 'git_rev' not in st.session_state:
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if 'ner_processing' not in st.session_state:
+    st.session_state['ner_processing'] = False
+
 
 def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
 
 
-@st.cache_resource
+# @st.cache_resource
 def init_qa(model):
     if model == 'chatgpt-3.5-turbo':
         chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
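
Note that `@st.cache_resource` is commented out rather than deleted. That cache is shared across all sessions and users, so a cached engine initialized with one visitor's API key would be served to everyone choosing the same model; keeping engines in `st.session_state` scopes them to the browser session instead. A sketch contrasting the two scopes, with `build_engine` as a hypothetical stand-in for `init_qa`:

    import streamlit as st

    def build_engine(model: str) -> dict:
        # Hypothetical stand-in for init_qa(); returns a dummy engine object.
        return {'model': model}

    @st.cache_resource
    def shared_engine(model: str) -> dict:
        # Cached per argument value and shared across ALL sessions and users.
        return build_engine(model)

    # Per-session alternative, as in this commit: each browser session
    # builds and keeps its own engine, so per-user API keys are not shared.
    if 'engine' not in st.session_state:
        st.session_state['engine'] = build_engine('chatgpt-3.5-turbo')
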
@@ -67,6 +73,7 @@ def init_qa(model):
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
     else:
         st.error("The model was not loaded properly. Try reloading. ")
+        st.stop()
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
@@ -94,7 +101,6 @@ def init_ner():
         grobid_quantities_client=quantities_client,
         grobid_superconductors_client=materials_client
     )
-
     return gqa
 
 
@@ -125,51 +131,52 @@ def play_old_messages():
 
 is_api_key_provided = st.session_state['api_key']
 
-
-
-
-
-
-
-
-
-
-
-
-
+with st.sidebar:
+    model = st.radio(
+        "Model (cannot be changed after selection or upload)",
+        ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),  # , "llama-2-70b-chat"),
+        index=1,
+        captions=[
+            "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
+            "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
+            # "LLama2-70B-Chat + Sentence BERT (embeddings)",
+        ],
+        help="Select the model you want to use.")
+
 if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
-    api_key = st.
-
+    api_key = st.text_input('Huggingface API Key',
+                            type="password") if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ[
+        'HUGGINGFACEHUB_API_TOKEN']
+    st.markdown(
+        "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
+
     if api_key:
         st.session_state['api_key'] = is_api_key_provided = True
-
-
+        st.session_state['api_keys']['mistral-7b-instruct-v0.1'] = api_key
+        if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
+            os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
+        st.session_state['rqa'][model] = init_qa(model)
+
 elif model == 'chatgpt-3.5-turbo':
-    api_key = st.
-
+    api_key = st.text_input('OpenAI API Key', type="password") if 'OPENAI_API_KEY' not in os.environ else \
+        os.environ['OPENAI_API_KEY']
+    st.markdown(
+        "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
     if api_key:
         st.session_state['api_key'] = is_api_key_provided = True
-
-
-
-
+        st.session_state['api_keys']['chatgpt-3.5-turbo'] = api_key
+        if 'OPENAI_API_KEY' not in os.environ:
+            os.environ['OPENAI_API_KEY'] = api_key
+        st.session_state['rqa'][model] = init_qa(model)
+# else:
+#     is_api_key_provided = st.session_state['api_key']
 
 st.title("📝 Scientific Document Insight Q&A")
 st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
 
-
-
-
-    disabled=not is_api_key_provided,
-    help="The full-text is extracted using Grobid. ")
-with radio_col:
-    mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0,
-                    help="LLM will respond the question, Embedding will show the "
-                         "paragraphs relevant to the question in the paper.")
-with context_col:
-    context_size = st.slider("Context size", 3, 10, value=4,
-                             help="Number of paragraphs to consider when answering a question",
-                             disabled=not uploaded_file)
+uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
+                                 disabled=not is_api_key_provided,
+                                 help="The full-text is extracted using Grobid. ")
 
 question = st.chat_input(
     "Ask something about the article",
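
Both model branches now share the pattern behind "allow env variables" in the commit title: use the key from the environment when the deployment provides one, otherwise fall back to a password-type input, then export whatever the user typed. A compact sketch of that pattern (the `resolve_api_key` wrapper is illustrative; the commit inlines the conditional expression):

    import os
    import streamlit as st

    def resolve_api_key(env_var: str, label: str) -> str:
        # Prefer a key provided by the deployment environment; otherwise
        # ask the user via a password input (empty string until typed).
        if env_var in os.environ:
            return os.environ[env_var]
        return st.text_input(label, type='password')

    api_key = resolve_api_key('HUGGINGFACEHUB_API_TOKEN', 'Huggingface API Key')
    if api_key:
        # Equivalent to the commit's "if ... not in os.environ" guard.
        os.environ.setdefault('HUGGINGFACEHUB_API_TOKEN', api_key)
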
@@ -178,14 +185,29 @@ question = st.chat_input(
 )
 
 with st.sidebar:
+    st.header("Settings")
+    mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0, horizontal=True,
+                    help="LLM will respond the question, Embedding will show the "
+                         "paragraphs relevant to the question in the paper.")
+    chunk_size = st.slider("Chunks size", 100, 2000, value=250,
+                           help="Size of chunks in which the document is partitioned",
+                           disabled=not uploaded_file)
+    context_size = st.slider("Context size", 3, 10, value=4,
+                             help="Number of chunks to consider when answering a question",
+                             disabled=not uploaded_file)
+
+    st.session_state['ner_processing'] = st.checkbox("NER processing on LLM response")
+    st.markdown(
+        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
+        unsafe_allow_html=True)
+
+    st.divider()
+
     st.header("Documentation")
     st.markdown("https://github.com/lfoppiano/document-qa")
     st.markdown(
         """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
 
-    st.markdown(
-        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
-        unsafe_allow_html=True)
     if st.session_state['git_rev'] != "unknown":
         st.markdown("**Revision number**: [" + st.session_state[
             'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
@@ -203,9 +225,9 @@ if uploaded_file and not st.session_state.loaded_embeddings:
     tmp_file = NamedTemporaryFile()
     tmp_file.write(bytearray(binary))
     # hash = get_file_hash(tmp_file.name)[:10]
-    st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name,
-
-
+    st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
+                                                                                                chunk_size=chunk_size,
+                                                                                                perc_overlap=0.1)
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
 
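
`create_memory_embeddings` now receives the chunking parameters chosen in the sidebar. The engine is a black box in this diff, but fixed-size chunking with fractional overlap typically works like the sketch below (an assumption about the implementation, not taken from the commit): with `chunk_size=250` and `perc_overlap=0.1`, the window advances by 225 tokens, so neighbouring chunks share 25.

    def chunk_tokens(tokens, chunk_size=250, perc_overlap=0.1):
        # Fixed-size windows with fractional overlap: the window advances
        # by chunk_size * (1 - perc_overlap) positions, so neighbours
        # share chunk_size * perc_overlap tokens (25 of 250 by default).
        step = max(1, int(chunk_size * (1 - perc_overlap)))
        return [tokens[i:i + chunk_size] for i in range(0, len(tokens), step)]

    chunks = chunk_tokens([f'tok{i}' for i in range(1000)])
    # -> windows starting at tokens 0, 225, 450, 675, 900
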
@@ -226,27 +248,26 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
     text_response = None
     if mode == "Embeddings":
         with st.spinner("Generating LLM response..."):
-            text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
-
+            text_response = st.session_state['rqa'][model].query_storage(question, st.session_state.doc_id,
+                                                                         context_size=context_size)
     elif mode == "LLM":
         with st.spinner("Generating response..."):
-            _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
-
+            _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
+                                                                             context_size=context_size)
 
     if not text_response:
         st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
 
     with st.chat_message("assistant"):
         if mode == "LLM":
-
-
-
-
-
-
-
-
-            text_response = decorated_text
+            if st.session_state['ner_processing']:
+                with st.spinner("Processing NER on LLM response..."):
+                    entities = gqa.process_single_text(text_response)
+                    decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
+                    decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+                    decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+                    text_response = decorated_text
+            st.markdown(text_response, unsafe_allow_html=True)
         else:
             st.write(text_response)
     st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
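
The restored NER branch post-processes the HTML returned by `decorate_text_with_annotations`: material labels are rewritten to green inline styles first, then the generic regex turns every remaining `label` class orange; the order matters so the second pass cannot clobber the material spans. A standalone sketch using a made-up annotated string:

    import re

    # Made-up example of annotated output, shaped like the spans the diff rewrites.
    response = ('The critical temperature of '
                '<span class="label material">MgB2</span> is '
                '<span class="label quantity">39 K</span>.')

    # Materials first, so the generic rule below cannot overwrite them ...
    decorated = response.replace('class="label material"', 'style="color:green"')
    # ... then every remaining label class is styled orange.
    decorated = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated)
    print(decorated)
    # The critical temperature of <span style="color:green">MgB2</span> is
    # <span style="color:orange">39 K</span>.
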
|