# pdf_chat / app.py
import gradio as gr
from sentence_transformers import SentenceTransformer
import numpy as np
from pypdf import PdfReader
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Global variables
chunks = []
embeddings = []
model = None
tokenizer = None
embed_model = None
def initialize_models():
    """Initialize models on startup"""
    global model, tokenizer, embed_model
    print("Loading models...")

    # Load the embedding model used for retrieval
    embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Load the small chat language model
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )
    print("Models loaded successfully!")
def process_pdf(pdf_file):
    """Process PDF and create embeddings"""
    global chunks, embeddings, embed_model
    if pdf_file is None:
        return "❌ Please upload a PDF file!", None
    try:
        # Read the PDF. Depending on the Gradio version, gr.File yields either
        # a filepath string or a tempfile-like object with a .name attribute.
        pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
        pdf_reader = PdfReader(pdf_path)
        text = ""
        for page in pdf_reader.pages:
            # extract_text() may return an empty string (or None in older pypdf)
            text += (page.extract_text() or "") + "\n"
        if not text.strip():
            return "❌ Could not extract text from PDF!", None

        # Split into overlapping chunks so context survives chunk boundaries
        chunk_size = 1000
        overlap = 200
        chunks = []
        for i in range(0, len(text), chunk_size - overlap):
            chunk = text[i:i + chunk_size]
            if chunk.strip():
                chunks.append(chunk)

        # Create one embedding vector per chunk
        embeddings = embed_model.encode(chunks, show_progress_bar=False)
        return f"✅ PDF processed! Created {len(chunks)} chunks. You can now ask questions!", None
    except Exception as e:
        return f"❌ Error: {str(e)}", None
def find_relevant_chunks(query, top_k=3):
    """Find the most relevant chunks using cosine similarity"""
    global chunks, embeddings, embed_model
    if not chunks:
        return []
    query_embedding = embed_model.encode([query])[0]

    # Cosine similarity between the query vector and every chunk vector
    similarities = np.dot(embeddings, query_embedding) / (
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
    )

    # Indices of the top_k highest similarities, best match first
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    return [chunks[i] for i in top_indices]
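
# Note: sentence-transformers can L2-normalize vectors at encode time, in
# which case a plain dot product is already the cosine similarity. A sketch
# of that optional simplification (not what this app does):
#     embeddings = embed_model.encode(chunks, normalize_embeddings=True)
#     query_embedding = embed_model.encode([query], normalize_embeddings=True)[0]
#     similarities = np.dot(embeddings, query_embedding)
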
def generate_response(question, context):
    """Generate a response using the language model"""
    global model, tokenizer
    # TinyLlama-Chat prompt format: <|system|>, <|user|>, <|assistant|> turns
    prompt = f"""<|system|>
You are a helpful assistant. Answer the question based on the provided context. Be concise and accurate.
</s>
<|user|>
Context: {context}
Question: {question}
</s>
<|assistant|>
"""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the assistant's turn of the decoded text
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>")[-1].strip()
    return response
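
# Note: recent transformers tokenizers expose apply_chat_template, which can
# build this prompt from role/content dicts instead of a hand-written string.
# A sketch, assuming the tokenizer ships TinyLlama's chat template:
#     messages = [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": f"Context: {context}\nQuestion: {question}"},
#     ]
#     input_ids = tokenizer.apply_chat_template(
#         messages, add_generation_prompt=True, return_tensors="pt")
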
def chat(message, history):
    """Handle one chat turn"""
    global chunks
    if not chunks:
        return history + [[message, "⚠️ Please upload and process a PDF first!"]]
    if not message.strip():
        return history
    try:
        # Retrieve the most relevant chunks and join them into one context block
        relevant_chunks = find_relevant_chunks(message)
        context = "\n\n".join(relevant_chunks)
        # Generate a response grounded in that context
        response = generate_response(message, context)
        return history + [[message, response]]
    except Exception as e:
        return history + [[message, f"❌ Error: {str(e)}"]]
def clear_all():
    """Clear the loaded document and chat state"""
    global chunks, embeddings
    chunks = []
    embeddings = []
    return None, "Ready to process a new PDF"
# Create UI
with gr.Blocks(title="Chat with PDF") as demo:
    gr.Markdown("# 📄 Chat with PDF - Simple Version")
    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="📎 Upload PDF", file_types=[".pdf"])
            process_btn = gr.Button("🔄 Process PDF", variant="primary")
            status = gr.Textbox(label="Status", lines=3)
            clear_all_btn = gr.Button("🗑️ Clear All")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="💬 Chat", height=400)
            msg = gr.Textbox(label="Question", placeholder="Ask about the PDF...")
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear Chat")

    # Events
    process_btn.click(process_pdf, [pdf_input], [status, chatbot])
    msg.submit(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])
    send_btn.click(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])
    clear_btn.click(lambda: None, None, [chatbot])
    clear_all_btn.click(clear_all, None, [chatbot, status])
# Initialize models once at import time so the Space is ready on launch
initialize_models()

if __name__ == "__main__":
    demo.launch()
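
# Note: Gradio can queue concurrent requests so heavy generate() calls do not
# overlap; if that matters for a shared Space, demo.queue().launch() is the
# usual pattern (optional here).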