Spaces: Runtime error (Hugging Face Space build/runtime status header)
| import os | |
| import sys | |
| import tempfile | |
| import gradio as gr | |
| import requests | |
| import uvicorn | |
| import torch | |
| import base64 | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import RedirectResponse, StreamingResponse | |
| from typing import List | |
| from pdf2image import convert_from_bytes | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| from transformers import AutoProcessor | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib.pagesizes import letter | |
| from io import BytesIO | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main'))) | |
| from colpali_engine.models.paligemma_colbert_architecture import ColPali | |
| from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator | |
| from colpali_engine.utils.colpali_processing_utils import ( | |
| process_images, | |
| process_queries, | |
| ) | |
app = FastAPI()

# --- Model initialisation -------------------------------------------------
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")

# Base PaliGemma weights are loaded on CPU first; the ColPali adapter is
# then applied on top of them.
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

# Move the model to the GPU when one is available.
# (torch.device compares equal to its string form, so this test is valid.)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device != model.device:
    model.to(device)

# Blank placeholder image required by the ColPali query-processing helper.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# --- In-memory storage ----------------------------------------------------
# `ds` holds one multi-vector page embedding per indexed page; `images`
# holds the matching PIL page images, index-aligned with `ds`.
ds = []
images = []
@app.get("/")
def read_root():
    """Redirect the bare root URL to FastAPI's interactive docs.

    NOTE(review): the source as scraped had no route decorators at all, so
    `app` exposed zero endpoints; the redirect target makes the intended
    route (`GET /`) unambiguous.
    """
    return RedirectResponse(url="/docs")
@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Convert uploaded PDFs to page images and (re)build the embedding index.

    Each call REPLACES the previous index: `images` receives one PIL image
    per PDF page and `ds` receives the matching ColPali multi-vector
    embedding, index-aligned with `images`.

    Args:
        files: one or more PDF uploads.

    Returns:
        A JSON message with the total number of pages indexed.
    """
    global ds, images
    images = []
    ds = []
    for file in files:
        content = await file.read()
        # One PIL image per page of the uploaded PDF.
        images.extend(convert_from_bytes(content))

    # Embed every page in small batches; results are moved back to CPU so
    # the index does not pin GPU memory.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {key: val.to(device) for key, val in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return {"message": f"Uploaded and converted {len(images)} pages"}
def generate_pdf(results):
    """Render each base64-encoded result image as one full page of a PDF.

    Args:
        results: list of dicts, each with an "image" key holding a
            base64-encoded PNG.

    Returns:
        BytesIO positioned at offset 0, containing the finished PDF.
    """
    pdf_buffer = BytesIO()
    c = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for result in results:
        img_data = base64.b64decode(result["image"])
        # reportlab's drawImage wants a file path, so stage the PNG in a
        # temporary file (closed before drawing; required on Windows).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            temp_file.write(img_data)
            temp_file.flush()
        try:
            # Stretch the page image to fill a full letter-sized page.
            c.drawImage(temp_file.name, 0, 0, width, height)
            c.showPage()
        finally:
            # Always remove the temp file, even if drawing raises —
            # the original leaked the file on any drawImage error.
            os.remove(temp_file.name)
    c.save()
    pdf_buffer.seek(0)
    return pdf_buffer
@app.get("/search")
async def search(query: str, k: int = 1):
    """Embed a text query, rank indexed pages, and return the top-k as a PDF.

    Args:
        query: free-text search query.
        k: number of best-matching pages to return.

    Returns:
        A streaming PDF attachment with one page per result, or a JSON
        error when nothing has been indexed yet (the original crashed here).
    """
    if not ds:
        return {"error": "No documents indexed. Call /index first."}

    qs = []
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: val.to(device) for key, val in batch_query.items()}
        embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Single query row: take the k highest-scoring pages, best first.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        results.append({
            "image": base64.b64encode(img_byte_arr.getvalue()).decode('utf-8'),
            "page": f"Page {idx}",
        })

    pdf_buffer = generate_pdf(results)
    # StreamingResponse serves the in-memory PDF without touching disk.
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'
    return response
@app.post("/recommendation")
async def recommendation(file: UploadFile = File(...), k: int = 10):
    """Use an uploaded PDF as the query and return the k most similar pages.

    Args:
        file: query PDF; its pages are embedded with the same model as the
            index. NOTE(review): only the scores of the FIRST query page
            (`scores[...][0]`) drive the ranking — confirm this is intended
            for multi-page query PDFs.
        k: number of similar pages to return.

    Returns:
        A streaming PDF attachment of the recommended pages, or a JSON
        error when nothing has been indexed yet (the original crashed here).
    """
    if not ds:
        return {"error": "No documents indexed. Call /index first."}

    content = await file.read()
    pdf_image_list = convert_from_bytes(content)

    qs = []
    dataloader = DataLoader(
        pdf_image_list,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {key: val.to(device) for key, val in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Drop the single top hit (assumed to be the query document itself when
    # it is part of the index) and take the next k pages, best first.
    top_k_indices = scores.argsort(axis=1)[0][-k-1:-1][::-1]

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        results.append({
            "image": base64.b64encode(img_byte_arr.getvalue()).decode('utf-8'),
            "page": f"Page {idx}",
        })

    pdf_buffer = generate_pdf(results)
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'
    return response
| if __name__ == "__main__": | |
| uvicorn.run(app, host="0.0.0.0", port=7860) |