import gradio as gr
import numpy as np
import torch
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# In-memory state shared across the Gradio callbacks
chunks = []
embeddings = []
model = None
tokenizer = None
embed_model = None


def initialize_models():
    """Initialize models on startup."""
    global model, tokenizer, embed_model

    print("Loading models...")

    # Sentence-embedding model used to index and retrieve PDF chunks
    embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Small chat model that can run on CPU
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True
    )

    print("Models loaded successfully!")


def process_pdf(pdf_file):
    """Process the uploaded PDF and create chunk embeddings."""
    global chunks, embeddings, embed_model

    if pdf_file is None:
        return "❌ Please upload a PDF file!", None

    try:
        # gr.File may hand back a tempfile-like object or a plain path,
        # depending on the Gradio version
        pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file

        # Extract text from every page
        pdf_reader = PdfReader(pdf_path)
        text = ""
        for page in pdf_reader.pages:
            text += (page.extract_text() or "") + "\n"

        if not text.strip():
            return "❌ Could not extract text from PDF!", None

        # Split the text into overlapping chunks
        chunk_size = 1000
        overlap = 200
        chunks = []

        for i in range(0, len(text), chunk_size - overlap):
            chunk = text[i:i + chunk_size]
            if chunk.strip():
                chunks.append(chunk)

        # Embed all chunks for similarity search
        embeddings = embed_model.encode(chunks, show_progress_bar=False)

        return f"✅ PDF processed! Created {len(chunks)} chunks. You can now ask questions!", None

    except Exception as e:
        return f"❌ Error: {str(e)}", None


def find_relevant_chunks(query, top_k=3):
    """Find the most relevant chunks using cosine similarity."""
    global chunks, embeddings, embed_model

    if not chunks:
        return []

    query_embedding = embed_model.encode([query])[0]

    # Cosine similarity between the query and every chunk embedding
    similarities = np.dot(embeddings, query_embedding) / (
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
    )

    # Indices of the top_k most similar chunks, best first
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    return [chunks[i] for i in top_indices]


def generate_response(question, context):
    """Generate a response using the language model."""
    global model, tokenizer

    # TinyLlama-Chat expects a Zephyr-style prompt with role markers
    prompt = f"""<|system|>
You are a helpful assistant. Answer the question based on the provided context. Be concise and accurate.
</s>
<|user|>
Context: {context}

Question: {question}
</s>
<|assistant|>
"""

    # Truncate the prompt so prompt plus generated tokens stays within
    # TinyLlama's 2048-token context window
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1700)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text after the assistant marker
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>")[-1].strip()

    return response


def chat(message, history):
    """Handle a single chat turn."""
    global chunks

    history = history or []

    if not chunks:
        return history + [[message, "⚠️ Please upload and process a PDF first!"]]

    if not message.strip():
        return history

    try:
        # Retrieve the chunks most relevant to the question
        relevant_chunks = find_relevant_chunks(message)
        context = "\n\n".join(relevant_chunks)

        # Answer the question using the retrieved context
        response = generate_response(message, context)

        return history + [[message, response]]

    except Exception as e:
        return history + [[message, f"❌ Error: {str(e)}"]]


def clear_all():
    """Clear everything."""
    global chunks, embeddings
    chunks = []
    embeddings = []
    return None, "Ready to process a new PDF"


with gr.Blocks(title="Chat with PDF") as demo:
    gr.Markdown("# 📄 Chat with PDF - Simple Version")

    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="📄 Upload PDF", file_types=[".pdf"])
            process_btn = gr.Button("🔄 Process PDF", variant="primary")
            status = gr.Textbox(label="Status", lines=3)
            clear_all_btn = gr.Button("🗑️ Clear All")

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="💬 Chat", height=400)
            msg = gr.Textbox(label="Question", placeholder="Ask about the PDF...")
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear Chat")

    # Wire up the events
    process_btn.click(process_pdf, [pdf_input], [status, chatbot])

    msg.submit(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])
    send_btn.click(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])

    clear_btn.click(lambda: None, None, [chatbot])
    clear_all_btn.click(clear_all, None, [chatbot, status])


initialize_models()

if __name__ == "__main__":
    demo.launch()
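
# Dependency note (a minimal sketch inferred from the imports above; exact
# versions are an assumption, since the original script does not pin any):
#   pip install gradio sentence-transformers numpy pypdf torch transformers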