Updated to use arbitrary model paths
app.py CHANGED
@@ -31,25 +31,22 @@ from generate_videos import generate_frames, video_from_interpolations, vid_to_g
 model_dir = "models"
 os.makedirs(model_dir, exist_ok=True)

-
-
-
+model_repos = {"e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
+               "dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
+               "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt")}

 def get_models():
     os.makedirs(model_dir, exist_ok=True)

-
-    hf_hub_download(repo_id=repo_id, filename=file_path)
-    if not "akhaliq" in repo_id:
-        shutil.move(file_path, os.path.join(model_dir, file_path))
-    elif "stylegan2" in file_path:
-        shutil.move(file_path, os.path.join(model_dir, "base.pt"))
+    model_paths = {}

-
+    for model_name, repo_details in model_repos.items():
+        download_path = hf_hub_download(repo_id=repo_details[0], filename=repo_details[1])
+        model_paths[model_name] = download_path

-    return
+    return model_paths

-
+model_paths = get_models()

 class ImageEditor(object):
     def __init__(self):
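The rewritten get_models no longer shuffles files into model_dir with shutil; hf_hub_download already caches each checkpoint locally and returns its absolute path, so the function only collects those paths into a dict keyed by model name. Below is a minimal standalone sketch of that flow, assuming huggingface_hub is installed; the "sketch" entry is a hypothetical repo id added only to show how another style checkpoint would be registered.

from huggingface_hub import hf_hub_download

model_repos = {
    "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt"),
    # Hypothetical extra style: any (repo_id, filename) pair is handled the same way.
    "sketch": ("some-user/jojogan-sketch", "sketch.pt"),
}

def get_models():
    # hf_hub_download fetches each file into the local Hugging Face cache
    # (or reuses it if already present) and returns the absolute cached path.
    return {name: hf_hub_download(repo_id=repo_id, filename=filename)
            for name, (repo_id, filename) in model_repos.items()}

model_paths = get_models()
print(model_paths["base"])  # absolute path inside the local HF cache

Because the files live in the hub cache, a later call to get_models() reuses the already-downloaded checkpoints instead of fetching them again.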
@@ -62,18 +59,20 @@ class ImageEditor(object):

         self.generators = {}

-        for
+        self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
+
+        for model in self.model_list:
             g_ema = Generator(
                 model_size, latent_size, n_mlp, channel_multiplier=channel_mult
             ).to(self.device)

-            checkpoint = torch.load(
+            checkpoint = torch.load(model_paths[model])

             g_ema.load_state_dict(checkpoint['g_ema'])

             self.generators[model] = g_ema

-        self.experiment_args = {"model_path": "
+        self.experiment_args = {"model_path": model_paths["e4e"]}
         self.experiment_args["transform"] = transforms.Compose(
             [
                 transforms.Resize((256, 256)),
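Inside __init__, model_paths now drives generator loading: the helper entries ("e4e", "dlib") are filtered out to form self.model_list, and one generator is loaded per remaining checkpoint from its downloaded path. The sketch below isolates that loop under the assumption, taken from the diff, that each checkpoint stores the EMA weights under the 'g_ema' key; build_generator is a hypothetical factory standing in for the repo's Generator(model_size, latent_size, n_mlp, channel_multiplier=channel_mult) call, and map_location is an added precaution for CPU-only machines, not part of the original code.

import torch

def load_generators(model_paths, build_generator, device="cpu"):
    generators = {}
    # Skip the encoder and landmark helpers; everything else is a style checkpoint.
    model_list = [name for name in model_paths if name not in ("e4e", "dlib")]
    for name in model_list:
        g_ema = build_generator().to(device)                       # fresh StyleGAN2 generator
        checkpoint = torch.load(model_paths[name], map_location=device)
        g_ema.load_state_dict(checkpoint["g_ema"])                 # EMA weights, as in the diff
        g_ema.eval()
        generators[name] = g_ema
    return generators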
@@ -96,7 +95,7 @@ class ImageEditor(object):
         self.e4e_net.cuda()

         self.shape_predictor = dlib.shape_predictor(
-
+            model_paths["dlib"]
         )

         print("setup complete")
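The dlib landmark model is likewise addressed through model_paths["dlib"] instead of a fixed location under models/. A small usage sketch with dlib's stock face detector, where "face.jpg" is only a placeholder input:

import dlib

def count_landmarks(predictor_path, image_path="face.jpg"):
    predictor = dlib.shape_predictor(predictor_path)   # e.g. model_paths["dlib"]
    detector = dlib.get_frontal_face_detector()
    img = dlib.load_rgb_image(image_path)              # placeholder image path
    for rect in detector(img, 1):                      # upsample once for smaller faces
        landmarks = predictor(img, rect)
        print(landmarks.num_parts)                     # 68 points for this model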
@@ -120,11 +119,11 @@ class ImageEditor(object):
     ):

         if output_style == 'all':
-            styles = model_list
+            styles = self.model_list
         elif output_style == 'list - enter below':
             styles = style_list.split(",")
             for style in styles:
-                if style not in model_list:
+                if style not in self.model_list:
                     raise ValueError(f"Encountered style '{style}' in the style_list which is not an available option.")
         else:
             styles = [output_style]
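Style selection now validates against self.model_list, which is derived from the keys of model_paths, so any checkpoint registered in model_repos becomes selectable without further changes. The same logic, pulled out into a hypothetical free function (resolve_styles is not part of the app) for illustration:

def resolve_styles(output_style, style_list, model_list):
    if output_style == 'all':
        return model_list
    if output_style == 'list - enter below':
        styles = style_list.split(",")
        for style in styles:
            if style not in model_list:
                raise ValueError(f"Encountered style '{style}' in the style_list which is not an available option.")
        return styles
    return [output_style]

# e.g. resolve_styles('all', "", ['base']) -> ['base']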