sirochild commited on
Commit
dc62d4a
·
verified ·
1 Parent(s): 7c8b666

Upload 4 files

Browse files
Files changed (3) hide show
  1. app.py +42 -48
  2. generate_dialogue_with_swallow.py +15 -22
  3. requirements.txt +3 -8
app.py CHANGED
@@ -3,9 +3,10 @@ from groq import Groq
3
  import os
4
  import json
5
  from dotenv import load_dotenv
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
- import torch
8
  import re
 
 
9
  from generate_dialogue_with_swallow import generate_dialogue_with_swallow
10
 
11
  # --- 1. 初期設定とAPIクライアントの初期化 ---
@@ -18,47 +19,44 @@ if not GROQ_API_KEY:
18
 
19
  groq_client = Groq(api_key=GROQ_API_KEY)
20
 
21
- # Swallowモデルの初期化
22
  print("Swallowモデルをロード中...")
23
- MODEL_ID = "tokyotech-llm/Swallow-MX-8x7b-NVE-v0.1"
 
24
 
25
  try:
 
 
 
 
 
26
  # Hugging Face Spaceでの実行時はGPUメモリを節約するための設定
27
  if os.getenv("SPACE_ID"):
28
  print("Hugging Face Space環境を検出しました。メモリ効率の良い設定を使用します。")
29
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
30
- swallow_model = AutoModelForCausalLM.from_pretrained(
31
- MODEL_ID,
32
- torch_dtype=torch.float16,
33
- device_map="auto",
34
- load_in_8bit=True
 
35
  )
36
  else:
37
  # ローカル環境での実行時の設定
38
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
39
- swallow_model = AutoModelForCausalLM.from_pretrained(
40
- MODEL_ID,
41
- torch_dtype=torch.float16,
42
- device_map="auto"
43
  )
44
  print("Swallowモデルのロード完了")
 
45
  except Exception as e:
46
  print(f"Swallowモデルのロードエラー: {e}")
47
- # フォールバックとして小さいモデルを使用
48
- try:
49
- print("フォールバックモデルをロード中...")
50
- tokenizer = AutoTokenizer.from_pretrained("elyza/ELYZA-japanese-Llama-2-7b-instruct")
51
- swallow_model = AutoModelForCausalLM.from_pretrained(
52
- "elyza/ELYZA-japanese-Llama-2-7b-instruct",
53
- torch_dtype=torch.float16,
54
- device_map="auto",
55
- load_in_8bit=True
56
- )
57
- print("フォールバックモデルのロード完了")
58
- except Exception as fallback_error:
59
- print(f"フォールバックモデルのロードエラー: {fallback_error}")
60
- swallow_model = None
61
- tokenizer = None
62
 
63
  # 日本語感情分析モデルの初期化(グローバル変数として保持)
64
  print("日本語感情分析モデルを初期化中...")
@@ -118,26 +116,20 @@ def detect_scene_change(history, message):
118
  ---
119
  # 出力
120
  """
121
- # Swallowモデルを使用してシーン検出
122
  try:
123
- # トークナイズ
124
- inputs = tokenizer(prompt, return_tensors="pt").to(swallow_model.device)
125
-
126
- # 生成パラメータ - シーン検出には低い温度を使用
127
- gen_kwargs = {
128
- "max_new_tokens": 50,
129
- "temperature": 0.1,
130
- "top_p": 0.9,
131
- "do_sample": False,
132
- "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
133
- }
134
-
135
- # 生成
136
- with torch.no_grad():
137
- output = swallow_model.generate(**inputs, **gen_kwargs)
138
 
139
- # デコード
140
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
141
 
142
  # プロンプトを除去して応答のみを取得
143
  response_text = generated_text[len(prompt):].strip().lower()
@@ -157,6 +149,8 @@ def detect_scene_change(history, message):
157
  return None
158
  except Exception as e:
159
  print(f"シーン検出LLMエラー: {e}")
 
 
160
  return None
161
 
162
  def generate_scene_instruction_with_groq(affection, stage_name, scene, previous_topic):
 
3
  import os
4
  import json
5
  from dotenv import load_dotenv
6
+ from transformers import pipeline
 
7
  import re
8
+ from llama_cpp import Llama
9
+ from huggingface_hub import hf_hub_download
10
  from generate_dialogue_with_swallow import generate_dialogue_with_swallow
11
 
12
  # --- 1. 初期設定とAPIクライアントの初期化 ---
 
19
 
20
  groq_client = Groq(api_key=GROQ_API_KEY)
21
 
22
+ # Swallowモデルの初期化(GGUF版)
23
  print("Swallowモデルをロード中...")
24
+ MODEL_REPO = "mmnga/tokyotech-llm-Swallow-MX-8x7b-NVE-v0.1-gguf"
25
+ MODEL_FILE = "tokyotech-llm-Swallow-MX-8x7b-NVE-v0.1-q4_K_M.gguf"
26
 
27
  try:
28
+ # モデルファイルをダウンロード
29
+ print(f"モデルファイル {MODEL_FILE} をダウンロード中...")
30
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
31
+ print(f"モデルファイルのダウンロード完了: {model_path}")
32
+
33
  # Hugging Face Spaceでの実行時はGPUメモリを節約するための設定
34
  if os.getenv("SPACE_ID"):
35
  print("Hugging Face Space環境を検出しました。メモリ効率の良い設定を使用します。")
36
+ # GPUを使用し、低いレイヤー数でロード
37
+ swallow_model = Llama(
38
+ model_path=model_path,
39
+ n_ctx=2048, # コンテキスト長
40
+ n_gpu_layers=-1, # 可能な限りGPUを使用
41
+ n_threads=4, # スレッド数を制限
42
+ verbose=False # デバッグ出力を無効化
43
  )
44
  else:
45
  # ローカル環境での実行時の設定
46
+ swallow_model = Llama(
47
+ model_path=model_path,
48
+ n_ctx=4096, # より長いコンテキスト長
49
+ n_gpu_layers=-1, # 可能な限りGPUを使用
50
+ verbose=True # デバッグ出力を有効化
51
  )
52
  print("Swallowモデルのロード完了")
53
+ tokenizer = None # llama-cppではtokenizerは不要
54
  except Exception as e:
55
  print(f"Swallowモデルのロードエラー: {e}")
56
+ import traceback
57
+ traceback.print_exc()
58
+ swallow_model = None
59
+ tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # 日本語感情分析モデルの初期化(グローバル変数として保持)
62
  print("日本語感情分析モデルを初期化中...")
 
116
  ---
117
  # 出力
118
  """
119
+ # Swallowモデル(GGUF版)を使用してシーン検出
120
  try:
121
+ # llama-cppを使用して生成
122
+ output = swallow_model(
123
+ prompt,
124
+ max_tokens=50,
125
+ temperature=0.1,
126
+ top_p=0.9,
127
+ stop=["#", "\n\n"],
128
+ echo=True # 入力プロンプトも含めて返す
129
+ )
 
 
 
 
 
 
130
 
131
+ # 生成されたテキストを取得
132
+ generated_text = output["choices"][0]["text"]
133
 
134
  # プロンプトを除去して応答のみを取得
135
  response_text = generated_text[len(prompt):].strip().lower()
 
149
  return None
150
  except Exception as e:
151
  print(f"シーン検出LLMエラー: {e}")
152
+ import traceback
153
+ traceback.print_exc()
154
  return None
155
 
156
  def generate_scene_instruction_with_groq(affection, stage_name, scene, previous_topic):
generate_dialogue_with_swallow.py CHANGED
@@ -1,11 +1,10 @@
1
- import torch
2
  import traceback
3
  import datetime
4
  import random
5
 
6
  def generate_dialogue_with_swallow(history, message, affection, stage_name, scene_params, instruction=None, use_simple_prompt=False, swallow_model=None, tokenizer=None, SYSTEM_PROMPT_MARI=None):
7
  """
8
- Swallowモデルを使用して対話応答を生成する関数
9
 
10
  Args:
11
  history: 会話履歴のリスト [(ユーザー発言, ボット応答), ...]
@@ -15,8 +14,8 @@ def generate_dialogue_with_swallow(history, message, affection, stage_name, scen
15
  scene_params: シーンパラメータの辞書
16
  instruction: 特別な指示(シーン遷移時など)
17
  use_simple_prompt: 簡潔なプロンプトを使用するかどうか
18
- swallow_model: Swallowモデルのインスタンス
19
- tokenizer: トークナイザーのインスタンス
20
  SYSTEM_PROMPT_MARI: システムプロンプト
21
 
22
  Returns:
@@ -27,7 +26,7 @@ def generate_dialogue_with_swallow(history, message, affection, stage_name, scen
27
  print(f"scene_params: {scene_params}")
28
 
29
  # モデルがロードされていない場合はフォールバック応答を返す
30
- if swallow_model is None or tokenizer is None:
31
  print("モデルがロードされていないため、フォールバック応答を返します")
32
  return "(……システムエラーが発生しました)"
33
 
@@ -77,24 +76,18 @@ def generate_dialogue_with_swallow(history, message, affection, stage_name, scen
77
  print(f"システムプロンプト: {system_prompt[:100]}...(省略)")
78
 
79
  try:
80
- # トークナイズ
81
- inputs = tokenizer(system_prompt, return_tensors="pt").to(swallow_model.device)
 
 
 
 
 
 
 
82
 
83
- # 生成パラメータ
84
- gen_kwargs = {
85
- "max_new_tokens": 200,
86
- "temperature": 0.95,
87
- "top_p": 0.9,
88
- "do_sample": True,
89
- "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
90
- }
91
-
92
- # 生成
93
- with torch.no_grad():
94
- output = swallow_model.generate(**inputs, **gen_kwargs)
95
-
96
- # デコード
97
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
98
 
99
  # プロンプトを除去して応答のみを取得
100
  response_text = generated_text[len(system_prompt):].strip()
 
 
1
  import traceback
2
  import datetime
3
  import random
4
 
5
  def generate_dialogue_with_swallow(history, message, affection, stage_name, scene_params, instruction=None, use_simple_prompt=False, swallow_model=None, tokenizer=None, SYSTEM_PROMPT_MARI=None):
6
  """
7
+ Swallowモデル(GGUF版)を使用して対話応答を生成する関数
8
 
9
  Args:
10
  history: 会話履歴のリスト [(ユーザー発言, ボット応答), ...]
 
14
  scene_params: シーンパラメータの辞書
15
  instruction: 特別な指示(シーン遷移時など)
16
  use_simple_prompt: 簡潔なプロンプトを使用するかどうか
17
+ swallow_model: Swallowモデル(llama-cpp)のインスタンス
18
+ tokenizer: 未使用(llama-cppでは不要)
19
  SYSTEM_PROMPT_MARI: システムプロンプト
20
 
21
  Returns:
 
26
  print(f"scene_params: {scene_params}")
27
 
28
  # モデルがロードされていない場合はフォールバック応答を返す
29
+ if swallow_model is None:
30
  print("モデルがロードされていないため、フォールバック応答を返します")
31
  return "(……システムエラーが発生しました)"
32
 
 
76
  print(f"システムプロンプト: {system_prompt[:100]}...(省略)")
77
 
78
  try:
79
+ # llama-cppを使用して生成
80
+ output = swallow_model(
81
+ system_prompt,
82
+ max_tokens=200,
83
+ temperature=0.95,
84
+ top_p=0.9,
85
+ stop=["ユーザー:", "\n\n"],
86
+ echo=True # 入力プロンプトも含めて返す
87
+ )
88
 
89
+ # 生成されたテキストを取得
90
+ generated_text = output["choices"][0]["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  # プロンプトを除去して応答のみを取得
93
  response_text = generated_text[len(system_prompt):].strip()
requirements.txt CHANGED
@@ -1,14 +1,9 @@
1
  gradio>=5.0.0
2
  groq
3
  python-dotenv
4
- transformers>=4.34.0
5
- torch>=2.0.0
6
- sentencepiece
7
  fugashi
8
  unidic_lite
9
- accelerate>=0.20.0
10
- bitsandbytes>=0.41.0
11
- einops>=0.6.0
12
- safetensors>=0.3.1
13
- huggingface_hub>=0.16.0
14
  protobuf>=3.20.0
 
1
  gradio>=5.0.0
2
  groq
3
  python-dotenv
4
+ llama-cpp-python>=0.2.19
5
+ huggingface_hub>=0.16.0
 
6
  fugashi
7
  unidic_lite
8
+ transformers>=4.34.0
 
 
 
 
9
  protobuf>=3.20.0