@tokestermw
Last active September 6, 2018 21:27
Test code of `_ElmoSoftmax`.
from typing import List, Tuple
import torch
from allennlp.data import Token, Instance, Vocabulary
from allennlp.data.dataset import Batch
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import ELMoTokenCharactersIndexer, SingleIdTokenIndexer
from allennlp.modules.elmo import (
    _ElmoCharacterEncoder, _ElmoBiLm, _ElmoSoftmax, Elmo,  # batch_to_ids
)
# allennlp's own batch_to_ids is left unimported: a variant that can also
# return word ids is defined below.
DEFAULT_OPTIONS_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json" # pylint: disable=line-too-long
DEFAULT_WEIGHT_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" # pylint: disable=line-too-long
# TODO: add softmax as an option to the elmo command
DEFAULT_SOFTMAX_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_softmax_weights.hdf5" # pylint: disable=line-too-long
DEFAULT_VOCAB_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/vocab-2016-09-10.txt" # pylint: disable=line-too-long
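# The options and weight files are the standard pretrained ELMo biLM; the
# softmax weights and vocabulary expose the biLM's output layer, which
# _ElmoSoftmax presumably loads to turn hidden states into per-token log
# probabilities.
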
def batch_to_ids(batch: List[List[str]], vocab: Vocabulary = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts a batch of tokenized sentences to a tensor of character ids of shape
    (len(batch), max sentence length, max word length).

    Parameters
    ----------
    batch : ``List[List[str]]``, required
        A list of tokenized sentences.
    vocab : ``Vocabulary``, optional
        A vocabulary of words, needed only if word ids should be returned as well.

    Returns
    -------
    If ``vocab`` is given, a tuple of character ids and word ids;
    otherwise, a single tensor of character ids.
    """
    instances = []
    char_indexer = ELMoTokenCharactersIndexer()
    if vocab:
        token_indexer = SingleIdTokenIndexer(
            namespace='tokens', lowercase_tokens=False)
    else:
        token_indexer = None

    for sentence in batch:
        tokens = [Token(token) for token in sentence]
        if vocab:
            field = TextField(tokens, {
                'character_ids': char_indexer,
                'word_ids': token_indexer,
            })
        else:
            field = TextField(tokens, {'character_ids': char_indexer})
        instance = Instance({"elmo": field})
        instances.append(instance)

    dataset = Batch(instances)
    dataset.index_instances(vocab)
    elmo_tensor_dict = dataset.as_tensor_dict()['elmo']
    if vocab:
        return elmo_tensor_dict['character_ids'], elmo_tensor_dict['word_ids']
    else:
        return elmo_tensor_dict['character_ids']
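
# Minimal usage sketch (comment-only; the trailing dimension assumes ELMo's
# character indexer pads every word to 50 character ids):
#
#     char_ids = batch_to_ids([['Hello', 'world', '!']])
#     char_ids.shape  # -> torch.Size([1, 3, 50]): batch, tokens, chars per token
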
def _tokenize(text):
    # Whitespace tokenization is enough for the pre-tokenized toy sentences below.
    return text.split()

if __name__ == '__main__':
    # Probe sentences: case, punctuation, and word-order variants of one another.
    sentences = [
        'How are you ?',
        'how are you ?',
        'How are you .',
        'You are how ?',
    ]
    sentences = [_tokenize(i) for i in sentences]
    # _ElmoCharacterEncoder is also importable for lower-level access, but the
    # full biLM is what _ElmoSoftmax needs here.
    elmo_bilm = _ElmoBiLm(DEFAULT_OPTIONS_FILE, DEFAULT_WEIGHT_FILE)
    elmo_softmax = _ElmoSoftmax(DEFAULT_SOFTMAX_FILE, DEFAULT_VOCAB_FILE)

    char_ids, word_ids = batch_to_ids(sentences, elmo_softmax.vocab)
    bilm_outputs = elmo_bilm(char_ids)
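    # _ElmoBiLm returns a dict with 'activations' (one tensor per biLM layer)
    # and 'mask'; _ElmoSoftmax presumably consumes the top-layer activations
    # together with the word ids to score each token.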
    # 'mean' averages the forward and backward language-model log probabilities.
    softmax_log_probs, softmax_mask = elmo_softmax(
        bilm_outputs, word_ids, aggregation_fun='mean')
    print(softmax_log_probs)
    print(softmax_mask)
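
    # Informal expectation (not asserted): the natural ordering 'How are you ?'
    # should receive a higher mean log probability than the scrambled
    # 'You are how ?', since the softmax scores tokens under the pretrained
    # language model.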