@napsternxg, created August 15, 2023
Gist: https://gist.github.com/napsternxg/86f3e1238ea66e12a39919687c085995

generate.py

import functools

import pandas as pd
import torch
import transformers
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from t5_training_utils import (
    GenerationType,
    build_prefix_allowed_tokens_fn,
    convert_to_features,
    get_gen_type_attributes,
    get_model_full_name,
    get_prediction_name,
)

torch_dtype = "auto"


model_ckpt = "t5-base"
gen_type = GenerationType.ALL_TOKENS

input_max_length = 512
label_max_length = 6

use_task_prefix = True

class_names = [
    "Soccer",
    "Cricket",
    "Handball",
    "Snow Cycling",
]
non_eligible_classes = {
    "Snow Cycling",
}

non_eligible_idx = [
    i for i, c in enumerate(class_names) if c in non_eligible_classes
]
num_classes = len(class_names)
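# With the class list above, non_eligible_idx == [3] ("Snow Cycling") and num_classes == 4.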


# Model training

### Uncomment a config section for the model type

## For small test run
train_batch_size = 8
eval_batch_size = 8
epochs = 30
save_every_k_epochs = 5

seed = 3333
torch.manual_seed(seed)

logging_steps = 100  # e.g. len(train_dataset) // train_batch_size
eval_step = 100

learning_rate = 2e-5
weight_decay = 0.01

data_version = "guidelines-fixed-occasion"
model_full_name = get_model_full_name(model_ckpt, gen_type, epochs, data_version)


def get_model(model_local_ckpt):
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_local_ckpt)
    return model.eval()


def get_dataset(
    data_path, tokenizer, class_text_map, task_prefix, accelerator, nrows=1024, offset=0
):
    df_data = pd.read_csv(data_path, sep="\t", nrows=nrows + offset).rename(
        columns={"query": "text"}
    )
    df_data = df_data.iloc[offset : offset + nrows]
    print(df_data)
    dataset = Dataset.from_pandas(df_data)
    dataset.reset_format()
    with accelerator.main_process_first():
        dataset = dataset.map(
            functools.partial(
                convert_to_features,
                class_text_map=class_text_map,
                task_prefix=task_prefix,
                query_key="text",
                label_key=None,
                tokenizer=tokenizer,
            )
        )
    return dataset, df_data
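# A minimal usage sketch (the TSV path is the one used in main(); the attribute
# helpers come from t5_training_utils):
#
#   accelerator = Accelerator()
#   tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)
#   class_text_map, _, task_prefix = get_gen_type_attributes(gen_type, tokenizer, class_names)
#   dataset, df = get_dataset("data.tsv", tokenizer, class_text_map, task_prefix, accelerator)
#
# Each mapped example then carries "inputs", "input_ids", and "attention_mask".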


def get_predictions_accelerate(data_path, model_local_ckpt, nrows=1024, offset=0):
    accelerator = Accelerator()

    device = accelerator.device

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)
    class_text_map, max_decoding_length, task_prefix = get_gen_type_attributes(
        gen_type, tokenizer, class_names
    )
    task_prefix = task_prefix if use_task_prefix else ""

    dataset, df_data = get_dataset(
        data_path,
        tokenizer,
        class_text_map,
        task_prefix,
        accelerator,
        nrows=nrows,
        offset=offset,
    )

    model = get_model(model_local_ckpt)
    model = model.to(device)

    allowed_sequences = [[0] + tokenizer.encode(x) for x in class_text_map.values()]
    dataset.set_format("pt")
    # shuffle must stay False so predictions line up with the rows of df_data below.
    custom_dataloader = DataLoader(
        dataset, shuffle=False, batch_size=eval_batch_size, num_workers=4
    )

    model, custom_dataloader = accelerator.prepare(model, custom_dataloader)
    preds = []
    with torch.no_grad():
        for batch in tqdm(
            custom_dataloader, disable=not accelerator.is_local_main_process
        ):
            batch_input_ids = batch["input_ids"].to(device)
            batch_attention_mask = batch["attention_mask"].to(device)
            # For DDP models use accelerator.unwrap_model(model).generate(inputs)
            # Taken from: https://github.com/huggingface/transformers/issues/18974
            batch_outs = accelerator.unwrap_model(model).generate(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                max_length=max_decoding_length,
                prefix_allowed_tokens_fn=build_prefix_allowed_tokens_fn(
                    allowed_sequences
                ),
            )
            batch_outs = accelerator.pad_across_processes(
                batch_outs, dim=1, pad_index=tokenizer.pad_token_id
            )
            batch_outs = accelerator.gather_for_metrics(batch_outs).cpu().numpy()
            preds.extend(tokenizer.batch_decode(batch_outs, skip_special_tokens=True))
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        if len(preds) != len(dataset):
            raise ValueError(
                f"Predictions and labels have different lengths. preds: {len(preds)} "
                f"labels: {len(dataset)}"
            )
        pred_col = get_prediction_name(model_full_name)
        df_data[pred_col] = preds
        class_text_map_reversed = {val: key for key, val in class_text_map.items()}
        df_data[pred_col] = df_data[pred_col].apply(lambda x: class_text_map_reversed[x])

        # eligible = ~df_test[pred_col].isin(non_eligible_classes)
        eligible = ~df_data[pred_col].isin(
            {v for v in non_eligible_classes if v != "Cricket"}
        )
        df_data["eligible_pred"] = eligible
        output_path = data_path.replace(".tsv", f".predicted.{offset}.{nrows}.tsv")
        print(df_data)
        print(f"Writing df_data with predictions to {output_path}")
        df_data.to_csv(output_path, sep="\t", index=False)

    return df_data
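# Intended to be run via `accelerate launch generate.py` (an assumption based on the
# Accelerator usage); with multiple processes, gather_for_metrics() drops the samples
# duplicated for even batching so that len(preds) matches len(dataset).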


def main():
    data_path = "data.tsv"
    offset = 400_000
    nrows = 153  # 600_000
    model_local_ckpt = "./model_path/checkpoint-2830"
    print(data_path)
    print(nrows)
    print(model_local_ckpt)

    get_predictions_accelerate(data_path, model_local_ckpt, nrows=nrows, offset=offset)


if __name__ == "__main__":
    main()

t5_training_utils.py

import enum
import os
import string
from typing import List, Mapping

import marisa_trie
import torch


class GenerationType(enum.Enum):
    ALL_TOKENS = "all"
    THREE_TOKENS = "gen3"
    TWO_TOKENS = "gen2"
    ONE_TOKEN = "gen1"
    OPTION_ID = "abcd"


def get_model_full_name(
    model_ckpt, gen_type: GenerationType, epochs: int, data_version: str = "0"
):
    model_base_name = model_ckpt.split("/")[-1]
    gen_type = gen_type.value
    model_full_name = f"{model_base_name}_{gen_type}_ep{epochs}_dt{data_version}"
    return model_full_name
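# e.g. get_model_full_name("google/t5-base", GenerationType.ALL_TOKENS, 30, "v1")
# returns "t5-base_all_ep30_dtv1".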


def get_outdir(model_name: str):
    return os.path.join("./classifers/", model_name)


def get_prediction_name(model_name: str):
    return f"{model_name}_predicted"


class MarisaTrie(object):
    def __init__(
        self,
        sequences: List[List[int]] = [],
        cache_first_branch=True,
        max_token_id=256001,
    ):
        # Map token ids to single characters so sequences can be stored as strings.
        # The jump from 55000 to 65000 skips the Unicode surrogate range, which
        # cannot be encoded in the trie keys.
        self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
            [chr(i) for i in range(65000, max_token_id + 10000)]
            if max_token_id >= 55000
            else []
        )
        self.char2int = {self.int2char[i]: i for i in range(max_token_id)}

        self.cache_first_branch = cache_first_branch
        if self.cache_first_branch:
            self.zero_iter = list({sequence[0] for sequence in sequences})
            assert len(self.zero_iter) == 1
            self.first_iter = list({sequence[1] for sequence in sequences})

        self.trie = marisa_trie.Trie(
            "".join([self.int2char[i] for i in sequence]) for sequence in sequences
        )

    def get(self, prefix_sequence: List[int]):
        if self.cache_first_branch and len(prefix_sequence) == 0:
            return self.zero_iter
        elif (
            self.cache_first_branch
            and len(prefix_sequence) == 1
            and self.zero_iter == prefix_sequence
        ):
            return self.first_iter
        else:
            key = "".join([self.int2char[i] for i in prefix_sequence])
            return list(
                {
                    self.char2int[e[len(key)]]
                    for e in self.trie.keys(key)
                    if len(e) > len(key)
                }
            )

    def __iter__(self):
        for sequence in self.trie.iterkeys():
            yield [self.char2int[e] for e in sequence]

    def __len__(self):
        return len(self.trie)

    def __getitem__(self, value):
        return self.get(value)
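# A small worked example of the constrained-decoding lookup (token ids are made up):
#
#   t = MarisaTrie([[0, 5, 10], [0, 5, 12]])
#   t.get([])          # -> [0]        every allowed sequence starts with 0
#   t.get([0])         # -> [5]        cached first branch
#   t.get([0, 5])      # -> [10, 12]   both continuations allowed (order unspecified)
#   t.get([0, 5, 10])  # -> []         sequence exhausted, nothing more may be generated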


def map_class_name(tokenizer, class_raw_name, num_tokens=None, delim=" "):
    if num_tokens is None:
        return class_raw_name
    class_raw_words = class_raw_name.split(delim)
    for i in range(1, len(class_raw_words) + 1):
        class_name_candidate = " ".join(class_raw_words[:i])
        tokens = tokenizer.tokenize(class_name_candidate)
        if len(tokens) == num_tokens:
            return class_name_candidate
    raise ValueError(
        f"Cannot find class name at the specified num_tokens: {class_raw_name}, {num_tokens}"
    )
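# e.g. map_class_name(tokenizer, "Snow Cycling", num_tokens=1) would return "Snow"
# if "Snow" tokenizes to exactly one piece; the result is tokenizer-dependent.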


def create_class_text_map(tokenizer, class_raw_names, num_tokens, delim="_"):
    res = {}
    for raw_name in class_raw_names:
        res[raw_name] = map_class_name(tokenizer, raw_name, num_tokens, delim=delim)
    return res


def get_task_prefix(gen_type, class_text_map):
    task_prefix = "Classify query intent into one of the following categories: "

    if gen_type in (
        GenerationType.TWO_TOKENS,
        GenerationType.ONE_TOKEN,
        GenerationType.ALL_TOKENS,
    ):
        classes = [f"'{x}'" for x in class_text_map.values()]
        task_prefix += ", ".join(classes)
        task_prefix += ". query: "
    elif gen_type == GenerationType.OPTION_ID:
        classes = [
            f"{val}: {' '.join(key.split('_')[:-1])}"
            for key, val in class_text_map.items()
        ]
        task_prefix += "\n" + "\n".join(classes)
        task_prefix += "\nquery: "

    return task_prefix
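# For GenerationType.ALL_TOKENS with the class names from generate.py this returns:
#   "Classify query intent into one of the following categories: 'Soccer', 'Cricket',
#    'Handball', 'Snow Cycling'. query: "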


def get_gen_type_attributes(gen_type, tokenizer, class_names):
    if gen_type == GenerationType.THREE_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 3)
        max_decoding_length = 3
    elif gen_type == GenerationType.TWO_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 2)
        max_decoding_length = 2
    elif gen_type == GenerationType.ONE_TOKEN:
        class_text_map = create_class_text_map(tokenizer, class_names, 1)
        max_decoding_length = 1
    elif gen_type == GenerationType.ALL_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, None)
        max_decoding_length = max(
            len([0] + tokenizer.encode(x)) for x in class_text_map.values()
        )
    elif gen_type == GenerationType.OPTION_ID:
        class_text_map = {}
        for i, raw_name in enumerate(class_names):
            class_text_map[raw_name] = string.ascii_uppercase[i]
        max_decoding_length = 2
    else:
        raise ValueError(f"Non-existent `gen_type`: {gen_type}")

    task_prefix = get_task_prefix(gen_type, class_text_map)
    return class_text_map, max_decoding_length, task_prefix


def convert_to_features(
    example_batch,
    class_text_map: Mapping[str, str],
    task_prefix: str,
    input_max_length=512,
    label_max_length=16,
    query_key="query",
    label_key="expected_single",
    class_names=None,
    tokenizer=None,
):
    q = example_batch[query_key]
    example_batch["input_text"] = f"{task_prefix}{q}"

    input_encodings = tokenizer(
        example_batch["input_text"],
        padding="max_length",
        max_length=input_max_length,
        truncation=True,
    )

    encodings = {
        "inputs": example_batch["input_text"],
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
    }

    if label_key:
        label = class_text_map[class_names[example_batch[label_key]]]
        example_batch["target_text"] = f"{label}"
        target_encodings = tokenizer(
            example_batch["target_text"],
            padding="max_length",
            max_length=label_max_length,
            truncation=True,
        )
        encodings["labels"] = target_encodings["input_ids"]

    return encodings
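# At inference time generate.py calls this with label_key=None, so only "inputs",
# "input_ids", and "attention_mask" are produced; during training a label_key adds
# tokenized "labels" as well.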


def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels


def build_prefix_allowed_tokens_fn(allowed_sequences):
    """Returns a function that provides next allowed tokens based on the prefix `seq`."""
    t = MarisaTrie(allowed_sequences)

    def fn(unused_batch_id, seq):
        return t.get(seq)

    return fn
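# Sketch of how this hooks into Hugging Face generation (mirrors generate.py):
#
#   allowed = [[0] + tokenizer.encode(x) for x in class_text_map.values()]
#   model.generate(
#       input_ids=input_ids,
#       prefix_allowed_tokens_fn=build_prefix_allowed_tokens_fn(allowed),
#   )
#
# At each decoding step generate() calls fn(batch_id, prefix_ids) and restricts
# sampling to the returned token ids, forcing outputs to be exact class names.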


def process_golden_labels(example_batch, class_text_map, class_names):
    def fn(expected_text):
        return [class_text_map[y.strip()] for y in expected_text.split(",")]

    # example_batch['golden_labels'] = fn(example_batch['expected'])
    example_batch["golden_labels"] = fn(class_names[example_batch["Label"]])
    return example_batch


def compute_metrics(preds):
    # Placeholder: no metrics are computed here.
    return {}