Add --keep_tokenization_spaces argument to control the space decoding
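
This change exposes the tokenizer's space clean-up behavior at decode time: `tokenizer.batch_decode` now receives `clean_up_tokenization_spaces=not keep_tokenization_spaces`, a matching `--keep_tokenization_spaces` store-true flag is added to the argument parser, and both decoding options are echoed alongside the other generation settings at startup.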
translate.py  CHANGED  (+14 -4)
@@ -31,7 +31,6 @@ def get_dataloader(
     batch_size: int,
     max_length: int,
 ) -> DataLoader:
-
     dataset = DatasetReader(filename, tokenizer, max_length)
     if accelerator.distributed_type == DistributedType.TPU:
         data_collator = DataCollatorForSeq2Seq(
@@ -76,8 +75,8 @@ def main(
     top_k: int = 50,
     top_p: float = 1.0,
     keep_special_tokens: bool = False,
+    keep_tokenization_spaces: bool = False,
 ):
-
     os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
 
     accelerator = Accelerator(
@@ -149,6 +148,8 @@ def main(
         f"Max length: {max_length}\n"
         f"Precision: {model.dtype}\n"
         f"Model: {model_name}\n"
+        f"Keep special tokens: {keep_special_tokens}\n"
+        f"Keep tokenization spaces: {keep_tokenization_spaces}\n"
     )
     print("** Generation parameters **")
     print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
@@ -197,7 +198,9 @@ def main(
             )
 
             tgt_text = tokenizer.batch_decode(
-                generated_tokens, skip_special_tokens=not keep_special_tokens
+                generated_tokens,
+                skip_special_tokens=not keep_special_tokens,
+                clean_up_tokenization_spaces=not keep_tokenization_spaces,
             )
             if accelerator.is_main_process:
                 if (
@@ -342,6 +345,12 @@ if __name__ == "__main__":
         help="Keep special tokens in the decoded text.",
     )
 
+    parser.add_argument(
+        "--keep_tokenization_spaces",
+        action="store_true",
+        help="Do not clean spaces in the decoded text.",
+    )
+
     args = parser.parse_args()
 
     main(
@@ -360,5 +369,6 @@ if __name__ == "__main__":
         temperature=args.temperature,
         top_k=args.top_k,
         top_p=args.top_p,
-        keep_special_tokens=args.keep_special_tokens
+        keep_special_tokens=args.keep_special_tokens,
+        keep_tokenization_spaces=args.keep_tokenization_spaces,
     )
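
For reference, the snippet below sketches what the new flag toggles. `clean_up_tokenization_spaces` is a standard option of Hugging Face tokenizers' `decode`/`batch_decode`: when enabled (the script's default), the extra spaces that word-piece vocabularies insert around punctuation and contractions are merged back during decoding. The tokenizer name and sample sentence are illustrative assumptions, not taken from translate.py:

from transformers import AutoTokenizer

# Illustrative tokenizer choice; translate.py loads whatever model the user passes.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer("Don't stop!").input_ids

# Script default (flag absent): cleanup on, punctuation re-attached.
print(tokenizer.decode(ids, skip_special_tokens=True,
                       clean_up_tokenization_spaces=True))
# -> don't stop!

# With --keep_tokenization_spaces: cleanup off, raw token spacing kept.
print(tokenizer.decode(ids, skip_special_tokens=True,
                       clean_up_tokenization_spaces=False))
# -> don ' t stop !

From the command line the same behavior is selected with the new flag, e.g. `python translate.py ... --keep_tokenization_spaces` (remaining arguments elided; given the script's use of Accelerator it may also be launched through accelerate).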