Commit 0c14216
Parent(s): 69cd139
Add support for SigLIP-trained weights.
Same network structure for now; this is just to make it easier to compare the two while experimenting.
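For context, a minimal sketch of how the two weight sets can be compared through the same entry point after this change. `Predictor` and the `predict` argument order are taken from the diff below; the argument values are placeholders, and the trailing `api_key` parameter is an assumption (the diff truncates `predict`'s parameter list after `api_username`, but `danbooru_id_to_url` takes both).

predictor = Predictor()

for variant in ("CLIP", "SigLIP"):
    # Passing a different variant makes predict() call load_params(),
    # which swaps the corresponding msgpack weights into the same
    # CLIP() network structure before embedding.
    results = predictor.predict(
        "examples/46657164_p1.jpg",  # pos_img_input
        None,                        # neg_img_input
        "marcille_donato",           # positive_tags
        "",                          # negative_tags
        variant,                     # selected_model
        ["General", "Sensitive"],    # selected_ratings
        5,                           # n_neighbours
        "",                          # api_username
        "",                          # api_key (assumed)
    )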
- app.py +40 -10
- data/wd-v1-4-convnext-tagger-v2/siglip.msgpack +3 -0
app.py
CHANGED
@@ -14,10 +14,13 @@ from Models.CLIP import CLIP
 
 def combine_embeddings(pos_img_embs, pos_tags_embs, neg_img_embs, neg_tags_embs):
     pos = pos_img_embs + pos_tags_embs
+    faiss.normalize_L2(pos)
 
     neg = neg_img_embs + neg_tags_embs
+    faiss.normalize_L2(neg)
 
     result = pos - neg
+    faiss.normalize_L2(result)
     return result
 
 
@@ -48,12 +51,9 @@ def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
 
 class Predictor:
     def __init__(self):
+        self.loaded_variant = None
         self.base_model = "wd-v1-4-convnext-tagger-v2"
 
-        with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
-            data = f.read()
-
-        self.params = flax.serialization.msgpack_restore(data)["model"]
         self.model = CLIP()
 
         self.tags_df = pd.read_csv("data/selected_tags.csv")
@@ -64,12 +64,27 @@ class Predictor:
         config = json.loads(open("index/cosine_infos.json").read())["index_param"]
         faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
 
+    def load_params(self, variant):
+        if self.loaded_variant == variant:
+            return
+
+        if variant == "CLIP":
+            with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
+                data = f.read()
+        elif variant == "SigLIP":
+            with open(f"data/{self.base_model}/siglip.msgpack", "rb") as f:
+                data = f.read()
+
+        self.params = flax.serialization.msgpack_restore(data)["model"]
+        self.loaded_variant = variant
+
     def predict(
         self,
         pos_img_input,
         neg_img_input,
         positive_tags,
         negative_tags,
+        selected_model,
         selected_ratings,
         n_neighbours,
         api_username,
@@ -78,6 +93,8 @@ class Predictor:
         tags_df = self.tags_df
         model = self.model
 
+        self.load_params(selected_model)
+
         num_classes = len(tags_df)
 
         output_shape = model.out_units
@@ -172,10 +189,10 @@ def main():
                 positive_tags = gr.Textbox(label="Positive tags")
                 negative_tags = gr.Textbox(label="Negative tags")
             with gr.Column():
-                selected_ratings = gr.CheckboxGroup(
-                    choices=["General", "Sensitive", "Questionable", "Explicit"],
-                    value=["General", "Sensitive"],
-                    label="Ratings",
+                selected_model = gr.Radio(
+                    choices=["CLIP", "SigLIP"],
+                    value="CLIP",
+                    label="Model",
                 )
                 n_neighbours = gr.Slider(
                     minimum=1,
@@ -185,8 +202,14 @@ def main():
                     label="# of images",
                 )
             with gr.Column():
-                api_username = gr.Textbox(label="Danbooru API Username")
-                api_key = gr.Textbox(label="Danbooru API Key")
+                selected_ratings = gr.CheckboxGroup(
+                    choices=["General", "Sensitive", "Questionable", "Explicit"],
+                    value=["General", "Sensitive"],
+                    label="Ratings",
+                )
+                with gr.Row():
+                    api_username = gr.Textbox(label="Danbooru API Username")
+                    api_key = gr.Textbox(label="Danbooru API Key")
 
         find_btn = gr.Button("Find similar images")
 
@@ -199,6 +222,7 @@ def main():
                     None,
                     "marcille_donato",
                     "",
+                    "CLIP",
                     ["General", "Sensitive"],
                     5,
                     "",
@@ -209,6 +233,7 @@ def main():
                     None,
                     "yellow_eyes,red_horns",
                     "",
+                    "CLIP",
                     ["General", "Sensitive"],
                     5,
                     "",
@@ -219,6 +244,7 @@ def main():
                     None,
                     "artoria_pendragon_(fate),solo",
                     "excalibur_(fate/stay_night),green_eyes,monochrome,blonde_hair",
+                    "CLIP",
                     ["General", "Sensitive"],
                     5,
                     "",
@@ -229,6 +255,7 @@ def main():
                     None,
                     "fujimaru_ritsuka_(female)",
                     "solo",
+                    "CLIP",
                     ["General", "Sensitive"],
                     5,
                     "",
@@ -239,6 +266,7 @@ def main():
                     "examples/46657164_p1.jpg",
                     "",
                     "",
+                    "CLIP",
                     ["General", "Sensitive"],
                     5,
                     "",
@@ -250,6 +278,7 @@ def main():
             neg_img_input,
             positive_tags,
             negative_tags,
+            selected_model,
             selected_ratings,
             n_neighbours,
             api_username,
@@ -268,6 +297,7 @@ def main():
             neg_img_input,
             positive_tags,
             negative_tags,
+            selected_model,
             selected_ratings,
             n_neighbours,
             api_username,
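The `faiss.normalize_L2` calls added in the first hunk keep each intermediate unit-length, which the cosine index (presumably what index/cosine_infos.json configures) expects: summing two unit vectors and re-normalizing averages their directions, and normalizing `result` makes the final query unit-norm again after the subtraction. A quick self-contained check of the in-place behavior (the call requires a 2-D float32 array; the values are illustrative):

import numpy as np
import faiss

v = np.array([[3.0, 4.0]], dtype=np.float32)
faiss.normalize_L2(v)  # normalizes each row in place
print(v)               # [[0.6 0.8]] -- now unit L2 norm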
data/wd-v1-4-convnext-tagger-v2/siglip.msgpack
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b541d6ed39a4df5ca2edd7e3431e936bbb61c9499026ad3365361af13aa06d06
+size 48689369
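The three `+` lines above are a Git LFS pointer, not the checkpoint itself; the roughly 48 MB msgpack is fetched by LFS at checkout. Once present, it is restored the same way `load_params` in the app.py diff does for either variant, roughly:

import flax

# Mirrors load_params() in the app.py diff above.
with open("data/wd-v1-4-convnext-tagger-v2/siglip.msgpack", "rb") as f:
    params = flax.serialization.msgpack_restore(f.read())["model"]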