Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	| import os | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| from colpali_engine.models.paligemma_colbert_architecture import ColPali | |
| from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator | |
| from colpali_engine.utils.colpali_processing_utils import ( | |
| process_images, | |
| process_queries, | |
| ) | |
| from pdf2image import convert_from_path | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from transformers import AutoProcessor, Idefics3ForConditionalGeneration | |
| import re | |
| import time | |
| from PIL import Image | |
| import torch | |
| import subprocess | |
# Install flash-attn at start-up, asking its build to skip compiling the CUDA
# kernels. The override is merged into the current environment: passing only
# the single variable as `env` would replace the whole environment and strip
# PATH, HOME, etc. from the child process.
# The return code is deliberately not checked (best-effort install).
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
| import time | |
def model_inference(
    images,
    text,
    assistant_prefix="Réfléchis step by step. Répond uniquement avec les informations du document fourni.",
    decoding_strategy="Greedy",
    temperature=0.4,
    max_new_tokens=512,
    repetition_penalty=1.2,
    top_p=0.8,
):
    """Answer ``text`` about the given gallery images with Idefics3.

    Args:
        images: Iterable of gallery items whose first element is an image
            path (each item is opened with ``Image.open(item[0])``). May be
            empty or ``None``.
        text: The user question. Must not be empty.
        assistant_prefix: Instruction prepended to ``text`` in the prompt;
            skipped when falsy.
        decoding_strategy: ``"Greedy"`` or ``"Top P Sampling"``.
        temperature: Sampling temperature (only used for top-p sampling).
        max_new_tokens: Maximum number of generated tokens.
        repetition_penalty: Repetition penalty passed to ``generate``.
        top_p: Nucleus probability mass (only used for top-p sampling).

    Returns:
        The decoded text of the first generated sequence, without the prompt.

    Raises:
        gr.Error: If ``text`` is empty.
        ValueError: If ``decoding_strategy`` is not a supported value.
    """
    # Validate first so a bad request does not pay for loading the model.
    if not text and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if not text:
        raise gr.Error("Please input a text query along the image(s).")
    if decoding_strategy not in ("Greedy", "Top P Sampling"):
        raise ValueError(f"Unsupported decoding strategy: {decoding_strategy!r}")

    pil_images = [Image.open(item[0]) for item in images or []]

    # The prefix has to be applied before the chat message is built,
    # otherwise it never reaches the prompt.
    prompt_text = f"{assistant_prefix} {text}" if assistant_prefix else text

    # One image placeholder per image so the prompt matches what is passed
    # to the processor.
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in pil_images]
            + [{"type": "text", "text": prompt_text}],
        }
    ]

    # NOTE(review): the model is loaded on every call, as in the original;
    # consider caching it if memory allows.
    id_processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
    id_model = Idefics3ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/Idefics3-8B-Llama3",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).to("cuda")

    prompt = id_processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = id_processor(
        text=prompt,
        images=[pil_images] if pil_images else None,
        return_tensors="pt",
    )
    inputs = {name: tensor.to("cuda") for name, tensor in inputs.items()}

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    else:  # "Top P Sampling"
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    generated_ids = id_model.generate(**inputs, **generation_args)

    # Drop the prompt tokens so only the newly generated text is decoded.
    prompt_length = inputs["input_ids"].size(1)
    generated_texts = id_processor.batch_decode(
        generated_ids[:, prompt_length:], skip_special_tokens=True
    )
    return generated_texts[0]
def search(query: str, ds, images, k):
    """Return the ``k`` indexed page images that best match ``query``.

    Args:
        query: Text query to embed with ColPali.
        ds: List of document embeddings, one entry per indexed page.
        images: Page images, indexed in the same order as ``ds``.
        k: Number of results to return (converted to ``int``, minimum 1).

    Returns:
        A list of entries of ``images``, best match first.

    Raises:
        gr.Error: If nothing has been indexed yet.
    """
    if not ds or not images:
        raise gr.Error("Please index documents before searching.")
    # The value is used as a slice bound, so it must be a positive integer
    # (a bound of 0 would select every page).
    top_k = max(int(k), 1)

    # Load colpali model
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=token,
    ).eval()
    model.load_adapter(model_name)
    model = model.eval()
    processor = AutoProcessor.from_pretrained(model_name, token=token)

    # Blank image handed to process_queries together with the query text.
    mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare torch.device objects; comparing the plain string to
    # model.device does not test what was intended.
    if torch.device(device) != model.device:
        model.to(device)

    start = time.time()
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # argsort is ascending: take the last top_k entries and reverse them.
    top_k_indices = scores.argsort(axis=1)[0][-top_k:][::-1]
    results = [images[idx] for idx in top_k_indices]

    print(f"Temps: {time.time()- start} s")
    print("done")
    return results
def index(files, ds):
    """Convert the uploaded PDFs to page images and embed them.

    Args:
        files: Uploaded PDF files, passed on to ``convert_files``.
        ds: Previous embedding list from the UI state. It is accepted for
            interface compatibility but not reused (see below).

    Returns:
        Whatever ``index_gpu`` returns: a status message, the embeddings and
        the page images.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF before indexing.")

    print("Converting files")
    images = convert_files(files)
    print(f"Files converted with {len(images)} images.")

    # The returned images replace the previous ones, so the embeddings must
    # start from an empty list as well; extending the old list would leave
    # embedding indices pointing at the wrong (or missing) images.
    return index_gpu(images, [])
def convert_files(files, max_pages=250):
    """Convert PDF files into a flat list of page images.

    Args:
        files: Iterable of PDF paths accepted by ``convert_from_path``.
            ``None`` is treated as no files.
        max_pages: Exclusive upper bound on the total number of pages.

    Returns:
        The page images of all files, in input order.

    Raises:
        gr.Error: If the total page count reaches ``max_pages``.
    """
    images = []
    for pdf_path in files or []:
        images.extend(convert_from_path(pdf_path, thread_count=4))
        # Checked per file so an oversized upload fails before the remaining
        # files are converted; the message reports the real limit.
        if len(images) >= max_pages:
            raise gr.Error(
                f"The number of images in the dataset should be less than {max_pages}."
            )
    return images
def index_gpu(images, ds):
    """Embed page images with ColPali and append the embeddings to ``ds``.

    Args:
        images: List of page images to embed.
        ds: List that receives one embedding per image (extended in place).

    Returns:
        A tuple ``(status_message, ds, images)``.
    """
    # Load colpali model
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=token,
    ).eval()
    model.load_adapter(model_name)
    model = model.eval()
    processor = AutoProcessor.from_pretrained(model_name, token=token)

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare torch.device objects; comparing the plain string to
    # model.device does not test what was intended.
    if torch.device(device) != model.device:
        model.to(device)

    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # Embeddings are moved to the CPU before being stored.
            ds.extend(torch.unbind(embeddings_doc.to("cpu")))

    print("done")
    return f"Uploaded and converted {len(images)} pages", ds, images
def get_example():
    """Return the example rows for the UI: ``[[pdf_file], question]`` each."""
    pdf_name = "RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"
    questions = (
        "Quels sont les 4 axes majeurs des achats?",
        "Quelles sont les actions entreprise en Afrique du Sud?",
        "fais moi un tableau de la répartition homme femme",
    )
    # Every row gets its own single-element file list.
    return [[[pdf_name], question] for question in questions]
# Gradio UI: upload and index PDFs, search them with ColPali, then answer the
# query with Idefics3 from the retrieved pages.
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation was lost - confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali + Idefics3: Efficient Document Retrieval with Vision Language Models 📚")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # File extensions must carry a leading dot to be treated as
            # extensions rather than as a generic file-type category.
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            convert_button = gr.Button("🔄 Index documents")
            # State holding the document embeddings and the page images.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
            img_chunk = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search with ColPali")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
            search_button = gr.Button("🔍 Search", variant="primary")

    with gr.Row():
        gr.Examples(
            examples=get_example(),
            inputs=[file, query],
        )

    # Define the actions
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])

    gr.Markdown("## 3️⃣ Get your answer with Idefics")
    answer_button = gr.Button("Answer", variant="primary")
    output = gr.Markdown(label="Output")
    # The retrieved gallery images and the query are the inputs for answering.
    answer_button.click(model_inference, inputs=[output_gallery, query], outputs=output)
if __name__ == "__main__":
    # Enable the request queue (max_size=10), then launch in debug mode.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)