@gsarti
Last active January 6, 2022 13:17
Use 🤗 transformers FlaxMarianMTModel to translate 🤗 datasets text fields on TPU
import argparse
import datasets
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from tqdm import tqdm
from transformers import FlaxMarianMTModel, MarianTokenizer


def translate(args):
    # Load either the default Helsinki-NLP OPUS-MT checkpoint for the
    # src-tgt pair or a user-specified checkpoint.
    if args.model_name is None:
        model = FlaxMarianMTModel.from_pretrained(
            f"Helsinki-NLP/opus-mt-{args.src}-{args.tgt}", from_pt=True
        )
        tokenizer = MarianTokenizer.from_pretrained(
            f"Helsinki-NLP/opus-mt-{args.src}-{args.tgt}"
        )
    else:
        model = FlaxMarianMTModel.from_pretrained(args.model_name, from_pt=True)
        tokenizer = MarianTokenizer.from_pretrained(args.model_name)

    data = datasets.load_dataset(
        args.dataset_name,
        args.dataset_config,
    )

    def _generate(params, batch):
        output_ids = model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            params=params,
            num_beams=args.num_beams,
            max_length=args.max_length,
            early_stopping=args.early_stopping,
        ).sequences
        return output_ids

    def encode_texts(texts):
        return tokenizer(
            texts,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=args.max_length,
        )

    def decode_texts(ids):
        return tokenizer.batch_decode(
            ids.reshape(-1, args.max_length),
            skip_special_tokens=True,
            max_length=args.max_length,
        )

    def generate_parallel(texts):
        # Replicate the model parameters on every device, shard the batch
        # along a new leading device axis, and generate with pmap.
        p_params = replicate(model.params)
        p_generate = jax.pmap(_generate, "batch")
        inputs = encode_texts(texts)
        p_inputs = shard(inputs.data)
        output_ids = p_generate(p_params, p_inputs)
        return decode_texts(output_ids.reshape(-1, args.max_length))

    def generate(texts):
        # Single-device fallback for the leftover examples that do not
        # divide evenly across devices.
        inputs = encode_texts(texts)
        output_ids = _generate(model.params, inputs)
        return decode_texts(output_ids.reshape(-1, args.max_length))

    def replace_translated_fields(ex, id, trans_dict):
        for field in trans_dict.keys():
            ex[field] = trans_dict[field][id] if id < len(trans_dict[field]) else "<NAN>"
        return ex

    if args.dataset_splits is None:
        args.dataset_splits = list(data.keys())

    for split in args.dataset_splits:
        print(f"Translating split {split}")
        orig_len = len(data[split]) if not args.debug else 555
        # pmap needs the number of examples to be divisible by the device count.
        shard_length = orig_len - (orig_len % args.num_devices)
        print(f"Original length: {orig_len}, using only {shard_length} entries")
        dic_translated_fields = {}
        for field in args.fields_to_translate:
            print(f"Processing field {field}...")
            dic_translated_fields[field] = [
                sent
                for idx in tqdm(range(0, shard_length, args.batch_size))
                for sent in generate_parallel(
                    data[split][idx : min(idx + args.batch_size, shard_length)][field]
                )
            ]
            if args.show_examples > 0:
                print("Translated parallel sentences:", len(dic_translated_fields[field]))
                print("Example:", dic_translated_fields[field][:args.show_examples])
            # If not all sentences were translated, translate the remaining
            # ones on a single device.
            if shard_length < orig_len:
                extra_sentences = generate(data[split][shard_length:orig_len][field])
                if args.show_examples > 0:
                    print("Translated extra sentences:", len(extra_sentences))
                    print("Example:", extra_sentences[:args.show_examples])
                dic_translated_fields[field] += extra_sentences
            assert len(dic_translated_fields[field]) == orig_len, (
                f"Size mismatch: {len(dic_translated_fields[field])} "
                f"does not match original size {orig_len}"
            )
        data[split] = data[split].map(
            lambda x, id: replace_translated_fields(x, id, dic_translated_fields),
            with_indices=True,
        )
        # Drop any row left untranslated (marked with the <NAN> placeholder).
        data[split] = data[split].filter(
            lambda ex: not ex[args.fields_to_translate[0]].startswith("<NAN>")
        )
    if args.debug:
        print(data)
    else:
        save_path = f"{args.dataset_name}_{args.tgt}"
        data.save_to_disk(save_path)
        print(f"Translated dataset saved at {save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Translate a dataset into a different language with FlaxMarianMTModel."
    )
    parser.add_argument("--dataset_name", required=True, help="Name of the dataset on which the translation should be performed.")
    parser.add_argument("--dataset_config", default=None, help="Name of the dataset configuration on which the translation should be performed.")
    parser.add_argument("--dataset_splits", default=None, nargs="+", help="Names of the splits on which the translation should be performed.")
    parser.add_argument("--model_name", default=None, help="Name of the model to use. If not specified, use the default MarianMTModel for src-tgt from the Helsinki-NLP account.")
    parser.add_argument("--src", default="en", help="Source language of the original fields.")
    parser.add_argument("--tgt", required=True, help="Target language of the translation.")
    parser.add_argument("--batch_size", type=int, default=512, help="The batch size used to carry out the translation.")
    parser.add_argument("--num_devices", type=int, default=8, help="Number of available devices.")
    parser.add_argument("--num_beams", type=int, default=4, help="Number of beams for beam search decoding.")
    parser.add_argument("--max_length", type=int, default=512, help="Max length for the input and the output of the generation.")
    parser.add_argument("--early_stopping", type=lambda x: x.lower() == "true", default=True, help="Whether to perform early stopping in generation.")
    parser.add_argument("--fields_to_translate", default=["text"], nargs="+", help="The list of fields that should be translated.")
    parser.add_argument("--debug", action="store_true", help="Set to use only a small number of sentences for debugging.")
    parser.add_argument("--show_examples", type=int, default=3, help="Set to show n examples from the translations.")
    args = parser.parse_args()
    translate(args)
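
The parallel path in generate_parallel rests on two utilities: flax's shard, which reshapes a batch so it gains a leading axis of size jax.device_count(), and jax.pmap, which runs the same function once per device. Below is a minimal standalone sketch of that mechanism; _toy_generate is a hypothetical stand-in for model.generate, used only to show the shapes involved.

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

# Hypothetical stand-in for _generate: doubles the inputs on each device.
def _toy_generate(batch):
    return batch * 2

p_fn = jax.pmap(_toy_generate, "batch")

# The batch size must be divisible by the device count, which the script
# enforces via shard_length = orig_len - (orig_len % args.num_devices).
batch = jnp.arange(jax.device_count() * 4).reshape(-1, 1)
sharded = shard(batch)           # (num_devices, per_device_batch, 1)
out = p_fn(sharded)              # same leading device axis on the output
print(out.reshape(-1, 1).shape)  # flattened back, as decode_texts does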
gsarti commented Jan 6, 2022

Example run: translating the premise, hypothesis, and explanation_1 fields of the train split of esnli from English to Dutch.

python run_marian_translate_flax.py \
	--dataset_name esnli \
	--dataset_config plain_text \
	--dataset_splits train \
	--model_name Helsinki-NLP/opus-mt-en-nl \
	--src en \
	--tgt nl \
	--fields_to_translate premise hypothesis explanation_1
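
Since the script writes its output with save_to_disk, the translated dataset can be reloaded with datasets.load_from_disk. A minimal sketch, assuming the run above completed so the save path is esnli_nl, following the script's f"{args.dataset_name}_{args.tgt}" convention:

import datasets

# Reload the translated dataset written by the script above.
data = datasets.load_from_disk("esnli_nl")
print(data["train"][0]["premise"])  # premise field, now in Dutch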
