Last active: July 27, 2023, 10:34
Revisions
manueldeprada revised this gist on Jul 27, 2023: 1 changed file with 11 additions and 4 deletions.

@@ -1,4 +1,4 @@
from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer, BatchEncoding
from transformers.modeling_outputs import Seq2SeqLMOutput
import torch

@@ -59,14 +59,18 @@ def forward(
        attentions = () if output_attentions else None
        hidden_states = () if output_hidden_states else None

        output = output.log()
        # replace -inf with -1e9
        # output[output == float('-inf')] = -1e9

        # Create Seq2SeqLMOutput object
        if not return_dict:
            att, hidd = attentions if attentions is not None else (), hidden_states if hidden_states is not None else ()
            return (output,) + past_key_values + hidd + att
        else:
            return Seq2SeqLMOutput(
                loss=None,
                logits=output,
                past_key_values=past_key_values,
                decoder_hidden_states=hidden_states,
                decoder_attentions=attentions,

@@ -77,14 +81,17 @@ def forward(
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def _reorder_cache(self, past, beam_idx):
        return past


class FakeTokenizer(PreTrainedTokenizer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, text, **kwargs):
        return BatchEncoding({"input_ids": torch.tensor([[0], [0]])})

    def batch_decode(self, token_ids, **kwargs):
        return [str(ids) for ids in token_ids]
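This revision returns log-probabilities from both output paths of forward, adds a _reorder_cache stub (the hook generate uses when reordering beams), and makes FakeTokenizer.__call__ return a BatchEncoding of tensors. Below is a minimal sketch of wiring the tokenizer and model together; the module name fake_transformer.py and the transformers version (~4.31, mid-2023) are assumptions, not part of the gist.

# Sketch only: assumes the gist is saved as fake_transformer.py and a mid-2023
# transformers release (~4.31); newer tokenizer base classes may reject the
# minimal FakeTokenizer.
from fake_transformer import FakeTransformer, FakeTransformerConfig, FakeTokenizer

model = FakeTransformer(FakeTransformerConfig())
tokenizer = FakeTokenizer()

# The fake tokenizer ignores its input and always returns input_ids of shape (2, 1).
batch = tokenizer("any text")

# Sampling follows the hard-coded per-step distributions in FakeTransformer.forward.
out = model.generate(
    batch["input_ids"],
    max_length=4,
    do_sample=True,
    return_dict_in_generate=True,
    output_scores=True,
)
print(tokenizer.batch_decode(out.sequences))  # e.g. ['tensor([0, 1, 2, 3])', ...]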
manueldeprada revised this gist on Jul 17, 2023: 1 changed file with 12 additions and 1 deletion.

@@ -1,4 +1,4 @@
from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput
import torch

@@ -77,6 +77,17 @@ def forward(
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}


class FakeTokenizer(PreTrainedTokenizer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, text, **kwargs):
        return {"input_ids": [0]}

    def batch_decode(self, token_ids, **kwargs):
        return [str(ids) for ids in token_ids]


if __name__ == "__main__":
manueldeprada revised this gist on Jul 17, 2023: 1 changed file with 7 additions and 9 deletions.

@@ -10,6 +10,8 @@ def __init__(self, vocab_size=4, **kwargs):
        super().__init__(pad_token_id=-1, eos_token_id=3, bos_token_id=0, **kwargs)
        self.vocab_size = vocab_size
        self.max_new_tokens = 4
        self.do_sample = True
        self.num_beams = 1


class FakeTransformer(PreTrainedModel):

@@ -76,18 +78,14 @@ def forward(
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}


if __name__ == "__main__":
    config = FakeTransformerConfig()
    model = FakeTransformer(config)
    input_ids = torch.tensor([[0]] * 10)
    output = model.generate(input_ids, max_length=4, do_sample=True, return_dict_in_generate=True, output_scores=True)
    print(f"generated tokens: {output.sequences}")
    print(f"probs at step 1: {torch.exp(output.scores[0][0, :])}")
    print(f"probs at step 2: {torch.exp(output.scores[1][0, :])}")
    print(f"probs at step 3: {torch.exp(output.scores[2][0, :])}")
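This revision moves the generation defaults (max_new_tokens, do_sample, num_beams) onto the config, so generate can pick them up without explicit arguments. A hedged sketch of that shorter call follows; it assumes a mid-2023 transformers release, where generation parameters stored on the model config are still honoured, and that the classes are importable from a file named fake_transformer.py.

# Sketch only: relies on the config-level generation defaults added in this
# revision (max_new_tokens=4, do_sample=True, num_beams=1); the module name
# fake_transformer.py is an assumption.
import torch
from fake_transformer import FakeTransformer, FakeTransformerConfig

model = FakeTransformer(FakeTransformerConfig())
input_ids = torch.tensor([[0]] * 10)

# No explicit max_length/do_sample: generate falls back to the config values.
output = model.generate(input_ids, return_dict_in_generate=True, output_scores=True)
print(output.sequences.shape)  # expected torch.Size([10, 4]): BOS + 3 tokens, ending in EOS (id 3)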
manueldeprada revised this gist on Jul 17, 2023: 1 changed file with 1 addition and 0 deletions.

@@ -9,6 +9,7 @@ class FakeTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=4, **kwargs):
        super().__init__(pad_token_id=-1, eos_token_id=3, bos_token_id=0, **kwargs)
        self.vocab_size = vocab_size
        self.max_new_tokens = 4


class FakeTransformer(PreTrainedModel):
manueldeprada revised this gist on Jul 17, 2023: 1 changed file with 1 addition and 1 deletion.

@@ -14,7 +14,7 @@ def __init__(self, vocab_size=4, **kwargs):
class FakeTransformer(PreTrainedModel):
    config_class = FakeTransformerConfig

    def __init__(self, config=None):
        config = FakeTransformerConfig() if config is None else config
        super().__init__(config)
        self.fake_param = torch.nn.Parameter(torch.tensor(1.0))  # need at least one parameter to be a valid model
manueldeprada revised this gist on Jul 17, 2023: 1 changed file with 1 addition and 0 deletions.

@@ -15,6 +15,7 @@ class FakeTransformer(PreTrainedModel):
    config_class = FakeTransformerConfig

    def __init__(self, config):
        config = FakeTransformerConfig() if config is None else config
        super().__init__(config)
        self.fake_param = torch.nn.Parameter(torch.tensor(1.0))  # need at least one parameter to be a valid model
manueldeprada revised this gist on Jul 17, 2023: 1 changed file with 1 addition and 1 deletion.

@@ -75,7 +75,7 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}


def main():
    config = FakeTransformerConfig()
    model = FakeTransformer(config)
manueldeprada revised this gist on Jul 17, 2023: 1 changed file with 6 additions and 2 deletions.

@@ -74,8 +74,8 @@ def forward(
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}


def main():
    config = FakeTransformerConfig()
    model = FakeTransformer(config)

@@ -85,3 +85,7 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
    print(f"probs at step 1: {torch.exp(output.scores[0][0,:])}")
    print(f"probs at step 2: {torch.exp(output.scores[1][0,:])}")
    print(f"probs at step 3: {torch.exp(output.scores[2][0,:])}")


if __name__ == "__main__":
    main()
manueldeprada created this gist on Jul 16, 2023.

@@ -0,0 +1,87 @@
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput
import torch


class FakeTransformerConfig(PretrainedConfig):
    model_type = "FakeTransformer"

    def __init__(self, vocab_size=4, **kwargs):
        super().__init__(pad_token_id=-1, eos_token_id=3, bos_token_id=0, **kwargs)
        self.vocab_size = vocab_size


class FakeTransformer(PreTrainedModel):
    config_class = FakeTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.fake_param = torch.nn.Parameter(torch.tensor(1.0))  # need at least one parameter to be a valid model

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        return_dict=True,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        seq_len = input_ids.shape[1]
        batch_size = input_ids.shape[0]

        # Placeholder for output probabilities
        output = torch.zeros(batch_size, seq_len, self.config.vocab_size)
        if seq_len >= 1:
            output[:, 0, 1] = 0.75
            output[:, 0, 2] = 0.25
        if seq_len >= 2:
            output[:, 1, 1] = 0.10
            output[:, 1, 2] = 0.90
        if seq_len >= 3:
            output[:, 2, :] = 0.0
            output[:, 2, 3] = 1.0
        if not past_key_values:  # when using past_key_values, only last token logits are generated
            output = output[:, -1, :].unsqueeze(1)

        # Placeholder for past_key_values, attentions, hidden_states
        past_key_values = () if past_key_values is None else past_key_values
        attentions = () if output_attentions else None
        hidden_states = () if output_hidden_states else None

        # Create Seq2SeqLMOutput object
        if not return_dict:
            att, hidd = attentions if attentions is not None else (), hidden_states if hidden_states is not None else ()
            return (output,) + past_key_values + hidd + att
        else:
            return Seq2SeqLMOutput(
                loss=None,
                logits=output.log(),
                past_key_values=past_key_values,
                decoder_hidden_states=hidden_states,
                decoder_attentions=attentions,
                encoder_last_hidden_state=None,
                encoder_hidden_states=None,
                encoder_attentions=None,
            )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}


if __name__ == "__main__":
    config = FakeTransformerConfig()
    model = FakeTransformer(config)
    input_ids = torch.tensor([[0]] * 10)
    output = model.generate(input_ids, max_length=4, do_sample=True, return_dict_in_generate=True, output_scores=True)
    print(f"generated tokens: {output.sequences}")
    print(f"probs at step 1: {torch.exp(output.scores[0][0,:])}")
    print(f"probs at step 2: {torch.exp(output.scores[1][0,:])}")
    print(f"probs at step 3: {torch.exp(output.scores[2][0,:])}")
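Since the per-step distributions are hard-coded, the gist lends itself to a quick statistical sanity check: with sampling, the first generated token should come out as 1 roughly 75% of the time and 2 roughly 25% of the time. A small verification sketch under that assumption follows (not part of the gist; the module name fake_transformer.py is hypothetical).

# Verification sketch: checks the empirical first-token frequencies against the
# hard-coded distribution P(token 1) = 0.75, P(token 2) = 0.25.
import torch
from fake_transformer import FakeTransformer, FakeTransformerConfig

torch.manual_seed(0)
model = FakeTransformer(FakeTransformerConfig())

# Sample many sequences, all starting from the BOS token (id 0).
input_ids = torch.zeros((1000, 1), dtype=torch.long)
sequences = model.generate(input_ids, max_length=4, do_sample=True)

# Column 1 holds the first sampled token of each sequence.
first_tokens = sequences[:, 1]
print((first_tokens == 1).float().mean())  # should be close to 0.75
print((first_tokens == 2).float().mean())  # should be close to 0.25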