Skip to content

Instantly share code, notes, and snippets.

@gsarti
Created January 6, 2022 13:21
Show Gist options
  • Save gsarti/9f5080c6622a3cb4edb09862f6fd14c6 to your computer and use it in GitHub Desktop.
Save gsarti/9f5080c6622a3cb4edb09862f6fd14c6 to your computer and use it in GitHub Desktop.
Evaluate the translation of a 🤗 dataset using Comet-based QE models
import argparse
import os
from datasets import load_dataset
from comet import download_model, load_from_checkpoint
def replace_missing(scores, sentences, default=-100):
    """Substitute a sentinel for the score of every missing sentence.

    ``scores`` and ``sentences`` are walked in lockstep. If the two inputs
    differ in length, pairs beyond the shorter one are dropped (``zip``
    semantics).

    Args:
        scores: Iterable of per-segment scores.
        sentences: Iterable of the sentences the scores belong to. A sentence
            counts as missing when it is falsy (e.g. ``""`` or ``None``).
        default: Value stored in place of the score of a missing sentence.

    Returns:
        A new list holding, for each (score, sentence) pair, the score itself
        when the sentence is present and ``default`` otherwise.
    """
    return [
        score if sentence else default
        for score, sentence in zip(scores, sentences)
    ]
def main(args):
    """Score translated dataset fields with a COMET QE model and save the results.

    For every requested split, each (original field, translated field) pair is
    turned into ``{'src': ..., 'mt': ...}`` records and scored with the model
    named by ``args.model``. Per-segment scores (with ``-100`` substituted for
    empty translations via ``replace_missing``) and the system score are
    written to a text file, and the scores are also appended to the translated
    split as a new column before the split is exported to JSON.

    Args:
        args: Namespace with the attributes ``model``, ``score_name``,
            ``orig_dataset_name``, ``translated_dataset_name``,
            ``orig_dataset_config``, ``translated_dataset_config``,
            ``auth_token``, ``orig_fields``, ``translated_fields``,
            ``output_dir`` and ``splits``. ``translated_fields``,
            ``score_name`` and ``splits`` are filled in on the namespace when
            they are ``None``.

    Raises:
        AssertionError: If the number of original and translated fields
            differs, or if a split has a different number of rows in the two
            datasets.
    """
    # Local import keeps the file's top-level imports untouched; torch is
    # already a dependency of COMET.
    import torch

    if args.translated_fields is None:
        args.translated_fields = args.orig_fields
    if args.score_name is None:
        args.score_name = args.model.split('-')[-1]
    assert len(args.orig_fields) == len(args.translated_fields), "Fields mismatch!"

    # NOTE(review): `use_auth_token` is reportedly deprecated in recent
    # `datasets` releases in favour of `token` - confirm against the pinned version.
    if args.splits is None:
        # Without `split`, load_dataset returns a mapping of split name -> dataset.
        orig_all = load_dataset(args.orig_dataset_name, args.orig_dataset_config, use_auth_token=args.auth_token)
        trans_all = load_dataset(args.translated_dataset_name, args.translated_dataset_config, use_auth_token=args.auth_token)
        args.splits = list(orig_all.keys())
        orig = [orig_all[s] for s in args.splits]
        trans = [trans_all[s] for s in args.splits]
    else:
        if isinstance(args.splits, str):
            # Programmatic callers may pass a single split name.
            args.splits = [args.splits]
        # With a list of split names, load_dataset returns a list of datasets
        # in the same order, so it must not be re-indexed by split name.
        orig = load_dataset(args.orig_dataset_name, args.orig_dataset_config, split=args.splits, use_auth_token=args.auth_token)
        trans = load_dataset(args.translated_dataset_name, args.translated_dataset_config, split=args.splits, use_auth_token=args.auth_token)

    model_path = download_model(args.model)
    model = load_from_checkpoint(model_path)
    # Fall back to CPU when no CUDA device is available instead of failing.
    gpus = 1 if torch.cuda.is_available() else 0

    os.makedirs(args.output_dir, exist_ok=True)
    # Hub-style model names ("org/name") would otherwise inject a directory
    # separator into the score file name.
    model_tag = args.model.split('/')[-1]

    for split, orig_split, trans_split in zip(args.splits, orig, trans):
        print(f"Processing split {split}")
        assert len(orig_split) == len(trans_split), f"Size mismatch for split {split}: {len(orig_split)} != {len(trans_split)}"
        json_filename = f"{args.translated_dataset_name.split('/')[-1]}_{args.orig_dataset_config if args.orig_dataset_config else ''}_{split}_{args.score_name}.json"
        for of, tf in zip(args.orig_fields, args.translated_fields):
            # Column access avoids materialising every full row just to read one field.
            data = [{'src': src, 'mt': mt} for src, mt in zip(orig_split[of], trans_split[tf])]
            print(f"Example translations: {data[:3]}")
            prediction = model.predict(data, batch_size=8, gpus=gpus)
            if isinstance(prediction, tuple):
                # Older COMET API: (segment scores, system score).
                seg_scores, sys_score = prediction
            else:
                # Newer COMET API: dict-like object exposing both values as attributes.
                seg_scores, sys_score = prediction.scores, prediction.system_score
            seg_scores = replace_missing(seg_scores, [d['mt'] for d in data])
            filename_scores = f'trans_{args.orig_dataset_name.split("/")[-1]}_{split}_{model_tag}_{of}.txt'
            with open(os.path.join(args.output_dir, filename_scores), 'w', encoding='utf-8') as f:
                f.writelines([str(score) + "\n" for score in seg_scores])
                f.write("==== System score ====\n")
                f.write(str(sys_score))
            # add_column returns a new dataset, so rebind to accumulate one
            # score column per translated field.
            trans_split = trans_split.add_column(f"{args.score_name}_{tf}", seg_scores)
            print(f"Done {of}!")
        trans_split.to_json(os.path.join(args.output_dir, json_filename))
# Command-line entry point: parse the options and hand them to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate the translation of a Hugging Face dataset using Comet-based QE models."
    )
    parser.add_argument('--model', '-m', type=str, default='wmt21-comet-qe-mqm',
                        help="Name of the COMET model to download and use for scoring.")
    parser.add_argument('--score_name', '-n', type=str, default=None,
                        help="Prefix of the added score columns; defaults to the last '-'-separated part of --model.")
    parser.add_argument('--orig_dataset_name', '-o', type=str, required=True,
                        help="Name of the dataset holding the source texts.")
    parser.add_argument('--translated_dataset_name', '-t', type=str, required=True,
                        help="Name of the dataset holding the translations.")
    parser.add_argument('--orig_dataset_config', '-x', type=str, default=None,
                        help="Configuration name of the original dataset.")
    parser.add_argument('--translated_dataset_config', '-y', type=str, default=None,
                        help="Configuration name of the translated dataset.")
    parser.add_argument('--auth_token', '-a', type=str, default=None,
                        help="Authentication token forwarded to load_dataset.")
    parser.add_argument('--orig_fields', '-e', nargs='+', type=str, default=["text"],
                        help="Fields of the original dataset used as sources.")
    parser.add_argument('--translated_fields', '-f', nargs='+', type=str, default=None,
                        help="Fields of the translated dataset used as translations; defaults to --orig_fields.")
    parser.add_argument('--output_dir', '-d', type=str, default='.',
                        help="Directory where score files and JSON exports are written.")
    parser.add_argument('--splits', '-s', nargs='+', type=str, default=None,
                        help="Splits to evaluate; defaults to every split of the original dataset.")
    args = parser.parse_args()
    print(args)
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment