import gradio as gr
import torch
import logging
import json
import uuid
import os
from datetime import datetime
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import HfApi
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Define output JSON path
OUTPUT_DIR = "outputs"
OUTPUT_JSON_PATH = os.path.join(OUTPUT_DIR, "captions.json")
# Initialize the Hugging Face API client for pushing results to the Space repo.
# The token must come from a Space secret; a hardcoded token would leak credentials.
# Repository() is deprecated in recent huggingface_hub releases, so the HTTP-based
# HfApi.upload_file is used instead. (Note: commits to a Space repo restart the
# Space; a dataset repo may be a better target for persistent logs.)
HF_TOKEN = os.environ.get("HF_TOKEN")
REPO_ID = "retromarz/plavu_microsoft-git-large"
api = HfApi(token=HF_TOKEN)
try:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    if not os.path.exists(OUTPUT_JSON_PATH):
        with open(OUTPUT_JSON_PATH, 'w') as f:
            json.dump([], f)
    logger.info("Output directory and captions.json initialized")
except Exception as e:
    logger.error(f"Failed to initialize output file: {str(e)}")
    raise
# Initialize model and processor
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    if torch.cuda.is_available():
        # An 8-bit quantized model must be placed via device_map;
        # calling .to("cuda") on it raises an error.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            quantization_config=quantization_config,
            device_map="auto"
        )
    else:
        model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    logger.info("Model and processor loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or processor: {str(e)}")
    raise
# Function to save results to JSON locally and push the file to the Space repo
def save_to_json(image_name: str, caption: str, caption_type: str, caption_length: str, error: str = None):
    try:
        # Append the result to the local JSON file
        with open(OUTPUT_JSON_PATH, 'r+') as f:
            data = json.load(f)
            data.append({
                "image_name": image_name,
                "caption": caption,
                "caption_type": caption_type,
                "caption_length": caption_length,
                "error": error,
                # ISO-8601 wall-clock timestamp (torch.cuda.Event().record()
                # is a GPU synchronization primitive, not a timestamp)
                "timestamp": datetime.now().isoformat()
            })
            f.seek(0)
            json.dump(data, f, indent=4)
            f.truncate()  # drop leftover bytes if the file was longer before
        logger.info(f"Saved result to {OUTPUT_JSON_PATH}")
        # Commit and push to the Space repo via the HTTP API
        api.upload_file(
            path_or_fileobj=OUTPUT_JSON_PATH,
            path_in_repo=OUTPUT_JSON_PATH,
            repo_id=REPO_ID,
            repo_type="space",
            commit_message=f"Update captions.json with new caption for {image_name}"
        )
        logger.info("Pushed captions.json to Hugging Face Space repository")
    except Exception as e:
        logger.error(f"Error saving to JSON or pushing to the Hub: {str(e)}")
# Define the captioning function
def generate_caption(input_image: Image.Image, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = "") -> str:
    logger.info("Starting generate_caption")
    if input_image is None:
        error_msg = "Please upload an image."
        logger.error(error_msg)
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg
    if not isinstance(input_image, Image.Image):
        error_msg = "Invalid image format. Please upload a valid JPG or PNG image."
        logger.error(error_msg)
        save_to_json("unknown", error_msg, caption_type, caption_length, error=error_msg)
        return error_msg
    # Generate a unique image name
    image_name = f"image_{uuid.uuid4().hex}.jpg"
    logger.info(f"Generated image name: {image_name}")
    try:
        # Resize image
        logger.info("Resizing image")
        input_image = input_image.resize((256, 256))
        # Prepare prompt
        if not prompt:
            prompt = f"Generate a {caption_type} caption for this image."
        logger.info(f"Prompt: {prompt}")
        # Prepare inputs
        logger.info("Processing image with processor")
        inputs = processor(images=input_image, text=prompt, return_tensors="pt").to(model.device)
        logger.info(f"Inputs prepared: {inputs.keys()}")
        # Generate the caption
        logger.info("Generating caption")
        with torch.no_grad():
            # max_new_tokens bounds only the generated text; max_length would
            # also count the prompt tokens and could leave no room for the caption
            max_new_tokens = {"short": 20, "medium": 50, "long": 100}.get(caption_length, 50)
            generated_ids = model.generate(
                pixel_values=inputs.pixel_values,
                input_ids=inputs.input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=5
            )
        # Decode the output
        logger.info("Decoding caption")
        caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        # GIT echoes the text prompt back as a prefix; strip it if present
        if caption.lower().startswith(("generate a", "write a")):
            caption = caption.split(".", 1)[1].strip() if "." in caption else caption
        logger.info(f"Generated caption: {caption}")
        # Save to JSON and push to the Space repo
        logger.info("Saving to JSON")
        save_to_json(image_name, caption, caption_type, caption_length, error=None)
        return caption
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        logger.error(error_msg)
        save_to_json(image_name, "", caption_type, caption_length, error=error_msg)
        return error_msg
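# A quick local sanity check of generate_caption (illustrative; assumes a
# local file named test.jpg that is not part of this Space):
#
#     from PIL import Image
#     print(generate_caption(Image.open("test.jpg").convert("RGB"), "descriptive", "short"))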
# Function to view caption history
def view_caption_history():
    try:
        with open(OUTPUT_JSON_PATH, 'r') as f:
            data = json.load(f)
        if not data:
            return "No captions generated yet."
        return "\n".join([f"Image: {item['image_name']}, Caption: {item['caption']}, Type: {item['caption_type']}, Length: {item['caption_length']}, Error: {item['error']}" for item in data])
    except Exception as e:
        return f"Error reading caption history: {str(e)}"
# Function for batch processing
def batch_generate_captions(image_list, caption_type: str = "descriptive", caption_length: str = "medium", prompt: str = ""):
    if not image_list:
        return "Please upload at least one image."
    results = []
    for img in image_list:
        # gr.Files yields tempfile objects (.name) or plain path strings,
        # depending on the Gradio version
        img_path = img.name if hasattr(img, "name") else img
        logger.info(f"Processing batch image: {img_path}")
        img_pil = Image.open(img_path).convert("RGB")
        caption = generate_caption(img_pil, caption_type, caption_length, prompt)
        results.append(f"Image {os.path.basename(img_path)}: {caption}")
    return "\n".join(results)
# Create Gradio Blocks interface
with gr.Blocks(title="Image Captioning with GIT") as demo:
    gr.Markdown("# Image Captioning with GIT")
    gr.Markdown("Upload one or more images to generate captions with the microsoft/git-large-coco model. Results are saved to outputs/captions.json and pushed to the Space repository.")
    # Tab for single image captioning
    with gr.Tab("Single Image Captioning"):
        with gr.Row():
            with gr.Column():
                single_image_input = gr.Image(label="Upload Image", type="pil")
                single_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
                single_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
                single_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
                single_submit = gr.Button("Generate Caption")
            single_output = gr.Textbox(label="Generated Caption", placeholder="Caption will appear here")
        single_submit.click(
            fn=generate_caption,
            inputs=[single_image_input, single_caption_type, single_caption_length, single_prompt],
            outputs=single_output
        )
    # Tab for viewing caption history
    with gr.Tab("Caption History"):
        history_output = gr.Textbox(label="Caption History", placeholder="History will appear here")
        history_button = gr.Button("View History")
        history_button.click(
            fn=view_caption_history,
            inputs=None,
            outputs=history_output
        )
    # Tab for batch processing
    with gr.Tab("Batch Image Captioning"):
        with gr.Row():
            with gr.Column():
                batch_image_input = gr.Files(label="Upload Multiple Images", file_types=["image"])
                batch_caption_type = gr.Dropdown(choices=["descriptive", "casual", "social media"], label="Caption Type", value="descriptive")
                batch_caption_length = gr.Dropdown(choices=["short", "medium", "long"], label="Caption Length", value="medium")
                batch_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for the model")
                batch_submit = gr.Button("Generate Captions")
            batch_output = gr.Textbox(label="Batch Caption Results", placeholder="Batch results will appear here")
        batch_submit.click(
            fn=batch_generate_captions,
            inputs=[batch_image_input, batch_caption_type, batch_caption_length, batch_prompt],
            outputs=batch_output
        )

if __name__ == "__main__":
    demo.launch()
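# Note: besides gradio, this file assumes torch, transformers, huggingface_hub,
# Pillow, and (for 8-bit loading on GPU) bitsandbytes and accelerate are listed
# in the Space's requirements.txt; exact versions are not pinned here.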