"""Gradio demo: top-k image classification with Hugging Face pipelines."""

from typing import List, Optional, Tuple

import gradio as gr
from PIL import Image
from transformers import pipeline
import torch

DESCRIPTION = """
# 🐼 Simple Image Classifier (ViT)
Sube una imagen y el modelo devuelve las **probabilidades top‑k** de las clases.
Puedes elegir distintos modelos del Hub (cargados automáticamente).
"""

DEFAULT_MODEL = "google/vit-base-patch16-224"
MODEL_OPTIONS = [
    "google/vit-base-patch16-224",
    "facebook/deit-base-patch16-224",
    "microsoft/resnet-50",
]

# Pipeline cache, keyed by model id, so a model is not reloaded when only
# parameters other than the model change.
_pipes = {}


def get_pipe(model_id: str):
    """Return the image-classification pipeline for *model_id*.

    The pipeline is built on first use and stored in the module-level
    ``_pipes`` cache; later calls with the same id return the cached object.

    Args:
        model_id: Model identifier passed to ``transformers.pipeline``.

    Returns:
        The cached (or newly created) pipeline object.
    """
    if model_id not in _pipes:
        # Simple device selection: first CUDA device if available, else -1.
        device = 0 if torch.cuda.is_available() else -1
        _pipes[model_id] = pipeline(
            task="image-classification",
            model=model_id,
            device=device,
        )
    return _pipes[model_id]


def classify(
    image: Optional[Image.Image],
    model_id: str,
    top_k: int,
) -> List[Tuple[str, float]]:
    """Classify *image* and return its top-k predictions.

    Args:
        image: Input PIL image, or ``None`` when nothing was uploaded.
        model_id: Model identifier used to look up the pipeline.
        top_k: Number of predictions to request; coerced to ``int``.

    Returns:
        A list of ``(label, score)`` tuples, or an empty list when *image*
        is ``None``.
    """
    if image is None:
        return []

    pipe = get_pipe(model_id)

    # Make sure the image is in RGB mode before handing it to the pipeline.
    rgb_image = image.convert("RGB")

    # The value comes from a UI slider; force an integer so a float such as
    # 5.0 never reaches the pipeline.
    preds = pipe(rgb_image, top_k=int(top_k))

    # Normalize to table rows: (label, score).
    return [(p["label"], float(p["score"])) for p in preds]


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=2):
            image = gr.Image(type="pil", label="Imagen de entrada")
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=DEFAULT_MODEL,
                    label="Modelo",
                )
                top_k = gr.Slider(1, 10, value=5, step=1, label="Top‑K")
            btn = gr.Button("Clasificar")

        with gr.Column(scale=1):
            output = gr.Dataframe(
                headers=["label", "score"],
                datatype=["str", "number"],
                label="Resultados",
            )

    # Event listeners must be registered inside the Blocks context.
    btn.click(fn=classify, inputs=[image, model_id, top_k], outputs=output)


if __name__ == "__main__":
    demo.launch()