Spaces:
Running
Running
Add --repetition_penalty flag
Browse files
- translate.py +12 -0
translate.py
CHANGED
|
@@ -76,6 +76,7 @@ def main(
|
|
| 76 |
top_p: float = 1.0,
|
| 77 |
keep_special_tokens: bool = False,
|
| 78 |
keep_tokenization_spaces: bool = False,
|
|
|
|
| 79 |
):
|
| 80 |
os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
|
| 81 |
|
|
@@ -132,6 +133,9 @@ def main(
|
|
| 132 |
"top_p": top_p,
|
| 133 |
}
|
| 134 |
|
|
|
|
|
|
|
|
|
|
| 135 |
total_lines: int = count_lines(sentences_path)
|
| 136 |
|
| 137 |
if accelerator.is_main_process:
|
|
@@ -351,6 +355,13 @@ if __name__ == "__main__":
|
|
| 351 |
help="Do not clean spaces in the decoded text.",
|
| 352 |
)
|
| 353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
args = parser.parse_args()
|
| 355 |
|
| 356 |
main(
|
|
@@ -371,4 +382,5 @@ if __name__ == "__main__":
|
|
| 371 |
top_p=args.top_p,
|
| 372 |
keep_special_tokens=args.keep_special_tokens,
|
| 373 |
keep_tokenization_spaces=args.keep_tokenization_spaces,
|
|
|
|
| 374 |
)
|
|
|
|
| 76 |
top_p: float = 1.0,
|
| 77 |
keep_special_tokens: bool = False,
|
| 78 |
keep_tokenization_spaces: bool = False,
|
| 79 |
+
repetition_penalty: float = None,
|
| 80 |
):
|
| 81 |
os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
|
| 82 |
|
|
|
|
| 133 |
"top_p": top_p,
|
| 134 |
}
|
| 135 |
|
| 136 |
+
if repetition_penalty is not None:
|
| 137 |
+
gen_kwargs["repetition_penalty"] = repetition_penalty
|
| 138 |
+
|
| 139 |
total_lines: int = count_lines(sentences_path)
|
| 140 |
|
| 141 |
if accelerator.is_main_process:
|
|
|
|
| 355 |
help="Do not clean spaces in the decoded text.",
|
| 356 |
)
|
| 357 |
|
| 358 |
+
parser.add_argument(
|
| 359 |
+
"--repetition_penalty",
|
| 360 |
+
type=float,
|
| 361 |
+
default=None,
|
| 362 |
+
help="Repetition penalty.",
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
args = parser.parse_args()
|
| 366 |
|
| 367 |
main(
|
|
|
|
| 382 |
top_p=args.top_p,
|
| 383 |
keep_special_tokens=args.keep_special_tokens,
|
| 384 |
keep_tokenization_spaces=args.keep_tokenization_spaces,
|
| 385 |
+
repetition_penalty=args.repetition_penalty,
|
| 386 |
)
|