Update
model.py CHANGED
@@ -215,7 +215,7 @@ class Model:
         model, args = InferenceModel.from_pretrained(self.args, 'coglm')
 
         elapsed = time.perf_counter() - start
-        logger.info(f'--- done ({elapsed=:.3f} ---')
+        logger.info(f'--- done ({elapsed=:.3f}) ---')
         return model, args
 
     def load_strategy(self) -> CoglmStrategy:
@@ -229,7 +229,7 @@ class Model:
             top_k_cluster=self.args.temp_cluster_gen)
 
         elapsed = time.perf_counter() - start
-        logger.info(f'--- done ({elapsed=:.3f} ---')
+        logger.info(f'--- done ({elapsed=:.3f}) ---')
         return strategy
 
     def load_srg(self) -> SRGroup:
@@ -239,7 +239,7 @@ class Model:
         srg = None if self.args.only_first_stage else SRGroup(self.args)
 
         elapsed = time.perf_counter() - start
-        logger.info(f'--- done ({elapsed=:.3f} ---')
+        logger.info(f'--- done ({elapsed=:.3f}) ---')
         return srg
 
     def update_style(self, style: str) -> None:
@@ -264,7 +264,7 @@ class Model:
         self.srg.itersr.strategy.topk = self.args.topk_itersr
 
         elapsed = time.perf_counter() - start
-        logger.info(f'--- done ({elapsed=:.3f} ---')
+        logger.info(f'--- done ({elapsed=:.3f}) ---')
 
     def run(self, text: str, style: str, seed: int, only_first_stage: bool,
             num: int) -> list[np.ndarray] | None:
@@ -302,7 +302,7 @@ class Model:
         seq = torch.tensor(seq + [-1] * 400, device=self.device)
 
         elapsed = time.perf_counter() - start
-        logger.info(f'--- done ({elapsed=:.3f} ---')
+        logger.info(f'--- done ({elapsed=:.3f}) ---')
         return seq, txt_len
 
     @torch.inference_mode()
@@ -340,7 +340,7 @@ class Model:
         logger.debug(f'{output_tokens.shape=}')
 
         elapsed = time.perf_counter() - start
-        logger.info(f'--- done ({elapsed=:.3f} ---')
+        logger.info(f'--- done ({elapsed=:.3f}) ---')
         return output_tokens
 
     @staticmethod
@@ -374,7 +374,7 @@ class Model:
         res.append(decoded_img)  # only the last image (target)
 
         elapsed = time.perf_counter() - start
-        logger.info(f'--- done ({elapsed=:.3f} ---')
+        logger.info(f'--- done ({elapsed=:.3f}) ---')
         return res
 
 
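For reference, the only change in each hunk is the closing parenthesis after the `{elapsed=:.3f}` debug specifier. A minimal standalone sketch of what that changes in the rendered log line (the `logging` setup below is illustrative only, not the Space's actual configuration):

import logging
import time

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

start = time.perf_counter()
time.sleep(0.1)  # stand-in for the model-loading / generation work being timed
elapsed = time.perf_counter() - start

# Before the fix: the '(' opened before the debug expression is never closed,
# so the message renders as e.g. "--- done (elapsed=0.100 ---"
logger.info(f'--- done ({elapsed=:.3f} ---')

# After the fix: the parenthesis is balanced,
# so the message renders as e.g. "--- done (elapsed=0.100) ---"
logger.info(f'--- done ({elapsed=:.3f}) ---')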