@Maximilian-Winter
Forked from thomwolf/top-k-top-p.py
Created June 10, 2023 19:04

Revisions

  1. @thomwolf revised this gist Jul 13, 2019. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions top-k-top-p.py
    @@ -4,6 +4,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
    logits: logits distribution shape (vocabulary size)
    top_k >0: keep only top k tokens with highest probability (top-k filtering).
    top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1)) # Safety check
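    This revision only documents the reference, but a toy run makes the nucleus threshold concrete. A minimal illustration of the cutoff on a hypothetical 4-token vocabulary (values are not from the gist):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])    # hypothetical 4-token vocabulary
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # cumulative_probs is roughly [0.61, 0.83, 0.97, 1.00]; with top_p = 0.9 the mask
    # cumulative_probs > top_p is shifted right so the first token past the threshold
    # is kept, and only the lowest-probability token ends up filtered out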
  2. @thomwolf revised this gist May 16, 2019. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion top-k-top-p.py
    @@ -1,10 +1,11 @@
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
    logits: logits distribution shape (vocabulary size) - batch size 1 for now
    logits: logits distribution shape (vocabulary size)
    top_k >0: keep only top k tokens with highest probability (top-k filtering).
    top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1)) # Safety check
    if top_k > 0:
    # Remove all tokens with a probability less than the last token of the top-k
  3. @thomwolf revised this gist May 16, 2019. 1 changed file with 3 additions and 3 deletions.
    6 changes: 3 additions & 3 deletions top-k-top-p.py
    @@ -1,7 +1,7 @@
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
    logits: logits distribution shape (..., vocabulary size)
    logits: logits distribution shape (vocabulary size) - batch size 1 for now
    top_k >0: keep only top k tokens with highest probability (top-k filtering).
    top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    @@ -33,8 +33,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
    # Get logits with a forward pass in our model (input is pre-defined)
    logits = model(input)

    # Keep only the last token predictions, apply a temperature coefficient and filter
    logits = logits[..., -1, :] / temperature
    # Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
    logits = logits[0, -1, :] / temperature
    filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    # Sample from the filtered distribution
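    The switch from logits[..., -1, :] to logits[0, -1, :] pins the usage example to batch size 1: assuming the model output has shape (batch, sequence_length, vocab_size), the new indexing yields a 1-D logits vector, which is what the assert logits.dim() == 1 shown in the revision above expects. A quick shape check under that assumed output shape:

    import torch

    vocab_size = 50
    model_output = torch.randn(1, 10, vocab_size)   # assumed shape: (batch=1, seq_len=10, vocab_size)
    last_logits = model_output[0, -1, :]            # 1-D tensor of shape (vocab_size,)
    assert last_logits.dim() == 1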
  4. @thomwolf revised this gist May 3, 2019. No changes.
  5. @thomwolf revised this gist May 3, 2019. 1 changed file with 8 additions and 7 deletions.
    15 changes: 8 additions & 7 deletions top-k-top-p.py
    @@ -1,14 +1,15 @@
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
    logits: logits distribution shape (..., vocabulary size)
    top_k >0: keep only top k tokens with highest probability.
    top_p >0.0: keep the top tokens with cumulative probability >= top_p.
    top_k >0: keep only top k tokens with highest probability (top-k filtering).
    top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    top_k = min(top_k, logits.size(-1)) # Safety check
    if top_k > 0:
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
    logits[indices_to_remove] = -float('Inf')
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits[indices_to_remove] = filter_value

    if top_p > 0.0:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    @@ -21,7 +22,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = -float('Inf')
    logits[indices_to_remove] = filter_value
    return logits

    # Here is how to use this function for top-p sampling
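    Two things change in this revision: the hard-coded -float('Inf') becomes a filter_value parameter, and the top-k threshold indexing goes from [0][:, -1] to [0][..., -1, None]. The second change matters because torch.topk on a 1-D logits vector returns 1-D values, so [:, -1] would raise an indexing error, while [..., -1, None] keeps a trailing dimension that broadcasts against logits. A small check of that indexing on illustrative toy values:

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])              # toy values, not from the gist
    top_k = 2
    threshold = torch.topk(logits, top_k)[0][..., -1, None]   # kth-largest logit, shape (1,)
    indices_to_remove = logits < threshold                    # tensor([False, False, True, True])
    logits[indices_to_remove] = -float('Inf')                 # the default filter_value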
  6. @thomwolf revised this gist May 1, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion top-k-top-p.py
    @@ -17,7 +17,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices[sorted_indices_to_remove]
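    The added .clone() matters because the source slice [..., :-1] and the destination slice [..., 1:] are overlapping views of the same tensor; cloning the source first makes the right-shift read the original mask values instead of values already overwritten in place. The shift on a toy mask (assumed values, just to show the pattern):

    import torch

    sorted_indices_to_remove = torch.tensor([False, False, True, True])
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # result: tensor([False, False, False, True]) - the first token past top_p is kept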
  7. @thomwolf revised this gist May 1, 2019. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions top-k-top-p.py
    @@ -12,10 +12,10 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):

    if top_p > 0.0:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probabilities > top_p
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
    sorted_indices_to_remove[..., 0] = 0
  8. @thomwolf revised this gist May 1, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion top-k-top-p.py
    @@ -6,7 +6,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    top_p >0.0: keep the top tokens with cumulative probability >= top_p.
    """
    if top_k > 0:
    # Remove all tokens with a probability less than the last token in the top-k tokens
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
    logits[indices_to_remove] = -float('Inf')

  9. @thomwolf revised this gist May 1, 2019. 1 changed file with 3 additions and 1 deletion.
    4 changes: 3 additions & 1 deletion top-k-top-p.py
    @@ -6,8 +6,10 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    top_p >0.0: keep the top tokens with cumulative probability >= top_p.
    """
    if top_k > 0:
    # Remove all tokens with a probability less than the last token in the top-k tokens
    indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
    logits[indices_to_remove] = -float('Inf')

    if top_p > 0.0:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    @@ -30,7 +32,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    # Get logits with a forward pass in our model (input is pre-defined)
    logits = model(input)

    # Keep only the predictions for the last token, apply a temperature coefficient and filter
    # Keep only the last token predictions, apply a temperature coefficient and filter
    logits = logits[..., -1, :] / temperature
    filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

  10. @thomwolf revised this gist May 1, 2019. 1 changed file with 6 additions and 13 deletions.
    19 changes: 6 additions & 13 deletions top-k-top-p.py
    @@ -2,19 +2,13 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
    Args:
    logits: logits distribution shape (..., vocabulary size)
    top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
    top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
    whose total probability mass is greater than or equal to the threshold top_p.
    In practice, we select the highest probability tokens whose cumulative probability mass exceeds
    the threshold top_p.
    top_k >0: keep only top k tokens with highest probability.
    top_p >0.0: keep the top tokens with cumulative probability >= top_p.
    """
    if top_k > 0:
    # Remove all tokens with a probability less than the last token in the top-k tokens
    indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
    logits[indices_to_remove] = -float('Inf')

    if top_p > 0.0:
    # Compute cumulative probabilities of sorted tokens
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    @@ -24,23 +18,22 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
    sorted_indices_to_remove[..., 0] = 0

    # Back to unsorted indices and set them to -infinity
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = -float('Inf')

    return logits

    # Here is how to use this function for top-p sampling
    temperature = 1.0
    top_k = 0
    top_p = 0.9

    # Get the logits with a forward pass in our model assuming we have prepared an input tensor
    # Get logits with a forward pass in our model (input is pre-defined)
    logits = model(input)

    # Keep only the predictions for the last token and apply a temperature coefficient if needed
    # Keep only the predictions for the last token, apply a temperature coefficient and filter
    logits = logits[..., -1, :] / temperature
    filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probabilities = F.softmax(filtered_logits, dim=-1)

    # Sample from the filtered distribution
    probabilities = F.softmax(filtered_logits, dim=-1)
    next_token = torch.multinomial(probabilities, 1)
  11. @thomwolf revised this gist May 1, 2019. 1 changed file with 16 additions and 1 deletion.
    17 changes: 16 additions & 1 deletion top-k-top-p.py
    @@ -1,4 +1,4 @@
    def apply_top_k_or_top_p(logits, top_k=0, top_p=0.0):
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
    Args:
    logits: logits distribution shape (..., vocabulary size)
    @@ -29,3 +29,18 @@ def apply_top_k_or_top_p(logits, top_k=0, top_p=0.0):
    logits[indices_to_remove] = -float('Inf')

    return logits

    # Here is how to use this function for top-p sampling
    temperature = 1.0
    top_k = 0
    top_p = 0.9

    # Get the logits with a forward pass in our model assuming we have prepared an input tensor
    logits = model(input)

    # Keep only the predictions for the last token and apply a temperature coefficient if needed
    logits = logits[..., -1, :] / temperature
    filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probabilities = F.softmax(filtered_logits, dim=-1)

    next_token = torch.multinomial(probabilities, 1)
  12. @thomwolf revised this gist May 1, 2019. 1 changed file with 6 additions and 21 deletions.
    27 changes: 6 additions & 21 deletions top-k-top-p.py
    @@ -1,19 +1,19 @@
    def apply_top_k_or_top_p(logits, top_k=0, top_p=0):
    def apply_top_k_or_top_p(logits, top_k=0, top_p=0.0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
    Args:
    logits: logits distribution shape (..., vocabulary size)
    top_k: 0 => no filtering, >0 => keep only top k tokens with highest probability.
    top_p: 0 => no filtering, >0 => keep only a subset S of candidates, where S is the smallest subset
    top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
    top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
    whose total probability mass is greater than or equal to the threshold top_p.
    In practice, we select the highest probability tokens whose cumulative probability mass exceeds
    the threshold top_p.
    """
    if top_k != 0:
    if top_k > 0:
    # Remove all tokens with a probability less than the last token in the top-k tokens
    indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
    logits[indices_to_remove] = -float('Inf')

    if top_p != 0:
    if top_p > 0.0:
    # Compute cumulative probabilities of sorted tokens
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    @@ -24,23 +24,8 @@ def apply_top_k_or_top_p(logits, top_k=0, top_p=0):
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
    sorted_indices_to_remove[..., 0] = 0

    # Back to unsorted indices
    # Back to unsorted indices and set them to -infinity
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = -float('Inf')

    return logits

    # Here is how to use this function for top-p sampling
    temperature = 1.0
    top_k = 0
    top_p = 0.9

    # Get the logits with a forward pass in our model assuming we have prepared an input tensor
    logits = model(input)

    # Keep only the predictions for the last token and apply a temperature coefficient if needed
    logits = logits[..., -1, :] / temperature
    filtered_logits = apply_top_k_or_top_p(logits, top_k=top_k, top_p=top_p)
    probabilities = F.softmax(filtered_logits, dim=-1)

    next_token = torch.multinomial(probabilities, 1)
  13. @thomwolf created this gist May 1, 2019.
    46 changes: 46 additions & 0 deletions top-k-top-p.py
    @@ -0,0 +1,46 @@
    def apply_top_k_or_top_p(logits, top_k=0, top_p=0):
    """ Filter a distribution of logits using top-k and/or top-p filtering
    Args:
    logits: logits distribution shape (..., vocabulary size)
    top_k: 0 => no filtering, >0 => keep only top k tokens with highest probability.
    top_p: 0 => no filtering, >0 => keep only a subset S of candidates, where S is the smallest subset
    whose total probability mass is greater than or equal to the threshold top_p.
    In practice, we select the highest probability tokens whose cumulative probability mass exceeds
    the threshold top_p.
    """
    if top_k != 0:
    # Remove all tokens with a probability less than the last token in the top-k tokens
    indices_to_remove = logits < torch.topk(logits, top_k)[0][:, -1]
    logits[indices_to_remove] = -float('Inf')

    if top_p != 0:
    # Compute cumulative probabilities of sorted tokens
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probabilities > top_p
    # Shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
    sorted_indices_to_remove[..., 0] = 0

    # Back to unsorted indices
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[indices_to_remove] = -float('Inf')

    return logits

    # Here is how to use this function for top-p sampling
    temperature = 1.0
    top_k = 0
    top_p = 0.9

    # Get the logits with a forward pass in our model assuming we have prepared an input tensor
    logits = model(input)

    # Keep only the predictions for the last token and apply a temperature coefficient if needed
    logits = logits[..., -1, :] / temperature
    filtered_logits = apply_top_k_or_top_p(logits, top_k=top_k, top_p=top_p)
    probabilities = F.softmax(filtered_logits, dim=-1)

    next_token = torch.multinomial(probabilities, 1)
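The revisions above converge on the function below. Running the gist as published still assumes that torch, torch.nn.functional as F, a model, and an input tensor already exist in scope; the sketch fills those assumptions with a stand-in of random logits of assumed shape (batch, sequence_length, vocab_size), purely for illustration, so the final (Jul 13, 2019) version can be run end to end:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

# Stand-in for "logits = model(input)": random logits over a small vocabulary (illustrative only)
temperature = 1.0
top_k = 0
top_p = 0.9
vocab_size = 50
model_output = torch.randn(1, 10, vocab_size)   # pretend output of shape (batch, seq_len, vocab_size)

# Keep only the last token predictions of the first batch item (batch size 1), apply temperature and filter
logits = model_output[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)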