Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| import torch | |
| from PIL import Image | |
| import io | |
# Fetch the shot-categorizer checkpoint and its processor once, at import time.
# trust_remote_code=True lets transformers run the custom model/processor code
# bundled with the checkpoint. Nothing below moves the model to an accelerator,
# so inference runs on CPU.
folder_path = "diffusers/shot-categorizer-v0"
model = AutoModelForCausalLM.from_pretrained(
    folder_path,
    trust_remote_code=True,
)
model.eval()  # inference mode: disables dropout and other training-only layers
processor = AutoProcessor.from_pretrained(
    folder_path,
    trust_remote_code=True,
)
# Task prompts understood by the model, mapped to the label shown in the report.
# Dict insertion order fixes both the generation order and the output order.
TASK_LABELS = {
    "<COLOR>": "Color",
    "<LIGHTING>": "Lighting",
    "<LIGHTING_TYPE>": "Lighting Type",
    "<COMPOSITION>": "Composition",
}


def analyze_image(image):
    """Run every shot-categorizer task on *image* and return a readable report.

    Args:
        image: A ``PIL.Image.Image`` (what the ``gr.Image(type="pil")`` input
            delivers), raw encoded image bytes, or ``None`` when the button is
            clicked before anything has been uploaded.

    Returns:
        A multi-line string: a header, a blank line, then one
        ``"<Label>: <answer>"`` line per task, each newline-terminated. When
        *image* is ``None`` a short message asking for an image is returned.
    """
    if image is None:
        # Without this guard, io.BytesIO(None) below raises TypeError.
        return "Please upload an image to analyze."

    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.open(io.BytesIO(image)).convert("RGB")

    lines = ["Image Analysis Results:", ""]
    with torch.no_grad():
        for prompt, label in TASK_LABELS.items():
            inputs = processor(text=prompt, images=img, return_tensors="pt")
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                early_stopping=False,
                do_sample=False,
                num_beams=3,
            )
            # Keep special tokens: post_process_generation parses them.
            generated_text = processor.batch_decode(
                generated_ids, skip_special_tokens=False
            )[0]
            parsed = processor.post_process_generation(
                generated_text, task=prompt, image_size=(img.width, img.height)
            )
            # Per the model card, the answer comes back wrapped as {task: answer};
            # unwrap it so the report reads "Color: ..." instead of showing the
            # raw dict. Any other shape is shown as-is.
            if isinstance(parsed, dict):
                parsed = parsed.get(prompt, parsed)
            lines.append(f"{label}: {parsed}")

    # Same layout as before: header, blank line, one newline-terminated line per task.
    return "\n".join(lines) + "\n"
# Build the Gradio UI: image upload + button on the left, text report on the right.
import os  # used only to check that the optional example image is present

EXAMPLE_IMAGE = "shot.jpg"

with gr.Blocks(title="Image Analyzer") as demo:
    gr.Markdown("# Image Analysis Demo")
    gr.Markdown("Upload an image to analyze its color, lighting, and composition characteristics.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            analyze_button = gr.Button("Analyze Image")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Results", lines=10)

    # Register the example only when its file ships with the app. Gradio
    # processes example files while the Blocks are built, so a missing file can
    # abort startup of the whole app rather than just dropping the example.
    if os.path.exists(EXAMPLE_IMAGE):
        examples = gr.Examples(
            examples=[[EXAMPLE_IMAGE]],  # one row per example, one value per input
            inputs=image_input,
            label="Try with this example",
        )

    # Clicking the button runs analyze_image on the uploaded image and shows the report.
    analyze_button.click(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_text,
    )

# Start the Gradio server.
demo.launch()