import gradio as gr
import torch
import logging
import json
import uuid
import os
import datetime
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import HfApi

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Define output JSON path
OUTPUT_DIR = "outputs"
OUTPUT_JSON_PATH = os.path.join(OUTPUT_DIR, "captions.json")
REPO_ID = "spaces/retromarz/plavu_microsoft-git-large"  # Space repository ID
DATASET_REPO_ID = "retromarz/plavu_dataset"  # Dataset repository that stores captions.json

# Initialize Hugging Face API
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    logger.error("HF_TOKEN environment variable not set. Please configure it in Space secrets.")
    raise ValueError("HF_TOKEN is required")
api = HfApi()

# Ensure output directory and captions.json exist
try:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    if not os.path.exists(OUTPUT_JSON_PATH):
        with open(OUTPUT_JSON_PATH, 'w') as f:
            json.dump([], f)
    logger.info("Output directory and captions.json initialized")
except Exception as e:
    logger.error(f"Failed to initialize output directory: {str(e)}")
    raise

# Initialize model and processor
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-base")
    # Quantization is disabled here; see the optional sketch below for enabling 8-bit loading
    quantization_config = None
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Model and processor loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or processor: {str(e)}")
    raise
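
# Optional: a minimal 8-bit loading sketch, left commented out. It assumes a CUDA
# GPU plus the bitsandbytes and accelerate packages are installed; swap it in for
# the loading code above if GPU memory is tight.
#
# if torch.cuda.is_available():
#     quantization_config = BitsAndBytesConfig(load_in_8bit=True)
#     model = AutoModelForCausalLM.from_pretrained(
#         "microsoft/git-base",
#         quantization_config=quantization_config,
#         device_map="auto",
#     )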

# Function to save results to JSON and upload to the dataset repository
def save_to_json(image_name: str, caption: str, caption_type: str, caption_length: str, error: str = None):
    try:
        # Write to local JSON
        with open(OUTPUT_JSON_PATH, 'r+') as f:
            data = json.load(f)
            data.append({
                "image_name": image_name,
                "caption": caption,
                "caption_type": caption_type,
                "caption_length": caption_length,
                "error": error,
                "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat()
            })
            f.seek(0)
            json.dump(data, f, indent=4)
            f.truncate()
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
        # Upload to Dataset repository
        api.upload_file(
            path_or_fileobj=OUTPUT_JSON_PATH,
            path_in_repo="outputs/captions.json",
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )
        logger.info(f"Uploaded {OUTPUT_JSON_PATH} to dataset repository {DATASET_REPO_ID}")
    except Exception as e:
        logger.error(f"Error saving to JSON or uploading: {str(e)}")
        raise
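
# Each entry appended to outputs/captions.json looks roughly like this
# (illustrative values):
# {
#     "image_name": "image_3f2a9c....jpg",
#     "caption": "a dog running on the beach",
#     "caption_type": "descriptive",
#     "caption_length": "medium",
#     "error": null,
#     "timestamp": "2024-01-01T12:00:00+00:00"
# }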

# Define the captioning function
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = "") -> str:
    logger.info("Starting generate_caption")
    if input_image is None:
        error_msg = "Please upload an image."
        logger.error(error_msg)
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg
    if not isinstance(input_image, Image.Image):
        error_msg = "Invalid image format. Please upload a valid JPG or PNG image."
        logger.error(error_msg)
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg
    # Generate a unique image name
    image_name = f"image_{uuid.uuid4().hex}.jpg"
    logger.info(f"Generated image name: {image_name}")
    try:
        # Resize image
        logger.info("Resizing image")
        input_image = input_image.resize((256, 256))
        # Prepare prompt
        if not prompt:
            prompt = f"Generate a {caption_type} caption for this image."
        logger.info(f"Prompt: {prompt}")
        # Prepare inputs
        logger.info("Processing image with processor")
        inputs = processor(images=input_image, text=prompt, return_tensors="pt").to(model.device)
        logger.info(f"Inputs prepared: {inputs.keys()}")
        logger.info(f"Pixel values shape: {inputs['pixel_values'].shape}")
        logger.info(f"Input IDs shape: {inputs['input_ids'].shape}")
        # Generate the caption
        logger.info("Generating caption")
        with torch.no_grad():
            max_length = {"short": 20, "medium": 50, "long": 100}.get(caption_length, 50)
            generated_ids = model.generate(
                pixel_values=inputs.pixel_values,
                input_ids=inputs.input_ids,
                max_length=max_length,
                num_beams=5,
                do_sample=True,  # Enable sampling to avoid empty outputs
                top_k=50,
                temperature=0.7
            )
        # Debug raw output
        logger.info(f"Raw generated IDs: {generated_ids}")
        logger.info(f"Generated IDs shape: {generated_ids.shape}")
        # Decode the output
        logger.info("Decoding caption")
        raw_caption = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()
        caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        logger.info(f"Raw decoded caption (with special tokens): {raw_caption}")
        logger.info(f"Generated caption: {caption}")
        # Fallback if caption is empty
        if not caption or caption == ".":
            caption = "No meaningful caption generated."
            logger.warning("Empty or invalid caption generated, using fallback")
        # Strip an echoed prompt prefix from the caption
        if caption.lower().startswith(("generate a", "write a")):
            caption = caption.split(".", 1)[1].strip() if "." in caption else caption
        # Save to JSON and upload
        logger.info("Saving to JSON")
        save_to_json(image_name, caption, caption_type, caption_length, error=None)
        return caption
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, caption_length, error=error_msg)
        return error_msg

# Function to view caption history
def view_caption_history():
    try:
        with open(OUTPUT_JSON_PATH, 'r') as f:
            data = json.load(f)
        if not data:
            return "No captions generated yet."
        return "\n".join([
            f"Image: {item['image_name']}, Caption: {item['caption']}, Type: {item['caption_type']}, "
            f"Length: {item['caption_length']}, Error: {item['error']}"
            for item in data
        ])
    except Exception as e:
        return f"Error reading caption history: {str(e)}"

# Function for batch processing
def batch_generate_captions(image_list, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = ""):
    results = []
    for img in image_list:
        # gr.Files may yield plain file paths or tempfile wrappers depending on the Gradio version
        path = img if isinstance(img, str) else img.name
        logger.info(f"Processing batch image: {path}")
        try:
            img_pil = Image.open(path).convert("RGB")
            caption = generate_caption(img_pil, caption_type, caption_length, prompt)
        except Exception as e:
            caption = f"Error: {str(e)}"
        results.append(f"Image {os.path.basename(path)}: {caption}")
    return "\n".join(results)

# Create Gradio Blocks interface
with gr.Blocks(title="Image Captioning with GIT") as demo:
    gr.Markdown("# Image Captioning with GIT")
    gr.Markdown("Upload an image or multiple images to generate captions using the microsoft/git-base model. Results are saved to outputs/captions.json and uploaded to the dataset repository.")
    # Tab for single image captioning
    with gr.Tab("Single Image Captioning"):
        with gr.Row():
            with gr.Column():
                single_image_input = gr.Image(label="Upload Image", type="pil")
                single_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
                single_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
                single_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
                single_submit = gr.Button("Generate Caption")
                single_output = gr.Textbox(label="Generated Caption", placeholder="Caption will appear here")
        single_submit.click(
            fn=generate_caption,
            inputs=[single_image_input, single_caption_type, single_caption_length, single_prompt],
            outputs=single_output
        )
    # Tab for viewing caption history
    with gr.Tab("Caption History"):
        history_output = gr.Textbox(label="Caption History", placeholder="History will appear here")
        history_button = gr.Button("View History")
        history_button.click(
            fn=view_caption_history,
            inputs=None,
            outputs=history_output
        )
    # Tab for batch processing
    with gr.Tab("Batch Image Captioning"):
        with gr.Row():
            with gr.Column():
                batch_image_input = gr.Files(label="Upload Multiple Images", file_types=["image"])
                batch_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
                batch_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
                batch_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
                batch_submit = gr.Button("Generate Captions")
                batch_output = gr.Textbox(label="Batch Caption Results", placeholder="Batch results will appear here")
        batch_submit.click(
            fn=batch_generate_captions,
            inputs=[batch_image_input, batch_caption_type, batch_caption_length, batch_prompt],
            outputs=batch_output
        )

if __name__ == "__main__":
    demo.launch()
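
# To run this app locally (a sketch; the token must have write access to the
# dataset repository named above):
#   export HF_TOKEN=hf_...
#   pip install gradio torch transformers huggingface_hub pillow
#   python app.py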