Spaces:
Sleeping
Sleeping
Commit
·
999d8f3
1
Parent(s):
b079c7b
Update app.py
Browse filesAdd newly released MOAT model
app.py
CHANGED
|
@@ -19,6 +19,7 @@ from Utils import dbimutils
|
|
| 19 |
TITLE = "WaifuDiffusion v1.4 Tags"
|
| 20 |
DESCRIPTION = """
|
| 21 |
Demo for:
|
|
|
|
| 22 |
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
| 23 |
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
| 24 |
- [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
|
|
@@ -35,6 +36,7 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
|
| 35 |
"""
|
| 36 |
|
| 37 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
|
|
|
| 38 |
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
| 39 |
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
| 40 |
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
|
@@ -63,7 +65,9 @@ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
|
|
| 63 |
def change_model(model_name):
|
| 64 |
global loaded_models
|
| 65 |
|
| 66 |
-
if model_name == "
|
|
|
|
|
|
|
| 67 |
model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
| 68 |
elif model_name == "ConvNext":
|
| 69 |
model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
|
@@ -78,7 +82,7 @@ def change_model(model_name):
|
|
| 78 |
|
| 79 |
def load_labels() -> list[str]:
|
| 80 |
path = huggingface_hub.hf_hub_download(
|
| 81 |
-
|
| 82 |
)
|
| 83 |
df = pd.read_csv(path)
|
| 84 |
|
|
@@ -213,11 +217,17 @@ def predict(
|
|
| 213 |
|
| 214 |
def main():
|
| 215 |
global loaded_models
|
| 216 |
-
loaded_models = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
args = parse_args()
|
| 219 |
|
| 220 |
-
change_model("
|
| 221 |
|
| 222 |
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
| 223 |
|
|
@@ -233,7 +243,11 @@ def main():
|
|
| 233 |
fn=func,
|
| 234 |
inputs=[
|
| 235 |
gr.Image(type="pil", label="Input"),
|
| 236 |
-
gr.Radio(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
gr.Slider(
|
| 238 |
0,
|
| 239 |
1,
|
|
@@ -257,7 +271,7 @@ def main():
|
|
| 257 |
gr.Label(label="Output (tags)"),
|
| 258 |
gr.HTML(),
|
| 259 |
],
|
| 260 |
-
examples=[["power.jpg", "
|
| 261 |
title=TITLE,
|
| 262 |
description=DESCRIPTION,
|
| 263 |
allow_flagging="never",
|
|
|
|
| 19 |
TITLE = "WaifuDiffusion v1.4 Tags"
|
| 20 |
DESCRIPTION = """
|
| 21 |
Demo for:
|
| 22 |
+
- [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
|
| 23 |
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
| 24 |
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
| 25 |
- [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
|
|
|
|
| 36 |
"""
|
| 37 |
|
| 38 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
| 39 |
+
MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
| 40 |
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
| 41 |
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
| 42 |
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
|
|
|
| 65 |
def change_model(model_name):
|
| 66 |
global loaded_models
|
| 67 |
|
| 68 |
+
if model_name == "MOAT":
|
| 69 |
+
model = load_model(MOAT_MODEL_REPO, MODEL_FILENAME)
|
| 70 |
+
elif model_name == "SwinV2":
|
| 71 |
model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
| 72 |
elif model_name == "ConvNext":
|
| 73 |
model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
|
|
|
| 82 |
|
| 83 |
def load_labels() -> list[str]:
|
| 84 |
path = huggingface_hub.hf_hub_download(
|
| 85 |
+
MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
|
| 86 |
)
|
| 87 |
df = pd.read_csv(path)
|
| 88 |
|
|
|
|
| 217 |
|
| 218 |
def main():
|
| 219 |
global loaded_models
|
| 220 |
+
loaded_models = {
|
| 221 |
+
"MOAT": None,
|
| 222 |
+
"SwinV2": None,
|
| 223 |
+
"ConvNext": None,
|
| 224 |
+
"ConvNextV2": None,
|
| 225 |
+
"ViT": None,
|
| 226 |
+
}
|
| 227 |
|
| 228 |
args = parse_args()
|
| 229 |
|
| 230 |
+
change_model("MOAT")
|
| 231 |
|
| 232 |
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
| 233 |
|
|
|
|
| 243 |
fn=func,
|
| 244 |
inputs=[
|
| 245 |
gr.Image(type="pil", label="Input"),
|
| 246 |
+
gr.Radio(
|
| 247 |
+
["MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT"],
|
| 248 |
+
value="MOAT",
|
| 249 |
+
label="Model",
|
| 250 |
+
),
|
| 251 |
gr.Slider(
|
| 252 |
0,
|
| 253 |
1,
|
|
|
|
| 271 |
gr.Label(label="Output (tags)"),
|
| 272 |
gr.HTML(),
|
| 273 |
],
|
| 274 |
+
examples=[["power.jpg", "MOAT", 0.35, 0.85]],
|
| 275 |
title=TITLE,
|
| 276 |
description=DESCRIPTION,
|
| 277 |
allow_flagging="never",
|