| import argparse | |
| import os | |
def _read_stripped_lines(path):
    """Return every line of the UTF-8 text file at *path*, whitespace-stripped."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def main(options=None):
    """Filter each target file down to the lines labelled with its own document id.

    For every file in ``tgt_path`` whose name ends with ``.<tgt_lang>``:

    * the name without its last extension is used as the source file name,
      and the last dot-separated component of that name as the document id;
    * the label file is the source file name with ``.<src_lang>.`` replaced
      by ``.id.``, read from ``disturb_src_path``; it must have exactly one
      label per target line;
    * a target line is kept when the part of its label before the first
      ``-`` equals the document id;
    * the number of kept lines must equal the number of lines of the source
      file of the same name in ``original_src_path``;
    * the kept lines are written to ``output_path`` under the target file's
      name, one per line.

    Args:
        options: Namespace providing ``original_src_path``,
            ``disturb_src_path``, ``tgt_path``, ``output_path`` and
            ``lang_pair`` (``"<src>-<tgt>"``). When omitted, the module-level
            ``args`` namespace is used.

    Raises:
        AssertionError: If the label/target line counts differ, or the number
            of kept lines differs from the original source line count.
    """
    if options is None:
        # Backward-compatible default: the namespace parsed at script start-up.
        options = args

    src_lang, tgt_lang = options.lang_pair.split("-")
    tgt_suffix = f".{tgt_lang}"
    src_marker = f".{src_lang}."

    # os.listdir() order is arbitrary; sort so runs are reproducible.
    tgt_names = sorted(
        name for name in os.listdir(options.tgt_path) if name.endswith(tgt_suffix)
    )

    for tgt_name in tgt_names:
        src_name = os.path.splitext(tgt_name)[0]
        doc_id = src_name.split(".")[-1]
        label_name = src_name.replace(src_marker, ".id.")

        labels = _read_stripped_lines(
            os.path.join(options.disturb_src_path, label_name)
        )
        tgt_lines = _read_stripped_lines(os.path.join(options.tgt_path, tgt_name))

        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type stays AssertionError.
        if len(labels) != len(tgt_lines):
            raise AssertionError(f"Length mismatch in {tgt_name} and {label_name}")

        kept_lines = [
            tgt
            for tgt, label in zip(tgt_lines, labels)
            if label.split("-")[0] == doc_id
        ]

        original_src_lines = _read_stripped_lines(
            os.path.join(options.original_src_path, src_name)
        )
        if len(original_src_lines) != len(kept_lines):
            raise AssertionError(
                f"Length mismatch in {src_name} and filtered {tgt_name}"
            )

        with open(
            os.path.join(options.output_path, tgt_name), "w", encoding="utf-8"
        ) as f:
            # One newline per kept line, so an empty result yields an empty
            # file instead of a file holding a single blank line.
            f.writelines(f"{line}\n" for line in kept_lines)
if __name__ == "__main__":
    # Command-line entry point: parse the options, make sure the output
    # directory exists, then run the filtering over every target file.
    parser = argparse.ArgumentParser(
        description=(
            "Keep only the target lines whose label matches the document id "
            "of their file."
        )
    )
    # Every option is used unconditionally by main(), so each is required;
    # omitting one now fails with a clear usage error instead of a later
    # TypeError/AttributeError on None.
    parser.add_argument(
        "--original_src_path",
        type=str,
        required=True,
        help="directory with the original source files (used to check line counts)",
    )
    parser.add_argument(
        "--disturb_src_path",
        type=str,
        required=True,
        help="directory with the per-line label ('.id.') files",
    )
    parser.add_argument(
        "--tgt_path",
        type=str,
        required=True,
        help="directory with the target files to filter",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="directory for the filtered target files (created if missing)",
    )
    parser.add_argument(
        "--lang_pair",
        type=str,
        required=True,
        help="language pair in the form <src>-<tgt>",
    )
    # `args` stays a module-level name because main() reads it as a global.
    args = parser.parse_args()
    os.makedirs(args.output_path, exist_ok=True)
    main()