Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 3 |
-
from langchain_core.messages import
|
| 4 |
from langchain_core.output_parsers import PydanticOutputParser
|
| 5 |
from langchain_core.prompts import PromptTemplate
|
| 6 |
from langchain.chains import LLMChain
|
|
@@ -17,7 +17,6 @@ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
|
| 17 |
from langchain_community.vectorstores import FAISS
|
| 18 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 19 |
|
| 20 |
-
# Environment Variable Checks
|
| 21 |
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
|
| 22 |
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")
|
| 23 |
|
|
@@ -29,10 +28,10 @@ if not HUGGINGFACE_ACCESS_TOKEN:
|
|
| 29 |
st.error("β HUGGINGFACE_ACCESS_TOKEN not found.")
|
| 30 |
st.stop()
|
| 31 |
|
| 32 |
-
# Initialize LLM and Embeddings
|
| 33 |
llm = ChatGoogleGenerativeAI(
|
| 34 |
model="gemini-1.5-pro",
|
| 35 |
-
google_api_key=GOOGLE_API_KEY
|
|
|
|
| 36 |
)
|
| 37 |
|
| 38 |
embeddings = HuggingFaceInferenceAPIEmbeddings(
|
|
@@ -40,7 +39,6 @@ embeddings = HuggingFaceInferenceAPIEmbeddings(
|
|
| 40 |
model_name="BAAI/bge-small-en-v1.5"
|
| 41 |
)
|
| 42 |
|
| 43 |
-
# Pydantic Models for Structured Output
|
| 44 |
class KeyPoint(BaseModel):
|
| 45 |
point: str = Field(description="A key point extracted from the document.")
|
| 46 |
|
|
@@ -51,10 +49,8 @@ class DocumentAnalysis(BaseModel):
|
|
| 51 |
key_points: List[KeyPoint] = Field(description="List of key points from the document.")
|
| 52 |
summary: Summary = Field(description="Summary of the document.")
|
| 53 |
|
| 54 |
-
# Output Parser
|
| 55 |
parser = PydanticOutputParser(pydantic_object=DocumentAnalysis)
|
| 56 |
|
| 57 |
-
# Prompt Template
|
| 58 |
prompt_template = """
|
| 59 |
Analyze the following text and extract key points and a summary.
|
| 60 |
{format_instructions}
|
|
@@ -66,19 +62,15 @@ prompt = PromptTemplate(
|
|
| 66 |
partial_variables={"format_instructions": parser.get_format_instructions()}
|
| 67 |
)
|
| 68 |
|
| 69 |
-
# LLM Chain
|
| 70 |
chain = LLMChain(llm=llm, prompt=prompt, output_parser=parser)
|
| 71 |
|
| 72 |
-
# Text Analysis Function
|
| 73 |
def analyze_text_structured(text):
|
| 74 |
return chain.run(text=text)
|
| 75 |
|
| 76 |
-
# PDF Text Extraction
|
| 77 |
def extract_text_from_pdf(pdf_file):
|
| 78 |
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
| 79 |
return "".join(page.extract_text() for page in pdf_reader.pages)
|
| 80 |
|
| 81 |
-
# JSON to Readable Text
|
| 82 |
def json_to_text(analysis):
|
| 83 |
text_output = "=== Summary ===\n" + f"{analysis.summary.summary}\n\n"
|
| 84 |
text_output += "=== Key Points ===\n"
|
|
@@ -86,7 +78,6 @@ def json_to_text(analysis):
|
|
| 86 |
text_output += f"{i}. {key_point.point}\n"
|
| 87 |
return text_output
|
| 88 |
|
| 89 |
-
# PDF Report Generation (Updated)
|
| 90 |
def create_pdf_report(analysis):
|
| 91 |
pdf = FPDF()
|
| 92 |
pdf.add_page()
|
|
@@ -95,13 +86,11 @@ def create_pdf_report(analysis):
|
|
| 95 |
pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
|
| 96 |
pdf.multi_cell(0, 10, txt=json_to_text(analysis))
|
| 97 |
|
| 98 |
-
# Use BytesIO to create a bytes-like object
|
| 99 |
pdf_bytes = io.BytesIO()
|
| 100 |
pdf.output(pdf_bytes, dest='S')
|
| 101 |
pdf_bytes.seek(0)
|
| 102 |
return pdf_bytes.getvalue()
|
| 103 |
|
| 104 |
-
# Word Report Generation (Updated)
|
| 105 |
def create_word_report(analysis):
|
| 106 |
doc = Document()
|
| 107 |
doc.add_heading('PDF Analysis Report', 0)
|
|
@@ -109,16 +98,13 @@ def create_word_report(analysis):
|
|
| 109 |
doc.add_heading('Analysis', level=1)
|
| 110 |
doc.add_paragraph(json_to_text(analysis))
|
| 111 |
|
| 112 |
-
# Use BytesIO to create a bytes-like object
|
| 113 |
docx_bytes = io.BytesIO()
|
| 114 |
doc.save(docx_bytes)
|
| 115 |
docx_bytes.seek(0)
|
| 116 |
return docx_bytes.getvalue()
|
| 117 |
|
| 118 |
-
# Streamlit Page Configuration
|
| 119 |
st.set_page_config(page_title="Chat With PDF", page_icon="π")
|
| 120 |
|
| 121 |
-
# Custom CSS
|
| 122 |
def local_css():
|
| 123 |
st.markdown("""
|
| 124 |
<style>
|
|
@@ -203,7 +189,6 @@ def local_css():
|
|
| 203 |
|
| 204 |
local_css()
|
| 205 |
|
| 206 |
-
# Session State Initialization
|
| 207 |
if "current_file" not in st.session_state:
|
| 208 |
st.session_state.current_file = None
|
| 209 |
if "pdf_summary" not in st.session_state:
|
|
@@ -219,14 +204,12 @@ if "vectorstore" not in st.session_state:
|
|
| 219 |
if "messages" not in st.session_state:
|
| 220 |
st.session_state.messages = []
|
| 221 |
|
| 222 |
-
# Main App Layout
|
| 223 |
st.markdown('<div class="main-header">', unsafe_allow_html=True)
|
| 224 |
st.markdown('<div class="flag-stripe"></div>', unsafe_allow_html=True)
|
| 225 |
st.title("π Chat With PDF")
|
| 226 |
st.caption("Your AI-powered Document Analyzer")
|
| 227 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 228 |
|
| 229 |
-
# PDF Upload and Analysis Section
|
| 230 |
with st.container():
|
| 231 |
st.markdown('<div class="card animate-fadeIn">', unsafe_allow_html=True)
|
| 232 |
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
|
|
@@ -276,7 +259,6 @@ with st.container():
|
|
| 276 |
)
|
| 277 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 278 |
|
| 279 |
-
# Document Chat Section
|
| 280 |
if "vectorstore" in st.session_state and st.session_state.vectorstore is not None:
|
| 281 |
st.subheader("Chat with the Document")
|
| 282 |
|
|
@@ -296,8 +278,7 @@ if "vectorstore" in st.session_state and st.session_state.vectorstore is not Non
|
|
| 296 |
context = "\n".join([doc.page_content for doc in docs])
|
| 297 |
|
| 298 |
messages = [
|
| 299 |
-
|
| 300 |
-
HumanMessage(content=f"Context: {context}\n\nQuestion: {prompt}")
|
| 301 |
]
|
| 302 |
|
| 303 |
response = llm.invoke(messages)
|
|
@@ -305,7 +286,6 @@ if "vectorstore" in st.session_state and st.session_state.vectorstore is not Non
|
|
| 305 |
|
| 306 |
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
| 307 |
|
| 308 |
-
# Footer
|
| 309 |
st.markdown(
|
| 310 |
f'<div class="footer">Analysis Time: {st.session_state.analysis_time:.1f}s | Powered by Google Generative AI</div>',
|
| 311 |
unsafe_allow_html=True
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 3 |
+
from langchain_core.messages import HumanMessage
|
| 4 |
from langchain_core.output_parsers import PydanticOutputParser
|
| 5 |
from langchain_core.prompts import PromptTemplate
|
| 6 |
from langchain.chains import LLMChain
|
|
|
|
| 17 |
from langchain_community.vectorstores import FAISS
|
| 18 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 19 |
|
|
|
|
| 20 |
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
|
| 21 |
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")
|
| 22 |
|
|
|
|
| 28 |
st.error("β HUGGINGFACE_ACCESS_TOKEN not found.")
|
| 29 |
st.stop()
|
| 30 |
|
|
|
|
| 31 |
llm = ChatGoogleGenerativeAI(
|
| 32 |
model="gemini-1.5-pro",
|
| 33 |
+
google_api_key=GOOGLE_API_KEY,
|
| 34 |
+
convert_system_message_to_human=True
|
| 35 |
)
|
| 36 |
|
| 37 |
embeddings = HuggingFaceInferenceAPIEmbeddings(
|
|
|
|
| 39 |
model_name="BAAI/bge-small-en-v1.5"
|
| 40 |
)
|
| 41 |
|
|
|
|
| 42 |
class KeyPoint(BaseModel):
|
| 43 |
point: str = Field(description="A key point extracted from the document.")
|
| 44 |
|
|
|
|
| 49 |
key_points: List[KeyPoint] = Field(description="List of key points from the document.")
|
| 50 |
summary: Summary = Field(description="Summary of the document.")
|
| 51 |
|
|
|
|
| 52 |
parser = PydanticOutputParser(pydantic_object=DocumentAnalysis)
|
| 53 |
|
|
|
|
| 54 |
prompt_template = """
|
| 55 |
Analyze the following text and extract key points and a summary.
|
| 56 |
{format_instructions}
|
|
|
|
| 62 |
partial_variables={"format_instructions": parser.get_format_instructions()}
|
| 63 |
)
|
| 64 |
|
|
|
|
| 65 |
chain = LLMChain(llm=llm, prompt=prompt, output_parser=parser)
|
| 66 |
|
|
|
|
| 67 |
def analyze_text_structured(text):
|
| 68 |
return chain.run(text=text)
|
| 69 |
|
|
|
|
| 70 |
def extract_text_from_pdf(pdf_file):
|
| 71 |
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
| 72 |
return "".join(page.extract_text() for page in pdf_reader.pages)
|
| 73 |
|
|
|
|
| 74 |
def json_to_text(analysis):
|
| 75 |
text_output = "=== Summary ===\n" + f"{analysis.summary.summary}\n\n"
|
| 76 |
text_output += "=== Key Points ===\n"
|
|
|
|
| 78 |
text_output += f"{i}. {key_point.point}\n"
|
| 79 |
return text_output
|
| 80 |
|
|
|
|
| 81 |
def create_pdf_report(analysis):
|
| 82 |
pdf = FPDF()
|
| 83 |
pdf.add_page()
|
|
|
|
| 86 |
pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
|
| 87 |
pdf.multi_cell(0, 10, txt=json_to_text(analysis))
|
| 88 |
|
|
|
|
| 89 |
pdf_bytes = io.BytesIO()
|
| 90 |
pdf.output(pdf_bytes, dest='S')
|
| 91 |
pdf_bytes.seek(0)
|
| 92 |
return pdf_bytes.getvalue()
|
| 93 |
|
|
|
|
| 94 |
def create_word_report(analysis):
|
| 95 |
doc = Document()
|
| 96 |
doc.add_heading('PDF Analysis Report', 0)
|
|
|
|
| 98 |
doc.add_heading('Analysis', level=1)
|
| 99 |
doc.add_paragraph(json_to_text(analysis))
|
| 100 |
|
|
|
|
| 101 |
docx_bytes = io.BytesIO()
|
| 102 |
doc.save(docx_bytes)
|
| 103 |
docx_bytes.seek(0)
|
| 104 |
return docx_bytes.getvalue()
|
| 105 |
|
|
|
|
| 106 |
st.set_page_config(page_title="Chat With PDF", page_icon="π")
|
| 107 |
|
|
|
|
| 108 |
def local_css():
|
| 109 |
st.markdown("""
|
| 110 |
<style>
|
|
|
|
| 189 |
|
| 190 |
local_css()
|
| 191 |
|
|
|
|
| 192 |
if "current_file" not in st.session_state:
|
| 193 |
st.session_state.current_file = None
|
| 194 |
if "pdf_summary" not in st.session_state:
|
|
|
|
| 204 |
if "messages" not in st.session_state:
|
| 205 |
st.session_state.messages = []
|
| 206 |
|
|
|
|
| 207 |
st.markdown('<div class="main-header">', unsafe_allow_html=True)
|
| 208 |
st.markdown('<div class="flag-stripe"></div>', unsafe_allow_html=True)
|
| 209 |
st.title("π Chat With PDF")
|
| 210 |
st.caption("Your AI-powered Document Analyzer")
|
| 211 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 212 |
|
|
|
|
| 213 |
with st.container():
|
| 214 |
st.markdown('<div class="card animate-fadeIn">', unsafe_allow_html=True)
|
| 215 |
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
|
|
|
|
| 259 |
)
|
| 260 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 261 |
|
|
|
|
| 262 |
if "vectorstore" in st.session_state and st.session_state.vectorstore is not None:
|
| 263 |
st.subheader("Chat with the Document")
|
| 264 |
|
|
|
|
| 278 |
context = "\n".join([doc.page_content for doc in docs])
|
| 279 |
|
| 280 |
messages = [
|
| 281 |
+
HumanMessage(content=f"You are a helpful assistant. Answer the question based on the provided document context.\n\nContext: {context}\n\nQuestion: {prompt}")
|
|
|
|
| 282 |
]
|
| 283 |
|
| 284 |
response = llm.invoke(messages)
|
|
|
|
| 286 |
|
| 287 |
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
| 288 |
|
|
|
|
| 289 |
st.markdown(
|
| 290 |
f'<div class="footer">Analysis Time: {st.session_state.analysis_time:.1f}s | Powered by Google Generative AI</div>',
|
| 291 |
unsafe_allow_html=True
|