import gradio as gr
from PIL import Image
import torch
import os
import sys
from huggingface_hub import login
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM
# Import spaces module for ZeroGPU support
try:
    import spaces
    has_spaces = True
    print("ZeroGPU support enabled via spaces module")
except ImportError:
    has_spaces = False
    print("spaces module not found, ZeroGPU features will be disabled")
# Create examples directory if it doesn't exist
os.makedirs("examples", exist_ok=True)

# Authenticate with Hugging Face Hub using environment variable
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN environment variable not set. Some features may not work.")
# Model and device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model identifier - hardcode the correct model path instead of using environment variables
model_id = "thorscribe/thorscribe-model-3"
print(f"Using model: {model_id}")

# Determine dtype based on available hardware
if device == "cuda":
    if torch.cuda.is_bf16_supported():
        torch_dtype = torch.bfloat16
        print("Using bfloat16 precision")
    else:
        torch_dtype = torch.float16
        print("Using float16 precision")
else:
    torch_dtype = torch.float32
    print("Using float32 precision (CPU mode)")
# Fixed target resolution that works well with the model
target_size = 1024
print(f"Using fixed image resolution of {target_size}x{target_size}")
def pad_to_square(image, background_color=(0, 0, 0)):
    """Pad image to a square with a black background."""
    if image is None:
        return None
    width, height = image.size
    if width == height:
        return image
    new_size = max(width, height)
    new_image = Image.new('RGB', (new_size, new_size), background_color)
    # Paste the original image centered in the square
    paste_x = (new_size - width) // 2
    paste_y = (new_size - height) // 2
    new_image.paste(image, (paste_x, paste_y))
    return new_image
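# For example, an 800x600 input becomes an 800x800 canvas with the original
# pasted at (0, 100): the image stays centered and its aspect ratio is
# preserved, rather than being distorted by a direct non-square resize.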
def process_image(image, size=1024):
    """Process image to be suitable for the model."""
    if image is None:
        return None
    # First make the image square by padding
    image = pad_to_square(image)
    # Then resize to the target size
    image = image.resize((size, size), Image.LANCZOS)
    print(f"Processed image to {image.size[0]}x{image.size[1]}")
    return image
# Load processor first (lower memory requirements)
print(f"Loading processor from {model_id}...")
try:
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print("Processor loaded successfully!")
except Exception as e:
    print(f"Error loading processor: {str(e)}")
    sys.exit(1)
# Load and inspect model config via AutoConfig
try:
    cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    print("Vision config - patch_size:", cfg.vision_config.patch_size)
    print("Vision config - patch_stride:", cfg.vision_config.patch_stride)
    print("Vision config - patch_padding:", cfg.vision_config.patch_padding)
except Exception as e:
    print(f"Error loading model config: {str(e)}")
    sys.exit(1)
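# Printing the patch geometry serves as a quick sanity check that the fixed
# 1024x1024 input resolution is compatible with how the vision tower tiles
# images into patches.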
# Load model with explicit config
try:
    print(f"Loading model from {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=cfg,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    # Only move the model to the GPU when we're actually using it;
    # under ZeroGPU this happens inside the spaces.GPU-wrapped function
    if not has_spaces and device == "cuda":
        model.to(device)
        print("Model moved to CUDA device")
    print("Model loaded successfully with explicit config!")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    import traceback
    print(traceback.format_exc())
    sys.exit(1)
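# low_cpu_mem_usage=True streams weights into the model as they load instead
# of first materialising a fully random-initialised copy, which roughly
# halves peak RAM during startup.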
# Default prompt to use (hidden from UI)
DEFAULT_PROMPT = "<THORSCRIBE> What does this figure show?"
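# The <THORSCRIBE> prefix is presumably the task token the model was
# fine-tuned with; changing it would likely degrade caption quality.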
# Define the generation function once; it is wrapped with the ZeroGPU
# decorator below when available, so both code paths share one implementation
def _generate_caption(image):
    if image is None:
        return "Please upload an image."
    try:
        # Move the model to the GPU (a no-op if it is already there or on CPU)
        model.to(device)
        # Process the image to be suitable for the model
        processed_image = process_image(image, size=target_size)
        # Process text and image separately
        pixel_values = processor.image_processor(images=processed_image, return_tensors="pt").pixel_values
        # Process the text with controlled parameters
        input_ids = processor.tokenizer(
            DEFAULT_PROMPT,
            return_tensors="pt",
            padding="max_length",
            max_length=77,  # Use a safe, reasonable value
            truncation=True
        ).input_ids
        # Build inputs dictionary
        inputs = {
            "pixel_values": pixel_values.to(device, dtype=torch_dtype),
            "input_ids": input_ids.to(device)
        }
        # Generate with conservative settings
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                num_beams=1,
                do_sample=False
            )
        # Decode the generated tokens
        text = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return text
    except Exception as e:
        import traceback
        print(f"Error: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing image: {str(e)[:200]}. Check console for full traceback."

if has_spaces:
    # Attach a GPU for each call on ZeroGPU; pass spaces.GPU(duration=...)
    # instead if generation takes longer than the default allowance
    generate_caption = spaces.GPU(_generate_caption)
else:
    # Regular function without ZeroGPU
    generate_caption = _generate_caption
# Create a simple Gradio interface without FastAPI integration
demo = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil", label="Upload Thoracic MRI/X-ray Image"),
    outputs=gr.Textbox(label="Generated Caption", lines=5, max_lines=5, show_copy_button=True),
    title="THORSCRIBE: AI-Powered Thoracic Image Captioning",
    description=(
        "THORSCRIBE is an advanced AI model that generates detailed captions for "
        "MRI and X-ray images of the thorax area. Upload your medical image to "
        "receive an informative caption." + (" (with ZeroGPU)" if has_spaces else "")
    ),
    allow_flagging="never",
    theme=gr.themes.Monochrome(),
    examples=(
        ["examples/example1.jpg", "examples/example2.jpg", "examples/example3.jpg", "examples/example4.jpg"]
        if os.path.exists("examples/example1.jpg")
        else None
    ),
    article=(
        "<div style='text-align: center; max-width: 800px; margin: 0 auto;'>"
        "<h3>About THORSCRIBE</h3>"
        "<p>THORSCRIBE is specialized in analyzing thoracic medical imagery, "
        "providing accurate descriptions of findings in MRI and X-ray images. "
        "This tool is designed to assist medical professionals in their "
        "diagnostic workflows.</p>"
        "<p><small>Powered by model: thorscribe/thorscribe-model-3</small></p></div>"
    )
)
# Launch the app - use 7860, the standard port for Hugging Face Spaces
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True,
    )