Test code for `_ElmoSoftmax`, which uses the pretrained ELMo bidirectional language model's softmax (output-layer) weights to score sentences as per-token log probabilities.
from typing import List, Tuple, Union

import torch

from allennlp.data import Token, Instance, Vocabulary
from allennlp.data.dataset import Batch
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import ELMoTokenCharactersIndexer, SingleIdTokenIndexer
from allennlp.modules.elmo import (
    _ElmoCharacterEncoder, _ElmoBiLm, _ElmoSoftmax, Elmo,  # batch_to_ids is redefined below
)

DEFAULT_OPTIONS_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"  # pylint: disable=line-too-long
DEFAULT_WEIGHT_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"  # pylint: disable=line-too-long
# TODO: add softmax as an option to the elmo command
DEFAULT_SOFTMAX_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_softmax_weights.hdf5"  # pylint: disable=line-too-long
DEFAULT_VOCAB_FILE = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/vocab-2016-09-10.txt"  # pylint: disable=line-too-long


def batch_to_ids(batch: List[List[str]],
                 vocab: Vocabulary = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Converts a batch of tokenized sentences to a tensor of character ids
    with shape ``(len(batch), max sentence length, max word length)``.

    Parameters
    ----------
    batch : ``List[List[str]]``, required
        A list of tokenized sentences.
    vocab : ``Vocabulary``, optional
        A vocabulary of words; if given, word ids are returned alongside
        the character ids.

    Returns
    -------
    If ``vocab`` is given, a tuple of (character ids, word ids);
    otherwise, just the tensor of character ids.
    """
    instances = []
    char_indexer = ELMoTokenCharactersIndexer()
    if vocab:
        token_indexer = SingleIdTokenIndexer(namespace='tokens', lowercase_tokens=False)
    else:
        token_indexer = None

    for sentence in batch:
        tokens = [Token(token) for token in sentence]
        if vocab:
            field = TextField(tokens, {
                'character_ids': char_indexer,
                'word_ids': token_indexer,
            })
        else:
            field = TextField(tokens, {'character_ids': char_indexer})
        instance = Instance({"elmo": field})
        instances.append(instance)

    dataset = Batch(instances)
    # the character indexer does not consult the vocabulary, but indexing
    # still expects one, so fall back to an empty Vocabulary
    dataset.index_instances(vocab if vocab is not None else Vocabulary())
    elmo_tensor_dict = dataset.as_tensor_dict()['elmo']

    if vocab:
        return elmo_tensor_dict['character_ids'], elmo_tensor_dict['word_ids']
    return elmo_tensor_dict['character_ids']


def _tokenize(text):
    return text.split()


if __name__ == '__main__':
    sentences = [
        'How are you ?',
        'how are you ?',
        'How are you .',
        'You are how ?',
    ]
    sentences = [_tokenize(i) for i in sentences]

    # _ElmoBiLm builds the character encoder (_ElmoCharacterEncoder) internally
    elmo_bilm = _ElmoBiLm(DEFAULT_OPTIONS_FILE, DEFAULT_WEIGHT_FILE)
    elmo_softmax = _ElmoSoftmax(DEFAULT_SOFTMAX_FILE, DEFAULT_VOCAB_FILE)

    char_ids, word_ids = batch_to_ids(sentences, elmo_softmax.vocab)
    bilm_outputs = elmo_bilm(char_ids)

    # average backward and forward log probs
    softmax_log_probs, softmax_mask = elmo_softmax(
        bilm_outputs, word_ids, aggregation_fun='mean')

    print(softmax_log_probs)
    print(softmax_mask)
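If, as the `aggregation_fun='mean'` call suggests, `softmax_log_probs` carries one mean token log probability per sentence, a per-sentence pseudo-perplexity falls out directly. The sketch below is not part of the gist and assumes a `(batch_size,)` output shape; `_ElmoSoftmax` appears to come from an unmerged branch (note the TODO above), so treat that shape as an assumption. The four test sentences vary casing, punctuation, and word order, and one would expect the scrambled "You are how ?" to come out least probable (highest pseudo-perplexity) under the pretrained biLM.

# A minimal sketch, not part of the gist: convert mean token log
# probabilities to pseudo-perplexities. Assumes `softmax_log_probs`
# has shape (batch_size,), which is an assumption about _ElmoSoftmax.
import torch

def pseudo_perplexity(mean_log_probs: torch.Tensor) -> torch.Tensor:
    """exp(-mean log prob): lower means more fluent under the LM."""
    return torch.exp(-mean_log_probs)

# e.g. a mean log prob of -2.3 nats gives a pseudo-perplexity of ~10
print(pseudo_perplexity(torch.tensor([-2.3, -4.6])))  # tensor([ 9.9742, 99.4843])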