Revisions
thomwolf revised this gist
Jul 13, 2019. 1 changed file with 1 addition and 0 deletions.
@@ -4,6 +4,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
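For reference, here is the whole function as it reads after this latest revision, stitched together from the hunks above into one self-contained snippet. The torch and torch.nn.functional imports are assumed, since the diffs themselves never show them.

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits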
thomwolf revised this gist
May 16, 2019. 1 changed file with 2 additions and 1 deletion.
@@ -1,10 +1,11 @@
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
thomwolf revised this gist
May 16, 2019. 1 changed file with 3 additions and 3 deletions.
@@ -1,7 +1,7 @@
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size) - batch size 1 for now
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
@@ -33,8 +33,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
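This revision pins the usage example to batch size 1 by indexing the first batch item. A quick shape check, using a random tensor as a stand-in for the model output (purely illustrative):

import torch

# Toy stand-in for a model output of shape (batch=1, sequence_length=5, vocab=10)
output = torch.randn(1, 5, 10)
last_token_logits = output[0, -1, :]   # shape (10,), the 1-D vector the filtering function expects
print(last_token_logits.shape)         # torch.Size([10])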
thomwolf revised this gist
May 3, 2019. No changes.
thomwolf revised this gist
May 3, 2019. 1 changed file with 8 additions and 7 deletions.
@@ -1,14 +1,15 @@
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (..., vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -21,7 +22,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

# Here is how to use this function for top-p sampling
thomwolf revised this gist
May 1, 2019. 1 changed file with 1 addition and 1 deletion.
@@ -17,7 +17,7 @@
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
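The only change is the added .clone(). A likely reason, not stated in the gist: without it, the right-hand side of the shifted assignment is a view over the same storage being written in place, so the copy can read values it has already overwritten. A minimal sketch of the shift on a standalone boolean mask:

import torch

mask = torch.tensor([False, False, True, True])
# Shift the mask one position to the right; cloning the source view avoids
# reading from and writing to overlapping memory during the in-place copy.
mask[1:] = mask[:-1].clone()
mask[0] = False
print(mask)  # tensor([False, False, False,  True])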
thomwolf revised this gist
May 1, 2019. 1 changed file with 2 additions and 2 deletions.
@@ -12,10 +12,10 @@
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
        sorted_indices_to_remove[..., 0] = 0
thomwolf revised this gist
May 1, 2019. 1 changed file with 1 addition and 1 deletion.
@@ -6,7 +6,7 @@
            top_p >0.0: keep the top tokens with cumulative probability >= top_p.
    """
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
        logits[indices_to_remove] = -float('Inf')
thomwolf revised this gist
May 1, 2019. 1 changed file with 3 additions and 1 deletion.
@@ -6,8 +6,10 @@
            top_p >0.0: keep the top tokens with cumulative probability >= top_p.
    """
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
        logits[indices_to_remove] = -float('Inf')
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -30,7 +32,7 @@
# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)

# Keep only the last token predictions, apply a temperature coefficient and filter
logits = logits[..., -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
thomwolf revised this gist
May 1, 2019. 1 changed file with 6 additions and 13 deletions.
@@ -2,19 +2,13 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
        Args:
            logits: logits distribution shape (..., vocabulary size)
            top_k >0: keep only top k tokens with highest probability.
            top_p >0.0: keep the top tokens with cumulative probability >= top_p.
    """
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
        logits[indices_to_remove] = -float('Inf')
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -24,23 +18,22 @@
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float('Inf')
    return logits

# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)

# Keep only the predictions for the last token, apply a temperature coefficient and filter
logits = logits[..., -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
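Because model and input are placeholders in the gist, here is a self-contained toy run of the same recipe with a hand-made logits vector (values invented for illustration; top_k=0, so only nucleus filtering is active):

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A fake "vocabulary" of 5 tokens with hand-picked logits instead of a model forward pass
logits = torch.tensor([2.0, 1.5, 0.5, 0.1, -1.0])

# clone() because the function modifies its input in place
filtered_logits = top_k_top_p_filtering(logits.clone(), top_k=0, top_p=0.9)
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)

print(filtered_logits)  # the low-probability tail is set to -inf
print(next_token)       # index of the sampled token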
thomwolf revised this gist
May 1, 2019. 1 changed file with 16 additions and 1 deletion.
@@ -1,4 +1,4 @@
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
        Args:
            logits: logits distribution shape (..., vocabulary size)
@@ -29,3 +29,18 @@ def apply_top_k_or_top_p(logits, top_k=0, top_p=0.0):
        logits[indices_to_remove] = -float('Inf')
    return logits

# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get the logits with a forward pass in our model assuming we have prepared an input tensor
logits = model(input)

# Keep only the predictions for the last token and apply a temperature coefficient if needed
logits = logits[..., -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
thomwolf revised this gist
May 1, 2019. 1 changed file with 6 additions and 21 deletions.
@@ -1,19 +1,19 @@
def apply_top_k_or_top_p(logits, top_k=0, top_p=0.0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
        Args:
            logits: logits distribution shape (..., vocabulary size)
            top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
            top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset whose
                total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest probability tokens whose cumulative probability mass exceeds
                the threshold top_p.
    """
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
        logits[indices_to_remove] = -float('Inf')
    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -24,23 +24,8 @@ def apply_top_k_or_top_p(logits, top_k=0, top_p=0):
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
        sorted_indices_to_remove[..., 0] = 0
        # Back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float('Inf')
    return logits
thomwolf created this gist
May 1, 2019.
@@ -0,0 +1,46 @@
def apply_top_k_or_top_p(logits, top_k=0, top_p=0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
        Args:
            logits: logits distribution shape (..., vocabulary size)
            top_k: 0 => no filtering, >0 => keep only top k tokens with highest probability.
            top_p: 0 => no filtering, >0 => keep only a subset S of candidates, where S is the smallest subset whose
                total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest probability tokens whose cumulative probability mass exceeds
                the threshold top_p.
    """
    if top_k != 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
        logits[indices_to_remove] = -float('Inf')
    if top_p != 0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
        sorted_indices_to_remove[..., 0] = 0
        # Back to unsorted indices
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float('Inf')
    return logits

# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get the logits with a forward pass in our model assuming we have prepared an input tensor
logits = model(input)

# Keep only the predictions for the last token and apply a temperature coefficient if needed
logits = logits[..., -1, :] / temperature
filtered_logits = apply_top_k_or_top_p(logits, top_k=top_k, top_p=top_p)

probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
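To make the top_p rule in the docstring concrete, here is a small worked example with invented numbers: with sorted probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the mask cumulative > 0.9 marks the last two tokens, and the right-shift un-marks the third one, so the kept set is the first three tokens, the smallest prefix whose mass (0.95) reaches 0.9.

import torch

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])   # already sorted, invented values
cumulative = torch.cumsum(probs, dim=-1)        # tensor([0.5000, 0.8000, 0.9500, 1.0000])
remove = cumulative > 0.9                       # tensor([False, False,  True,  True])
remove[1:] = remove[:-1].clone()                # shift right to keep the first token above the threshold
remove[0] = False
print(remove)                                   # tensor([False, False, False,  True])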