Spaces:
Runtime error
Runtime error
hysts
commited on
Commit
·
552d2be
1
Parent(s):
2f2bde2
Update logger
Browse files
model.py
CHANGED
|
@@ -215,8 +215,7 @@ class Model:
|
|
| 215 |
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
|
| 216 |
|
| 217 |
elapsed = time.perf_counter() - start
|
| 218 |
-
logger.info(f'
|
| 219 |
-
logger.info('--- done ---')
|
| 220 |
return model, args
|
| 221 |
|
| 222 |
def load_strategy(self) -> CoglmStrategy:
|
|
@@ -230,8 +229,7 @@ class Model:
|
|
| 230 |
top_k_cluster=self.args.temp_cluster_gen)
|
| 231 |
|
| 232 |
elapsed = time.perf_counter() - start
|
| 233 |
-
logger.info(f'
|
| 234 |
-
logger.info('--- done ---')
|
| 235 |
return strategy
|
| 236 |
|
| 237 |
def load_srg(self) -> SRGroup:
|
|
@@ -241,8 +239,7 @@ class Model:
|
|
| 241 |
srg = None if self.args.only_first_stage else SRGroup(self.args)
|
| 242 |
|
| 243 |
elapsed = time.perf_counter() - start
|
| 244 |
-
logger.info(f'
|
| 245 |
-
logger.info('--- done ---')
|
| 246 |
return srg
|
| 247 |
|
| 248 |
def update_style(self, style: str) -> None:
|
|
@@ -267,8 +264,7 @@ class Model:
|
|
| 267 |
self.srg.itersr.strategy.topk = self.args.topk_itersr
|
| 268 |
|
| 269 |
elapsed = time.perf_counter() - start
|
| 270 |
-
logger.info(f'
|
| 271 |
-
logger.info('--- done ---')
|
| 272 |
|
| 273 |
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
|
| 274 |
num: int) -> list[np.ndarray] | None:
|
|
@@ -306,8 +302,7 @@ class Model:
|
|
| 306 |
seq = torch.tensor(seq + [-1] * 400, device=self.device)
|
| 307 |
|
| 308 |
elapsed = time.perf_counter() - start
|
| 309 |
-
logger.info(f'
|
| 310 |
-
logger.info('--- done ---')
|
| 311 |
return seq, txt_len
|
| 312 |
|
| 313 |
@torch.inference_mode()
|
|
@@ -345,8 +340,7 @@ class Model:
|
|
| 345 |
logger.debug(f'{output_tokens.shape=}')
|
| 346 |
|
| 347 |
elapsed = time.perf_counter() - start
|
| 348 |
-
logger.info(f'
|
| 349 |
-
logger.info('--- done ---')
|
| 350 |
return output_tokens
|
| 351 |
|
| 352 |
@staticmethod
|
|
@@ -380,8 +374,7 @@ class Model:
|
|
| 380 |
res.append(decoded_img) # only the last image (target)
|
| 381 |
|
| 382 |
elapsed = time.perf_counter() - start
|
| 383 |
-
logger.info(f'
|
| 384 |
-
logger.info('--- done ---')
|
| 385 |
return res
|
| 386 |
|
| 387 |
|
|
|
|
| 215 |
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
|
| 216 |
|
| 217 |
elapsed = time.perf_counter() - start
|
| 218 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
|
| 219 |
return model, args
|
| 220 |
|
| 221 |
def load_strategy(self) -> CoglmStrategy:
|
|
|
|
| 229 |
top_k_cluster=self.args.temp_cluster_gen)
|
| 230 |
|
| 231 |
elapsed = time.perf_counter() - start
|
| 232 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
|
| 233 |
return strategy
|
| 234 |
|
| 235 |
def load_srg(self) -> SRGroup:
|
|
|
|
| 239 |
srg = None if self.args.only_first_stage else SRGroup(self.args)
|
| 240 |
|
| 241 |
elapsed = time.perf_counter() - start
|
| 242 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
|
| 243 |
return srg
|
| 244 |
|
| 245 |
def update_style(self, style: str) -> None:
|
|
|
|
| 264 |
self.srg.itersr.strategy.topk = self.args.topk_itersr
|
| 265 |
|
| 266 |
elapsed = time.perf_counter() - start
|
| 267 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
|
| 268 |
|
| 269 |
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
|
| 270 |
num: int) -> list[np.ndarray] | None:
|
|
|
|
| 302 |
seq = torch.tensor(seq + [-1] * 400, device=self.device)
|
| 303 |
|
| 304 |
elapsed = time.perf_counter() - start
|
| 305 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
|
| 306 |
return seq, txt_len
|
| 307 |
|
| 308 |
@torch.inference_mode()
|
|
|
|
| 340 |
logger.debug(f'{output_tokens.shape=}')
|
| 341 |
|
| 342 |
elapsed = time.perf_counter() - start
|
| 343 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
|
| 344 |
return output_tokens
|
| 345 |
|
| 346 |
@staticmethod
|
|
|
|
| 374 |
res.append(decoded_img) # only the last image (target)
|
| 375 |
|
| 376 |
elapsed = time.perf_counter() - start
|
| 377 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
|
| 378 |
return res
|
| 379 |
|
| 380 |
|