dangthr commited on
Commit
1565972
·
verified ·
1 Parent(s): a713267

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -49,16 +49,17 @@ def resolve_path(user_path, repo_root):
49
  1. 優先檢查本地路徑(絕對或相對)。
50
  2. 如果找不到,則嘗試從 HF 快取目錄中尋找。
51
  """
 
52
  if os.path.exists(user_path):
53
  print(f"找到本地檔案: {os.path.abspath(user_path)}")
54
  return os.path.abspath(user_path)
55
 
 
56
  potential_repo_path = os.path.join(repo_root, user_path)
57
  if os.path.exists(potential_repo_path):
58
  print(f"在 Hugging Face 快取目錄中找到檔案: {potential_repo_path}")
59
  return potential_repo_path
60
 
61
- print(f"錯誤:在任何位置都找不到檔案: {user_path}")
62
  return None
63
 
64
  def setup_models(repo_root, model_version):
@@ -66,10 +67,15 @@ def setup_models(repo_root, model_version):
66
  pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
67
  pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
68
 
69
- # 這個路徑現在應該可以正確找到了
70
- config_path = os.path.join(repo_root, "deepspeed_config/wan2.1/wan_civitai.yaml")
71
- if not os.path.exists(config_path):
72
- raise FileNotFoundError(f"設定檔未找到: {config_path},請檢查 snapshot_download 是否已下載此檔案。")
 
 
 
 
 
73
  config = OmegaConf.load(config_path)
74
  sampler_name = "Flow"
75
 
@@ -202,17 +208,17 @@ def main():
202
  args = parser.parse_args()
203
 
204
  print("--- 步驟 1: 正在檢查並下載模型與設定檔 ---")
205
- # <<< 核心修正:加入 'deepspeed_config/**' 來下載設定檔 >>>
206
  repo_root = snapshot_download(
207
  repo_id="FrancisRing/StableAvatar",
208
  allow_patterns=[
209
  "StableAvatar-1.3B/*",
210
  "Wan2.1-Fun-V1.1-1.3B-InP/*",
211
  "wav2vec2-base-960h/*",
212
- "deepspeed_config/**" # <-- 修正點
 
213
  ],
214
  )
215
- # <<< 修正結束 >>>
216
  print("模型檔案已準備就緒。")
217
 
218
  print("\n--- 步驟 2: 正在解析輸入檔案路徑 ---")
@@ -227,7 +233,6 @@ def main():
227
  return
228
 
229
  print("\n--- 步驟 3: 正在載入模型 ---")
230
- # 將 repo_root 傳遞給 setup_models,這樣它才能在正確的位置找到設定檔
231
  pipeline, transformer3d, vae = setup_models(repo_root, args.model_version)
232
  print("模型載入完成。")
233
 
 
49
  1. 優先檢查本地路徑(絕對或相對)。
50
  2. 如果找不到,則嘗試從 HF 快取目錄中尋找。
51
  """
52
+ # 優先檢查本地路徑
53
  if os.path.exists(user_path):
54
  print(f"找到本地檔案: {os.path.abspath(user_path)}")
55
  return os.path.abspath(user_path)
56
 
57
+ # 其次,嘗試從 HF 快取目錄中尋找
58
  potential_repo_path = os.path.join(repo_root, user_path)
59
  if os.path.exists(potential_repo_path):
60
  print(f"在 Hugging Face 快取目錄中找到檔案: {potential_repo_path}")
61
  return potential_repo_path
62
 
 
63
  return None
64
 
65
  def setup_models(repo_root, model_version):
 
67
  pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
68
  pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
69
 
70
+ # <<< 核心修正:對設定檔使用與輸入檔案相同的路徑解析邏輯 >>>
71
+ config_relative_path = "deepspeed_config/wan2.1/wan_civitai.yaml"
72
+ config_path = resolve_path(config_relative_path, repo_root)
73
+
74
+ if not config_path:
75
+ raise FileNotFoundError(f"設定檔 '{config_relative_path}' 在當前目錄或 HF 快取中都找不到。請確保該檔案存在。")
76
+ # <<< 修正結束 >>>
77
+
78
+ print(f"正在從 {config_path} 載入設定...")
79
  config = OmegaConf.load(config_path)
80
  sampler_name = "Flow"
81
 
 
208
  args = parser.parse_args()
209
 
210
  print("--- 步驟 1: 正在檢查並下載模型與設定檔 ---")
211
+ # 確保所有需要的檔案都被下載,以作為本地找不到檔案時的後備
212
  repo_root = snapshot_download(
213
  repo_id="FrancisRing/StableAvatar",
214
  allow_patterns=[
215
  "StableAvatar-1.3B/*",
216
  "Wan2.1-Fun-V1.1-1.3B-InP/*",
217
  "wav2vec2-base-960h/*",
218
+ "deepspeed_config/**",
219
+ "example_case/**" # 也下載範例,以防使用者直接執行預設參數
220
  ],
221
  )
 
222
  print("模型檔案已準備就緒。")
223
 
224
  print("\n--- 步驟 2: 正在解析輸入檔案路徑 ---")
 
233
  return
234
 
235
  print("\n--- 步驟 3: 正在載入模型 ---")
 
236
  pipeline, transformer3d, vae = setup_models(repo_root, args.model_version)
237
  print("模型載入完成。")
238