dahara1 commited on
Commit
485b838
·
verified ·
1 Parent(s): e0887f1

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. prompt_generator.py +6 -2
app.py CHANGED
@@ -64,7 +64,7 @@ vae = None
64
  # スタイルリストから名前のみを抽出
65
  style_names = [style["name"] for style in style_list]
66
 
67
- @spaces.GPU(timeout_seconds=300)
68
  def initialize_llm():
69
  """アプリケーション起動時にLLMだけを初期化する関数"""
70
 
@@ -270,7 +270,7 @@ def convert_text_to_prompt(
270
  return f"エラーが発生しました: {str(e)}", novel_text
271
 
272
  @spaces.GPU
273
- def load_image_model(timeout_seconds=300):
274
  """画像生成モデルをロードする関数"""
275
  global pipe, vae
276
 
 
64
  # スタイルリストから名前のみを抽出
65
  style_names = [style["name"] for style in style_list]
66
 
67
+ @spaces.GPU(timeout_seconds=120)
68
  def initialize_llm():
69
  """アプリケーション起動時にLLMだけを初期化する関数"""
70
 
 
270
  return f"エラーが発生しました: {str(e)}", novel_text
271
 
272
  @spaces.GPU
273
+ def load_image_model(timeout_seconds=120):
274
  """画像生成モデルをロードする関数"""
275
  global pipe, vae
276
 
prompt_generator.py CHANGED
@@ -79,6 +79,7 @@ _model = None
79
  _tokenizer = None
80
 
81
 
 
82
  def load_model():
83
  """モデルをロードする関数"""
84
  global _model, _tokenizer
@@ -133,8 +134,6 @@ def load_model():
133
  logger.error(f"Failed to load prompt generation model: {str(e)}")
134
  raise
135
 
136
-
137
-
138
  def unload_model():
139
  """メモリからモデルをアンロードする関数"""
140
  global _model, _tokenizer
@@ -154,6 +153,7 @@ def unload_model():
154
 
155
  logger.info("Prompt generation model unloaded")
156
 
 
157
  def generate_prompt(
158
  novel_text: str,
159
  series_name: str = "original",
@@ -163,6 +163,7 @@ def generate_prompt(
163
  try:
164
  # モデルとトークナイザーの読み込み
165
  model, tokenizer = load_model()
 
166
 
167
  # 入力の検証
168
  if not novel_text or novel_text.isspace():
@@ -264,6 +265,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
264
  add_generation_prompt=True,
265
  return_tensors="pt",
266
  ).to(model.device)
 
267
 
268
  # 長すぎる入力のトリミング
269
  if inputs.shape[1] > max_input_length:
@@ -271,6 +273,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
271
  logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
272
 
273
  # 生成
 
274
  with torch.no_grad():
275
  generated_ids = model.generate(
276
  input_ids=inputs,
@@ -286,6 +289,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
286
  pad_token_id=tokenizer.pad_token_id,
287
  )
288
 
 
289
  # デコード
290
  full_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
291
 
 
79
  _tokenizer = None
80
 
81
 
82
+ @spaces.GPU
83
  def load_model():
84
  """モデルをロードする関数"""
85
  global _model, _tokenizer
 
134
  logger.error(f"Failed to load prompt generation model: {str(e)}")
135
  raise
136
 
 
 
137
  def unload_model():
138
  """メモリからモデルをアンロードする関数"""
139
  global _model, _tokenizer
 
153
 
154
  logger.info("Prompt generation model unloaded")
155
 
156
+ @spaces.GPU
157
  def generate_prompt(
158
  novel_text: str,
159
  series_name: str = "original",
 
163
  try:
164
  # モデルとトークナイザーの読み込み
165
  model, tokenizer = load_model()
166
+ logger.info("Loading model, tokenizer is ok...")
167
 
168
  # 入力の検証
169
  if not novel_text or novel_text.isspace():
 
265
  add_generation_prompt=True,
266
  return_tensors="pt",
267
  ).to(model.device)
268
+ logger.info("tokenizer.apply_chat_template is ok...")
269
 
270
  # 長すぎる入力のトリミング
271
  if inputs.shape[1] > max_input_length:
 
273
  logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
274
 
275
  # 生成
276
+ logger.info("before ttorch.no_grad")
277
  with torch.no_grad():
278
  generated_ids = model.generate(
279
  input_ids=inputs,
 
289
  pad_token_id=tokenizer.pad_token_id,
290
  )
291
 
292
+ logger.info("after ttorch.no_grad")
293
  # デコード
294
  full_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
295