import json

import faiss
import flax
import gradio as gr
import jax
import numpy as np
import pandas as pd
import requests

from Models.CLIP import CLIP
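

# Resolve a Danbooru post ID to a direct image URL via the posts/<id>.json endpoint.
# Returns None if the request fails, the post has no large_file_url, or its rating is
# outside selected_ratings; API credentials are optional.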
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    if api_username != "" and api_key != "":
        image_url = f"{image_url}?api_key={api_key}&login={api_username}"

    r = requests.get(image_url, headers=headers)
    if r.status_code != 200:
        return None

    content = json.loads(r.text)
    image_url = content["large_file_url"] if "large_file_url" in content else None
    image_url = image_url if content["rating"] in acceptable_ratings else None
    return image_url
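

# Predictor pairs the CLIP text encoder (Flax weights stored under data/<base_model>/)
# with a prebuilt FAISS index of image embeddings. It expects data/selected_tags.csv and
# the index/cosine_* artifacts (post IDs, k-NN index, index parameters) to exist at startup.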
class Predictor:
    def __init__(self):
        self.base_model = "wd-v1-4-convnext-tagger-v2"

        # Restore the serialized Flax parameters of the CLIP model.
        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
            data = f.read()
        self.params = flax.serialization.msgpack_restore(data)["model"]
        self.model = CLIP()

        # Tag vocabulary; row order defines each tag's position in the input vector.
        self.tags_df = pd.read_csv("data/selected_tags.csv")

        # Prebuilt cosine k-NN index over image embeddings and the matching post IDs.
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")
        with open("index/cosine_infos.json") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
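
    # predict maps comma-separated tag strings to the n_neighbours most similar Danbooru
    # posts, returned as (image_url, "id/distance") pairs for the gallery.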
    def predict(self, positive_tags, negative_tags, n_neighbours=5):
        tags_df = self.tags_df
        model = self.model
        num_classes = len(tags_df)

        # Tag names are matched verbatim against selected_tags.csv, so the input should be
        # comma-separated without surrounding spaces.
        positive_tags = positive_tags.split(",")
        negative_tags = negative_tags.split(",")
        positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
        negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()

        # Multi-hot vector of positive tags -> query embedding from the CLIP text encoder.
        tags = np.zeros((1, num_classes), dtype=np.float32)
        tags[0][positive_tags_idxs] = 1
        emb_from_logits = model.apply(
            {"params": self.params},
            tags,
            method=model.encode_text,
        )
        emb_from_logits = jax.device_get(emb_from_logits)

        # If negative tags were given, embed them the same way and subtract the result to
        # steer the query away from those concepts.
        if len(negative_tags_idxs) > 0:
            tags = np.zeros((1, num_classes), dtype=np.float32)
            tags[0][negative_tags_idxs] = 1
            neg_emb_from_logits = model.apply(
                {"params": self.params},
                tags,
                method=model.encode_text,
            )
            neg_emb_from_logits = jax.device_get(neg_emb_from_logits)
            emb_from_logits = emb_from_logits - neg_emb_from_logits

        # L2-normalize the query so the search against the cosine-built index amounts to
        # cosine similarity.
        faiss.normalize_L2(emb_from_logits)
        dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
        neighbours_ids = self.images_ids[indexes][0]
        neighbours_ids = [int(x) for x in neighbours_ids]

        # Resolve each neighbour to an image URL; posts that cannot be resolved are skipped.
        captions = []
        image_urls = []
        for image_id, dist in zip(neighbours_ids, dists[0]):
            current_url = danbooru_id_to_url(
                image_id,
                ["General", "Sensitive", "Questionable", "Explicit"],
            )
            if current_url is not None:
                image_urls.append(current_url)
                captions.append(f"{image_id}/{dist:.2f}")

        return list(zip(image_urls, captions))
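

# Minimal Gradio UI: two tag textboxes, a search button, and a gallery of retrieved images.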
def main():
    predictor = Predictor()

    with gr.Blocks() as demo:
        with gr.Row():
            positive_tags = gr.Textbox(label="Positive tags")
            negative_tags = gr.Textbox(label="Negative tags")
        find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        find_btn.click(
            fn=predictor.predict,
            inputs=[positive_tags, negative_tags],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()