Skip to content

Instantly share code, notes, and snippets.

@darrenangle
Forked from wassname/choice_tree.py
Created June 3, 2024 13:44
Show Gist options
  • Save darrenangle/36850c556877751009f248cf10f1328f to your computer and use it in GitHub Desktop.
Save darrenangle/36850c556877751009f248cf10f1328f to your computer and use it in GitHub Desktop.

Revisions

  1. @wassname wassname revised this gist May 10, 2024. No changes.
  2. @wassname wassname revised this gist May 10, 2024. No changes.
  3. @wassname wassname created this gist May 10, 2024.
    61 changes: 61 additions & 0 deletions choice_tree.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,61 @@
    from jaxtyping import Float, Int
    import torch
    from torch.nn import functional as F
    from torch import Tensor
    from typing import List, Callable, Tuple, Dict, Optional
    import pandas as pd
    from transformers import AutoModelForCausalLM, AutoTokenizer


    def get_valid_next_choices(choices_tokens, current_tokens):
    next_choices = []
    for choice_tokens in choices_tokens:
    # if we have some more slots left
    if len(current_tokens) < len(choice_tokens):
    # see if current_tokens matches
    if (choice_tokens[: len(current_tokens)] == current_tokens).all():
    c = choice_tokens[len(current_tokens)].item()
    next_choices.append(c)

    next_choices = list(set(next_choices))
    return torch.LongTensor(next_choices)


    def choice_tree(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    input_ids: Int[Tensor, "seq"],
    choices_tokens: List[Int[Tensor, "seq"]],
    choice: Optional[Int[Tensor, ""]] = None,
    prob: float = 1,
    current_tokens: Int[Tensor, "seq"] = torch.LongTensor([]),
    z=[],
    ):
    if choice is not None:
    c = choice[None].to(current_tokens.device)
    current_tokens = torch.cat([current_tokens, c], dim=-1)
    c = choice[None].to(input_ids.device)
    input_ids = torch.cat([input_ids, c], dim=-1)

    next_choices = get_valid_next_choices(choices_tokens, current_tokens)
    if len(next_choices) == 0:
    s = tokenizer.decode(current_tokens)
    r = dict(prob=prob, choice=s)
    yield r
    else:
    o = model(input_ids[None])
    logits_constrained = o.logits[0, -1][next_choices]
    probs = F.softmax(logits_constrained, dim=-1)
    for i in range(len(next_choices)):
    next_choice = next_choices[i]
    next_prob = prob * probs[i].item()
    yield from choice_tree(
    model=model,
    tokenizer=tokenizer,
    choices_tokens=choices_tokens,
    input_ids=input_ids,
    choice=next_choice,
    prob=next_prob,
    current_tokens=current_tokens,
    z=z + [i],
    )