Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,6 @@ from langchain_core.prompts import PromptTemplate
|
|
| 6 |
from langchain.chains import LLMChain
|
| 7 |
from pydantic import BaseModel, Field
|
| 8 |
from typing import List
|
| 9 |
-
from dotenv import load_dotenv
|
| 10 |
import os
|
| 11 |
import time
|
| 12 |
from datetime import datetime
|
|
@@ -18,12 +17,36 @@ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
|
| 18 |
from langchain_community.vectorstores import FAISS
|
| 19 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
class KeyPoint(BaseModel):
|
| 28 |
point: str = Field(description="A key point extracted from the document.")
|
| 29 |
|
|
@@ -33,7 +56,10 @@ class Summary(BaseModel):
|
|
| 33 |
class DocumentAnalysis(BaseModel):
|
| 34 |
key_points: List[KeyPoint] = Field(description="List of key points from the document.")
|
| 35 |
summary: Summary = Field(description="Summary of the document.")
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
| 37 |
parser = PydanticOutputParser(pydantic_object=DocumentAnalysis)
|
| 38 |
|
| 39 |
prompt_template = """
|
|
@@ -46,54 +72,51 @@ prompt = PromptTemplate(
|
|
| 46 |
input_variables=["text"],
|
| 47 |
partial_variables={"format_instructions": parser.get_format_instructions()}
|
| 48 |
)
|
| 49 |
-
|
| 50 |
chain = LLMChain(llm=llm, prompt=prompt, output_parser=parser)
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
def analyze_text_structured(text):
|
| 53 |
-
|
| 54 |
-
return output
|
| 55 |
|
| 56 |
-
|
| 57 |
def extract_text_from_pdf(pdf_file):
|
| 58 |
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
| 59 |
-
|
| 60 |
-
for page in pdf_reader.pages:
|
| 61 |
-
text += page.extract_text()
|
| 62 |
-
return text
|
| 63 |
|
| 64 |
-
|
| 65 |
def json_to_text(analysis):
|
| 66 |
text_output = "=== Summary ===\n" + f"{analysis.summary.summary}\n\n"
|
| 67 |
text_output += "=== Key Points ===\n"
|
| 68 |
for i, key_point in enumerate(analysis.key_points, start=1):
|
| 69 |
text_output += f"{i}. {key_point.point}\n"
|
| 70 |
return text_output
|
| 71 |
-
|
| 72 |
def create_pdf_report(analysis):
|
| 73 |
pdf = FPDF()
|
| 74 |
pdf.add_page()
|
| 75 |
pdf.set_font('Helvetica', '', 12)
|
| 76 |
pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
|
| 77 |
pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
|
| 78 |
-
|
| 79 |
-
pdf.multi_cell(0, 10, txt=clean_text)
|
| 80 |
return pdf.output(dest='S')
|
| 81 |
-
|
| 82 |
def create_word_report(analysis):
|
| 83 |
doc = Document()
|
| 84 |
doc.add_heading('PDF Analysis Report', 0)
|
| 85 |
doc.add_paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
|
| 86 |
-
clean_text = json_to_text(analysis)
|
| 87 |
doc.add_heading('Analysis', level=1)
|
| 88 |
-
doc.add_paragraph(
|
| 89 |
docx_bytes = io.BytesIO()
|
| 90 |
doc.save(docx_bytes)
|
| 91 |
docx_bytes.seek(0)
|
| 92 |
return docx_bytes.getvalue()
|
| 93 |
-
|
| 94 |
-
st.set_page_config(page_title="Chat With PDF", page_icon="😒")
|
| 95 |
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
def local_css():
|
| 98 |
st.markdown("""
|
| 99 |
<style>
|
|
@@ -177,7 +200,8 @@ def local_css():
|
|
| 177 |
""", unsafe_allow_html=True)
|
| 178 |
|
| 179 |
local_css()
|
| 180 |
-
|
|
|
|
| 181 |
if "current_file" not in st.session_state:
|
| 182 |
st.session_state.current_file = None
|
| 183 |
if "pdf_summary" not in st.session_state:
|
|
@@ -193,85 +217,94 @@ if "vectorstore" not in st.session_state:
|
|
| 193 |
if "messages" not in st.session_state:
|
| 194 |
st.session_state.messages = []
|
| 195 |
|
| 196 |
-
|
| 197 |
st.markdown('<div class="main-header">', unsafe_allow_html=True)
|
| 198 |
st.markdown('<div class="flag-stripe"></div>', unsafe_allow_html=True)
|
| 199 |
-
st.title("
|
| 200 |
-
st.caption("Your AI-powered
|
| 201 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
st.session_state.
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
| 214 |
st.session_state.messages = []
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
st.
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
st.session_state.
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
| 250 |
if "vectorstore" in st.session_state:
|
| 251 |
st.subheader("Chat with the Document")
|
|
|
|
| 252 |
for message in st.session_state.messages:
|
| 253 |
with st.chat_message(message["role"]):
|
| 254 |
st.markdown(message["content"])
|
| 255 |
|
| 256 |
if prompt := st.chat_input("Ask a question about the document"):
|
| 257 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
|
|
|
| 258 |
with st.chat_message("user"):
|
| 259 |
st.markdown(prompt)
|
| 260 |
|
| 261 |
with st.chat_message("assistant"):
|
| 262 |
with st.spinner("Thinking..."):
|
| 263 |
-
|
| 264 |
docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
|
| 265 |
context = "\n".join([doc.page_content for doc in docs])
|
| 266 |
-
|
| 267 |
messages = [
|
| 268 |
-
SystemMessage(content="You are a
|
| 269 |
HumanMessage(content=f"Context: {context}\n\nQuestion: {prompt}")
|
| 270 |
]
|
| 271 |
|
| 272 |
response = llm.invoke(messages)
|
| 273 |
st.markdown(response.content)
|
|
|
|
| 274 |
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
| 275 |
|
| 276 |
-
|
| 277 |
-
st.markdown(
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from langchain.chains import LLMChain
|
| 7 |
from pydantic import BaseModel, Field
|
| 8 |
from typing import List
|
|
|
|
| 9 |
import os
|
| 10 |
import time
|
| 11 |
from datetime import datetime
|
|
|
|
| 17 |
from langchain_community.vectorstores import FAISS
|
| 18 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 19 |
|
| 20 |
+
# ======================
|
| 21 |
+
# SECRETS CONFIGURATION
|
| 22 |
+
# ======================
|
| 23 |
+
# Get API keys from Hugging Face Secrets
|
| 24 |
+
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
|
| 25 |
+
HUGGINGFACE_ACCESS_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")
|
| 26 |
+
|
| 27 |
+
# Validate required secrets
|
| 28 |
+
if not GOOGLE_API_KEY:
|
| 29 |
+
st.error("❌ GOOGLE_API_KEY not found. Please set it in Space Settings > Secrets.")
|
| 30 |
+
st.stop()
|
| 31 |
+
|
| 32 |
+
if not HUGGINGFACE_ACCESS_TOKEN:
|
| 33 |
+
st.error("❌ HUGGINGFACE_ACCESS_TOKEN not found. Please set it in Space Settings > Secrets.")
|
| 34 |
+
st.stop()
|
| 35 |
+
|
| 36 |
+
# Initialize LLM and embeddings with secrets
|
| 37 |
+
llm = ChatGoogleGenerativeAI(
|
| 38 |
+
model="gemini-1.5-pro",
|
| 39 |
+
google_api_key=GOOGLE_API_KEY
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
embeddings = HuggingFaceInferenceAPIEmbeddings(
|
| 43 |
+
api_key=HUGGINGFACE_ACCESS_TOKEN,
|
| 44 |
+
model_name="BAAI/bge-small-en-v1.5"
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# ======================
|
| 48 |
+
# DOCUMENT ANALYSIS CLASSES
|
| 49 |
+
# ======================
|
| 50 |
class KeyPoint(BaseModel):
|
| 51 |
point: str = Field(description="A key point extracted from the document.")
|
| 52 |
|
|
|
|
| 56 |
class DocumentAnalysis(BaseModel):
|
| 57 |
key_points: List[KeyPoint] = Field(description="List of key points from the document.")
|
| 58 |
summary: Summary = Field(description="Summary of the document.")
|
| 59 |
+
|
| 60 |
+
# ======================
|
| 61 |
+
# CHAIN SETUP
|
| 62 |
+
# ======================
|
| 63 |
parser = PydanticOutputParser(pydantic_object=DocumentAnalysis)
|
| 64 |
|
| 65 |
prompt_template = """
|
|
|
|
| 72 |
input_variables=["text"],
|
| 73 |
partial_variables={"format_instructions": parser.get_format_instructions()}
|
| 74 |
)
|
| 75 |
+
|
| 76 |
chain = LLMChain(llm=llm, prompt=prompt, output_parser=parser)
|
| 77 |
+
|
| 78 |
+
# ======================
|
| 79 |
+
# UTILITY FUNCTIONS
|
| 80 |
+
# ======================
|
| 81 |
def analyze_text_structured(text):
|
| 82 |
+
return chain.run(text=text)
|
|
|
|
| 83 |
|
|
|
|
| 84 |
def extract_text_from_pdf(pdf_file):
|
| 85 |
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
| 86 |
+
return "".join(page.extract_text() for page in pdf_reader.pages)
|
|
|
|
|
|
|
|
|
|
| 87 |
|
|
|
|
| 88 |
def json_to_text(analysis):
|
| 89 |
text_output = "=== Summary ===\n" + f"{analysis.summary.summary}\n\n"
|
| 90 |
text_output += "=== Key Points ===\n"
|
| 91 |
for i, key_point in enumerate(analysis.key_points, start=1):
|
| 92 |
text_output += f"{i}. {key_point.point}\n"
|
| 93 |
return text_output
|
| 94 |
+
|
| 95 |
def create_pdf_report(analysis):
|
| 96 |
pdf = FPDF()
|
| 97 |
pdf.add_page()
|
| 98 |
pdf.set_font('Helvetica', '', 12)
|
| 99 |
pdf.cell(200, 10, txt="PDF Analysis Report", ln=True, align='C')
|
| 100 |
pdf.cell(200, 10, txt=f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=True, align='C')
|
| 101 |
+
pdf.multi_cell(0, 10, txt=json_to_text(analysis))
|
|
|
|
| 102 |
return pdf.output(dest='S')
|
| 103 |
+
|
| 104 |
def create_word_report(analysis):
|
| 105 |
doc = Document()
|
| 106 |
doc.add_heading('PDF Analysis Report', 0)
|
| 107 |
doc.add_paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
|
|
|
|
| 108 |
doc.add_heading('Analysis', level=1)
|
| 109 |
+
doc.add_paragraph(json_to_text(analysis))
|
| 110 |
docx_bytes = io.BytesIO()
|
| 111 |
doc.save(docx_bytes)
|
| 112 |
docx_bytes.seek(0)
|
| 113 |
return docx_bytes.getvalue()
|
|
|
|
|
|
|
| 114 |
|
| 115 |
+
# ======================
|
| 116 |
+
# STREAMLIT UI
|
| 117 |
+
# ======================
|
| 118 |
+
st.set_page_config(page_title="Chat With PDF", page_icon="📄")
|
| 119 |
+
|
| 120 |
def local_css():
|
| 121 |
st.markdown("""
|
| 122 |
<style>
|
|
|
|
| 200 |
""", unsafe_allow_html=True)
|
| 201 |
|
| 202 |
local_css()
|
| 203 |
+
|
| 204 |
+
# Initialize session state
|
| 205 |
if "current_file" not in st.session_state:
|
| 206 |
st.session_state.current_file = None
|
| 207 |
if "pdf_summary" not in st.session_state:
|
|
|
|
| 217 |
if "messages" not in st.session_state:
|
| 218 |
st.session_state.messages = []
|
| 219 |
|
| 220 |
+
# UI Components
|
| 221 |
st.markdown('<div class="main-header">', unsafe_allow_html=True)
|
| 222 |
st.markdown('<div class="flag-stripe"></div>', unsafe_allow_html=True)
|
| 223 |
+
st.title("📄 Chat With PDF")
|
| 224 |
+
st.caption("Your AI-powered Document Analyzer")
|
| 225 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 226 |
+
|
| 227 |
+
# File Uploader
|
| 228 |
+
with st.container():
|
| 229 |
+
st.markdown('<div class="card animate-fadeIn">', unsafe_allow_html=True)
|
| 230 |
+
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
|
| 231 |
+
|
| 232 |
+
if uploaded_file is not None:
|
| 233 |
+
if st.session_state.current_file != uploaded_file.name:
|
| 234 |
+
st.session_state.current_file = uploaded_file.name
|
| 235 |
+
st.session_state.pdf_summary = None
|
| 236 |
+
st.session_state.pdf_report = None
|
| 237 |
+
st.session_state.word_report = None
|
| 238 |
+
st.session_state.vectorstore = None
|
| 239 |
st.session_state.messages = []
|
| 240 |
+
|
| 241 |
+
text = extract_text_from_pdf(uploaded_file)
|
| 242 |
+
|
| 243 |
+
if st.button("Analyze Text"):
|
| 244 |
+
start_time = time.time()
|
| 245 |
+
with st.spinner("Analyzing..."):
|
| 246 |
+
analysis = analyze_text_structured(text)
|
| 247 |
+
st.session_state.pdf_summary = analysis
|
| 248 |
+
|
| 249 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
| 250 |
+
chunks = text_splitter.split_text(text)
|
| 251 |
+
st.session_state.vectorstore = FAISS.from_texts(chunks, embeddings)
|
| 252 |
+
|
| 253 |
+
st.session_state.pdf_report = create_pdf_report(analysis)
|
| 254 |
+
st.session_state.word_report = create_word_report(analysis)
|
| 255 |
+
|
| 256 |
+
st.session_state.analysis_time = time.time() - start_time
|
| 257 |
+
st.subheader("Analysis Results")
|
| 258 |
+
st.text(json_to_text(analysis))
|
| 259 |
+
|
| 260 |
+
col1, col2 = st.columns(2)
|
| 261 |
+
with col1:
|
| 262 |
+
st.download_button(
|
| 263 |
+
label="Download PDF Report",
|
| 264 |
+
data=st.session_state.pdf_report,
|
| 265 |
+
file_name="analysis_report.pdf",
|
| 266 |
+
mime="application/pdf"
|
| 267 |
+
)
|
| 268 |
+
with col2:
|
| 269 |
+
st.download_button(
|
| 270 |
+
label="Download Word Report",
|
| 271 |
+
data=st.session_state.word_report,
|
| 272 |
+
file_name="analysis_report.docx",
|
| 273 |
+
mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
| 274 |
+
)
|
| 275 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 276 |
+
|
| 277 |
+
# Chat Interface
|
| 278 |
if "vectorstore" in st.session_state:
|
| 279 |
st.subheader("Chat with the Document")
|
| 280 |
+
|
| 281 |
for message in st.session_state.messages:
|
| 282 |
with st.chat_message(message["role"]):
|
| 283 |
st.markdown(message["content"])
|
| 284 |
|
| 285 |
if prompt := st.chat_input("Ask a question about the document"):
|
| 286 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 287 |
+
|
| 288 |
with st.chat_message("user"):
|
| 289 |
st.markdown(prompt)
|
| 290 |
|
| 291 |
with st.chat_message("assistant"):
|
| 292 |
with st.spinner("Thinking..."):
|
|
|
|
| 293 |
docs = st.session_state.vectorstore.similarity_search(prompt, k=3)
|
| 294 |
context = "\n".join([doc.page_content for doc in docs])
|
| 295 |
+
|
| 296 |
messages = [
|
| 297 |
+
SystemMessage(content="You are a helpful assistant. Answer the question based on the provided document context."),
|
| 298 |
HumanMessage(content=f"Context: {context}\n\nQuestion: {prompt}")
|
| 299 |
]
|
| 300 |
|
| 301 |
response = llm.invoke(messages)
|
| 302 |
st.markdown(response.content)
|
| 303 |
+
|
| 304 |
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
| 305 |
|
| 306 |
+
# Footer
|
| 307 |
+
st.markdown(
|
| 308 |
+
f'<div class="footer">Analysis Time: {st.session_state.analysis_time:.1f}s | Powered by Google Generative AI</div>',
|
| 309 |
+
unsafe_allow_html=True
|
| 310 |
+
)
|