"""Keep only the target lines that belong to each file's own document.

For every file in ``--tgt_path`` ending in ``.<tgt_lang>``, a label file in
``--disturb_src_path`` holds one label per target line. A target line is kept
when the part of its label before the first ``-`` equals the document id taken
from the file name. The kept lines must match the line count of the
corresponding file in ``--original_src_path``. They are written to
``--output_path`` under the target file's name.
"""
import argparse
import os


def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, each ``str.strip()``-ed."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _filter_document(tgt_file, src_lang, args):
    """Filter one target file and write the kept lines to ``args.output_path``.

    Args:
        tgt_file: Bare file name (no directory) of a file in ``args.tgt_path``.
        src_lang: Source-language code from ``--lang_pair``; used to derive the
            label file name by replacing ``.<src_lang>.`` with ``.id.``.
        args: Parsed command-line namespace (see ``_build_parser``).

    Raises:
        ValueError: If the label file and the target file differ in line count,
            or the number of kept lines differs from the original source file.
    """
    src_file = os.path.splitext(tgt_file)[0]
    # The document id is the last dot-separated part of the source name,
    # e.g. "test.en.3" -> "3".
    doc_id = src_file.split(".")[-1]
    label_file = src_file.replace(f".{src_lang}.", ".id.")

    labels = _read_lines(os.path.join(args.disturb_src_path, label_file))
    tgt_lines = _read_lines(os.path.join(args.tgt_path, tgt_file))
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if len(labels) != len(tgt_lines):
        raise ValueError(f"Length mismatch in {tgt_file} and {label_file}")

    # Keep a line only when its label's prefix (before the first "-") is this doc id.
    filtered_tgt_lines = [
        tgt for tgt, label in zip(tgt_lines, labels) if label.split("-")[0] == doc_id
    ]

    original_src_lines = _read_lines(os.path.join(args.original_src_path, src_file))
    if len(original_src_lines) != len(filtered_tgt_lines):
        raise ValueError(f"Length mismatch in {src_file} and filtered {tgt_file}")

    with open(os.path.join(args.output_path, tgt_file), "w", encoding="utf-8") as f:
        # One newline per kept line; an empty result yields an empty file.
        f.writelines(f"{line}\n" for line in filtered_tgt_lines)


def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument("--original_src_path", type=str, required=True,
                        help="directory holding the original source files")
    parser.add_argument("--disturb_src_path", type=str, required=True,
                        help="directory holding the per-line label (.id.) files")
    parser.add_argument("--tgt_path", type=str, required=True,
                        help="directory holding the target-language files")
    parser.add_argument("--output_path", type=str, required=True,
                        help="directory to write filtered target files to")
    parser.add_argument("--lang_pair", type=str, required=True,
                        help='language pair as "<src>-<tgt>", e.g. "en-de"')
    return parser


def main(args=None):
    """Filter every target file found in ``args.tgt_path``.

    Args:
        args: Parsed ``argparse.Namespace``; when ``None`` the command line is
            parsed. Passing it explicitly lets the script be used as a library.

    Raises:
        ValueError: If ``--lang_pair`` does not split into exactly two parts on
            "-", or a file fails one of the length checks.
    """
    if args is None:
        args = _build_parser().parse_args()
    src_lang, tgt_lang = args.lang_pair.split("-")
    os.makedirs(args.output_path, exist_ok=True)
    # Sorted for a deterministic processing order (os.listdir order is arbitrary).
    for tgt_file in sorted(os.listdir(args.tgt_path)):
        if tgt_file.endswith(f".{tgt_lang}"):
            _filter_document(tgt_file, src_lang, args)


if __name__ == "__main__":
    main()