|
|
import argparse |
|
|
import requests |
|
|
import time |
|
|
import os |
|
|
|
|
|
|
|
|
def get_comet_score(instances: list[dict], timeout=100, max_retries=10, comet_api: str = None, retry_delay: float = 5):
    """Score translation instances with a remote COMET service.

    All ``instances`` are sent in a single POST request to
    ``http://<host>/evaluate`` as ``{'instances': instances}``, and the
    ``scores`` field of the JSON response is returned.

    Args:
        instances: Items to score, sent as-is in the request payload
            (``main`` builds them as ``{'src': ..., 'mt': ..., 'ref': ...}``).
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        max_retries: Maximum number of attempts that time out or return a
            non-200 status before giving up.
        comet_api: ``host[:port]`` of the service. When ``None``, the
            ``COMET_API`` environment variable is used instead.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        The value of ``response.json()['scores']`` from the first successful
        (HTTP 200) response.

    Raises:
        RuntimeError: If no host is configured, a non-timeout request error
            occurs, or every attempt times out / returns a non-200 status.
    """
    host = comet_api if comet_api is not None else os.getenv('COMET_API')
    if not host:
        # Without this check the URL silently becomes "http://None/evaluate".
        raise RuntimeError(
            "No COMET API host given: pass comet_api or set the COMET_API environment variable."
        )
    url = f"http://{host}/evaluate"
    payload = {'instances': instances}

    retries = 0
    while retries < max_retries:
        try:
            response = requests.post(url, json=payload, timeout=timeout)
            if response.status_code == 200:
                # Kept inside the try so a malformed JSON body (a
                # RequestException subclass in recent requests) is wrapped below.
                return response.json()['scores']
            print(f"Request failed with status code: {response.status_code}")
        except requests.Timeout:
            print(f"Request timed out. Retrying... ({retries + 1}/{max_retries})")
        except requests.RequestException as e:
            # Connection errors etc. are not retried; preserve the cause.
            raise RuntimeError(f"Request failed due to: {e}") from e

        # Both timeouts and non-200 responses consume an attempt; previously a
        # persistent error status looped forever without any backoff.
        retries += 1
        if retries < max_retries:
            time.sleep(retry_delay)

    raise RuntimeError("Max retries exceeded. Request failed.")
|
|
|
|
|
|
|
|
def _read_lines(path):
    """Return the lines of a UTF-8 text file with their trailing newline removed."""
    with open(path, 'r', encoding='utf-8') as f:
        # Strip only the line terminator so the file artifact is not sent as
        # part of the segment; other whitespace is left untouched.
        return [line.rstrip('\n') for line in f]


def main():
    """Compute the average COMET score of a translation file.

    Reads line-aligned source, target (system output) and reference files,
    scores every segment through the COMET service at ``--url`` via
    ``get_comet_score``, and prints the mean segment score (``-1.0`` when
    no scores are returned).

    Exits with a usage error if the three files have different line counts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_file', '-s', type=str, required=True)
    parser.add_argument('--target_file', '-t', type=str, required=True)
    parser.add_argument('--reference_file', '-r', type=str, required=True)
    parser.add_argument('--url', '-u', type=str, required=True)
    args = parser.parse_args()

    target_file = args.target_file
    source_lines = _read_lines(args.source_file)
    target_lines = _read_lines(target_file)
    reference_lines = _read_lines(args.reference_file)

    # zip() would silently truncate to the shortest file; a line-count
    # mismatch almost always means the files are misaligned, so fail loudly.
    counts = (len(source_lines), len(target_lines), len(reference_lines))
    if len(set(counts)) != 1:
        parser.error(
            f"line count mismatch: source={counts[0]}, target={counts[1]}, reference={counts[2]}"
        )

    instances = [
        {'src': src, 'mt': mt, 'ref': ref}
        for src, mt, ref in zip(source_lines, target_lines, reference_lines)
    ]
    line_comet_scores = get_comet_score(instances, comet_api=args.url)

    avg_score = sum(line_comet_scores) / len(line_comet_scores) if line_comet_scores else -1.0

    print(f'{target_file}\tscore: {avg_score:.4f}')
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
|
|
|