import json
import logging
import os
import uuid
from datetime import datetime, timezone

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import HfApi

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Output paths and repository IDs
OUTPUT_DIR = "outputs"
OUTPUT_JSON_PATH = os.path.join(OUTPUT_DIR, "captions.json")
REPO_ID = "spaces/retromarz/plavu_microsoft-git-large"  # Space repository ID
DATASET_REPO_ID = "retromarz/plavu_dataset"  # Dataset repository for caption uploads

# Initialize Hugging Face API
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    logger.error("HF_TOKEN environment variable not set. Please configure it in Space secrets.")
    raise ValueError("HF_TOKEN is required")
api = HfApi()

# Ensure the output directory and captions.json exist
try:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    if not os.path.exists(OUTPUT_JSON_PATH):
        with open(OUTPUT_JSON_PATH, 'w') as f:
            json.dump([], f)
    logger.info("Output directory and captions.json initialized")
except Exception as e:
    logger.error(f"Failed to initialize output directory: {str(e)}")
    raise

# Initialize model and processor
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-base")
    # Quantization is disabled for now; re-enable with
    # BitsAndBytesConfig(load_in_8bit=True) on CUDA if memory becomes a constraint.
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    logger.info("Model and processor loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or processor: {str(e)}")
    raise

# Save a result to the local JSON file and upload it to the dataset repository
def save_to_json(image_name: str, caption: str, caption_type: str, caption_length: str, error: str = None):
    try:
        # Append the new record to the local JSON file
        with open(OUTPUT_JSON_PATH, 'r+') as f:
            data = json.load(f)
            data.append({
                "image_name": image_name,
                "caption": caption,
                "caption_type": caption_type,
                "caption_length": caption_length,
                "error": error,
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
            f.seek(0)
            json.dump(data, f, indent=4)
            f.truncate()  # Drop stale trailing bytes left over from longer previous contents
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")

        # Upload to the dataset repository
        api.upload_file(
            path_or_fileobj=OUTPUT_JSON_PATH,
            path_in_repo="outputs/captions.json",
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )
        logger.info(f"Uploaded {OUTPUT_JSON_PATH} to dataset repository {DATASET_REPO_ID}")
    except Exception as e:
        logger.error(f"Error saving to JSON or uploading: {str(e)}")
        raise

# Generate a caption for a single image
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = "") -> str:
    logger.info("Starting generate_caption")
    if input_image is None:
        error_msg = "Please upload an image."
        logger.error(error_msg)
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg
    if not isinstance(input_image, Image.Image):
        error_msg = "Invalid image format. Please upload a valid JPG or PNG image."
        logger.error(error_msg)
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg

    # Generate a unique image name
    image_name = f"image_{uuid.uuid4().hex}.jpg"
    logger.info(f"Generated image name: {image_name}")

    try:
        # Resize image
        logger.info("Resizing image")
        input_image = input_image.resize((256, 256))

        # Prepare prompt
        if not prompt:
            prompt = f"Generate a {caption_type} caption for this image."
logger.info(f"Prompt: {prompt}") # Prepare inputs logger.info("Processing image with processor") inputs = processor(images=input_image, text=prompt, return_tensors="pt").to(model.device) logger.info(f"Inputs prepared: {inputs.keys()}") logger.info(f"Pixel values shape: {inputs['pixel_values'].shape}") logger.info(f"Input IDs shape: {inputs['input_ids'].shape}") # Generate the caption logger.info("Generating caption") with torch.no_grad(): max_length = {"short": 20, "medium": 50, "long": 100}.get(caption_length, 50) generated_ids = model.generate( pixel_values=inputs.pixel_values, input_ids=inputs.input_ids, max_length=max_length, num_beams=5, do_sample=True, # Enable sampling to avoid empty outputs top_k=50, temperature=0.7 ) # Debug raw output logger.info(f"Raw generated IDs: {generated_ids}") logger.info(f"Generated IDs shape: {generated_ids.shape}") # Decode the output logger.info("Decoding caption") raw_caption = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip() caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() logger.info(f"Raw decoded caption (with special tokens): {raw_caption}") logger.info(f"Generated caption: {caption}") # Fallback if caption is empty if not caption or caption in [".", ""]: caption = "No meaningful caption generated." logger.warning("Empty or invalid caption generated, using fallback") # Clean up caption if caption.lower().startswith(("generate a", "write a")): caption = caption.split(".", 1)[1].strip() if "." in caption else caption # Save to JSON and upload to Space logger.info("Saving to JSON") save_to_json(image_name, caption, caption_type, caption_length, error=None) return caption except Exception as e: error_msg = f"Error generating caption: {str(e)}" logger.error(error_msg) save_to_json(image_name, "", caption_type, caption_length, error=error_msg) return error_msg # Function to view caption history def view_caption_history(): try: with open(OUTPUT_JSON_PATH, 'r') as f: data = json.load(f) if not data: return "No captions generated yet." return "\n".join([f"Image: {item['image_name']}, Caption: {item['caption']}, Type: {item['caption_type']}, Length: {item['caption_length']}, Error: {item['error']}" for item in data]) except Exception as e: return f"Error reading caption history: {str(e)}" # Function for batch processing def batch_generate_captions(image_list, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = ""): results = [] for img in image_list: logger.info(f"Processing batch image: {img.name}") img_pil = Image.open(img.name).convert("RGB") caption = generate_caption(img_pil, caption_type, caption_length, prompt) results.append(f"Image {os.path.basename(img.name)}: {caption}") return "\n".join(results) # Create Gradio Blocks interface with gr.Blocks(title="Image Captioning with GIT") as demo: gr.Markdown("# Image Captioning with GIT") gr.Markdown("Upload an image or multiple images to generate captions using the Microsoft/git-large-coco model. 
        "Results are saved to outputs/captions.json and uploaded to the dataset repository."
    )

    # Tab for single image captioning
    with gr.Tab("Single Image Captioning"):
        with gr.Row():
            with gr.Column():
                single_image_input = gr.Image(label="Upload Image", type="pil")
                single_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
                single_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
                single_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
                single_submit = gr.Button("Generate Caption")
        single_output = gr.Textbox(label="Generated Caption", placeholder="Caption will appear here")
        single_submit.click(
            fn=generate_caption,
            inputs=[single_image_input, single_caption_type, single_caption_length, single_prompt],
            outputs=single_output
        )

    # Tab for viewing caption history
    with gr.Tab("Caption History"):
        history_output = gr.Textbox(label="Caption History", placeholder="History will appear here")
        history_button = gr.Button("View History")
        history_button.click(
            fn=view_caption_history,
            inputs=None,
            outputs=history_output
        )

    # Tab for batch processing
    with gr.Tab("Batch Image Captioning"):
        with gr.Row():
            with gr.Column():
                batch_image_input = gr.Files(label="Upload Multiple Images", file_types=["image"])
                batch_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
                batch_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
                batch_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
                batch_submit = gr.Button("Generate Captions")
        batch_output = gr.Textbox(label="Batch Caption Results", placeholder="Batch results will appear here")
        batch_submit.click(
            fn=batch_generate_captions,
            inputs=[batch_image_input, batch_caption_type, batch_caption_length, batch_prompt],
            outputs=batch_output
        )

if __name__ == "__main__":
    demo.launch()
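
# --- Optional local smoke test (a sketch, not part of the app) ---------------
# A minimal way to exercise generate_caption directly (e.g. from a REPL) before
# launching the UI, useful for debugging the model path without Gradio in the
# loop. "sample.jpg" is an assumed local image path, not a file shipped with
# this Space. Left commented out so it never interferes with demo.launch():
#
#     from PIL import Image
#     img = Image.open("sample.jpg").convert("RGB")  # assumed test image
#     print(generate_caption(img, caption_type="descriptive", caption_length="short"))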