Spaces:
Runtime error
Runtime error
| import json | |
| import faiss | |
| import flax | |
| import gradio as gr | |
| import jax | |
| import numpy as np | |
| import pandas as pd | |
| import requests | |
| from imgutils.tagging import wd14 | |
| from Models.CLIP import CLIP | |
def _l2_normalize_rows(x):
    """Return a float32 copy of `x` with each non-zero row scaled to unit L2 norm.

    Rows (last axis) whose norm is zero are left as zeros, so an absent query
    term contributes nothing. The caller's array is never modified.
    """
    out = np.array(x, dtype=np.float32)  # always a fresh, writable, contiguous copy
    norms = np.linalg.norm(out, axis=-1, keepdims=True)
    # Divide only where the norm is positive; elsewhere `out` keeps its zeros.
    np.divide(out, norms, out=out, where=norms > 0)
    return out


def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
    """Combine positive and negative query embeddings into one search vector.

    The positive image and tag embeddings are summed and normalised, as are the
    negative ones. The normalised negative is then subtracted from the
    normalised positive, and the difference is normalised again.

    Args:
        pos_img_embs, pos_tags_embs: positive embeddings, same shape
            (e.g. (1, dim)); zeros mean "no positive input of this kind".
        neg_img_embs, neg_tags_embs: negative embeddings, same shape; zeros
            mean "no negative input of this kind".

    Returns:
        A new float32 array of the same shape, unit-normalised per row (or all
        zeros if positive and negative cancel out). Inputs are not modified.
    """
    pos = _l2_normalize_rows(np.asarray(pos_img_embs) + np.asarray(pos_tags_embs))
    neg = _l2_normalize_rows(np.asarray(neg_img_embs) + np.asarray(neg_tags_embs))
    return _l2_normalize_rows(pos - neg)
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key="", timeout=10):
    """Fetch a Danbooru post and return its large-file URL if its rating is allowed.

    Args:
        image_id: Danbooru post id.
        selected_ratings: iterable of rating names ("General", "Sensitive",
            "Questionable", "Explicit") the caller accepts. An unknown name
            raises KeyError.
        api_username: optional Danbooru login; only sent together with `api_key`.
        api_key: optional Danbooru API key; only sent together with `api_username`.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        The post's "large_file_url", or None when:
        - no rating is selected;
        - the request fails or does not return 200;
        - the body is not JSON;
        - the post has no such URL;
        - the post's rating is not in `selected_ratings`.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = {ratings_to_letters[x] for x in (selected_ratings or ())}
    if not acceptable_ratings:
        # Nothing could pass the rating filter, so skip the network round-trip.
        return None

    params = {}
    if api_username and api_key:
        # Passed via `params` so requests URL-encodes credentials with special characters.
        params = {"api_key": api_key, "login": api_username}

    try:
        r = requests.get(
            f"https://danbooru.donmai.us/posts/{image_id}.json",
            headers=headers,
            params=params,
            timeout=timeout,
        )
    except requests.RequestException:
        # Treat network failures like a non-200 response: this post is skipped.
        return None
    if r.status_code != 200:
        return None
    try:
        content = r.json()
    except ValueError:
        return None

    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
class Predictor:
    """Nearest-neighbour image search over a prebuilt FAISS index.

    A query is built from four parts:
    - optional positive and negative images (WD14 ConvNext embeddings);
    - optional positive and negative Danbooru tags, encoded by the loaded
      tag-model variant.
    The parts are merged with `combine_embeddings` and searched in the index.
    """

    # Variant name (as offered in the UI) -> weights file under data/<base_model>/.
    _PARAM_FILES = {"CLIP": "clip.msgpack", "SigLIP": "siglip.msgpack"}

    def __init__(self):
        self.loaded_variant = None  # variant whose weights are currently in self.params
        self.base_model = "wd-v1-4-convnext-tagger-v2"
        self.model = CLIP()
        self.tags_df = pd.read_csv("data/selected_tags.csv")
        # Position i in the FAISS index corresponds to Danbooru post id self.images_ids[i].
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")
        # Context manager so the config file handle is closed promptly.
        with open("index/cosine_infos.json", encoding="utf-8") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def load_params(self, variant):
        """Load the tag-model weights for `variant` into self.params (no-op if already loaded).

        Raises:
            ValueError: if `variant` is not one of the known variant names.
        """
        if self.loaded_variant == variant:
            return
        try:
            filename = self._PARAM_FILES[variant]
        except KeyError:
            raise ValueError(
                f"Unknown model variant {variant!r}; expected one of {sorted(self._PARAM_FILES)}"
            ) from None
        with open(f"data/{self.base_model}/{filename}", "rb") as f:
            data = f.read()
        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.loaded_variant = variant

    @staticmethod
    def _tag_indices(tags_df, tags_text):
        """Return row positions in `tags_df` whose "name" matches a comma-separated tag.

        Whitespace around each tag is ignored, empty entries are dropped and
        names absent from `tags_df` are skipped.
        """
        names = [t.strip() for t in (tags_text or "").split(",")]
        names = [t for t in names if t]
        # Positional indices (not index labels) because they index a NumPy row below.
        return np.flatnonzero(tags_df["name"].isin(names).to_numpy())

    @staticmethod
    def _as_query(embedding):
        """Return `embedding` as a fresh (1, dim) float32 array, L2-normalised in place by faiss."""
        # faiss.normalize_L2 writes in place and needs a writable, C-contiguous float32
        # buffer; a copy guarantees that even for read-only device_get results.
        query = np.array(embedding, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(query)
        return query

    def _embed_image(self, image):
        """Return the normalised WD14 ConvNext embedding of `image` as a (1, dim) array."""
        # NOTE(review): presumably a 1-D vector; reshape in _as_query also accepts (1, dim).
        embedding = wd14.get_wd14_tags(image, model_name="ConvNext", fmt="embedding")
        return self._as_query(embedding)

    def _embed_tags(self, tag_idxs, num_classes):
        """Encode a multi-hot vector over `tag_idxs` with the loaded tag model and normalise it."""
        tags = np.zeros((1, num_classes), dtype=np.float32)
        tags[0, tag_idxs] = 1
        embedding = self.model.apply(
            {"params": self.params},
            tags,
            method=self.model.encode_text,
        )
        return self._as_query(jax.device_get(embedding))

    def predict(
        self,
        pos_img_input,
        neg_img_input,
        positive_tags,
        negative_tags,
        selected_model,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        """Find indexed Danbooru posts similar to the given images and tags.

        Args:
            pos_img_input: image to move toward, or None (the UI passes PIL images).
            neg_img_input: image to move away from, or None.
            positive_tags: comma-separated tag names to move toward.
            negative_tags: comma-separated tag names to move away from.
                For both tag strings, whitespace is ignored and unknown tags
                are skipped.
            selected_model: tag-model variant name passed to `load_params`.
            selected_ratings: rating names forwarded to `danbooru_id_to_url`.
            n_neighbours: number of neighbours to retrieve (coerced to int).
            api_username: optional Danbooru login.
            api_key: optional Danbooru API key.

        Returns:
            A list of (image_url, "post_id/distance") tuples. Neighbours whose URL
            lookup returned None are omitted, so the list may be shorter than
            `n_neighbours`.
        """
        self.load_params(selected_model)
        num_classes = len(self.tags_df)
        # combine_embeddings only reads its inputs, so one zero array can stand in
        # for every missing query part.
        zeros = np.zeros((1, self.model.out_units), dtype=np.float32)

        pos_img_embs = zeros if pos_img_input is None else self._embed_image(pos_img_input)
        neg_img_embs = zeros if neg_img_input is None else self._embed_image(neg_img_input)

        pos_idxs = self._tag_indices(self.tags_df, positive_tags)
        neg_idxs = self._tag_indices(self.tags_df, negative_tags)
        pos_tags_embs = self._embed_tags(pos_idxs, num_classes) if pos_idxs.size else zeros
        neg_tags_embs = self._embed_tags(neg_idxs, num_classes) if neg_idxs.size else zeros

        embeddings = combine_embeddings(
            pos_img_embs,
            pos_tags_embs,
            neg_img_embs,
            neg_tags_embs,
        )
        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))

        results = []
        for idx, dist in zip(indexes[0], dists[0]):
            if idx < 0:
                # FAISS pads with -1 when fewer than k neighbours exist; -1 would
                # otherwise silently select the last image id.
                continue
            image_id = int(self.images_ids[idx])
            url = danbooru_id_to_url(image_id, selected_ratings, api_username, api_key)
            if url is not None:
                results.append((url, f"{image_id}/{dist:.2f}"))
        return results
def main():
    """Build the Gradio search UI around one shared Predictor and launch it."""
    predictor = Predictor()
    with gr.Blocks() as demo:
        with gr.Row():
            pos_img_input = gr.Image(type="pil", label="Positive input")
            neg_img_input = gr.Image(type="pil", label="Negative input")
        with gr.Row():
            with gr.Column():
                positive_tags = gr.Textbox(label="Positive tags")
                negative_tags = gr.Textbox(label="Negative tags")
                selected_model = gr.Radio(
                    choices=["CLIP", "SigLIP"],
                    value="CLIP",
                    label="Tags embedding model",
                )
            with gr.Column():
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                n_neighbours = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="# of images",
                )
                # NOTE(review): confirm whether the credentials row belongs inside this column.
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    api_key = gr.Textbox(label="Danbooru API Key")
        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        # One definition of Predictor.predict's positional inputs, shared by the
        # examples and the button so the two wirings cannot drift apart.
        inputs = [
            pos_img_input,
            neg_img_input,
            positive_tags,
            negative_tags,
            selected_model,
            selected_ratings,
            n_neighbours,
            api_username,
            api_key,
        ]
        gr.Examples(
            [
                [
                    None,
                    None,
                    "marcille_donato",
                    "",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    None,
                    None,
                    "yellow_eyes,red_horns",
                    "",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    None,
                    None,
                    "artoria_pendragon_(fate),solo",
                    "green_eyes",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/60378883_p0.jpg",
                    None,
                    "fujimaru_ritsuka_(female)",
                    "solo",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
                [
                    "examples/DaRlExxUwAAcUOS-orig.jpg",
                    "examples/46657164_p1.jpg",
                    "",
                    "",
                    "CLIP",
                    ["General", "Sensitive"],
                    5,
                    "",
                    "",
                ],
            ],
            inputs=inputs,
            outputs=[similar_images],
            fn=predictor.predict,
            run_on_click=True,
            cache_examples=False,
        )
        find_btn.click(
            fn=predictor.predict,
            inputs=inputs,
            outputs=[similar_images],
        )
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    # Start the app only when run as a script, not when imported.
    main()