Use 🤗 transformers FlaxMarianMTModel to translate 🤗 datasets text fields on TPU
import argparse

import datasets
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from tqdm import tqdm
from transformers import FlaxMarianMTModel, MarianTokenizer


def translate(args):
    # Load the model and tokenizer: either an explicit checkpoint or the
    # default Helsinki-NLP OPUS-MT checkpoint for the src-tgt pair.
    if args.model_name is None:
        model = FlaxMarianMTModel.from_pretrained(
            f"Helsinki-NLP/opus-mt-{args.src}-{args.tgt}", from_pt=True
        )
        tokenizer = MarianTokenizer.from_pretrained(
            f"Helsinki-NLP/opus-mt-{args.src}-{args.tgt}"
        )
    else:
        model = FlaxMarianMTModel.from_pretrained(args.model_name, from_pt=True)
        tokenizer = MarianTokenizer.from_pretrained(args.model_name)
    data = datasets.load_dataset(
        args.dataset_name,
        args.dataset_config,
    )

    def _generate(params, batch):
        # Beam-search generation; `.sequences` holds the generated token ids,
        # padded to max_length.
        output_ids = model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            params=params,
            num_beams=args.num_beams,
            max_length=args.max_length,
            early_stopping=args.early_stopping,
        ).sequences
        return output_ids

    def encode_texts(texts):
        # Pad every batch to max_length so all sharded inputs share the same
        # shape and pmap does not recompile between batches.
        return tokenizer(
            texts,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=args.max_length,
        )

    def decode_texts(ids):
        # Flatten the (devices, batch, length) output back to (batch, length)
        # before decoding.
        return tokenizer.batch_decode(
            ids.reshape(-1, args.max_length),
            skip_special_tokens=True,
        )

    # Replicate the parameters across devices and build the parallel
    # generation function once, instead of once per batch.
    p_params = replicate(model.params)
    p_generate = jax.pmap(_generate, "batch")

    def generate_parallel(texts):
        inputs = encode_texts(texts)
        p_inputs = shard(inputs.data)
        output_ids = p_generate(p_params, p_inputs)
        return decode_texts(output_ids)

    def generate(texts):
        # Single-device fallback for the leftover examples that cannot be
        # sharded evenly across devices.
        inputs = encode_texts(texts)
        output_ids = _generate(model.params, inputs)
        return decode_texts(output_ids)

    def replace_translated_fields(ex, id, trans_dict):
        # "<NAN>" marks entries without a translation; they are filtered out below.
        for field in trans_dict.keys():
            ex[field] = trans_dict[field][id] if id < len(trans_dict[field]) else "<NAN>"
        return ex

    if args.dataset_splits is None:
        args.dataset_splits = list(data.keys())
    for split in args.dataset_splits:
        print(f"Translating split {split}")
        orig_len = len(data[split]) if not args.debug else 555
        # Trim to a multiple of num_devices so every batch shards evenly.
        shard_length = orig_len - (orig_len % args.num_devices)
        print(f"Original length: {orig_len}, using only {shard_length} entries")
        dic_translated_fields = {}
        for field in args.fields_to_translate:
            print(f"Processing field {field}...")
            dic_translated_fields[field] = [
                sent
                for idx in tqdm(range(0, shard_length, args.batch_size))
                for sent in generate_parallel(
                    data[split][idx : min(idx + args.batch_size, shard_length)][field]
                )
            ]
            if args.show_examples > 0:
                print("Translated parallel sentences:", len(dic_translated_fields[field]))
                print("Example:", dic_translated_fields[field][: args.show_examples])
            # If not all sentences were translated, translate the remaining
            # ones on a single device.
            if shard_length < orig_len:
                extra_sentences = generate(data[split][shard_length:orig_len][field])
                if args.show_examples > 0:
                    print("Translated extra sentences:", len(extra_sentences))
                    print("Example:", extra_sentences[: args.show_examples])
                dic_translated_fields[field] += extra_sentences
            assert len(dic_translated_fields[field]) == orig_len, (
                f"Size mismatch: {len(dic_translated_fields[field])} does not "
                f"match original size {orig_len}"
            )
        data[split] = data[split].map(
            lambda x, id: replace_translated_fields(x, id, dic_translated_fields),
            with_indices=True,
        )
        data[split] = data[split].filter(
            lambda ex: not ex[args.fields_to_translate[0]].startswith("<NAN>")
        )
    if args.debug:
        print(data)
    else:
        save_path = f"{args.dataset_name}_{args.tgt}"
        data.save_to_disk(save_path)
        print(f"Translated dataset saved at {save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Translate a dataset into a different language with FlaxMarianMTModels."
    )
    parser.add_argument("--dataset_name", required=True, help="Name of the dataset on which the translation should be performed.")
    parser.add_argument("--dataset_config", default=None, help="Name of the dataset configuration on which the translation should be performed.")
    parser.add_argument("--dataset_splits", default=None, nargs="+", help="Names of the splits on which the translation should be performed.")
    parser.add_argument("--model_name", default=None, help="Name of the model to use. If not specified, use the default MarianMTModel for src-tgt from the Helsinki-NLP account.")
    parser.add_argument("--src", default="en", help="Source language of the original fields.")
    parser.add_argument("--tgt", required=True, help="Target language of the translation.")
    parser.add_argument("--batch_size", default=512, type=int, help="The batch size used to carry out the translation.")
    parser.add_argument("--num_devices", default=8, type=int, help="Number of available devices.")
    parser.add_argument("--num_beams", default=4, type=int, help="Number of beams for beam search decoding.")
    parser.add_argument("--max_length", default=512, type=int, help="Max length for the input and the output of the generation.")
    parser.add_argument("--early_stopping", default=True, type=lambda s: str(s).lower() in ("true", "1"), help="Whether to perform early stopping in generation.")
    parser.add_argument("--fields_to_translate", default=["text"], nargs="+", help="The list of fields that should be translated.")
    parser.add_argument("--debug", action="store_true", help="Set to use only a small amount of sentences for debugging.")
    parser.add_argument("--show_examples", default=3, type=int, help="Set to show n examples from the translations.")
    args = parser.parse_args()
    translate(args)
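Note on the sharding logic: each split is trimmed to a multiple of num_devices because flax's shard simply reshapes the leading (batch) axis into (local_devices, batch // local_devices), which fails when the batch does not divide evenly; the leftover examples are handled by the single-device generate(). A minimal sketch of this behavior, not part of the gist (shapes assume an 8-device TPU):

import numpy as np
from flax.training.common_utils import shard

# A toy batch of 16 sequences of length 8.
batch = {"input_ids": np.zeros((16, 8), dtype=np.int32)}

# shard() splits the leading axis across local devices:
# on an 8-device TPU the result is (8, 2, 8).
print(shard(batch)["input_ids"].shape)

# A batch of 15 cannot be reshaped to (8, -1, 8) and raises an error,
# which is why the script trims each split to a multiple of num_devices.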
Example run: translating from English to Dutch all fields of the train split of esnli.
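A hypothetical invocation for that run might look as follows (the script filename and the esnli column names are assumptions, not part of the gist):

python translate_dataset.py \
    --dataset_name esnli \
    --dataset_splits train \
    --src en \
    --tgt nl \
    --fields_to_translate premise hypothesis explanation_1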