Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| from transformers import pipeline | |
| import torch | |
# Markdown rendered at the top of the UI by gr.Markdown.
DESCRIPTION = """
# 🐼 Simple Image Classifier (ViT)
Sube una imagen y el modelo devuelve las **probabilidades top‑k** de las clases.
Puedes elegir distintos modelos del Hub (cargados automáticamente).
"""

# Hub checkpoints offered in the model dropdown.
MODEL_OPTIONS = [
    "google/vit-base-patch16-224",
    "facebook/deit-base-patch16-224",
    "microsoft/resnet-50",
]
# Checkpoint preselected in the dropdown: the first (ViT) entry above.
DEFAULT_MODEL = MODEL_OPTIONS[0]

# Pipelines already built, keyed by model id. get_pipe() fills it lazily so that
# changing parameters other than the model does not reload the weights.
_pipes = {}
def get_pipe(model_id: str):
    """Return the image-classification pipeline for *model_id*, creating it once.

    Built pipelines are memoised in the module-level ``_pipes`` dict. Later calls
    with the same id reuse the loaded model instead of constructing it again.

    The pipeline is created with ``device=0`` when ``torch.cuda.is_available()``
    is true. Otherwise it uses ``device=-1`` (CPU).
    """
    try:
        return _pipes[model_id]
    except KeyError:
        pass

    # ``device`` is a single device index (not a device_map): GPU 0 or CPU (-1).
    target_device = -1
    if torch.cuda.is_available():
        target_device = 0

    _pipes[model_id] = pipeline(
        task="image-classification",
        model=model_id,
        device=target_device,
    )
    return _pipes[model_id]
def classify(image: Image.Image, model_id: str, top_k: int):
    """Classify *image* with *model_id* and return the top-k predictions as rows.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        model_id: Hub model id, resolved to a cached pipeline via ``get_pipe``.
        top_k: Number of classes to return. A float such as ``5.0`` (possible
            through the Gradio API) is truncated to ``int``, and values below 1
            are raised to 1. ``None`` is passed through, so the pipeline uses
            its own default.

    Returns:
        A list of ``(label, score)`` tuples, one per predicted class. Returns an
        empty list when *image* is ``None``.
    """
    if image is None:
        return []
    # transformers hands top_k to torch.topk, which requires an int, so a float
    # from the slider/API would raise. 0 or negative values are clamped to 1.
    k = None if top_k is None else max(1, int(top_k))
    # NOTE(review): model_id is not checked against MODEL_OPTIONS here. Any Hub
    # id that reaches this function gets downloaded by get_pipe(); confirm the
    # Dropdown rejects custom values in the Gradio version in use.
    pipe = get_pipe(model_id)
    # Force 3-channel input (e.g. RGBA PNGs, grayscale or palette images).
    rgb = image.convert("RGB")
    preds = pipe(rgb, top_k=k)
    # Flatten the pipeline's [{"label", "score"}] dicts into table rows.
    return [(p["label"], float(p["score"])) for p in preds]
# Build the Gradio UI: image input and model/top-k controls on the left,
# a results table on the right, wired to classify() through a button.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column (scale=2, i.e. twice the relative width of the right one).
        with gr.Column(scale=2):
            # type="pil" makes Gradio pass classify() a PIL image (None when empty).
            image = gr.Image(type="pil", label="Imagen de entrada")
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=DEFAULT_MODEL,
                    label="Modelo"
                )
                # Range 1..10 in steps of 1, default 5.
                top_k = gr.Slider(1, 10, value=5, step=1, label="Top‑K")
            btn = gr.Button("Clasificar")
        # Right column: predictions shown as a (label, score) table.
        with gr.Column(scale=1):
            output = gr.Dataframe(
                headers=["label", "score"],
                datatype=["str", "number"],
                label="Resultados"
            )
    # The only event wired here: inference runs when the button is clicked.
    btn.click(fn=classify, inputs=[image, model_id, top_k], outputs=output)

# Start the Gradio server when the file is executed directly.
if __name__ == "__main__":
    demo.launch()