| from transformers import AutoFeatureExtractor, ResNetForImageClassification | |
| import torch | |
| # from datasets import load_dataset | |
| # dataset = load_dataset("huggingface/cats-image") | |
| # image = dataset["test"]["image"][0] | |
# Single source of truth for the checkpoint, so the preprocessing config and
# the model weights can never be loaded from different repositories.
MODEL_ID = "microsoft/resnet-50"

# Preprocessor and classifier for the same checkpoint (fetched from the
# Hugging Face Hub or the local cache at import time).
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = ResNetForImageClassification.from_pretrained(MODEL_ID)
| import gradio as gr | |
def segment(image, *, top_k=None):
    """Classify *image* with the ResNet-50 ImageNet classifier.

    Despite its name, this performs whole-image classification, not
    segmentation; the name is kept for compatibility with existing callers.

    Args:
        image: The picture to classify, in any form ``feature_extractor``
            accepts. It comes from Gradio's ``"image"`` input and is
            presumably a NumPy array; confirm against the installed Gradio
            version. ``None`` (nothing uploaded) yields an empty mapping
            instead of an error.
        top_k: If given, keep only the ``top_k`` most probable classes,
            ordered from most to least probable. ``None`` (the default)
            returns every class in model index order.

    Returns:
        A dict mapping class label to softmax probability (a Python float),
        which is the format Gradio's ``"label"`` output consumes.
    """
    if image is None:
        return {}

    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is passed in, so row 0 holds its class scores.
    probs = logits.softmax(dim=-1)[0]
    id2label = model.config.id2label

    if top_k is None:
        # One bulk conversion instead of a float() call per class.
        return {id2label[idx]: prob for idx, prob in enumerate(probs.tolist())}

    values, indices = torch.topk(probs, k=min(top_k, probs.numel()))
    return {
        id2label[idx]: value
        for idx, value in zip(indices.tolist(), values.tolist())
    }
# Web UI: upload an image, see the model's class probabilities.
# Bound to ``demo`` (the name the ``gradio`` CLI's reload mode looks for by
# default) so the interface can be reused without launching it.
demo = gr.Interface(fn=segment, inputs="image", outputs="label")

# Start the server only when run as a script, so importing this module
# (from tests or another app) does not block on a web server.
if __name__ == "__main__":
    demo.launch()