sbv2_onnx / app.py
aka7774's picture
Update app.py
cd51476 verified
raw
history blame
15 kB
import gradio as gr
import subprocess
import os
import sys
import yaml
from pathlib import Path
import time
import threading
import tempfile
import shutil
import gc # ガベージコレクション用
# --- 定数 ---
# Dockerfile内でクローンされるパスに合わせる
SCRIPT_DIR = Path(__file__).parent
SBV2_REPO_PATH = SCRIPT_DIR / "Style-Bert-VITS2"
# ダウンロード用ファイルの一時置き場 (コンテナ内に作成)
OUTPUT_DIR = SCRIPT_DIR / "outputs"
# --- ヘルパー関数 ---
def add_sbv2_to_path():
"""Style-Bert-VITS2リポジトリのパスを sys.path に追加"""
repo_path_str = str(SBV2_REPO_PATH.resolve())
if SBV2_REPO_PATH.exists() and repo_path_str not in sys.path:
sys.path.insert(0, repo_path_str)
print(f"Added {repo_path_str} to sys.path")
elif not SBV2_REPO_PATH.exists():
print(f"Warning: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}")
def stream_process_output(process, log_list):
"""サブプロセスの標準出力/エラーをリアルタイムでリストに追加"""
try:
if process.stdout:
for line in iter(process.stdout.readline, ''):
log_list.append(line.strip()) # 余分な改行を削除
if process.stderr:
for line in iter(process.stderr.readline, ''):
# WARNING以外のstderrはエラーとして強調表示しても良い
processed_line = f"stderr: {line.strip()}"
if "warning" not in line.lower():
processed_line = f"ERROR (stderr): {line.strip()}"
log_list.append(processed_line)
except Exception as e:
log_list.append(f"Error reading process stream: {e}")
# --- Gradio アプリのバックエンド関数 ---
def convert_safetensors_to_onnx_gradio(safetensors_file_obj, progress=gr.Progress(track_tqdm=True)):
"""
アップロードされたSafetensorsモデルをONNXに変換し、結果をダウンロード可能にする。
"""
log = ["Starting ONNX conversion..."]
# 初期状態ではダウンロードファイルは空
yield "\n".join(log), None
if safetensors_file_obj is None:
log.append("Error: No safetensors file uploaded. Please upload a .safetensors file.")
# エラーメッセージを表示し、ダウンロードはNoneのまま
yield "\n".join(log), None
return
# Style-Bert-VITS2 パスの確認
add_sbv2_to_path()
if not SBV2_REPO_PATH.exists():
log.append(f"Error: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}. Check Space build logs.")
yield "\n".join(log), None
return
# 出力ディレクトリ作成 (存在しない場合)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# 以前の出力ファイルを削除 (オプション、ディスクスペース節約のため)
# for item in OUTPUT_DIR.glob("*.onnx"):
# try:
# item.unlink()
# log.append(f"Cleaned up previous output: {item.name}")
# except OSError as e:
# log.append(f"Warning: Could not delete previous file {item.name}: {e}")
onnx_output_path_str = None # 最終的なONNXファイルパス (文字列)
current_log = log[:] # ログリストをコピー
try:
# 一時ディレクトリを作成して処理
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
# Gradio Fileコンポーネントは一時ファイルパスを .name で持つ
original_filename = Path(safetensors_file_obj.name).name
# ファイル名に不正な文字が含まれていないか基本的なチェック (オプション)
if "/" in original_filename or "\\" in original_filename or ".." in original_filename:
current_log.append(f"Error: Invalid characters found in filename: {original_filename}")
yield "\n".join(current_log), None
return
temp_safetensors_path_in_root = temp_dir_path / original_filename
current_log.append(f"Processing uploaded file: {original_filename}")
current_log.append(f"Copying to temporary location: {temp_safetensors_path_in_root}")
yield "\n".join(current_log), None
# アップロードされたファイルオブジェクトから一時ディレクトリにコピー
shutil.copy(safetensors_file_obj.name, temp_safetensors_path_in_root)
# --- SBV2が期待するディレクトリ構造を一時ディレクトリ内に作成 ---
# モデル名をファイル名から取得 (拡張子なし)
model_name = temp_safetensors_path_in_root.stem
# assets_root を一時ディレクトリ自体にする
assets_root = temp_dir_path
# assets_root の下に model_name のディレクトリを作成
model_dir_in_temp = assets_root / model_name
model_dir_in_temp.mkdir(exist_ok=True)
# safetensorsファイルを model_name ディレクトリに移動
temp_safetensors_path = model_dir_in_temp / original_filename
shutil.move(temp_safetensors_path_in_root, temp_safetensors_path)
# dataset_root も assets_root と同じにしておく (今回は使用しない)
dataset_root = assets_root
current_log.append(f"Using temporary assets_root: {assets_root}")
current_log.append(f"Created temporary model directory: {model_dir_in_temp}")
current_log.append(f"Using temporary model path: {temp_safetensors_path}")
yield "\n".join(current_log), None
# --- paths.yml を一時的に設定 ---
config_path = SBV2_REPO_PATH / "configs" / "paths.yml"
config_path.parent.mkdir(parents=True, exist_ok=True)
paths_config = {"dataset_root": str(dataset_root.resolve()), "assets_root": str(assets_root.resolve())}
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(paths_config, f)
current_log.append(f"Saved temporary paths config to {config_path}")
yield "\n".join(current_log), None
# --- ONNX変換スクリプト実行 ---
current_log.append(f"\nStarting ONNX conversion script for: {temp_safetensors_path.name}")
convert_script = SBV2_REPO_PATH / "convert_onnx.py"
if not convert_script.exists():
current_log.append(f"Error: convert_onnx.py not found at '{convert_script}'.")
yield "\n".join(current_log), None
return # tryブロックを抜ける
python_executable = sys.executable
command = [
python_executable,
str(convert_script.resolve()),
"--model",
str(temp_safetensors_path.resolve()) # 一時ディレクトリ内のモデルパス
]
current_log.append(f"\nRunning command: {' '.join(command)}")
yield "\n".join(current_log), None
process_env = os.environ.copy()
# メモリリーク対策? (あまり効果ないかも)
# process_env["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, encoding='utf-8', errors='replace',
cwd=SBV2_REPO_PATH, # スクリプトの場所で実行
env=process_env)
# ログ出力用リスト (スレッドと共有)
process_output_lines = []
thread = threading.Thread(target=stream_process_output, args=(process, process_output_lines))
thread.start()
# 進捗表示のためのループ
while thread.is_alive():
# yieldでGradio UIを更新 (現在のログ + プロセスからのログ)
yield "\n".join(current_log + process_output_lines), None
time.sleep(0.2) # 更新頻度
# スレッド終了待ち
thread.join()
# プロセス終了待ち (タイムアウト設定)
try:
process.wait(timeout=900) # 15分タイムアウト
except subprocess.TimeoutExpired:
current_log.extend(process_output_lines) # ここまでのログを追加
current_log.append("\nError: Conversion process timed out after 15 minutes.")
process.kill() # タイムアウトしたらプロセスを強制終了
yield "\n".join(current_log), None
return # tryブロックを抜ける
# 最終的なプロセス出力を取得
final_stdout, final_stderr = process.communicate()
if final_stdout:
process_output_lines.extend(final_stdout.strip().split('\n'))
if final_stderr:
processed_stderr = []
for line in final_stderr.strip().split('\n'):
processed_line = f"stderr: {line.strip()}"
if "warning" not in line.lower():
processed_line = f"ERROR (stderr): {line.strip()}"
processed_stderr.append(processed_line)
if processed_stderr:
process_output_lines.append("--- stderr ---")
process_output_lines.extend(processed_stderr)
process_output_lines.append("--------------")
# 全てのプロセスログをメインログに追加
current_log.extend(process_output_lines)
current_log.append("\n-------------------------------")
# --- 結果の確認と出力ファイルのコピー ---
if process.returncode == 0:
current_log.append("ONNX conversion command finished successfully.")
# 期待されるONNXファイルパス (入力と同じディレクトリ内)
expected_onnx_path_in_temp = temp_safetensors_path.with_suffix(".onnx")
if expected_onnx_path_in_temp.exists():
current_log.append(f"Found converted ONNX file: {expected_onnx_path_in_temp.name}")
# 一時ディレクトリから永続的な出力ディレクトリにコピー
final_onnx_path = OUTPUT_DIR / expected_onnx_path_in_temp.name
shutil.copy(expected_onnx_path_in_temp, final_onnx_path)
current_log.append(f"Copied ONNX file for download to: {final_onnx_path}")
onnx_output_path_str = str(final_onnx_path) # ダウンロード用ファイルパスを設定
else:
current_log.append(f"Warning: Expected ONNX file not found at '{expected_onnx_path_in_temp}'. Please check the logs.")
else:
current_log.append(f"ONNX conversion command failed with return code {process.returncode}.")
current_log.append("Please check the logs above for errors (especially lines starting with 'ERROR').")
# 一時ディレクトリが自動で削除される前に yield する必要がある
yield "\n".join(current_log), onnx_output_path_str
except FileNotFoundError as e:
# コマンドが見つからない場合など
current_log.append(f"\nError: A required command or file was not found: {e.filename}. Check Dockerfile setup and PATH.")
current_log.append(f"{e}")
yield "\n".join(current_log), None
except Exception as e:
current_log.append(f"\nAn unexpected error occurred: {e}")
import traceback
current_log.append(traceback.format_exc())
# エラー発生時も最終ログとNoneを返す
yield "\n".join(current_log), None
finally:
# ガベージコレクションを試みる (メモリ解放目的)
gc.collect()
# 最終的な状態をUIに反映させるための最後のyield (重要)
# tryブロック内で既にyieldしている場合でも、finallyで再度yieldすることで
# UIが最終状態(エラーメッセージや成功メッセージ+ダウンロードリンク)に更新される
print("Conversion function finished.") # サーバーログ用
# ここで再度yieldするとUIの更新が確実になることがある
# yield "\n".join(current_log), onnx_output_path_str
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Style-Bert-VITS2 Safetensors to ONNX Converter")
gr.Markdown(
"Upload your `.safetensors` model file, convert it to ONNX format, and download the result. "
"The environment setup (cloning repo, installing dependencies, running initialize.py) "
"is handled automatically when this Space starts."
)
with gr.Row():
with gr.Column(scale=1):
safetensors_upload = gr.File(
label="1. Upload Safetensors Model",
file_types=[".safetensors"],
# file_count="single" (default)
)
convert_button = gr.Button("2. Convert to ONNX", variant="primary")
gr.Markdown("---")
onnx_download = gr.File(
label="3. Download ONNX Model",
interactive=False, # 出力専用
)
gr.Markdown(
"**Note:** Conversion can take several minutes, especially on free hardware. "
"Please be patient. The log on the right will update during the process."
)
with gr.Column(scale=2):
output_log = gr.Textbox(
label="Conversion Log",
lines=25, # 少し高さを増やす
interactive=False,
autoscroll=True,
max_lines=1500 # ログが多くなる可能性を考慮
)
# ボタンクリック時のアクション設定
convert_button.click(
convert_safetensors_to_onnx_gradio,
inputs=[safetensors_upload],
outputs=[output_log, onnx_download] # ログとダウンロードファイルの2つを出力
)
# --- アプリの起動 ---
if __name__ == "__main__":
# Style-Bert-VITS2 へのパスを追加
add_sbv2_to_path()
# 出力ディレクトリ作成 (存在確認含む)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR.resolve()}")
# Gradioアプリを起動
# share=True にするとパブリックリンクが生成される(HF Spacesでは不要)
# queue() を使うと複数ユーザーのリクエストを処理しやすくなる
demo.queue().launch()