dahara1 commited on
Commit
255a49e
·
verified ·
1 Parent(s): 485b838

Upload prompt_generator.py

Browse files
Files changed (1) hide show
  1. prompt_generator.py +6 -2
prompt_generator.py CHANGED
@@ -1,9 +1,14 @@
1
  import os
2
  import logging
3
  import torch
 
4
  from typing import Tuple, Optional, Dict, Any
5
  import gc
6
 
 
 
 
 
7
  # ロギング設定
8
  logging.basicConfig(
9
  level=logging.INFO,
@@ -99,8 +104,7 @@ def load_model():
99
  # HuggingFaceからモデルを直接ロードする
100
  model_path = model_name
101
 
102
- is_spaces_environment = "SPACE_ID" in os.environ and os.environ.get("SYSTEM") == "spaces"
103
- if is_spaces_environment:
104
  # Spaces環境ではdevice_mapの設定を変更
105
  device_map = None
106
  torch_dtype = torch.float16
 
1
  import os
2
  import logging
3
  import torch
4
+ import utils
5
  from typing import Tuple, Optional, Dict, Any
6
  import gc
7
 
8
+ if utils.is_space_environment():
9
+ import spaces
10
+
11
+
12
  # ロギング設定
13
  logging.basicConfig(
14
  level=logging.INFO,
 
104
  # HuggingFaceからモデルを直接ロードする
105
  model_path = model_name
106
 
107
+ if utils.is_space_environment():
 
108
  # Spaces環境ではdevice_mapの設定を変更
109
  device_map = None
110
  torch_dtype = torch.float16