# Source: train-scripts/preprocess_robust.py
# (uploaded by Ashton2000 via huggingface_hub, commit 981b783)
import argparse
import os
def _read_stripped_lines(path):
    """Return the lines of the UTF-8 text file at *path*, each passed through ``str.strip()``."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _write_lines(path, lines):
    """Write *lines* to *path* as UTF-8, each terminated by a newline.

    An empty *lines* produces an empty file (not a file holding a single
    newline, which would read back as one empty line).
    """
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in lines)


def main(opts: "argparse.Namespace | None" = None):
    """Filter target-language lines down to those belonging to each original document.

    For every file in ``opts.tgt_path`` whose name ends in ``.<tgt_lang>``:

    * the source file name is the target name with its last extension removed
      (e.g. ``doc.en.3.de`` -> ``doc.en.3``), and the last dot-separated
      component of that name is the document id (``3``);
    * the label file, read from ``opts.disturb_src_path``, is the source name
      with ``.<src_lang>.`` replaced by ``.id.`` (``doc.id.3``) and must have
      exactly one line per target line;
    * a target line is kept when the text of its label before the first ``-``
      equals the document id;
    * the number of kept lines must equal the number of lines in the original
      source file (``opts.original_src_path``/<source name>); the kept lines
      are written to ``opts.output_path`` under the target file's name.

    Args:
        opts: parsed options with ``original_src_path``, ``disturb_src_path``,
            ``tgt_path``, ``output_path`` and ``lang_pair`` (``"<src>-<tgt>"``).
            When omitted, the module-level ``args`` set by the script entry
            point is used, so the original zero-argument call still works.

    Raises:
        ValueError: if a label file and its target file differ in line count,
            or the filtered target lines do not match the original source's
            line count.
    """
    if opts is None:
        # Backward compatibility: the script entry point stores the parsed
        # options in a module-level ``args`` and calls ``main()`` bare.
        opts = args
    src_lang, tgt_lang = opts.lang_pair.split("-")
    os.makedirs(opts.output_path, exist_ok=True)

    # os.listdir order is arbitrary; sort so files are processed (and any
    # mismatch is reported) in a stable order.
    tgt_file_list = sorted(
        file for file in os.listdir(opts.tgt_path) if file.endswith(f".{tgt_lang}")
    )
    for tgt_file in tgt_file_list:
        src_file = os.path.splitext(tgt_file)[0]
        doc_id = src_file.split(".")[-1]
        label_file = src_file.replace(f".{src_lang}.", ".id.")

        labels = _read_stripped_lines(os.path.join(opts.disturb_src_path, label_file))
        tgt_lines = _read_stripped_lines(os.path.join(opts.tgt_path, tgt_file))
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if len(labels) != len(tgt_lines):
            raise ValueError(f"Length mismatch in {tgt_file} and {label_file}")

        # Keep target lines whose label's document id (text before the first
        # '-') matches this file's document id.
        filtered_tgt_lines = [
            tgt for tgt, label in zip(tgt_lines, labels) if label.split("-")[0] == doc_id
        ]

        original_src_lines = _read_stripped_lines(
            os.path.join(opts.original_src_path, src_file)
        )
        if len(original_src_lines) != len(filtered_tgt_lines):
            raise ValueError(f"Length mismatch in {src_file} and filtered {tgt_file}")

        _write_lines(os.path.join(opts.output_path, tgt_file), filtered_tgt_lines)
if __name__ == "__main__":
    # Script entry point: parse options into the module-level ``args`` that
    # ``main()`` reads, make sure the output directory exists, then run.
    parser = argparse.ArgumentParser(
        description="Filter target-language lines back to the original documents."
    )
    # Every option is consumed unconditionally by main(); mark them required
    # so a missing one yields an argparse usage error instead of a TypeError
    # (e.g. os.makedirs(None)) or AttributeError deep inside the run.
    parser.add_argument("--original_src_path", type=str, required=True,
                        help="directory holding the original source files")
    parser.add_argument("--disturb_src_path", type=str, required=True,
                        help="directory holding the per-line '.id.' label files")
    parser.add_argument("--tgt_path", type=str, required=True,
                        help="directory holding the target-language files to filter")
    parser.add_argument("--output_path", type=str, required=True,
                        help="directory to write the filtered target files to")
    parser.add_argument("--lang_pair", type=str, required=True,
                        help="language pair as '<src>-<tgt>', e.g. 'en-de'")
    args = parser.parse_args()
    os.makedirs(args.output_path, exist_ok=True)
    main()