File size: 1,757 Bytes
981b783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import argparse
import os


def _read_lines(path):
    """Return every line of the UTF-8 text file at *path*, whitespace-stripped."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def main(args=None):
    """Filter each target file down to the lines labelled with its document id.

    For every file in ``args.tgt_path`` whose name ends in ``.<tgt_lang>``:

    * drop that final extension to get the source file name (for example
      ``test.en.3.de`` -> ``test.en.3``); its last dot-separated field is the
      document id (``3``);
    * read the label file from ``args.disturb_src_path``. Its name is the source
      name with ``.<src_lang>.`` replaced by ``.id.`` (``test.id.3``), and it
      holds one label per target line;
    * keep only the target lines whose label prefix (text before the first
      ``-``) equals the document id;
    * check that the kept lines match the line count of the source file in
      ``args.original_src_path``, then write them to ``args.output_path`` under
      the target file's name.

    Args:
        args: Namespace with ``original_src_path``, ``disturb_src_path``,
            ``tgt_path``, ``output_path`` and ``lang_pair`` (``"src-tgt"``).
            Defaults to the module-level ``args`` set by the ``__main__`` block.

    Raises:
        RuntimeError: if no namespace is passed and no module-level ``args``
            exists.
        ValueError: if ``lang_pair`` does not split into exactly two parts on
            ``-``, or a line-count check fails.
        OSError: if an input file cannot be read or an output file written.
    """
    if args is None:
        # Backward compatibility: the script's ``__main__`` block stores the
        # parsed CLI namespace in the module-level ``args`` and calls ``main()``.
        args = globals().get("args")
        if args is None:
            raise RuntimeError(
                "main() needs an argparse.Namespace: pass one explicitly or "
                "set the module-level `args`"
            )

    src_lang, tgt_lang = args.lang_pair.split("-")
    # Sorted so files are processed, and errors reported, in a stable order.
    tgt_files = sorted(
        name for name in os.listdir(args.tgt_path) if name.endswith(f".{tgt_lang}")
    )
    for tgt_file in tgt_files:
        src_file = os.path.splitext(tgt_file)[0]
        doc_id = src_file.split(".")[-1]
        label_file = src_file.replace(f".{src_lang}.", ".id.")

        labels = _read_lines(os.path.join(args.disturb_src_path, label_file))
        tgt_lines = _read_lines(os.path.join(args.tgt_path, tgt_file))
        # Raise rather than assert: asserts vanish under ``python -O``.
        if len(labels) != len(tgt_lines):
            raise ValueError(
                f"Length mismatch in {tgt_file} ({len(tgt_lines)} lines) and "
                f"{label_file} ({len(labels)} labels)"
            )

        # Labels pair positionally with target lines; keep the lines whose
        # label's leading field (before the first "-") is this document's id.
        filtered_tgt_lines = [
            tgt
            for tgt, label in zip(tgt_lines, labels)
            if label.split("-")[0] == doc_id
        ]

        original_src_lines = _read_lines(
            os.path.join(args.original_src_path, src_file)
        )
        if len(original_src_lines) != len(filtered_tgt_lines):
            raise ValueError(
                f"Length mismatch in {src_file} ({len(original_src_lines)} lines) "
                f"and filtered {tgt_file} ({len(filtered_tgt_lines)} lines)"
            )

        with open(os.path.join(args.output_path, tgt_file), "w", encoding="utf-8") as f:
            # Newline-terminate each line individually so an empty result gives
            # an empty file instead of a single blank line.
            f.writelines(f"{line}\n" for line in filtered_tgt_lines)


def build_parser():
    """Return the command-line parser for this script.

    Every option is required because ``main`` reads all of them. A missing
    option would otherwise fail later with an opaque ``TypeError`` (from
    ``os.makedirs(None)``) or ``AttributeError`` (from ``None.split``).
    """
    parser = argparse.ArgumentParser(
        description="Filter target files down to the lines whose label "
        "matches the file's document id."
    )
    parser.add_argument("--original_src_path", type=str, required=True,
                        help="directory holding the original source files")
    parser.add_argument("--disturb_src_path", type=str, required=True,
                        help="directory holding the per-line label (.id.) files")
    parser.add_argument("--tgt_path", type=str, required=True,
                        help="directory of target-language files to filter")
    parser.add_argument("--output_path", type=str, required=True,
                        help="directory for the filtered files (created if missing)")
    parser.add_argument("--lang_pair", type=str, required=True,
                        help='language pair as "src-tgt", e.g. "en-de"')
    return parser


if __name__ == "__main__":
    # ``main()`` reads this module-level namespace.
    args = build_parser().parse_args()

    os.makedirs(args.output_path, exist_ok=True)

    main()