Spaces:
Build error
Build error
| #!/usr/bin/env python | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import absolute_import, division, print_function, unicode_literals | |
| import argparse | |
| import contextlib | |
| import sys | |
| import sentencepiece as spm | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--model", required=True, help="sentencepiece model to use for encoding" | |
| ) | |
| parser.add_argument( | |
| "--inputs", nargs="+", default=["-"], help="input files to filter/encode" | |
| ) | |
| parser.add_argument( | |
| "--outputs", nargs="+", default=["-"], help="path to save encoded outputs" | |
| ) | |
| parser.add_argument("--output_format", choices=["piece", "id"], default="piece") | |
| parser.add_argument( | |
| "--min-len", | |
| type=int, | |
| metavar="N", | |
| help="filter sentence pairs with fewer than N tokens", | |
| ) | |
| parser.add_argument( | |
| "--max-len", | |
| type=int, | |
| metavar="N", | |
| help="filter sentence pairs with more than N tokens", | |
| ) | |
| args = parser.parse_args() | |
| assert len(args.inputs) == len( | |
| args.outputs | |
| ), "number of input and output paths should match" | |
| sp = spm.SentencePieceProcessor() | |
| sp.Load(args.model) | |
| if args.output_format == "piece": | |
| def encode(l): | |
| return sp.EncodeAsPieces(l) | |
| elif args.output_format == "id": | |
| def encode(l): | |
| return list(map(str, sp.EncodeAsIds(l))) | |
| else: | |
| raise NotImplementedError | |
| if args.min_len is not None or args.max_len is not None: | |
| def valid(line): | |
| return (args.min_len is None or len(line) >= args.min_len) and ( | |
| args.max_len is None or len(line) <= args.max_len | |
| ) | |
| else: | |
| def valid(lines): | |
| return True | |
| with contextlib.ExitStack() as stack: | |
| inputs = [ | |
| stack.enter_context(open(input, "r", encoding="utf-8")) | |
| if input != "-" | |
| else sys.stdin | |
| for input in args.inputs | |
| ] | |
| outputs = [ | |
| stack.enter_context(open(output, "w", encoding="utf-8")) | |
| if output != "-" | |
| else sys.stdout | |
| for output in args.outputs | |
| ] | |
| stats = { | |
| "num_empty": 0, | |
| "num_filtered": 0, | |
| } | |
| def encode_line(line): | |
| line = line.strip() | |
| if len(line) > 0: | |
| line = encode(line) | |
| if valid(line): | |
| return line | |
| else: | |
| stats["num_filtered"] += 1 | |
| else: | |
| stats["num_empty"] += 1 | |
| return None | |
| for i, lines in enumerate(zip(*inputs), start=1): | |
| enc_lines = list(map(encode_line, lines)) | |
| if not any(enc_line is None for enc_line in enc_lines): | |
| for enc_line, output_h in zip(enc_lines, outputs): | |
| print(" ".join(enc_line), file=output_h) | |
| if i % 10000 == 0: | |
| print("processed {} lines".format(i), file=sys.stderr) | |
| print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) | |
| print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) | |
| if __name__ == "__main__": | |
| main() | |