Merge branch 'main' into pdf-render
Files changed:
- README.md (+7 -4)
- document_qa/document_qa_engine.py (+23 -9)
- pyproject.toml (+1 -1)
- streamlit_app.py (+25 -2)
README.md
CHANGED
```diff
@@ -16,11 +16,14 @@ license: apache-2.0
 
 ## Introduction
 
-Question/Answering on scientific documents using LLMs
+Question/Answering on scientific documents using LLMs: ChatGPT-3.5-turbo, Mistral-7b-instruct and Zephyr-7b-beta.
+The Streamlit application demonstrates the implementation of RAG (Retrieval Augmented Generation) on scientific documents that we are developing at NIMS (National Institute for Materials Science) in Tsukuba, Japan.
+Unlike most similar projects, we focus on scientific articles.
+We target only the full text, using [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than a raw PDF2Text converter (comparable with most other solutions).
 
+Additionally, this frontend provides the visualisation of named entities on LLM responses to extract <span style="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span style="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
+
+The conversation is backed by a sliding-window memory (the 4 most recent messages) that helps the assistant refer to information previously discussed in the chat.
 
 **Demos**:
 - (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
```
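The sliding-window memory mentioned in the introduction corresponds to LangChain's `ConversationBufferWindowMemory`, imported in `streamlit_app.py` below. A minimal sketch of how a `k=4` window behaves; the question/answer pairs are made-up examples:

```python
# Sketch of the k=4 sliding-window memory described above; the
# question/answer pairs are made-up examples.
from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(k=4)

# Each call stores one question/answer exchange.
for i in range(6):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

# Only the 4 most recent exchanges are handed back to the LLM.
print(memory.load_memory_variables({})["history"])
```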
document_qa/document_qa_engine.py
CHANGED
```diff
@@ -23,7 +23,13 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
+    def __init__(self,
+                 llm,
+                 embedding_function,
+                 qa_chain_type="stuff",
+                 embeddings_root_path=None,
+                 grobid_url=None,
+                 ):
         self.embedding_function = embedding_function
         self.llm = llm
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
@@ -81,14 +87,14 @@ class DocumentQAEngine:
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False) -> (
+                       verbose=False, memory=None) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size)
+        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
@@ -138,9 +144,15 @@ class DocumentQAEngine:
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4):
+    def _run_query(self, doc_id, query, memory=None, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
+        if memory:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query,
+                                  memory=memory)
+        else:
+            return self.chain.run(input_documents=relevant_documents,
+                                  question=query)
         # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
 
     def _get_context(self, doc_id, query, context_size=4):
@@ -150,6 +162,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
+        """Return the full context of a document."""
         db = self.embeddings_dict[doc_id]
         docs = db.get()
         return docs['documents']
@@ -161,6 +174,7 @@ class DocumentQAEngine:
         return relevant_documents
 
     def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
+        """Extract text from a document using Grobid; if chunk_size < 0, each paragraph is kept as a separate chunk."""
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -209,18 +223,17 @@ class DocumentQAEngine:
 
         if hash not in self.embeddings_dict.keys():
             self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
+                                                           collection_name=hash)
         else:
             self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
             self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
                                                            collection_name=hash)
 
         self.embeddings_root_path = None
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -238,7 +251,8 @@ class DocumentQAEngine:
                 print(data_path, "exists. Skipping it ")
                 continue
 
+            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
+                                                               perc_overlap=perc_overlap)
             filename = metadata[0]['filename']
 
             vector_db_document = Chroma.from_texts(texts,
```
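For orientation, a rough usage sketch of the updated `DocumentQAEngine` API; the chat model, embeddings, Grobid URL, PDF path and `doc_id` value below are illustrative assumptions rather than code from the repository:

```python
# Rough usage sketch of the updated DocumentQAEngine API; the model objects,
# Grobid URL, PDF path and doc_id value are illustrative placeholders.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferWindowMemory

from document_qa.document_qa_engine import DocumentQAEngine

engine = DocumentQAEngine(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    embedding_function=OpenAIEmbeddings(),
    grobid_url="http://localhost:8070",  # assumes a locally running Grobid
)

# Grobid-based extraction: chunk_size < 0 keeps one chunk per paragraph.
texts, metadata, ids = engine.get_text_from_document("paper.pdf", chunk_size=-1)
print(f"{len(texts)} chunks extracted")

# doc_id is the hash assigned when the document is embedded into Chroma
# (see `return hash` in the hunk above); a literal placeholder is used here.
memory = ConversationBufferWindowMemory(k=4)
_, answer = engine.query_document(
    "What is the main finding of the paper?",
    doc_id="0123456789abcdef",
    context_size=4,
    memory=memory,
)
print(answer)
```

Since `memory` defaults to `None`, existing callers that never pass it keep the previous single-turn behaviour.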
pyproject.toml
CHANGED
```diff
@@ -3,7 +3,7 @@ requires = ["setuptools", "setuptools-scm"]
 build-backend = "setuptools.build_meta"
 
 [tool.bumpversion]
+current_version = "0.3.0"
 commit = "true"
 tag = "true"
 tag_name = "v{new_version}"
```
streamlit_app.py
CHANGED
```diff
@@ -7,6 +7,7 @@ from tempfile import NamedTemporaryFile
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.memory import ConversationBufferWindowMemory
 
 dotenv.load_dotenv(override=True)
 
@@ -52,6 +53,9 @@ if 'ner_processing' not in st.session_state:
 if 'uploaded' not in st.session_state:
     st.session_state['uploaded'] = False
 
+if 'memory' not in st.session_state:
+    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+
 if 'binary' not in st.session_state:
     st.session_state['binary'] = None
 
@@ -82,6 +86,11 @@ def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
     st.session_state['uploaded'] = True
+    st.session_state['memory'].clear()
+
+
+def clear_memory():
+    st.session_state['memory'].clear()
 
 
 # @st.cache_resource
@@ -112,6 +121,7 @@ def init_qa(model, api_key=None):
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
+        return
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
@@ -183,7 +193,7 @@ with st.sidebar:
        disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
 
    st.markdown(
+        ":warning: Mistral and Zephyr are **FREE** to use. Requests might fail anytime. Use at your own risk. :warning: ")
 
    if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
        if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -219,6 +229,12 @@ with st.sidebar:
            st.session_state['rqa'][model] = init_qa(model)
        # else:
        #     is_api_key_provided = st.session_state['api_key']
+
+    st.button(
+        'Reset chat memory.',
+        on_click=clear_memory,
+        help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.")
+
 left_column, right_column = st.columns([1, 1])
 
 with right_column:
@@ -349,7 +365,8 @@ with right_column:
    elif mode == "LLM":
        with st.spinner("Generating response..."):
            _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
+                                                                             context_size=context_size,
+                                                                             memory=st.session_state.memory)
 
        if not text_response:
            st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
@@ -368,5 +385,11 @@ with right_column:
            st.write(text_response)
            st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
+        for id in range(0, len(st.session_state.messages), 2):
+            question = st.session_state.messages[id]['content']
+            if len(st.session_state.messages) > id + 1:
+                answer = st.session_state.messages[id + 1]['content']
+                st.session_state.memory.save_context({"input": question}, {"output": answer})
+
    elif st.session_state.loaded_embeddings and st.session_state.doc_id:
        play_old_messages()
```
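Putting the pieces together, a stripped-down sketch of the memory wiring introduced above (session-state initialisation, a reset callback, and saving each exchange); the echo answer and widget labels are placeholders, not part of the app:

```python
# Stripped-down sketch of the session-state memory wiring used above;
# the echo "answer" and the widget labels are placeholders.
import streamlit as st
from langchain.memory import ConversationBufferWindowMemory

if 'memory' not in st.session_state:
    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)


def clear_memory():
    st.session_state['memory'].clear()


st.button("Reset chat memory", on_click=clear_memory)  # pass the callback, do not call it

if question := st.chat_input("Ask something about the document"):
    # What the next LLM call would see: only the last k=4 exchanges.
    st.caption(st.session_state['memory'].load_memory_variables({})["history"])
    answer = f"(echo) {question}"  # stand-in for the real query_document() call
    st.write(answer)
    st.session_state['memory'].save_context({"input": question}, {"output": answer})
```

The important detail is `on_click=clear_memory` without parentheses: Streamlit stores the callback and runs it when the button is pressed, rather than executing it on every rerun.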