Created
January 6, 2022 13:21
-
-
Save gsarti/9f5080c6622a3cb4edb09862f6fd14c6 to your computer and use it in GitHub Desktop.
Evaluate the translation of a 🤗 dataset using Comet-based QE models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import os | |
| from datasets import load_dataset | |
| from comet import download_model, load_from_checkpoint | |
def replace_missing(scores, sentences, default=-100):
    """Substitute a sentinel score for entries whose sentence is empty.

    Args:
        scores: iterable of per-segment QE scores.
        sentences: iterable of the corresponding translated sentences.
        default: value used wherever the sentence is falsy (empty/None).

    Returns:
        A list the length of the shorter input, with `default` in place of
        the score for every missing sentence.
    """
    return [
        score if sentence else default
        for score, sentence in zip(scores, sentences)
    ]
def main(args):
    """Score translated fields of a HF dataset with a COMET-based QE model.

    For each requested split and each (orig_field, translated_field) pair,
    builds src/mt pairs, runs the QE model, writes per-segment scores plus the
    system score to a text file, appends the scores as a new column, and dumps
    the augmented translated split to JSON in ``args.output_dir``.
    """
    if args.translated_fields is None:
        args.translated_fields = args.orig_fields
    if args.score_name is None:
        # e.g. 'wmt21-comet-qe-mqm' -> 'mqm'
        args.score_name = args.model.split('-')[-1]
    assert len(args.orig_fields) == len(args.translated_fields), "Fields mismatch!"
    orig = load_dataset(args.orig_dataset_name, args.orig_dataset_config, split=args.splits, use_auth_token=args.auth_token)
    trans = load_dataset(args.translated_dataset_name, args.translated_dataset_config, split=args.splits, use_auth_token=args.auth_token)
    model_path = download_model(args.model)
    model = load_from_checkpoint(model_path)
    if args.splits is None:
        # split=None yields a DatasetDict: take every split, in dict order,
        # and flatten to lists aligned with args.splits.
        args.splits = list(orig.keys())
        orig = [orig[s] for s in args.splits]
        trans = [trans[s] for s in args.splits]
    # else: BUG FIX — when a list of split names is passed, load_dataset
    # already returns a list of Datasets in that order; the original code
    # then indexed that list with string split names and crashed.
    os.makedirs(args.output_dir, exist_ok=True)  # hoisted: loop-invariant
    for split, orig_split, trans_split in zip(args.splits, orig, trans):
        print("Processing split {}".format(split))
        assert len(orig_split) == len(trans_split), f"Size mismatch for split {split}: {len(orig_split)} != {len(trans_split)}"
        json_filename = f"{args.translated_dataset_name.split('/')[-1]}_{args.orig_dataset_config if args.orig_dataset_config else ''}_{split}_{args.score_name}.json"
        for of, tf in zip(args.orig_fields, args.translated_fields):
            # COMET QE input format: source segment + machine translation.
            data = [{'src': orig_split[idx][of], 'mt': trans_split[idx][tf]} for idx in range(len(orig_split))]
            print(f"Example translations: {data[:3]}")
            seg_scores, sys_score = model.predict(data, batch_size=8, gpus=1)
            # Empty translations get a sentinel score instead of a model score.
            seg_scores = replace_missing(seg_scores, [d['mt'] for d in data])
            filename_scores = f'trans_{args.orig_dataset_name.split("/")[-1]}_{split}_{args.model}_{of}.txt'
            with open(os.path.join(args.output_dir, filename_scores), 'w') as f:
                f.writelines([str(score) + "\n" for score in seg_scores])
                f.write("==== System score ====\n")
                f.write(str(sys_score))
            trans_split = trans_split.add_column(f"{args.score_name}_{tf}", seg_scores)
            print(f"Done {of}!")
        trans_split.to_json(os.path.join(args.output_dir, json_filename))
if __name__ == '__main__':
    # CLI spec as (flags, kwargs) pairs — one row per option keeps the
    # argument table compact and easy to scan.
    cli_spec = [
        (('--model', '-m'), dict(type=str, default='wmt21-comet-qe-mqm')),
        (('--score_name', '-n'), dict(type=str, default=None)),
        (('--orig_dataset_name', '-o'), dict(type=str, required=True)),
        (('--translated_dataset_name', '-t'), dict(type=str, required=True)),
        (('--orig_dataset_config', '-x'), dict(type=str, default=None)),
        (('--translated_dataset_config', '-y'), dict(type=str, default=None)),
        (('--auth_token', '-a'), dict(type=str, default=None)),
        (('--orig_fields', '-e'), dict(nargs='+', type=str, default=["text"])),
        (('--translated_fields', '-f'), dict(nargs='+', type=str, default=None)),
        (('--output_dir', '-d'), dict(type=str, default='.')),
        (('--splits', '-s'), dict(nargs='+', type=str, default=None)),
    ]
    parser = argparse.ArgumentParser()
    for flags, kwargs in cli_spec:
        parser.add_argument(*flags, **kwargs)
    args = parser.parse_args()
    print(args)
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment