Spaces:
Running
Running
Upload 2 files
Browse files- app.py +2 -2
- prompt_generator.py +6 -2
app.py
CHANGED
|
@@ -64,7 +64,7 @@ vae = None
|
|
| 64 |
# スタイルリストから名前のみを抽出
|
| 65 |
style_names = [style["name"] for style in style_list]
|
| 66 |
|
| 67 |
-
@spaces.GPU(timeout_seconds=
|
| 68 |
def initialize_llm():
|
| 69 |
"""アプリケーション起動時にLLMだけを初期化する関数"""
|
| 70 |
|
|
@@ -270,7 +270,7 @@ def convert_text_to_prompt(
|
|
| 270 |
return f"エラーが発生しました: {str(e)}", novel_text
|
| 271 |
|
| 272 |
@spaces.GPU
|
| 273 |
-
def load_image_model(timeout_seconds=
|
| 274 |
"""画像生成モデルをロードする関数"""
|
| 275 |
global pipe, vae
|
| 276 |
|
|
|
|
| 64 |
# スタイルリストから名前のみを抽出
|
| 65 |
style_names = [style["name"] for style in style_list]
|
| 66 |
|
| 67 |
+
@spaces.GPU(timeout_seconds=120)
|
| 68 |
def initialize_llm():
|
| 69 |
"""アプリケーション起動時にLLMだけを初期化する関数"""
|
| 70 |
|
|
|
|
| 270 |
return f"エラーが発生しました: {str(e)}", novel_text
|
| 271 |
|
| 272 |
@spaces.GPU
|
| 273 |
+
def load_image_model(timeout_seconds=120):
|
| 274 |
"""画像生成モデルをロードする関数"""
|
| 275 |
global pipe, vae
|
| 276 |
|
prompt_generator.py
CHANGED
|
@@ -79,6 +79,7 @@ _model = None
|
|
| 79 |
_tokenizer = None
|
| 80 |
|
| 81 |
|
|
|
|
| 82 |
def load_model():
|
| 83 |
"""モデルをロードする関数"""
|
| 84 |
global _model, _tokenizer
|
|
@@ -133,8 +134,6 @@ def load_model():
|
|
| 133 |
logger.error(f"Failed to load prompt generation model: {str(e)}")
|
| 134 |
raise
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
def unload_model():
|
| 139 |
"""メモリからモデルをアンロードする関数"""
|
| 140 |
global _model, _tokenizer
|
|
@@ -154,6 +153,7 @@ def unload_model():
|
|
| 154 |
|
| 155 |
logger.info("Prompt generation model unloaded")
|
| 156 |
|
|
|
|
| 157 |
def generate_prompt(
|
| 158 |
novel_text: str,
|
| 159 |
series_name: str = "original",
|
|
@@ -163,6 +163,7 @@ def generate_prompt(
|
|
| 163 |
try:
|
| 164 |
# モデルとトークナイザーの読み込み
|
| 165 |
model, tokenizer = load_model()
|
|
|
|
| 166 |
|
| 167 |
# 入力の検証
|
| 168 |
if not novel_text or novel_text.isspace():
|
|
@@ -264,6 +265,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
|
|
| 264 |
add_generation_prompt=True,
|
| 265 |
return_tensors="pt",
|
| 266 |
).to(model.device)
|
|
|
|
| 267 |
|
| 268 |
# 長すぎる入力のトリミング
|
| 269 |
if inputs.shape[1] > max_input_length:
|
|
@@ -271,6 +273,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
|
|
| 271 |
logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
|
| 272 |
|
| 273 |
# 生成
|
|
|
|
| 274 |
with torch.no_grad():
|
| 275 |
generated_ids = model.generate(
|
| 276 |
input_ids=inputs,
|
|
@@ -286,6 +289,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
|
|
| 286 |
pad_token_id=tokenizer.pad_token_id,
|
| 287 |
)
|
| 288 |
|
|
|
|
| 289 |
# デコード
|
| 290 |
full_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
| 291 |
|
|
|
|
| 79 |
_tokenizer = None
|
| 80 |
|
| 81 |
|
| 82 |
+
@spaces.GPU
|
| 83 |
def load_model():
|
| 84 |
"""モデルをロードする関数"""
|
| 85 |
global _model, _tokenizer
|
|
|
|
| 134 |
logger.error(f"Failed to load prompt generation model: {str(e)}")
|
| 135 |
raise
|
| 136 |
|
|
|
|
|
|
|
| 137 |
def unload_model():
|
| 138 |
"""メモリからモデルをアンロードする関数"""
|
| 139 |
global _model, _tokenizer
|
|
|
|
| 153 |
|
| 154 |
logger.info("Prompt generation model unloaded")
|
| 155 |
|
| 156 |
+
@spaces.GPU
|
| 157 |
def generate_prompt(
|
| 158 |
novel_text: str,
|
| 159 |
series_name: str = "original",
|
|
|
|
| 163 |
try:
|
| 164 |
# モデルとトークナイザーの読み込み
|
| 165 |
model, tokenizer = load_model()
|
| 166 |
+
logger.info("Loading model, tokenizer is ok...")
|
| 167 |
|
| 168 |
# 入力の検証
|
| 169 |
if not novel_text or novel_text.isspace():
|
|
|
|
| 265 |
add_generation_prompt=True,
|
| 266 |
return_tensors="pt",
|
| 267 |
).to(model.device)
|
| 268 |
+
logger.info("tokenizer.apply_chat_template is ok...")
|
| 269 |
|
| 270 |
# 長すぎる入力のトリミング
|
| 271 |
if inputs.shape[1] > max_input_length:
|
|
|
|
| 273 |
logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
|
| 274 |
|
| 275 |
# 生成
|
| 276 |
+
logger.info("before ttorch.no_grad")
|
| 277 |
with torch.no_grad():
|
| 278 |
generated_ids = model.generate(
|
| 279 |
input_ids=inputs,
|
|
|
|
| 289 |
pad_token_id=tokenizer.pad_token_id,
|
| 290 |
)
|
| 291 |
|
| 292 |
+
logger.info("after ttorch.no_grad")
|
| 293 |
# デコード
|
| 294 |
full_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
| 295 |
|