Spaces:
Runtime error
Runtime error
add cli arg to 🅱️oost 🅱️eams
Browse filesSigned-off-by: peter szemraj <peterszemraj@gmail.com>
- aggregate.py +9 -0
- app.py +14 -0
aggregate.py
CHANGED
|
@@ -179,6 +179,15 @@ class BatchAggregator:
|
|
| 179 |
|
| 180 |
self.aggregator.model.generation_config.update(**kwargs)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
def update_loglevel(self, level: str = "INFO"):
|
| 183 |
"""
|
| 184 |
Update the log level.
|
|
|
|
| 179 |
|
| 180 |
self.aggregator.model.generation_config.update(**kwargs)
|
| 181 |
|
| 182 |
+
def get_generation_config(self) -> dict:
|
| 183 |
+
"""
|
| 184 |
+
Get the current generation configuration.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
dict: The current generation configuration.
|
| 188 |
+
"""
|
| 189 |
+
return self.aggregator.model.generation_config.to_dict()
|
| 190 |
+
|
| 191 |
def update_loglevel(self, level: str = "INFO"):
|
| 192 |
"""
|
| 193 |
Update the log level.
|
app.py
CHANGED
|
@@ -427,6 +427,14 @@ def parse_args():
|
|
| 427 |
default=None,
|
| 428 |
help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
|
| 429 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
parser.add_argument(
|
| 431 |
"-level",
|
| 432 |
"--log_level",
|
|
@@ -460,6 +468,12 @@ if __name__ == "__main__":
|
|
| 460 |
logger.info(f"Adding token batch option {args.token_batch_option} to the list")
|
| 461 |
TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
|
| 462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
logger.info("Loading OCR model")
|
| 464 |
with contextlib.redirect_stdout(None):
|
| 465 |
ocr_model = ocr_predictor(
|
|
|
|
| 427 |
default=None,
|
| 428 |
help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
|
| 429 |
)
|
| 430 |
+
parser.add_argument(
|
| 431 |
+
"-max_agg",
|
| 432 |
+
"-2x",
|
| 433 |
+
"--aggregator_beam_boost",
|
| 434 |
+
dest="aggregator_beam_boost",
|
| 435 |
+
action="store_true",
|
| 436 |
+
help="Double the number of beams for the aggregator during beam search",
|
| 437 |
+
)
|
| 438 |
parser.add_argument(
|
| 439 |
"-level",
|
| 440 |
"--log_level",
|
|
|
|
| 468 |
logger.info(f"Adding token batch option {args.token_batch_option} to the list")
|
| 469 |
TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
|
| 470 |
|
| 471 |
+
if args.aggregator_beam_boost:
|
| 472 |
+
logger.info("Doubling aggregator num_beams")
|
| 473 |
+
_agg_cfg = aggregator.get_generation_config()
|
| 474 |
+
_agg_cfg["num_beams"] = _agg_cfg["num_beams"] * 2
|
| 475 |
+
aggregator.update_generation_config(**_agg_cfg)
|
| 476 |
+
|
| 477 |
logger.info("Loading OCR model")
|
| 478 |
with contextlib.redirect_stdout(None):
|
| 479 |
ocr_model = ocr_predictor(
|