| import argparse | |
| import json | |
| import faiss | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| from imgutils.tagging import wd14 | |
# Markdown heading; passed to gr.Markdown as the first element of the page.
TITLE = "## Danbooru Explorer"

# Markdown blurb rendered right after TITLE: credits the feature extractor and
# indexing libraries and links to a related space.
DESCRIPTION = """
Image similarity-based retrieval tool using:
- [SmilingWolf/wd-swinv2-tagger-v3](https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3) as feature extractor
- [dghs-imgutils](https://github.com/deepghs/imgutils) for feature extraction
- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing
Also, check out [SmilingWolf/danbooru2022_embeddings_playground](https://huggingface.co/spaces/SmilingWolf/danbooru2022_embeddings_playground) for a similar space with experimental support for text input combined with image input.
"""
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the app.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour callers got before this parameter existed.

    Returns:
        Namespace with a boolean ``share`` attribute, ``True`` only when
        ``--share`` was given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    return parser.parse_args(argv)
def danbooru_id_to_url(
    image_id,
    selected_ratings,
    api_username="",
    api_key="",
    timeout=10,
):
    """Resolve a Danbooru post id to its large image URL, filtered by rating.

    Args:
        image_id: Id of the Danbooru post to look up.
        selected_ratings: Iterable of rating names to accept; each must be
            one of "General", "Sensitive", "Questionable" or "Explicit".
        api_username: Danbooru login; only sent when ``api_key`` is set too.
        api_key: Danbooru API key; only sent when ``api_username`` is set too.
        timeout: Seconds to wait for the Danbooru API before giving up.

    Returns:
        The post's ``large_file_url``, or ``None`` when the request fails or
        times out, the response is not HTTP 200 or not a JSON object, the
        post has no ``large_file_url``, or its rating is not among
        ``selected_ratings``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    # Credentials are passed via ``params`` so requests URL-encodes them
    # instead of splicing raw user input into the query string.
    params = {}
    if api_username and api_key:
        params = {"api_key": api_key, "login": api_username}

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    try:
        # Without a timeout a stalled connection would block forever.
        r = requests.get(
            post_url,
            headers=headers,
            params=params,
            timeout=timeout,
        )
    except requests.RequestException:
        # Treated like a non-200 answer: this post is simply skipped.
        return None

    if r.status_code != 200:
        return None

    try:
        content = r.json()
    except ValueError:
        return None
    if not isinstance(content, dict):
        return None

    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
class SimilaritySearcher:
    """Nearest-neighbour image lookup backed by a prebuilt Faiss index.

    The constructor reads three files relative to the working directory:
    ``index/cosine_ids.npy`` (ids addressed by index row),
    ``index/cosine_knn.index`` (the Faiss index) and
    ``index/cosine_infos.json`` (whose ``index_param`` entry is applied to
    the index as its search parameters).
    """

    def __init__(self):
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as infos_file:
            config = json.load(infos_file)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def predict(
        self,
        img_input,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        """Find posts similar to ``img_input`` and resolve them to image URLs.

        Args:
            img_input: Query image handed to the wd14 feature extractor, or
                ``None`` when the user has not provided one.
            selected_ratings: Rating names forwarded to
                ``danbooru_id_to_url`` as the acceptance filter.
            n_neighbours: Number of neighbours to request from the index.
            api_username: Danbooru login forwarded to ``danbooru_id_to_url``.
            api_key: Danbooru API key forwarded to ``danbooru_id_to_url``.

        Returns:
            List of ``(image_url, caption)`` tuples, one per neighbour whose
            URL could be resolved; the caption is ``"<post id>/<distance>"``
            with the distance as returned by the index, to two decimals.
            Empty when ``img_input`` is ``None``.
        """
        # Nothing to search for; return an empty gallery instead of letting
        # the feature extractor fail on a missing image.
        if img_input is None:
            return []

        embeddings = wd14.get_wd14_tags(
            img_input,
            model_name="SwinV2_v3",
            fmt="embedding",
        )
        # The index is searched with a batch of one query vector.
        embeddings = np.expand_dims(embeddings, 0)
        faiss.normalize_L2(embeddings)

        # The UI slider may deliver a float; Faiss needs an integer k.
        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))

        results = []
        for index, dist in zip(indexes[0], dists[0]):
            # Faiss pads with -1 when fewer than k neighbours are found;
            # indexing with it would silently pick the last stored id.
            if index < 0:
                continue

            image_id = int(self.images_ids[index])
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                results.append((current_url, f"{image_id}/{dist:.2f}"))

        return results
def main():
    """Build the Gradio interface and launch the app.

    Parses the command-line options, loads the similarity index, lays out
    the input widgets and result gallery, wires the button to
    ``SimilaritySearcher.predict`` and starts the server (with a public
    share link when ``--share`` was given).
    """
    args = parse_args()
    searcher = SimilaritySearcher()

    # NOTE(review): the widget nesting below is reconstructed from a source
    # whose indentation was lost -- confirm it matches the intended layout.
    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            img_input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    # Mask the key so the secret is not shown in clear text.
                    api_key = gr.Textbox(
                        label="Danbooru API Key",
                        type="password",
                    )
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        # Input order must match the parameter order of predict().
        find_btn.click(
            fn=searcher.predict,
            inputs=[
                img_input,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch(share=args.share)
| if __name__ == "__main__": | |
| main() | |