File size: 1,801 Bytes
3ae84a3
 
 
78cdedf
3e5aff4
3ae84a3
 
 
dc86583
9b28e54
3ae84a3
 
ab7a42f
 
 
0d69242
3ae84a3
 
0d69242
78cdedf
3ae84a3
 
a2223b7
9b28e54
 
78cdedf
0d69242
c45624f
19cb4eb
d92e73a
c6131ee
657f29b
3ae84a3
 
19cb4eb
ab7a42f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a836b0
ab7a42f
 
 
a9cb5c7
ab7a42f
19cb4eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pickle

import gradio as gr
import torch
import wikipedia
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor


# --- One-time start-up state -------------------------------------------------
# Everything below executes a single time, when the module is first run.

# Checkpoint used for both preprocessing and embedding, so the two stay in sync.
_MODEL_ID = "sasha/autotrain-butterfly-similarity-2490576840"

# Pre-built nearest-neighbour index; query() calls index.query(embedding, k=...).
# NOTE(review): pickle.load can execute arbitrary code — this must only ever
# read a trusted file shipped alongside the app, never user-supplied data.
with open("butts_1024_new.pickle", "rb") as handle:
    index = pickle.load(handle)

# Preprocessing + model that turn a query image into an embedding.
feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_ID)
model = AutoModel.from_pretrained(_MODEL_ID)

# Candidate images; query() maps the index's neighbour ids onto rows of this
# split with ds.select(...). Presumably the index was built over this exact
# split in the same row order — confirm when rebuilding either one.
dataset = load_dataset("sasha/butterflies_10k_names_multiple")
ds = dataset["train"]


def query(image, top_k=1):
    """Return the ``top_k`` dataset images closest to ``image``.

    The image is preprocessed with ``feature_extractor``, embedded with
    ``model`` (its ``pooler_output``), and looked up in the nearest-neighbour
    ``index``. The matching rows' ``"image"`` column from ``ds`` is returned
    for display in a gallery.

    Args:
        image: The picture supplied by the ``gr.Image`` input, or ``None``
            when the user submits without uploading anything.
        top_k: Number of neighbours to retrieve.

    Returns:
        A list of images from ``ds`` in the order the index returned them.
        Returns an empty list when ``image`` is ``None``.
    """
    # Pressing the button with an empty input delivers None, which the
    # feature extractor cannot process; show an empty gallery instead.
    if image is None:
        return []

    inputs = feature_extractor(image, return_tensors="pt")
    # Pure inference: don't build an autograd graph for the forward pass
    # (this also makes the pooled output safe to hand to non-torch code).
    with torch.no_grad():
        model_output = model(**inputs)
    embedding = model_output.pooler_output

    results = index.query(embedding, k=top_k)
    # Element 0 of the result holds neighbour row ids; we sent one query
    # vector, so take its single row. (Element 1 is presumably distances;
    # it is not needed here.)
    neighbour_ids = results[0][0].tolist()
    return ds.select(neighbour_ids)["image"]


# --- User interface ------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Find my Butterfly 🦋")
    gr.Markdown(
        "## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!"
    )
    with gr.Row():
        # Left column: the query image and the trigger button.
        with gr.Column(scale=1):
            inputs = gr.Image(width=288, height=384)
            btn = gr.Button("Find my butterfly!")
            # Empty placeholder; no event in this file writes to it.
            description = gr.Markdown()

        # Right column: the matches returned by query().
        with gr.Column(scale=2):
            outputs = gr.Gallery(rows=1)

    gr.Markdown("### Image Examples")
    # cache_examples=True runs query() on each example ahead of time, so
    # clicking an example shows stored results instead of recomputing them.
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        outputs=outputs,
        fn=query,
        cache_examples=True,
    )
    btn.click(query, inputs, outputs)

# Only start the web server when run as a script (e.g. `python app.py`),
# not when this module is imported by other code or tooling.
if __name__ == "__main__":
    demo.launch()