# Gradio demo: find Danbooru images similar to a set of tags
# using CLIP text embeddings and a faiss cosine-similarity index.
import json
import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests
from Models.CLIP import CLIP
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    """Resolve a Danbooru post id to the URL of its large image.

    Args:
        image_id: Numeric Danbooru post id.
        selected_ratings: Iterable of rating names ("General", "Sensitive",
            "Questionable", "Explicit"); posts outside these are rejected.
        api_username: Optional Danbooru login for authenticated requests.
        api_key: Optional Danbooru API key, used together with api_username.

    Returns:
        The post's ``large_file_url`` string, or None when the request fails,
        the post lacks a large image, or its rating is not acceptable.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    # Pass credentials as query params so requests URL-encodes them,
    # instead of interpolating them into the URL string by hand.
    params = {}
    if api_username != "" and api_key != "":
        params = {"login": api_username, "api_key": api_key}
    try:
        # fix: the original call had no timeout, so a stalled connection
        # would hang the whole prediction loop indefinitely.
        r = requests.get(image_url, headers=headers, params=params, timeout=10)
    except requests.RequestException:
        return None
    if r.status_code != 200:
        return None
    content = r.json()
    # .get avoids a KeyError on posts missing "rating" or "large_file_url"
    # (the original indexed content["rating"] unconditionally).
    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
class Predictor:
    """Nearest-neighbour image search over a Danbooru index.

    Encodes comma-separated tag lists into CLIP text embeddings and queries
    a prebuilt faiss cosine index of image embeddings.
    """

    def __init__(self):
        self.base_model = "wd-v1-4-convnext-tagger-v2"
        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()
        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()
        self.tags_df = pd.read_csv("data/selected_tags.csv")
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")
        # fix: the original used open(...).read(), leaking the file handle;
        # the context manager guarantees it is closed.
        with open("index/cosine_infos.json") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def _encode_tags(self, tag_idxs, num_classes):
        """Encode a multi-hot tag-index vector into a CLIP text embedding.

        Returns a host-side (1, dim) numpy array.
        """
        tags = np.zeros((1, num_classes), dtype=np.float32)
        tags[0][tag_idxs] = 1
        emb = self.model.apply(
            {"params": self.params},
            tags,
            method=self.model.encode_text,
        )
        # Pull the embedding back from the accelerator so faiss can use it.
        return jax.device_get(emb)

    def predict(self, positive_tags, negative_tags, n_neighbours=5):
        """Find images matching positive tags while avoiding negative ones.

        Args:
            positive_tags: Comma-separated tag names to search for.
            negative_tags: Comma-separated tag names to steer away from.
            n_neighbours: Number of nearest neighbours to retrieve.

        Returns:
            List of (image_url, "image_id/distance") pairs; posts whose URL
            cannot be resolved are silently dropped.
        """
        tags_df = self.tags_df
        num_classes = len(tags_df)
        positive_tags = positive_tags.split(",")
        negative_tags = negative_tags.split(",")
        positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
        negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()

        # Embed the positive tags; subtract the negative-tag embedding to
        # push the query away from unwanted concepts (original behavior).
        emb_from_logits = self._encode_tags(positive_tags_idxs, num_classes)
        if len(negative_tags_idxs) > 0:
            emb_from_logits = emb_from_logits - self._encode_tags(
                negative_tags_idxs, num_classes
            )

        # Cosine similarity = inner product on L2-normalized vectors.
        faiss.normalize_L2(emb_from_logits)
        dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
        neighbours_ids = [int(x) for x in self.images_ids[indexes][0]]

        captions = []
        image_urls = []
        for image_id, dist in zip(neighbours_ids, dists[0]):
            current_url = danbooru_id_to_url(
                image_id,
                [
                    "General",
                    "Sensitive",
                    "Questionable",
                    "Explicit",
                ],
            )
            if current_url is not None:
                image_urls.append(current_url)
                captions.append(f"{image_id}/{dist:.2f}")
        return list(zip(image_urls, captions))
def main():
    """Build and launch the Gradio UI for tag-based similar-image search."""
    app = Predictor()

    with gr.Blocks() as ui:
        with gr.Row():
            pos_box = gr.Textbox(label="Positive tags")
            neg_box = gr.Textbox(label="Negative tags")
        search_btn = gr.Button("Find similar images")
        gallery = gr.Gallery(label="Similar images", columns=[5])

        # Clicking the button runs a search and fills the gallery.
        search_btn.click(
            fn=app.predict,
            inputs=[pos_box, neg_box],
            outputs=[gallery],
        )

    ui.queue()
    ui.launch()


if __name__ == "__main__":
    main()