import gradio as gr
import torch
import logging
import json
import uuid
import os
from datetime import datetime
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import HfApi
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Define output JSON path and the dataset repository that persists it
OUTPUT_DIR = "outputs"
OUTPUT_JSON_PATH = os.path.join(OUTPUT_DIR, "captions.json")
DATASET_REPO_ID = "retromarz/plavu_dataset"  # Dataset repository for captions.json
# Initialize Hugging Face API
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
logger.error("HF_TOKEN environment variable not set. Please configure it in Space secrets.")
raise ValueError("HF_TOKEN is required")
api = HfApi()
# Ensure output directory exists
try:
os.makedirs(OUTPUT_DIR, exist_ok=True)
if not os.path.exists(OUTPUT_JSON_PATH):
with open(OUTPUT_JSON_PATH, 'w') as f:
json.dump([], f)
logger.info("Output directory and captions.json initialized")
except Exception as e:
logger.error(f"Failed to initialize output directory: {str(e)}")
raise
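# Note: a Space's local filesystem is ephemeral across restarts, so captions.json
# is also pushed to a persistent dataset repository in save_to_json() below.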
# Initialize model and processor
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-base")
    # 8-bit quantization is disabled here; on CUDA hardware it could be re-enabled
    # by passing BitsAndBytesConfig(load_in_8bit=True) as quantization_config
    quantization_config = None
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Model and processor loaded successfully")
except Exception as e:
logger.error(f"Failed to load model or processor: {str(e)}")
raise
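# With no GPU available the model runs in full precision on CPU; generation
# still works but is noticeably slower.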
# Function to save results to JSON and upload to the dataset repository
def save_to_json(image_name: str, caption: str, caption_type: str, caption_length: str, error: str = None):
try:
        # Read, append, and rewrite the local JSON file
        with open(OUTPUT_JSON_PATH, 'r+') as f:
            data = json.load(f)
            data.append({
                "image_name": image_name,
                "caption": caption,
                "caption_type": caption_type,
                "caption_length": caption_length,
                "error": error,
                "timestamp": datetime.now().isoformat()
            })
            f.seek(0)
            json.dump(data, f, indent=4)
            f.truncate()  # Drop leftover bytes if the new JSON is shorter
logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
        # Upload the updated file to the dataset repository
        api.upload_file(
            path_or_fileobj=OUTPUT_JSON_PATH,
            path_in_repo="outputs/captions.json",
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )
        logger.info(f"Uploaded {OUTPUT_JSON_PATH} to dataset repository {DATASET_REPO_ID}")
except Exception as e:
logger.error(f"Error saving to JSON or uploading to Space: {str(e)}")
raise
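# Each save_to_json call creates one commit in the dataset repo. If caption volume
# grows, huggingface_hub's CommitScheduler (which batches uploads on a timer) may
# be a better fit; the simple per-call upload is kept here.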
# Define the captioning function
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = "") -> str:
logger.info("Starting generate_caption")
if input_image is None:
error_msg = "Please upload an image."
logger.error(error_msg)
save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
return error_msg
if not isinstance(input_image, Image.Image):
error_msg = "Invalid image format. Please upload a valid JPG or PNG image."
logger.error(error_msg)
save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
return error_msg
# Generate a unique image name
image_name = f"image_{uuid.uuid4().hex}.jpg"
logger.info(f"Generated image name: {image_name}")
try:
# Resize image
logger.info("Resizing image")
input_image = input_image.resize((256, 256))
# Prepare prompt
if not prompt:
prompt = f"Generate a {caption_type} caption for this image."
logger.info(f"Prompt: {prompt}")
# Prepare inputs
logger.info("Processing image with processor")
inputs = processor(images=input_image, text=prompt, return_tensors="pt").to(model.device)
logger.info(f"Inputs prepared: {inputs.keys()}")
logger.info(f"Pixel values shape: {inputs['pixel_values'].shape}")
logger.info(f"Input IDs shape: {inputs['input_ids'].shape}")
# Generate the caption
logger.info("Generating caption")
with torch.no_grad():
max_length = {"short": 20, "medium": 50, "long": 100}.get(caption_length, 50)
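            # Note: max_length counts the prompt tokens in input_ids as well as the
            # newly generated ones, so a long prompt shrinks the caption budget;
            # max_new_tokens would bound only the generated tokens.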
generated_ids = model.generate(
pixel_values=inputs.pixel_values,
input_ids=inputs.input_ids,
max_length=max_length,
num_beams=5,
do_sample=True, # Enable sampling to avoid empty outputs
top_k=50,
temperature=0.7
)
# Debug raw output
logger.info(f"Raw generated IDs: {generated_ids}")
logger.info(f"Generated IDs shape: {generated_ids.shape}")
# Decode the output
logger.info("Decoding caption")
raw_caption = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
logger.info(f"Raw decoded caption (with special tokens): {raw_caption}")
logger.info(f"Generated caption: {caption}")
        # Fall back if the model produced nothing usable
        if not caption or caption == ".":
            caption = "No meaningful caption generated."
            logger.warning("Empty or invalid caption generated, using fallback")
        # GIT tends to echo the text prompt, so strip a leading
        # "Generate a ..." / "Write a ..." sentence if present
        if caption.lower().startswith(("generate a", "write a")):
            caption = caption.split(".", 1)[1].strip() if "." in caption else caption
# Save to JSON and upload to Space
logger.info("Saving to JSON")
save_to_json(image_name, caption, caption_type, caption_length, error=None)
return caption
except Exception as e:
error_msg = f"Error generating caption: {str(e)}"
logger.error(error_msg)
save_to_json(image_name, "", caption_type, caption_length, error=error_msg)
return error_msg
# Function to view caption history
def view_caption_history():
try:
with open(OUTPUT_JSON_PATH, 'r') as f:
data = json.load(f)
if not data:
return "No captions generated yet."
        return "\n".join(
            f"Image: {item['image_name']}, Caption: {item['caption']}, "
            f"Type: {item['caption_type']}, Length: {item['caption_length']}, Error: {item['error']}"
            for item in data
        )
except Exception as e:
return f"Error reading caption history: {str(e)}"
# Function for batch processing
def batch_generate_captions(image_list, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = ""):
    if not image_list:
        return "Please upload at least one image."
    results = []
    for img in image_list:
        # gr.Files may yield plain paths or tempfile-like objects depending on the Gradio version
        path = img if isinstance(img, str) else img.name
        logger.info(f"Processing batch image: {path}")
        img_pil = Image.open(path).convert("RGB")
        caption = generate_caption(img_pil, caption_type, caption_length, prompt)
        results.append(f"Image {os.path.basename(path)}: {caption}")
    return "\n".join(results)
# Create Gradio Blocks interface
with gr.Blocks(title="Image Captioning with GIT") as demo:
gr.Markdown("# Image Captioning with GIT")
gr.Markdown("Upload an image or multiple images to generate captions using the Microsoft/git-large-coco model. Results are saved to outputs/captions.json and uploaded to the Space repository.")
# Tab for single image captioning
with gr.Tab("Single Image Captioning"):
with gr.Row():
with gr.Column():
single_image_input = gr.Image(label="Upload Image", type="pil")
single_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
single_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
single_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
single_submit = gr.Button("Generate Caption")
single_output = gr.Textbox(label="Generated Caption", placeholder="Caption will appear here")
single_submit.click(
fn=generate_caption,
inputs=[single_image_input, single_caption_type, single_caption_length, single_prompt],
outputs=single_output
)
# Tab for viewing caption history
with gr.Tab("Caption History"):
history_output = gr.Textbox(label="Caption History", placeholder="History will appear here")
history_button = gr.Button("View History")
history_button.click(
fn=view_caption_history,
inputs=None,
outputs=history_output
)
# Tab for batch processing
with gr.Tab("Batch Image Captioning"):
with gr.Row():
with gr.Column():
batch_image_input = gr.Files(label="Upload Multiple Images", file_types=["image"])
batch_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
batch_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
batch_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
batch_submit = gr.Button("Generate Captions")
batch_output = gr.Textbox(label="Batch Caption Results", placeholder="Batch results will appear here")
batch_submit.click(
fn=batch_generate_captions,
inputs=[batch_image_input, batch_caption_type, batch_caption_length, batch_prompt],
outputs=batch_output
)
if __name__ == "__main__":
demo.launch()