Srikesh committed
Commit 00eb76e · verified · 1 Parent(s): e37be53

Update app.py

Files changed (1): app.py +125 -199
app.py CHANGED
@@ -1,259 +1,185 @@
  import gradio as gr
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.llms import HuggingFacePipeline
- from langchain.chains import ConversationalRetrievalChain
- from langchain.memory import ConversationBufferMemory
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from pypdf import PdfReader
  import torch

- # Initialize global variables
- vectorstore = None
- qa_chain = None
- llm_pipeline = None

- def initialize_llm():
-     """Initialize the language model (done once at startup)"""
-     global llm_pipeline
-
-     if llm_pipeline is not None:
-         return
-
-     print("Loading language model...")
-
-     # Use a smaller, efficient model that works without API
      model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
      tokenizer = AutoTokenizer.from_pretrained(model_name)
      model = AutoModelForCausalLM.from_pretrained(
          model_name,
-         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-         device_map="auto",
          low_cpu_mem_usage=True
      )

-     pipe = pipeline(
-         "text-generation",
-         model=model,
-         tokenizer=tokenizer,
-         max_new_tokens=512,
-         temperature=0.7,
-         top_p=0.95,
-         repetition_penalty=1.15
-     )
-
-     llm_pipeline = HuggingFacePipeline(pipeline=pipe)
-     print("Model loaded successfully!")

  def process_pdf(pdf_file):
-     """Process uploaded PDF and create vector store"""
-     global vectorstore, qa_chain

      if pdf_file is None:
-         return "Please upload a PDF file!", None, None

      try:
-         # Extract text from PDF
          pdf_reader = PdfReader(pdf_file.name)
          text = ""
          for page in pdf_reader.pages:
-             text += page.extract_text()

          if not text.strip():
-             return "Could not extract text from PDF. Please ensure it's a valid PDF with text content.", None, None
-
-         # Split text into chunks
-         text_splitter = RecursiveCharacterTextSplitter(
-             chunk_size=1000,
-             chunk_overlap=200,
-             length_function=len
-         )
-         chunks = text_splitter.split_text(text)
-
-         # Create embeddings (using a lightweight model)
-         embeddings = HuggingFaceEmbeddings(
-             model_name="sentence-transformers/all-MiniLM-L6-v2",
-             model_kwargs={'device': 'cpu'}
-         )
-
-         # Create vector store
-         vectorstore = FAISS.from_texts(chunks, embeddings)

-         # Initialize LLM if not already done
-         initialize_llm()

-         # Create memory for conversation
-         memory = ConversationBufferMemory(
-             memory_key="chat_history",
-             return_messages=True,
-             output_key="answer"
-         )

-         # Create conversational chain
-         qa_chain = ConversationalRetrievalChain.from_llm(
-             llm=llm_pipeline,
-             retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
-             memory=memory,
-             return_source_documents=True,
-             verbose=False
-         )

-         return f"✅ PDF processed successfully! Extracted {len(chunks)} text chunks. You can now ask questions!", None, None

      except Exception as e:
-         return f"❌ Error processing PDF: {str(e)}", None, None

  def chat(message, history):
-     """Handle chat interactions"""
-     global qa_chain

-     if qa_chain is None:
          return history + [[message, "⚠️ Please upload and process a PDF first!"]]

      if not message.strip():
          return history

      try:
-         # Get response from chain
-         result = qa_chain({"question": message})
-         answer = result["answer"]

-         # Clean up the answer (remove any system prompts)
-         if "Answer:" in answer:
-             answer = answer.split("Answer:")[-1].strip()

-         return history + [[message, answer]]

      except Exception as e:
          return history + [[message, f"❌ Error: {str(e)}"]]

- def clear_chat():
-     """Clear chat history and reset chain"""
-     global qa_chain
-     if qa_chain is not None and hasattr(qa_chain, 'memory'):
-         qa_chain.memory.clear()
-     return None

- # Create Gradio interface
- with gr.Blocks(theme=gr.themes.Soft(), title="Chat with PDF") as demo:
-     gr.Markdown(
-         """
-         # 📄 Chat with PDF using AI
-         Upload a PDF document and ask questions about its content - No API key required!
-
-         **Instructions:**
-         1. Upload a PDF file
-         2. Click "Process PDF" and wait for confirmation
-         3. Start asking questions about your document!
-         """
-     )

      with gr.Row():
          with gr.Column(scale=1):
-             pdf_input = gr.File(
-                 label="📎 Upload PDF",
-                 file_types=[".pdf"],
-                 type="filepath"
-             )
-             process_btn = gr.Button("🔄 Process PDF", variant="primary", size="lg")
-             status_output = gr.Textbox(
-                 label="📊 Status",
-                 interactive=False,
-                 lines=3
-             )

-             gr.Markdown(
-                 """
-                 ### 💡 Tips:
-                 - Processing may take 30-60 seconds
-                 - Ask specific questions about the content
-                 - You can ask follow-up questions
-                 - Best with text-based PDFs (not scanned images)
-                 """
-             )
-
          with gr.Column(scale=2):
-             chatbot = gr.Chatbot(
-                 label="💬 Chat History",
-                 height=500,
-                 bubble_full_width=False
-             )
              with gr.Row():
-                 msg = gr.Textbox(
-                     label="Your Question",
-                     placeholder="Ask a question about your PDF...",
-                     lines=2,
-                     scale=4
-                 )
-             with gr.Row():
-                 submit_btn = gr.Button("📤 Send", variant="primary", scale=1)
-                 clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
-
-     gr.Markdown(
-         """
-         ---
-         ### 🔌 API Access
-         Once deployed on Hugging Face Spaces, you can access this via API:
-         ```python
-         # Python example
-         from gradio_client import Client
-
-         client = Client("YOUR_USERNAME/YOUR_SPACE_NAME")
-
-         # Process PDF
-         result = client.predict("path/to/file.pdf", api_name="/process_pdf")
-
-         # Ask questions
-         result = client.predict("What is this document about?", [], api_name="/chat")
-         ```
-
-         ```javascript
-         // JavaScript example
-         const response = await fetch("https://YOUR_USERNAME-YOUR_SPACE_NAME.hf.space/api/predict", {
-             method: "POST",
-             headers: { "Content-Type": "application/json" },
-             body: JSON.stringify({
-                 data: ["What is this document about?", []]
-             })
-         });
-         ```
-         """
-     )

-     # Event handlers
-     process_btn.click(
-         fn=process_pdf,
-         inputs=[pdf_input],
-         outputs=[status_output, chatbot, msg]
-     )

-     msg.submit(
-         fn=chat,
-         inputs=[msg, chatbot],
-         outputs=[chatbot]
-     ).then(
-         fn=lambda: "",
-         outputs=[msg]
-     )

-     submit_btn.click(
-         fn=chat,
-         inputs=[msg, chatbot],
-         outputs=[chatbot]
-     ).then(
-         fn=lambda: "",
-         outputs=[msg]
-     )
-
-     clear_btn.click(
-         fn=clear_chat,
-         outputs=[chatbot]
-     )

- # Initialize model on startup
- initialize_llm()

  if __name__ == "__main__":
-     demo.launch(share=False)
 
  import gradio as gr
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
  from pypdf import PdfReader
  import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

+ # Global variables
+ chunks = []
+ embeddings = []
+ model = None
+ tokenizer = None
+ embed_model = None

+ def initialize_models():
+     """Initialize models on startup"""
+     global model, tokenizer, embed_model

+     print("Loading models...")

+     # Load embedding model
+     embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

+     # Load language model
      model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
      tokenizer = AutoTokenizer.from_pretrained(model_name)
      model = AutoModelForCausalLM.from_pretrained(
          model_name,
+         torch_dtype=torch.float32,
          low_cpu_mem_usage=True
      )

+     print("Models loaded successfully!")

  def process_pdf(pdf_file):
+     """Process PDF and create embeddings"""
+     global chunks, embeddings, embed_model

      if pdf_file is None:
+         return "❌ Please upload a PDF file!", None

      try:
+         # Read PDF
          pdf_reader = PdfReader(pdf_file.name)
          text = ""
          for page in pdf_reader.pages:
+             text += page.extract_text() + "\n"

          if not text.strip():
+             return "❌ Could not extract text from PDF!", None

+         # Split into chunks
+         chunk_size = 1000
+         overlap = 200
+         chunks = []

+         for i in range(0, len(text), chunk_size - overlap):
+             chunk = text[i:i + chunk_size]
+             if chunk.strip():
+                 chunks.append(chunk)

+         # Create embeddings
+         embeddings = embed_model.encode(chunks, show_progress_bar=False)

+         return f"✅ PDF processed! Created {len(chunks)} chunks. You can now ask questions!", None

      except Exception as e:
+         return f"❌ Error: {str(e)}", None
+
+ def find_relevant_chunks(query, top_k=3):
+     """Find most relevant chunks using cosine similarity"""
+     global chunks, embeddings, embed_model
+
+     if not chunks:
+         return []
+
+     query_embedding = embed_model.encode([query])[0]
+
+     # Calculate cosine similarity
+     similarities = np.dot(embeddings, query_embedding) / (
+         np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
+     )
+
+     # Get top k indices
+     top_indices = np.argsort(similarities)[-top_k:][::-1]
+
+     return [chunks[i] for i in top_indices]
+
+ def generate_response(question, context):
+     """Generate response using the language model"""
+     global model, tokenizer
+
+     prompt = f"""<|system|>
+ You are a helpful assistant. Answer the question based on the provided context. Be concise and accurate.
+ </s>
+ <|user|>
+ Context: {context}
+
+ Question: {question}
+ </s>
+ <|assistant|>
+ """
+
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=300,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Extract only the assistant's response
+     if "<|assistant|>" in response:
+         response = response.split("<|assistant|>")[-1].strip()
+
+     return response

  def chat(message, history):
+     """Handle chat"""
+     global chunks

+     if not chunks:
          return history + [[message, "⚠️ Please upload and process a PDF first!"]]

      if not message.strip():
          return history

      try:
+         # Find relevant context
+         relevant_chunks = find_relevant_chunks(message)
+         context = "\n\n".join(relevant_chunks)

+         # Generate response
+         response = generate_response(message, context)

+         return history + [[message, response]]

      except Exception as e:
          return history + [[message, f"❌ Error: {str(e)}"]]

+ def clear_all():
+     """Clear everything"""
+     global chunks, embeddings
+     chunks = []
+     embeddings = []
+     return None, "Ready to process a new PDF"

+ # Create UI
+ with gr.Blocks(title="Chat with PDF") as demo:
+     gr.Markdown("# 📄 Chat with PDF - Simple Version")

      with gr.Row():
          with gr.Column(scale=1):
+             pdf_input = gr.File(label="📎 Upload PDF", file_types=[".pdf"])
+             process_btn = gr.Button("🔄 Process PDF", variant="primary")
+             status = gr.Textbox(label="Status", lines=3)
+             clear_all_btn = gr.Button("🗑️ Clear All")

          with gr.Column(scale=2):
+             chatbot = gr.Chatbot(label="💬 Chat", height=400)
+             msg = gr.Textbox(label="Question", placeholder="Ask about the PDF...")
              with gr.Row():
+                 send_btn = gr.Button("Send", variant="primary")
+                 clear_btn = gr.Button("Clear Chat")

+     # Events
+     process_btn.click(process_pdf, [pdf_input], [status, chatbot])

+     msg.submit(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])
+     send_btn.click(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])

+     clear_btn.click(lambda: None, None, [chatbot])
+     clear_all_btn.click(clear_all, None, [chatbot, status])

+ # Initialize on startup
+ initialize_models()

  if __name__ == "__main__":
+     demo.launch()
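
For reference, a minimal standalone sketch of the retrieval step the updated app.py now uses: plain numpy cosine similarity over sentence-transformers embeddings instead of the removed FAISS/LangChain chain. The model name, chunk size, overlap, and top-k mirror the diff above; the sample text and query are placeholders, not part of the commit.

```python
# Standalone sketch of the new retrieval step (mirrors find_relevant_chunks above).
# Assumes sentence-transformers and numpy are installed; `text` stands in for the
# text extracted from the PDF with pypdf.
from sentence_transformers import SentenceTransformer
import numpy as np

embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

text = "..."  # placeholder for the extracted PDF text
chunk_size, overlap = 1000, 200
chunks = [text[i:i + chunk_size]
          for i in range(0, len(text), chunk_size - overlap)
          if text[i:i + chunk_size].strip()]

# Embed all chunks once, then embed the query.
embeddings = embed_model.encode(chunks, show_progress_bar=False)
query_embedding = embed_model.encode(["What is this document about?"])[0]

# Cosine similarity between the query and every chunk, then keep the top 3.
similarities = np.dot(embeddings, query_embedding) / (
    np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
)
top_chunks = [chunks[i] for i in np.argsort(similarities)[-3:][::-1]]
context = "\n\n".join(top_chunks)
```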