Spaces:
Running
Running
Implement num_return_sequences parameter.
Browse filesDefine how many possible translations you want for each source sentence. Defualt:1
- translate.py +10 -1
translate.py
CHANGED
|
@@ -59,6 +59,7 @@ def main(
|
|
| 59 |
precision: str = "32",
|
| 60 |
max_length: int = 128,
|
| 61 |
num_beams: int = 4,
|
|
|
|
| 62 |
):
|
| 63 |
|
| 64 |
if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):
|
|
@@ -96,7 +97,7 @@ def main(
|
|
| 96 |
gen_kwargs = {
|
| 97 |
"max_length": max_length,
|
| 98 |
"num_beams": num_beams,
|
| 99 |
-
"num_return_sequences":
|
| 100 |
}
|
| 101 |
|
| 102 |
# total_lines: int = count_lines(sentences_path)
|
|
@@ -246,6 +247,13 @@ if __name__ == "__main__":
|
|
| 246 |
help="Number of beams for beam search, m2m10 author recommends 5, but it might use too much memory",
|
| 247 |
)
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
parser.add_argument(
|
| 250 |
"--precision",
|
| 251 |
type=str,
|
|
@@ -266,5 +274,6 @@ if __name__ == "__main__":
|
|
| 266 |
cache_dir=args.cache_dir,
|
| 267 |
max_length=args.max_length,
|
| 268 |
num_beams=args.num_beams,
|
|
|
|
| 269 |
precision=args.precision,
|
| 270 |
)
|
|
|
|
| 59 |
precision: str = "32",
|
| 60 |
max_length: int = 128,
|
| 61 |
num_beams: int = 4,
|
| 62 |
+
num_return_sequences: int = 1,
|
| 63 |
):
|
| 64 |
|
| 65 |
if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):
|
|
|
|
| 97 |
gen_kwargs = {
|
| 98 |
"max_length": max_length,
|
| 99 |
"num_beams": num_beams,
|
| 100 |
+
"num_return_sequences": num_return_sequences,
|
| 101 |
}
|
| 102 |
|
| 103 |
# total_lines: int = count_lines(sentences_path)
|
|
|
|
| 247 |
help="Number of beams for beam search, m2m10 author recommends 5, but it might use too much memory",
|
| 248 |
)
|
| 249 |
|
| 250 |
+
parser.add_argument(
|
| 251 |
+
"--num_return_sequences",
|
| 252 |
+
type=int,
|
| 253 |
+
default=1,
|
| 254 |
+
help="Number of possible translation to return for each sentence (num_return_sequences<=num_beams).",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
parser.add_argument(
|
| 258 |
"--precision",
|
| 259 |
type=str,
|
|
|
|
| 274 |
cache_dir=args.cache_dir,
|
| 275 |
max_length=args.max_length,
|
| 276 |
num_beams=args.num_beams,
|
| 277 |
+
num_return_sequences=args.num_return_sequences,
|
| 278 |
precision=args.precision,
|
| 279 |
)
|