sirochild commited on
Commit
90babcf
·
verified ·
1 Parent(s): 1224dcb

Upload core_dialogue.py

Browse files
Files changed (1) hide show
  1. core_dialogue.py +110 -22
core_dialogue.py CHANGED
@@ -16,7 +16,10 @@ class DialogueGenerator:
16
  def __init__(self):
17
  self.client = None
18
  self.model = None
 
 
19
  self._initialize_client()
 
20
 
21
  def _initialize_client(self):
22
  """Together.ai APIクライアントの初期化"""
@@ -35,6 +38,23 @@ class DialogueGenerator:
35
  except Exception as e:
36
  logger.error(f"Together.ai APIクライアントの初期化に失敗しました: {e}")
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def get_system_prompt_mari(self, use_ura_mode: bool = False) -> str:
39
  """環境変数からシステムプロンプトを取得、なければデフォルトを返す"""
40
  if use_ura_mode:
@@ -90,16 +110,9 @@ class DialogueGenerator:
90
  return os.getenv("SYSTEM_PROMPT_MARI", default_prompt)
91
 
92
  def call_llm(self, system_prompt: str, user_prompt: str, is_json_output: bool = False) -> str:
93
- """Together.ai APIを呼び出す"""
94
  logger.info(f"🔗 call_llm開始 - is_json_output: {is_json_output}")
95
 
96
- if not self.client:
97
- logger.warning("⚠️ APIクライアントが利用できません - デモモード応答を返します")
98
- # デモモード用の固定応答(隠された真実付き)
99
- if is_json_output:
100
- return '{"scene": "none"}'
101
- return "[HIDDEN:(本当は話したいけど...)]は?何それ。あたしに話しかけてるの?"
102
-
103
  # 入力検証
104
  if not isinstance(system_prompt, str) or not isinstance(user_prompt, str):
105
  logger.error(f"プロンプトが文字列ではありません: system={type(system_prompt)}, user={type(user_prompt)}")
@@ -107,41 +120,116 @@ class DialogueGenerator:
107
  return '{"scene": "none"}'
108
  return "…なんか変なこと言ってない?"
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  try:
111
- # Together.ai APIを呼び出し
 
 
 
112
  # JSON出力の場合は短く、通常の対話は適度な長さに制限
113
  max_tokens = 150 if is_json_output else 500
114
  logger.info(f"🔗 Together.ai API呼び出し開始 - model: {self.model}, max_tokens: {max_tokens}")
115
 
116
- response = self.client.chat.completions.create(
117
- model=self.model,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  messages=[
119
  {"role": "system", "content": system_prompt},
120
  {"role": "user", "content": user_prompt}
121
  ],
122
  temperature=0.8,
123
  max_tokens=max_tokens,
 
124
  )
125
 
126
- logger.info("🔗 Together.ai API呼び出し完了")
127
 
128
  content = response.choices[0].message.content if response.choices else ""
129
- logger.info(f"🔗 API応答内容: '{content[:100]}...' (長さ: {len(content)}文字)")
130
 
131
  if not content:
132
- logger.warning("Together.ai API応答が空です")
133
- if is_json_output:
134
- return '{"scene": "none"}'
135
- return "[HIDDEN:(何て言えばいいか分からない...)]…言葉が出てこない。"
136
 
137
- logger.info("🔗 call_llm正常完了")
138
  return content
139
 
140
  except Exception as e:
141
- logger.error(f"Together.ai API呼び出しエラー: {e}")
142
- if is_json_output:
143
- return '{"scene": "none"}'
144
- return "[HIDDEN:(システムが不調で困ってる...)]…システムの調子が悪いみたい。"
145
 
146
  def generate_dialogue(self, history: List[Tuple[str, str]], message: str,
147
  affection: int, stage_name: str, scene_params: Dict[str, Any],
 
16
  def __init__(self):
17
  self.client = None
18
  self.model = None
19
+ self.groq_client = None
20
+ self.groq_model = None
21
  self._initialize_client()
22
+ self._initialize_groq_client()
23
 
24
  def _initialize_client(self):
25
  """Together.ai APIクライアントの初期化"""
 
38
  except Exception as e:
39
  logger.error(f"Together.ai APIクライアントの初期化に失敗しました: {e}")
40
 
41
+ def _initialize_groq_client(self):
42
+ """Groq APIクライアントの初期化(フォールバック用)"""
43
+ try:
44
+ groq_api_key = os.getenv("GROQ_API_KEY")
45
+ if not groq_api_key:
46
+ logger.warning("環境変数 GROQ_API_KEY が設定されていません。Groqフォールバックは利用できません。")
47
+ return
48
+
49
+ self.groq_client = OpenAI(
50
+ api_key=groq_api_key,
51
+ base_url="https://api.groq.com/openai/v1"
52
+ )
53
+ self.groq_model = "llama-3.1-70b-versatile"
54
+ logger.info("Groq APIクライアントの初期化が完了しました(フォールバック用)。")
55
+ except Exception as e:
56
+ logger.error(f"Groq APIクライアントの初期化に失敗しました: {e}")
57
+
58
  def get_system_prompt_mari(self, use_ura_mode: bool = False) -> str:
59
  """環境変数からシステムプロンプトを取得、なければデフォルトを返す"""
60
  if use_ura_mode:
 
110
  return os.getenv("SYSTEM_PROMPT_MARI", default_prompt)
111
 
112
  def call_llm(self, system_prompt: str, user_prompt: str, is_json_output: bool = False) -> str:
113
+ """Together.ai APIを呼び出し、15秒でタイムアウトした場合はGroq APIにフォールバック"""
114
  logger.info(f"🔗 call_llm開始 - is_json_output: {is_json_output}")
115
 
 
 
 
 
 
 
 
116
  # 入力検証
117
  if not isinstance(system_prompt, str) or not isinstance(user_prompt, str):
118
  logger.error(f"プロンプトが文字列ではありません: system={type(system_prompt)}, user={type(user_prompt)}")
 
120
  return '{"scene": "none"}'
121
  return "…なんか変なこと言ってない?"
122
 
123
+ # まずTogether.ai APIを試行
124
+ together_result = self._call_together_api(system_prompt, user_prompt, is_json_output)
125
+ if together_result is not None:
126
+ return together_result
127
+
128
+ # Together.ai APIが失敗した場合、Groq APIにフォールバック
129
+ logger.warning("🔄 Together.ai APIが失敗、Groq APIにフォールバック")
130
+ groq_result = self._call_groq_api(system_prompt, user_prompt, is_json_output)
131
+ if groq_result is not None:
132
+ return groq_result
133
+
134
+ # 両方のAPIが失敗した場合のデモモード応答
135
+ logger.error("⚠️ 全てのAPIが利用できません - デモモード応答を返します")
136
+ if is_json_output:
137
+ return '{"scene": "none"}'
138
+ return "[HIDDEN:(本当は話したいけど...)]は?何それ。あたしに話しかけてるの?"
139
+
140
+ def _call_together_api(self, system_prompt: str, user_prompt: str, is_json_output: bool = False) -> Optional[str]:
141
+ """Together.ai APIを15秒タイムアウトで呼び出し(Windows対応)"""
142
+ if not self.client:
143
+ logger.warning("⚠️ Together.ai APIクライアントが利用できません")
144
+ return None
145
+
146
  try:
147
+ import time
148
+ import threading
149
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
150
+
151
  # JSON出力の場合は短く、通常の対話は適度な長さに制限
152
  max_tokens = 150 if is_json_output else 500
153
  logger.info(f"🔗 Together.ai API呼び出し開始 - model: {self.model}, max_tokens: {max_tokens}")
154
 
155
+ start_time = time.time()
156
+
157
+ def api_call():
158
+ """API呼び出しを別スレッドで実行"""
159
+ return self.client.chat.completions.create(
160
+ model=self.model,
161
+ messages=[
162
+ {"role": "system", "content": system_prompt},
163
+ {"role": "user", "content": user_prompt}
164
+ ],
165
+ temperature=0.8,
166
+ max_tokens=max_tokens,
167
+ timeout=15 # APIレベルでも15秒タイムアウト
168
+ )
169
+
170
+ # ThreadPoolExecutorを使用して15秒タイムアウトを実装
171
+ with ThreadPoolExecutor(max_workers=1) as executor:
172
+ future = executor.submit(api_call)
173
+ try:
174
+ response = future.result(timeout=15) # 15秒タイムアウト
175
+
176
+ elapsed_time = time.time() - start_time
177
+ logger.info(f"🔗 Together.ai API呼び出し完了 ({elapsed_time:.2f}秒)")
178
+
179
+ content = response.choices[0].message.content if response.choices else ""
180
+ logger.info(f"🔗 Together.ai API応答内容: '{content[:100]}...' (長さ: {len(content)}文字)")
181
+
182
+ if not content:
183
+ logger.warning("Together.ai API応答が空です")
184
+ return None
185
+
186
+ return content
187
+
188
+ except FutureTimeoutError:
189
+ elapsed_time = time.time() - start_time
190
+ logger.warning(f"⏰ Together.ai API呼び出しタイムアウト ({elapsed_time:.2f}秒)")
191
+ return None
192
+
193
+ except Exception as e:
194
+ logger.error(f"Together.ai API呼び出しエラー: {e}")
195
+ return None
196
+
197
+ def _call_groq_api(self, system_prompt: str, user_prompt: str, is_json_output: bool = False) -> Optional[str]:
198
+ """Groq APIを呼び出し(フォールバック用)"""
199
+ if not self.groq_client:
200
+ logger.warning("⚠️ Groq APIクライアントが利用できません")
201
+ return None
202
+
203
+ try:
204
+ # JSON出力の場合は短く、通常の対話は適度な長さに制限
205
+ max_tokens = 150 if is_json_output else 500
206
+ logger.info(f"🔄 Groq API呼び出し開始 - model: {self.groq_model}, max_tokens: {max_tokens}")
207
+
208
+ response = self.groq_client.chat.completions.create(
209
+ model=self.groq_model,
210
  messages=[
211
  {"role": "system", "content": system_prompt},
212
  {"role": "user", "content": user_prompt}
213
  ],
214
  temperature=0.8,
215
  max_tokens=max_tokens,
216
+ timeout=10 # Groqは10秒タイムアウト
217
  )
218
 
219
+ logger.info("🔄 Groq API呼び出し完了")
220
 
221
  content = response.choices[0].message.content if response.choices else ""
222
+ logger.info(f"🔄 Groq API応答内容: '{content[:100]}...' (長さ: {len(content)}文字)")
223
 
224
  if not content:
225
+ logger.warning("Groq API応答が空です")
226
+ return None
 
 
227
 
 
228
  return content
229
 
230
  except Exception as e:
231
+ logger.error(f"Groq API呼び出しエラー: {e}")
232
+ return None
 
 
233
 
234
  def generate_dialogue(self, history: List[Tuple[str, str]], message: str,
235
  affection: int, stage_name: str, scene_params: Dict[str, Any],