File size: 2,183 Bytes
981b783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import requests
import time
import os


def get_comet_score(instances: list[dict], timeout=100, max_retries=10, comet_api: str=None):
    """POST ``instances`` to a COMET scoring service and return its scores.

    The request goes to ``http://<host>/evaluate`` with the JSON body
    ``{'instances': instances}``. On an HTTP 200 reply the value under the
    ``'scores'`` key of the JSON response is returned.

    Args:
        instances: Items to score, sent verbatim as JSON. (``main`` in this
            file passes dicts with ``src``/``mt``/``ref`` keys.)
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        max_retries: Maximum number of failed attempts (timeouts or non-200
            replies) tolerated before giving up.
        comet_api: ``host[:port]`` of the service. When ``None``, the
            ``COMET_API`` environment variable is used instead.

    Returns:
        The ``'scores'`` entry of the service's JSON response.

    Raises:
        RuntimeError: If no endpoint is configured, if the request fails with
            a non-timeout ``requests.RequestException``, or if every attempt
            timed out or returned a non-200 status.
    """
    host = comet_api if comet_api is not None else os.getenv('COMET_API')
    if not host:
        # Fail fast instead of posting to the meaningless "http://None/evaluate".
        raise RuntimeError(
            "No COMET endpoint configured: pass comet_api or set the COMET_API environment variable."
        )
    url = f"http://{host}/evaluate"
    payload = {'instances': instances}

    retries = 0
    while retries < max_retries:
        try:
            response = requests.post(url, json=payload, timeout=timeout)
        except requests.Timeout:
            retries += 1
            print(f"Request timed out. Retrying... ({retries}/{max_retries})")
        except requests.RequestException as e:
            # Connection errors and the like are treated as fatal, not retried.
            raise RuntimeError(f"Request failed due to: {e}") from e
        else:
            if response.status_code == 200:
                return response.json()['scores']
            # A bad status must consume a retry too; otherwise a server that
            # keeps answering non-200 would make this loop spin forever.
            retries += 1
            print(f"Request failed with status code: {response.status_code}")

        # Back off before the next attempt; no point sleeping after the last one.
        if retries < max_retries:
            time.sleep(5)

    raise RuntimeError("Max retries exceeded. Request failed.")


def main():
    """Score a translation file against a COMET service and print the average.

    Reads the source, target (MT output) and reference files line by line,
    pairs them up positionally, sends them to ``get_comet_score`` and prints
    ``<target_file>\\tscore: <average>`` with four decimals. An empty score
    list is reported as ``-1.0``.

    Raises:
        ValueError: If the three files do not contain the same number of lines.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_file', '-s', type=str, required=True)
    parser.add_argument('--target_file', '-t', type=str, required=True)
    parser.add_argument('--reference_file', '-r', type=str, required=True)
    parser.add_argument('--url', '-u', type=str, required=True)
    args = parser.parse_args()

    target_file = args.target_file

    # Read with an explicit encoding so results do not depend on the locale.
    # Lines keep their trailing newline, exactly as readlines() returns them.
    columns = []
    for path in (args.source_file, target_file, args.reference_file):
        with open(path, 'r', encoding='utf-8') as f:
            columns.append(f.readlines())
    source_lines, target_lines, reference_lines = columns

    # zip() would silently drop the surplus lines of the longer files and
    # score a truncated (and possibly misaligned) set, so reject mismatches.
    if not (len(source_lines) == len(target_lines) == len(reference_lines)):
        raise ValueError(
            "Line count mismatch: "
            f"source={len(source_lines)}, target={len(target_lines)}, "
            f"reference={len(reference_lines)}"
        )

    instances = [
        {'src': src, 'mt': mt, 'ref': ref}
        for src, mt, ref in zip(source_lines, target_lines, reference_lines)
    ]
    line_comet_scores = get_comet_score(instances, comet_api=args.url)

    # -1.0 is the sentinel for "nothing was scored" (empty or missing list).
    avg_score = sum(line_comet_scores) / len(line_comet_scores) if line_comet_scores else -1.0
    print(f'{target_file}\tscore: {avg_score:.4f}')


if __name__ == '__main__':
    main()