Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import TimmWrapperModel | |
| import torch | |
| import torchvision.transforms.v2 as T | |
# Preprocessing settings for each selectable model, keyed by the repo id that
# is passed to `from_pretrained`. Each entry is consumed by
# `config_to_processor`.
MODEL_MAP = {
    "p1atdev/style_250412.vit_base_patch16_siglip_384.v2_webli": dict(
        # With mean 0 and std 1, Normalize is an identity: pixels stay in [0, 1]
        # after ToDtype(scale=True).
        mean=[0, 0, 0],
        std=[1.0, 1.0, 1.0],
        # Side length of the square tensor fed to the model.
        image_size=384,
        # Fill value used when padding the resized image out to a square.
        background=0,
    ),
}
def config_to_processor(config: dict):
    """Build the image preprocessing pipeline described by one MODEL_MAP entry.

    The pipeline turns a PIL image into a float32 tensor of shape
    (C, image_size, image_size). The image is scaled to fit, centred on a
    square canvas filled with ``config["background"]``, then normalised.

    Args:
        config: Mapping with keys ``"mean"``, ``"std"``, ``"image_size"`` and
            ``"background"``.

    Returns:
        A ``torchvision.transforms.v2.Compose`` callable.
    """
    side = config["image_size"]

    steps = [
        # PIL image -> uint8 tensor (C, H, W).
        T.PILToTensor(),
        # size=None with max_size: the longer edge is resized to `side`, and the
        # aspect ratio is kept.
        # NOTE(review): size=None needs a recent torchvision (0.18+, I believe).
        # Confirm the Space pins a version that supports it.
        T.Resize(
            size=None,
            max_size=side,
            interpolation=T.InterpolationMode.NEAREST,
        ),
        # Pad generously on every side, then take the centre side x side crop.
        # This leaves the image centred on a square of background colour.
        T.Pad(padding=side // 2, fill=config["background"]),
        T.CenterCrop(size=(side, side)),
        # uint8 0..255 -> float32 0..1
        T.ToDtype(dtype=torch.float32, scale=True),
        T.Normalize(mean=config["mean"], std=config["std"]),
    ]
    return T.Compose(steps)
def load_model(name: str):
    """Load ``name`` with ``TimmWrapperModel.from_pretrained`` and freeze it.

    The returned module is in eval mode and all of its parameters have
    ``requires_grad=False``, so it is ready for inference only.
    """
    model = TimmWrapperModel.from_pretrained(name)
    model.eval()  # switch layers such as dropout to inference behaviour
    model.requires_grad_(False)  # weights are never trained here
    return model
# Registry of ready-to-use models, built eagerly at import time: one
# {"model", "processor"} pair per MODEL_MAP entry, in MODEL_MAP's order.
MODELS = {
    repo_id: {
        "model": load_model(repo_id),
        "processor": config_to_processor(settings),
    }
    for repo_id, settings in MODEL_MAP.items()
}
def calculate_similarity(model_name: str, image_1: Image.Image, image_2: Image.Image):
    """Return the cosine similarity between the embeddings of two images.

    Both images are run through the processor and model registered under
    ``model_name`` in ``MODELS``. The model's ``pooler_output`` vectors are
    L2-normalised, and their dot product is returned.

    Args:
        model_name: Key into ``MODELS`` (one of the ``MODEL_MAP`` repo ids).
        image_1: First image, as delivered by ``gr.Image(type="pil")``.
            This is ``None`` when nothing has been uploaded.
        image_2: Second image, with the same conventions as ``image_1``.

    Returns:
        The cosine similarity as a Python float, in [-1, 1].

    Raises:
        gr.Error: If either image is missing. The UI then shows a readable
            message instead of a stack trace.
    """
    if image_1 is None or image_2 is None:
        raise gr.Error("Please upload both images before submitting.")

    model = MODELS[model_name]["model"]
    processor = MODELS[model_name]["processor"]

    # The pipeline normalises with 3-channel mean/std. RGBA, greyscale ("L") or
    # palette ("P") uploads would otherwise give tensors with the wrong channel
    # count, so convert everything to RGB first. Any alpha channel is dropped.
    rgb_images = [
        img if img.mode == "RGB" else img.convert("RGB") for img in (image_1, image_2)
    ]

    with torch.no_grad():
        pixel_values = torch.stack([processor(img) for img in rgb_images])
        # Presumably (2, hidden_dim) pooled features. TODO confirm for every model.
        embeddings = model(pixel_values).pooler_output
        # Unlike a raw in-place division, F.normalize clamps the norm away from
        # zero, so an all-zero embedding does not turn into NaN.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        # Element-wise product + sum is the dot product. It works for any row
        # shape and avoids `.T` on 1-D tensors, which PyTorch has deprecated.
        similarity = (embeddings[0] * embeddings[1]).sum()

    return similarity.item()
with gr.Blocks() as demo:
    # Layout: inputs (two images, model picker, submit button) in the left
    # column, and the similarity score in the right column.
    with gr.Row():
        with gr.Column():
            image_1 = gr.Image(label="Image 1", type="pil")
            image_2 = gr.Image(label="Image 2", type="pil")
            model_name = gr.Dropdown(
                label="Model",
                choices=list(MODELS),
                value=list(MODELS)[0],  # default to the first registered model
            )
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            similarity = gr.Label(label="Similarity")

    # Clicking Submit calls calculate_similarity(model_name, image_1, image_2)
    # and shows its return value in the Label.
    gr.on(
        triggers=[submit_btn.click],
        fn=calculate_similarity,
        inputs=[model_name, image_1, image_2],
        outputs=[similarity],
    )

# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()