Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import os | |
| # モデルをダウンロードするディレクトリ | |
| model_dir = "models/Miwa-Keita/zenz-v1-checkpoints" | |
| # 不要なファイルを除外し、特定のファイルのみダウンロード | |
| snapshot_download( | |
| repo_id="Miwa-Keita/zenz-v1-checkpoints", | |
| local_dir=model_dir, | |
| allow_patterns=["*.bin", "*.json", "*.txt", "*.model"], # 必要なファイルだけ取得 | |
| ignore_patterns=["optimizer.pt", "checkpoint*"], # いらないファイルを無視 | |
| ) | |
| # モデルとトークナイザーのロード(GPT-2 アーキテクチャ) | |
| tokenizer = AutoTokenizer.from_pretrained(model_dir) | |
| model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float32) | |
| # 入力を調整する関数 | |
| def preprocess_input(user_input): | |
| prefix = "\uEE00" # 前に付与する文字列 | |
| suffix = "\uEE01" # 後ろに付与する文字列 | |
| processed_input = prefix + user_input + suffix | |
| return processed_input | |
| # 出力を調整する関数 | |
| def postprocess_output(model_output): | |
| suffix = "\uEE01" | |
| # \uEE01の後の部分を抽出 | |
| if suffix in model_output: | |
| return model_output.split(suffix)[1] | |
| return model_output | |
| # 変換関数 | |
| def generate_text(user_input): | |
| processed_input = preprocess_input(user_input) | |
| # テキストをトークン化 | |
| inputs = tokenizer(processed_input, return_tensors="pt") | |
| # モデルで生成 | |
| outputs = model.generate(**inputs, max_length=100) | |
| # 出力のデコード | |
| decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # 出力の整形 | |
| return postprocess_output(decoded_output) | |
| # Gradio インターフェース | |
| iface = gr.Interface( | |
| fn=generate_text, | |
| inputs=gr.Textbox(label="変換する文字列(カタカナ)"), | |
| outputs=gr.Textbox(label="変換結果"), | |
| title="ニューラルかな漢字変換モデル zenz-v1 のデモ", | |
| description="変換したい文字列をカタカナを入力してください" | |
| ) | |
| # ローンチ | |
| iface.launch() |