File size: 15,029 Bytes
0d64808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd51476
0d64808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import gradio as gr
import subprocess
import os
import sys
import yaml
from pathlib import Path
import time
import threading
import tempfile
import shutil
import gc # ガベージコレクション用

# --- 定数 ---
# Dockerfile内でクローンされるパスに合わせる
SCRIPT_DIR = Path(__file__).parent
SBV2_REPO_PATH = SCRIPT_DIR / "Style-Bert-VITS2"
# ダウンロード用ファイルの一時置き場 (コンテナ内に作成)
OUTPUT_DIR = SCRIPT_DIR / "outputs"

# --- ヘルパー関数 ---
def add_sbv2_to_path():
    """Style-Bert-VITS2リポジトリのパスを sys.path に追加"""
    repo_path_str = str(SBV2_REPO_PATH.resolve())
    if SBV2_REPO_PATH.exists() and repo_path_str not in sys.path:
        sys.path.insert(0, repo_path_str)
        print(f"Added {repo_path_str} to sys.path")
    elif not SBV2_REPO_PATH.exists():
        print(f"Warning: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}")

def stream_process_output(process, log_list):
    """サブプロセスの標準出力/エラーをリアルタイムでリストに追加"""
    try:
        if process.stdout:
            for line in iter(process.stdout.readline, ''):
                log_list.append(line.strip()) # 余分な改行を削除
        if process.stderr:
             for line in iter(process.stderr.readline, ''):
                # WARNING以外のstderrはエラーとして強調表示しても良い
                processed_line = f"stderr: {line.strip()}"
                if "warning" not in line.lower():
                     processed_line = f"ERROR (stderr): {line.strip()}"
                log_list.append(processed_line)
    except Exception as e:
        log_list.append(f"Error reading process stream: {e}")

# --- Gradio アプリのバックエンド関数 ---
def convert_safetensors_to_onnx_gradio(safetensors_file_obj, progress=gr.Progress(track_tqdm=True)):
    """
    アップロードされたSafetensorsモデルをONNXに変換し、結果をダウンロード可能にする。
    """
    log = ["Starting ONNX conversion..."]
    # 初期状態ではダウンロードファイルは空
    yield "\n".join(log), None

    if safetensors_file_obj is None:
        log.append("Error: No safetensors file uploaded. Please upload a .safetensors file.")
        # エラーメッセージを表示し、ダウンロードはNoneのまま
        yield "\n".join(log), None
        return

    # Style-Bert-VITS2 パスの確認
    add_sbv2_to_path()
    if not SBV2_REPO_PATH.exists():
        log.append(f"Error: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}. Check Space build logs.")
        yield "\n".join(log), None
        return

    # 出力ディレクトリ作成 (存在しない場合)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # 以前の出力ファイルを削除 (オプション、ディスクスペース節約のため)
    # for item in OUTPUT_DIR.glob("*.onnx"):
    #     try:
    #         item.unlink()
    #         log.append(f"Cleaned up previous output: {item.name}")
    #     except OSError as e:
    #         log.append(f"Warning: Could not delete previous file {item.name}: {e}")

    onnx_output_path_str = None # 最終的なONNXファイルパス (文字列)
    current_log = log[:] # ログリストをコピー

    try:
        # 一時ディレクトリを作成して処理
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir_path = Path(temp_dir)
            # Gradio Fileコンポーネントは一時ファイルパスを .name で持つ
            original_filename = Path(safetensors_file_obj.name).name
            # ファイル名に不正な文字が含まれていないか基本的なチェック (オプション)
            if "/" in original_filename or "\\" in original_filename or ".." in original_filename:
                 current_log.append(f"Error: Invalid characters found in filename: {original_filename}")
                 yield "\n".join(current_log), None
                 return

            temp_safetensors_path_in_root = temp_dir_path / original_filename

            current_log.append(f"Processing uploaded file: {original_filename}")
            current_log.append(f"Copying to temporary location: {temp_safetensors_path_in_root}")
            yield "\n".join(current_log), None
            # アップロードされたファイルオブジェクトから一時ディレクトリにコピー
            shutil.copy(safetensors_file_obj.name, temp_safetensors_path_in_root)


            # --- SBV2が期待するディレクトリ構造を一時ディレクトリ内に作成 ---
            # モデル名をファイル名から取得 (拡張子なし)
            model_name = temp_safetensors_path_in_root.stem
            # assets_root を一時ディレクトリ自体にする
            assets_root = temp_dir_path
            # assets_root の下に model_name のディレクトリを作成
            model_dir_in_temp = assets_root / model_name
            model_dir_in_temp.mkdir(exist_ok=True)
            # safetensorsファイルを model_name ディレクトリに移動
            temp_safetensors_path = model_dir_in_temp / original_filename
            shutil.move(temp_safetensors_path_in_root, temp_safetensors_path)

            # dataset_root も assets_root と同じにしておく (今回は使用しない)
            dataset_root = assets_root

            current_log.append(f"Using temporary assets_root: {assets_root}")
            current_log.append(f"Created temporary model directory: {model_dir_in_temp}")
            current_log.append(f"Using temporary model path: {temp_safetensors_path}")
            yield "\n".join(current_log), None

            # --- paths.yml を一時的に設定 ---
            config_path = SBV2_REPO_PATH / "configs" / "paths.yml"
            config_path.parent.mkdir(parents=True, exist_ok=True)
            paths_config = {"dataset_root": str(dataset_root.resolve()), "assets_root": str(assets_root.resolve())}
            with open(config_path, "w", encoding="utf-8") as f:
                yaml.dump(paths_config, f)
            current_log.append(f"Saved temporary paths config to {config_path}")
            yield "\n".join(current_log), None

            # --- ONNX変換スクリプト実行 ---
            current_log.append(f"\nStarting ONNX conversion script for: {temp_safetensors_path.name}")
            convert_script = SBV2_REPO_PATH / "convert_onnx.py"
            if not convert_script.exists():
                 current_log.append(f"Error: convert_onnx.py not found at '{convert_script}'.")
                 yield "\n".join(current_log), None
                 return # tryブロックを抜ける

            python_executable = sys.executable
            command = [
                python_executable,
                str(convert_script.resolve()),
                "--model",
                str(temp_safetensors_path.resolve()) # 一時ディレクトリ内のモデルパス
            ]
            current_log.append(f"\nRunning command: {' '.join(command)}")
            yield "\n".join(current_log), None

            process_env = os.environ.copy()
            # メモリリーク対策? (あまり効果ないかも)
            # process_env["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                       text=True, encoding='utf-8', errors='replace',
                                       cwd=SBV2_REPO_PATH, # スクリプトの場所で実行
                                       env=process_env)

            # ログ出力用リスト (スレッドと共有)
            process_output_lines = []
            thread = threading.Thread(target=stream_process_output, args=(process, process_output_lines))
            thread.start()

            # 進捗表示のためのループ
            while thread.is_alive():
                # yieldでGradio UIを更新 (現在のログ + プロセスからのログ)
                yield "\n".join(current_log + process_output_lines), None
                time.sleep(0.2) # 更新頻度

            # スレッド終了待ち
            thread.join()
            # プロセス終了待ち (タイムアウト設定)
            try:
                 process.wait(timeout=900) # 15分タイムアウト
            except subprocess.TimeoutExpired:
                 current_log.extend(process_output_lines) # ここまでのログを追加
                 current_log.append("\nError: Conversion process timed out after 15 minutes.")
                 process.kill() # タイムアウトしたらプロセスを強制終了
                 yield "\n".join(current_log), None
                 return # tryブロックを抜ける

            # 最終的なプロセス出力を取得
            final_stdout, final_stderr = process.communicate()
            if final_stdout:
                 process_output_lines.extend(final_stdout.strip().split('\n'))
            if final_stderr:
                 processed_stderr = []
                 for line in final_stderr.strip().split('\n'):
                     processed_line = f"stderr: {line.strip()}"
                     if "warning" not in line.lower():
                         processed_line = f"ERROR (stderr): {line.strip()}"
                     processed_stderr.append(processed_line)
                 if processed_stderr:
                      process_output_lines.append("--- stderr ---")
                      process_output_lines.extend(processed_stderr)
                      process_output_lines.append("--------------")

            # 全てのプロセスログをメインログに追加
            current_log.extend(process_output_lines)
            current_log.append("\n-------------------------------")

            # --- 結果の確認と出力ファイルのコピー ---
            if process.returncode == 0:
                current_log.append("ONNX conversion command finished successfully.")
                # 期待されるONNXファイルパス (入力と同じディレクトリ内)
                expected_onnx_path_in_temp = temp_safetensors_path.with_suffix(".onnx")

                if expected_onnx_path_in_temp.exists():
                    current_log.append(f"Found converted ONNX file: {expected_onnx_path_in_temp.name}")
                    # 一時ディレクトリから永続的な出力ディレクトリにコピー
                    final_onnx_path = OUTPUT_DIR / expected_onnx_path_in_temp.name
                    shutil.copy(expected_onnx_path_in_temp, final_onnx_path)
                    current_log.append(f"Copied ONNX file for download to: {final_onnx_path}")
                    onnx_output_path_str = str(final_onnx_path) # ダウンロード用ファイルパスを設定
                else:
                    current_log.append(f"Warning: Expected ONNX file not found at '{expected_onnx_path_in_temp}'. Please check the logs.")
            else:
                current_log.append(f"ONNX conversion command failed with return code {process.returncode}.")
                current_log.append("Please check the logs above for errors (especially lines starting with 'ERROR').")

            # 一時ディレクトリが自動で削除される前に yield する必要がある
            yield "\n".join(current_log), onnx_output_path_str

    except FileNotFoundError as e:
        # コマンドが見つからない場合など
        current_log.append(f"\nError: A required command or file was not found: {e.filename}. Check Dockerfile setup and PATH.")
        current_log.append(f"{e}")
        yield "\n".join(current_log), None
    except Exception as e:
        current_log.append(f"\nAn unexpected error occurred: {e}")
        import traceback
        current_log.append(traceback.format_exc())
        # エラー発生時も最終ログとNoneを返す
        yield "\n".join(current_log), None
    finally:
        # ガベージコレクションを試みる (メモリ解放目的)
        gc.collect()
        # 最終的な状態をUIに反映させるための最後のyield (重要)
        # tryブロック内で既にyieldしている場合でも、finallyで再度yieldすることで
        # UIが最終状態(エラーメッセージや成功メッセージ+ダウンロードリンク)に更新される
        print("Conversion function finished.") # サーバーログ用
        # ここで再度yieldするとUIの更新が確実になることがある
        # yield "\n".join(current_log), onnx_output_path_str


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Style-Bert-VITS2 Safetensors to ONNX Converter")
    gr.Markdown(
        "Upload your `.safetensors` model file, convert it to ONNX format, and download the result. "
        "The environment setup (cloning repo, installing dependencies, running initialize.py) "
        "is handled automatically when this Space starts."
    )

    with gr.Row():
        with gr.Column(scale=1):
            safetensors_upload = gr.File(
                label="1. Upload Safetensors Model",
                file_types=[".safetensors"],
                # file_count="single" (default)
            )
            convert_button = gr.Button("2. Convert to ONNX", variant="primary")
            gr.Markdown("---")
            onnx_download = gr.File(
                label="3. Download ONNX Model",
                interactive=False, # 出力専用
            )
            gr.Markdown(
                "**Note:** Conversion can take several minutes, especially on free hardware. "
                "Please be patient. The log on the right will update during the process."
            )

        with gr.Column(scale=2):
            output_log = gr.Textbox(
                label="Conversion Log",
                lines=25, # 少し高さを増やす
                interactive=False,
                autoscroll=True,
                max_lines=1500 # ログが多くなる可能性を考慮
            )

    # ボタンクリック時のアクション設定
    convert_button.click(
        convert_safetensors_to_onnx_gradio,
        inputs=[safetensors_upload],
        outputs=[output_log, onnx_download] # ログとダウンロードファイルの2つを出力
    )

# --- アプリの起動 ---
if __name__ == "__main__":
    # Style-Bert-VITS2 へのパスを追加
    add_sbv2_to_path()
    # 出力ディレクトリ作成 (存在確認含む)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {OUTPUT_DIR.resolve()}")

    # Gradioアプリを起動
    # share=True にするとパブリックリンクが生成される(HF Spacesでは不要)
    # queue() を使うと複数ユーザーのリクエストを処理しやすくなる
    demo.queue().launch()