Spaces:
Paused
Paused
Julian Bilcke
commited on
Commit
·
56c6949
1
Parent(s):
90ccabd
tentative fix
Browse files
app.py
CHANGED
|
@@ -35,16 +35,20 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
| 35 |
# Import args_config module first
|
| 36 |
import OmniAvatar.utils.args_config
|
| 37 |
|
| 38 |
-
# Create
|
| 39 |
class Args:
|
| 40 |
def __init__(self):
|
| 41 |
self.rank = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
self.dtype = 'bf16'
|
| 43 |
-
self.exp_path = str(MODELS_DIR / "OmniAvatar-1.3B")
|
| 44 |
-
self.dit_path = str(MODELS_DIR / "Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
|
| 45 |
-
self.text_encoder_path = str(MODELS_DIR / "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 46 |
-
self.vae_path = str(MODELS_DIR / "Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 47 |
-
self.wav2vec_path = str(MODELS_DIR / "wav2vec2-base-960h")
|
| 48 |
self.train_architecture = 'lora'
|
| 49 |
self.lora_rank = 128
|
| 50 |
self.lora_alpha = 64.0
|
|
@@ -69,6 +73,24 @@ class Args:
|
|
| 69 |
self.image_sizes_720 = [[400, 720], [720, 720], [720, 400]]
|
| 70 |
self.image_sizes_1280 = [[720, 720], [528, 960], [960, 528], [720, 1280], [1280, 720]]
|
| 71 |
self.seq_len = 200
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# Set the global args before any other OmniAvatar imports
|
| 74 |
OmniAvatar.utils.args_config.args = Args()
|
|
|
|
| 35 |
# Import args_config module first
|
| 36 |
import OmniAvatar.utils.args_config
|
| 37 |
|
| 38 |
+
# Create and set global args before any other OmniAvatar imports
|
| 39 |
class Args:
|
| 40 |
def __init__(self):
|
| 41 |
self.rank = 0
|
| 42 |
+
self.world_size = 1
|
| 43 |
+
self.local_rank = 0
|
| 44 |
+
self.device = 'cuda:0'
|
| 45 |
+
self.num_nodes = 1
|
| 46 |
self.dtype = 'bf16'
|
| 47 |
+
self.exp_path = str(Path(os.environ.get('MODELS_DIR', 'pretrained_models')) / "OmniAvatar-1.3B")
|
| 48 |
+
self.dit_path = str(Path(os.environ.get('MODELS_DIR', 'pretrained_models')) / "Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
|
| 49 |
+
self.text_encoder_path = str(Path(os.environ.get('MODELS_DIR', 'pretrained_models')) / "Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 50 |
+
self.vae_path = str(Path(os.environ.get('MODELS_DIR', 'pretrained_models')) / "Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 51 |
+
self.wav2vec_path = str(Path(os.environ.get('MODELS_DIR', 'pretrained_models')) / "wav2vec2-base-960h")
|
| 52 |
self.train_architecture = 'lora'
|
| 53 |
self.lora_rank = 128
|
| 54 |
self.lora_alpha = 64.0
|
|
|
|
| 73 |
self.image_sizes_720 = [[400, 720], [720, 720], [720, 400]]
|
| 74 |
self.image_sizes_1280 = [[720, 720], [528, 960], [960, 528], [720, 1280], [1280, 720]]
|
| 75 |
self.seq_len = 200
|
| 76 |
+
self.infer = True
|
| 77 |
+
self.debug = False
|
| 78 |
+
|
| 79 |
+
def __contains__(self, key):
|
| 80 |
+
"""Support 'in' operator for checking if attribute exists"""
|
| 81 |
+
return hasattr(self, key)
|
| 82 |
+
|
| 83 |
+
def __iter__(self):
|
| 84 |
+
"""Make the Args object iterable over its attributes"""
|
| 85 |
+
return iter(self.__dict__)
|
| 86 |
+
|
| 87 |
+
def keys(self):
|
| 88 |
+
"""Return the attribute names"""
|
| 89 |
+
return self.__dict__.keys()
|
| 90 |
+
|
| 91 |
+
def __getitem__(self, key):
|
| 92 |
+
"""Support dictionary-style access"""
|
| 93 |
+
return getattr(self, key)
|
| 94 |
|
| 95 |
# Set the global args before any other OmniAvatar imports
|
| 96 |
OmniAvatar.utils.args_config.args = Args()
|