Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Extracts random constraints from reference files.""" | |
| import argparse | |
| import random | |
| import sys | |
| from sacrebleu import extract_ngrams | |
def get_phrase(words, index, length):
    """Cut a run of words out of ``words`` and return it as one string.

    Args:
        words: List of word strings; it is modified in place.
        index: Position of the first word of the phrase.
        length: Number of consecutive words to take.

    Returns:
        The removed words joined by single spaces.

    Raises:
        AssertionError: If ``index`` is not below ``len(words) - length + 1``.
    """
    assert index < len(words) - length + 1
    stop = index + length
    phrase = " ".join(words[index:stop])
    # Each pop shifts the following words left, so removing at the same
    # position ``length`` times deletes the whole run.
    for _ in range(length):
        words.pop(index)
    return phrase
def _sample_constraints(target, number, length):
    """Sample up to ``number`` non-overlapping phrases from ``target``.

    The target is split on whitespace once, and the still-available text is
    tracked as half-open ``[start, end)`` ranges of token indices. On each
    round a random range is picked, then a random start token inside it; the
    phrase covers up to ``length`` tokens and is truncated at the end of the
    range. Whatever remains to the left and right of the phrase goes back
    into the pool, including single-token remainders.

    Args:
        target: Target sentence (any whitespace separates tokens).
        number: Maximum number of phrases to sample.
        length: Maximum number of tokens per phrase.

    Returns:
        The sampled phrases, ordered by their position in ``target``. The
        list is empty when ``length`` is below 1 or ``target`` has fewer
        than ``length`` tokens.
    """
    tokens = target.split()
    if length < 1 or len(tokens) < length:
        return []

    segments = [(0, len(tokens))]
    # Keyed by the index of the phrase's first token, so sorting the keys
    # restores sentence order even when the same word occurs several times.
    chosen = {}
    for _ in range(number):
        if not segments:
            break
        seg_start, seg_end = segments.pop(random.randrange(len(segments)))
        phrase_start = seg_start + random.randrange(seg_end - seg_start)
        phrase_end = min(seg_end, phrase_start + length)
        chosen[phrase_start] = " ".join(tokens[phrase_start:phrase_end])
        if phrase_start > seg_start:
            segments.append((seg_start, phrase_start))
        if phrase_end < seg_end:
            segments.append((phrase_end, seg_end))
    return [chosen[start] for start in sorted(chosen)]


def main(args):
    """Read lines from stdin and print each source with sampled constraints.

    Every input line is either a bare source sentence or
    ``source<TAB>target``. For lines with a target, phrases are sampled from
    the target (optionally wrapped in ``<s>`` / ``</s>``) and appended to the
    source as extra tab-separated fields. Lines without a tab are printed
    with trailing whitespace stripped and no constraints.

    Args:
        args: Options object providing ``seed``, ``number``, ``len``,
            ``add_sos`` and ``add_eos``.

    Raises:
        ValueError: If an input line contains more than one tab.
    """
    # A falsy seed (0) leaves the random generator unseeded.
    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        constraints = []
        source = line.rstrip()
        if "\t" in line:
            source, target = line.split("\t")
            if args.add_sos:
                target = f"<s> {target}"
            if args.add_eos:
                # The target still ends with its newline here; that is
                # harmless because tokenization splits on any whitespace.
                target = f"{target} </s>"
            constraints = _sample_constraints(target, args.number, args.len)
        print(source, *constraints, sep="\t")
if __name__ == "__main__":
    # Command-line entry point: build the option parser, parse the command
    # line, and hand the parsed options to main().
    parser = argparse.ArgumentParser()

    # How many phrases to sample per line, and the length of each one.
    parser.add_argument(
        "--number",
        "-n",
        default=1,
        type=int,
        help="number of phrases",
    )
    parser.add_argument(
        "--len",
        "-l",
        default=1,
        type=int,
        help="phrase length",
    )

    # Optional sentence-boundary markers wrapped around the target.
    parser.add_argument(
        "--add-sos", action="store_true", default=False, help="add <s> token"
    )
    parser.add_argument(
        "--add-eos", action="store_true", default=False, help="add </s> token"
    )

    # Seed value passed through to main(); the default is 0.
    parser.add_argument("--seed", "-s", type=int, default=0)

    args = parser.parse_args()
    main(args)