Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import subprocess | |
| import os | |
| import sys | |
| import yaml | |
| from pathlib import Path | |
| import time | |
| import threading | |
| import tempfile | |
| import shutil | |
| import gc # ガベージコレクション用 | |
| # --- 定数 --- | |
| # Dockerfile内でクローンされるパスに合わせる | |
| SCRIPT_DIR = Path(__file__).parent | |
| SBV2_REPO_PATH = SCRIPT_DIR / "Style-Bert-VITS2" | |
| # ダウンロード用ファイルの一時置き場 (コンテナ内に作成) | |
| OUTPUT_DIR = SCRIPT_DIR / "outputs" | |
| # --- ヘルパー関数 --- | |
| def add_sbv2_to_path(): | |
| """Style-Bert-VITS2リポジトリのパスを sys.path に追加""" | |
| repo_path_str = str(SBV2_REPO_PATH.resolve()) | |
| if SBV2_REPO_PATH.exists() and repo_path_str not in sys.path: | |
| sys.path.insert(0, repo_path_str) | |
| print(f"Added {repo_path_str} to sys.path") | |
| elif not SBV2_REPO_PATH.exists(): | |
| print(f"Warning: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}") | |
| def stream_process_output(process, log_list): | |
| """サブプロセスの標準出力/エラーをリアルタイムでリストに追加""" | |
| try: | |
| if process.stdout: | |
| for line in iter(process.stdout.readline, ''): | |
| log_list.append(line.strip()) # 余分な改行を削除 | |
| if process.stderr: | |
| for line in iter(process.stderr.readline, ''): | |
| # WARNING以外のstderrはエラーとして強調表示しても良い | |
| processed_line = f"stderr: {line.strip()}" | |
| if "warning" not in line.lower(): | |
| processed_line = f"ERROR (stderr): {line.strip()}" | |
| log_list.append(processed_line) | |
| except Exception as e: | |
| log_list.append(f"Error reading process stream: {e}") | |
| # --- Gradio アプリのバックエンド関数 --- | |
| def convert_safetensors_to_onnx_gradio(safetensors_file_obj, progress=gr.Progress(track_tqdm=True)): | |
| """ | |
| アップロードされたSafetensorsモデルをONNXに変換し、結果をダウンロード可能にする。 | |
| """ | |
| log = ["Starting ONNX conversion..."] | |
| # 初期状態ではダウンロードファイルは空 | |
| yield "\n".join(log), None | |
| if safetensors_file_obj is None: | |
| log.append("Error: No safetensors file uploaded. Please upload a .safetensors file.") | |
| # エラーメッセージを表示し、ダウンロードはNoneのまま | |
| yield "\n".join(log), None | |
| return | |
| # Style-Bert-VITS2 パスの確認 | |
| add_sbv2_to_path() | |
| if not SBV2_REPO_PATH.exists(): | |
| log.append(f"Error: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}. Check Space build logs.") | |
| yield "\n".join(log), None | |
| return | |
| # 出力ディレクトリ作成 (存在しない場合) | |
| OUTPUT_DIR.mkdir(parents=True, exist_ok=True) | |
| # 以前の出力ファイルを削除 (オプション、ディスクスペース節約のため) | |
| # for item in OUTPUT_DIR.glob("*.onnx"): | |
| # try: | |
| # item.unlink() | |
| # log.append(f"Cleaned up previous output: {item.name}") | |
| # except OSError as e: | |
| # log.append(f"Warning: Could not delete previous file {item.name}: {e}") | |
| onnx_output_path_str = None # 最終的なONNXファイルパス (文字列) | |
| current_log = log[:] # ログリストをコピー | |
| try: | |
| # 一時ディレクトリを作成して処理 | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| temp_dir_path = Path(temp_dir) | |
| # Gradio Fileコンポーネントは一時ファイルパスを .name で持つ | |
| original_filename = Path(safetensors_file_obj.name).name | |
| # ファイル名に不正な文字が含まれていないか基本的なチェック (オプション) | |
| if "/" in original_filename or "\\" in original_filename or ".." in original_filename: | |
| current_log.append(f"Error: Invalid characters found in filename: {original_filename}") | |
| yield "\n".join(current_log), None | |
| return | |
| temp_safetensors_path_in_root = temp_dir_path / original_filename | |
| current_log.append(f"Processing uploaded file: {original_filename}") | |
| current_log.append(f"Copying to temporary location: {temp_safetensors_path_in_root}") | |
| yield "\n".join(current_log), None | |
| # アップロードされたファイルオブジェクトから一時ディレクトリにコピー | |
| shutil.copy(safetensors_file_obj.name, temp_safetensors_path_in_root) | |
| # --- SBV2が期待するディレクトリ構造を一時ディレクトリ内に作成 --- | |
| # モデル名をファイル名から取得 (拡張子なし) | |
| model_name = temp_safetensors_path_in_root.stem | |
| # assets_root を一時ディレクトリ自体にする | |
| assets_root = temp_dir_path | |
| # assets_root の下に model_name のディレクトリを作成 | |
| model_dir_in_temp = assets_root / model_name | |
| model_dir_in_temp.mkdir(exist_ok=True) | |
| # safetensorsファイルを model_name ディレクトリに移動 | |
| temp_safetensors_path = model_dir_in_temp / original_filename | |
| shutil.move(temp_safetensors_path_in_root, temp_safetensors_path) | |
| # dataset_root も assets_root と同じにしておく (今回は使用しない) | |
| dataset_root = assets_root | |
| current_log.append(f"Using temporary assets_root: {assets_root}") | |
| current_log.append(f"Created temporary model directory: {model_dir_in_temp}") | |
| current_log.append(f"Using temporary model path: {temp_safetensors_path}") | |
| yield "\n".join(current_log), None | |
| # --- paths.yml を一時的に設定 --- | |
| config_path = SBV2_REPO_PATH / "configs" / "paths.yml" | |
| config_path.parent.mkdir(parents=True, exist_ok=True) | |
| paths_config = {"dataset_root": str(dataset_root.resolve()), "assets_root": str(assets_root.resolve())} | |
| with open(config_path, "w", encoding="utf-8") as f: | |
| yaml.dump(paths_config, f) | |
| current_log.append(f"Saved temporary paths config to {config_path}") | |
| yield "\n".join(current_log), None | |
| # --- ONNX変換スクリプト実行 --- | |
| current_log.append(f"\nStarting ONNX conversion script for: {temp_safetensors_path.name}") | |
| convert_script = SBV2_REPO_PATH / "convert_onnx.py" | |
| if not convert_script.exists(): | |
| current_log.append(f"Error: convert_onnx.py not found at '{convert_script}'.") | |
| yield "\n".join(current_log), None | |
| return # tryブロックを抜ける | |
| python_executable = sys.executable | |
| command = [ | |
| python_executable, | |
| str(convert_script.resolve()), | |
| "--model", | |
| str(temp_safetensors_path.resolve()) # 一時ディレクトリ内のモデルパス | |
| ] | |
| current_log.append(f"\nRunning command: {' '.join(command)}") | |
| yield "\n".join(current_log), None | |
| process_env = os.environ.copy() | |
| # メモリリーク対策? (あまり効果ないかも) | |
| # process_env["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" | |
| process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, | |
| text=True, encoding='utf-8', errors='replace', | |
| cwd=SBV2_REPO_PATH, # スクリプトの場所で実行 | |
| env=process_env) | |
| # ログ出力用リスト (スレッドと共有) | |
| process_output_lines = [] | |
| thread = threading.Thread(target=stream_process_output, args=(process, process_output_lines)) | |
| thread.start() | |
| # 進捗表示のためのループ | |
| while thread.is_alive(): | |
| # yieldでGradio UIを更新 (現在のログ + プロセスからのログ) | |
| yield "\n".join(current_log + process_output_lines), None | |
| time.sleep(0.2) # 更新頻度 | |
| # スレッド終了待ち | |
| thread.join() | |
| # プロセス終了待ち (タイムアウト設定) | |
| try: | |
| process.wait(timeout=900) # 15分タイムアウト | |
| except subprocess.TimeoutExpired: | |
| current_log.extend(process_output_lines) # ここまでのログを追加 | |
| current_log.append("\nError: Conversion process timed out after 15 minutes.") | |
| process.kill() # タイムアウトしたらプロセスを強制終了 | |
| yield "\n".join(current_log), None | |
| return # tryブロックを抜ける | |
| # 最終的なプロセス出力を取得 | |
| final_stdout, final_stderr = process.communicate() | |
| if final_stdout: | |
| process_output_lines.extend(final_stdout.strip().split('\n')) | |
| if final_stderr: | |
| processed_stderr = [] | |
| for line in final_stderr.strip().split('\n'): | |
| processed_line = f"stderr: {line.strip()}" | |
| if "warning" not in line.lower(): | |
| processed_line = f"ERROR (stderr): {line.strip()}" | |
| processed_stderr.append(processed_line) | |
| if processed_stderr: | |
| process_output_lines.append("--- stderr ---") | |
| process_output_lines.extend(processed_stderr) | |
| process_output_lines.append("--------------") | |
| # 全てのプロセスログをメインログに追加 | |
| current_log.extend(process_output_lines) | |
| current_log.append("\n-------------------------------") | |
| # --- 結果の確認と出力ファイルのコピー --- | |
| if process.returncode == 0: | |
| current_log.append("ONNX conversion command finished successfully.") | |
| # 期待されるONNXファイルパス (入力と同じディレクトリ内) | |
| expected_onnx_path_in_temp = temp_safetensors_path.with_suffix(".onnx") | |
| if expected_onnx_path_in_temp.exists(): | |
| current_log.append(f"Found converted ONNX file: {expected_onnx_path_in_temp.name}") | |
| # 一時ディレクトリから永続的な出力ディレクトリにコピー | |
| final_onnx_path = OUTPUT_DIR / expected_onnx_path_in_temp.name | |
| shutil.copy(expected_onnx_path_in_temp, final_onnx_path) | |
| current_log.append(f"Copied ONNX file for download to: {final_onnx_path}") | |
| onnx_output_path_str = str(final_onnx_path) # ダウンロード用ファイルパスを設定 | |
| else: | |
| current_log.append(f"Warning: Expected ONNX file not found at '{expected_onnx_path_in_temp}'. Please check the logs.") | |
| else: | |
| current_log.append(f"ONNX conversion command failed with return code {process.returncode}.") | |
| current_log.append("Please check the logs above for errors (especially lines starting with 'ERROR').") | |
| # 一時ディレクトリが自動で削除される前に yield する必要がある | |
| yield "\n".join(current_log), onnx_output_path_str | |
| except FileNotFoundError as e: | |
| # コマンドが見つからない場合など | |
| current_log.append(f"\nError: A required command or file was not found: {e.filename}. Check Dockerfile setup and PATH.") | |
| current_log.append(f"{e}") | |
| yield "\n".join(current_log), None | |
| except Exception as e: | |
| current_log.append(f"\nAn unexpected error occurred: {e}") | |
| import traceback | |
| current_log.append(traceback.format_exc()) | |
| # エラー発生時も最終ログとNoneを返す | |
| yield "\n".join(current_log), None | |
| finally: | |
| # ガベージコレクションを試みる (メモリ解放目的) | |
| gc.collect() | |
| # 最終的な状態をUIに反映させるための最後のyield (重要) | |
| # tryブロック内で既にyieldしている場合でも、finallyで再度yieldすることで | |
| # UIが最終状態(エラーメッセージや成功メッセージ+ダウンロードリンク)に更新される | |
| print("Conversion function finished.") # サーバーログ用 | |
| # ここで再度yieldするとUIの更新が確実になることがある | |
| # yield "\n".join(current_log), onnx_output_path_str | |
| # --- Gradio Interface --- | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# Style-Bert-VITS2 Safetensors to ONNX Converter") | |
| gr.Markdown( | |
| "Upload your `.safetensors` model file, convert it to ONNX format, and download the result. " | |
| "The environment setup (cloning repo, installing dependencies, running initialize.py) " | |
| "is handled automatically when this Space starts." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| safetensors_upload = gr.File( | |
| label="1. Upload Safetensors Model", | |
| file_types=[".safetensors"], | |
| # file_count="single" (default) | |
| ) | |
| convert_button = gr.Button("2. Convert to ONNX", variant="primary") | |
| gr.Markdown("---") | |
| onnx_download = gr.File( | |
| label="3. Download ONNX Model", | |
| interactive=False, # 出力専用 | |
| ) | |
| gr.Markdown( | |
| "**Note:** Conversion can take several minutes, especially on free hardware. " | |
| "Please be patient. The log on the right will update during the process." | |
| ) | |
| with gr.Column(scale=2): | |
| output_log = gr.Textbox( | |
| label="Conversion Log", | |
| lines=25, # 少し高さを増やす | |
| interactive=False, | |
| autoscroll=True, | |
| max_lines=1500 # ログが多くなる可能性を考慮 | |
| ) | |
| # ボタンクリック時のアクション設定 | |
| convert_button.click( | |
| convert_safetensors_to_onnx_gradio, | |
| inputs=[safetensors_upload], | |
| outputs=[output_log, onnx_download] # ログとダウンロードファイルの2つを出力 | |
| ) | |
| # --- アプリの起動 --- | |
| if __name__ == "__main__": | |
| # Style-Bert-VITS2 へのパスを追加 | |
| add_sbv2_to_path() | |
| # 出力ディレクトリ作成 (存在確認含む) | |
| OUTPUT_DIR.mkdir(parents=True, exist_ok=True) | |
| print(f"Output directory: {OUTPUT_DIR.resolve()}") | |
| # Gradioアプリを起動 | |
| # share=True にするとパブリックリンクが生成される(HF Spacesでは不要) | |
| # queue() を使うと複数ユーザーのリクエストを処理しやすくなる | |
| demo.queue().launch() |