import streamlit as st
from PIL import Image
from transformers import pipeline
import torch  # Needed to check for CUDA availability
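# Assumed runtime dependencies (e.g., listed in requirements.txt):
#   streamlit, pillow, transformers, torch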
# --- Configuration ---
# Specify the model to use
MODEL_NAME = "microsoft/maira-2"
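# Note: maira-2 is a gated model on the Hugging Face Hub that ships custom modeling
# code, so loading it may additionally require accepting its license and passing
# trust_remote_code=True; any load failure is surfaced in the UI below.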
# --- Model Loading (using pipeline) ---
@st.cache_resource  # Cache the loaded pipeline across reruns for performance
def load_pipeline():
    """Loads the visual-question-answering pipeline."""
    try:
        # Run on the first GPU if CUDA is available, otherwise fall back to CPU
        device = 0 if torch.cuda.is_available() else -1
        vqa_pipeline = pipeline("visual-question-answering", model=MODEL_NAME, device=device)
        return vqa_pipeline
    except Exception as e:
        st.error(f"Error loading pipeline: {e}")
        return None
# --- Image Preprocessing ---
def prepare_image(image):
    """Prepares the uploaded PIL Image for the pipeline (converts RGBA to RGB)."""
    # The visual-question-answering pipeline expects a PIL Image (or a URL/path),
    # so return the converted PIL Image directly rather than raw encoded bytes.
    if image.mode == "RGBA":
        image = image.convert("RGB")
    return image
# --- Streamlit App ---
def main():
    st.title("Chest X-ray Analysis with Maira-2 (Transformers Pipeline)")
    st.write("Upload a chest X-ray image. This app uses the Maira-2 model via the Transformers library.")

    vqa_pipeline = load_pipeline()
    if vqa_pipeline is None:
        st.warning("Pipeline not loaded. Predictions will not be available.")
        return

    uploaded_file = st.file_uploader("Choose a chest X-ray image (JPG, PNG)", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        with st.spinner("Analyzing image with Maira-2..."):
            prepared_image = prepare_image(image)
            try:
                results = vqa_pipeline(
                    image=prepared_image,  # Pass the prepared PIL Image
                    question=(
                        "Analyze this chest X-ray image and provide detailed findings. "
                        "Include any abnormalities, their locations, and potential diagnoses. "
                        "Be as specific as possible."
                    ),
                )
                if results:  # The pipeline returns a list of {'answer', 'score'} dicts
                    if isinstance(results, list) and len(results) > 0:
                        best_answer = max(results, key=lambda x: x.get('score', 0))
                        if 'answer' in best_answer:
                            st.subheader("Findings:")
                            st.write(best_answer['answer'])
                        else:
                            st.warning("Could not find 'answer' in results.")
                    else:
                        st.warning("Unexpected result format.")
            except Exception as e:
                st.error(f"An error occurred during analysis: {e}")
    else:
        st.write("Please upload an image.")

    st.write("---")
    st.write("Disclaimer: For informational purposes only. Not medical advice.")


if __name__ == "__main__":
    main()
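# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py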