napsternxg created this gist on Aug 15, 2023.
File 1 of 2 — prediction script (imports from t5_training_utils below):

import functools

import pandas as pd
import torch
import transformers
from accelerate import Accelerator
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from t5_training_utils import (
    GenerationType,
    build_prefix_allowed_tokens_fn,
    convert_to_features,
    get_gen_type_attributes,
    get_model_full_name,
    get_prediction_name,
)

torch_dtype = "auto"
model_ckpt = "t5-base"
gen_type = GenerationType.ALL_TOKENS
input_max_length = 512
label_max_length = 6
use_task_prefix = True

class_names = [
    "Soccer",
    "Cricket",
    "Handball",
    "Snow Cycling",
]
non_eligible_classes = {"Snow Cycling"}
non_eligible_idx = [i for i, c in enumerate(class_names) if c in non_eligible_classes]
num_classes = len(class_names)

# Model training
### Uncomment a config section for the model type
## For small test run
train_batch_size = 8
eval_batch_size = 8
epochs = 30
save_every_k_epochs = 5
seed = 3333
torch.manual_seed(seed)
logging_steps = 100  # len(squad["train"]) // batch_size
eval_step = 100
learning_rate = 2e-5
weight_decay = 0.01
data_version = "guidelines-fixed-occasion"
model_full_name = get_model_full_name(model_ckpt, gen_type, epochs, data_version)


def get_model(model_local_ckpt):
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_local_ckpt)
    return model.eval()


def get_dataset(
    data_path, tokenizer, class_text_map, task_prefix, accelerator, nrows=1024, offset=0
):
    # Read offset + nrows rows, then slice off the first `offset` so each shard
    # of a large TSV can be scored independently.
    df_data = pd.read_csv(data_path, sep="\t", nrows=nrows + offset).rename(
        columns={"query": "text"}
    )
    df_data = df_data.iloc[offset : offset + nrows]
    print(df_data)
    dataset = Dataset.from_pandas(df_data)
    dataset.reset_format()
    with accelerator.main_process_first():
        dataset = dataset.map(
            functools.partial(
                convert_to_features,
                class_text_map=class_text_map,
                task_prefix=task_prefix,
                query_key="text",
                label_key=None,
                tokenizer=tokenizer,
            )
        )
    return dataset, df_data
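
# A minimal sketch, not part of the original gist: a smoke test of the feature
# pipeline above. It assumes the public "t5-base" tokenizer; `_demo_features`
# is a hypothetical name added for illustration.
def _demo_features():
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)
    class_text_map, max_decoding_length, task_prefix = get_gen_type_attributes(
        gen_type, tokenizer, class_names
    )
    # For ALL_TOKENS the targets are the raw class names, e.g.
    # {"Soccer": "Soccer", ..., "Snow Cycling": "Snow Cycling"}.
    print(class_text_map)
    features = convert_to_features(
        {"text": "premier league scores"},
        class_text_map=class_text_map,
        task_prefix=task_prefix,
        query_key="text",
        label_key=None,  # inference only: no labels
        tokenizer=tokenizer,
    )
    # `features` holds "inputs" (the prefixed prompt), "input_ids" and
    # "attention_mask"; "labels" is absent because label_key is None.
    print(features["inputs"])
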
def get_predictions_accelerate(data_path, model_local_ckpt, nrows=1024, offset=0):
    accelerator = Accelerator()
    device = accelerator.device
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)
    class_text_map, max_decoding_length, task_prefix = get_gen_type_attributes(
        gen_type, tokenizer, class_names
    )
    task_prefix = task_prefix if use_task_prefix else ""
    dataset, df_data = get_dataset(
        data_path,
        tokenizer,
        class_text_map,
        task_prefix,
        accelerator,
        nrows=nrows,
        offset=offset,
    )
    model = get_model(model_local_ckpt)
    model = model.to(device)
    # Decoder outputs start with the pad token (id 0), hence the leading [0].
    allowed_sequences = [[0] + tokenizer.encode(x) for x in class_text_map.values()]
    dataset.set_format("pt")
    # NOTE: shuffle must stay False; predictions are written back to df_data by
    # position, so a shuffled loader would misalign rows and predictions.
    custom_dataloader = DataLoader(
        dataset, shuffle=False, batch_size=eval_batch_size, num_workers=4
    )
    model, custom_dataloader = accelerator.prepare(model, custom_dataloader)
    preds = []
    with torch.no_grad():
        for batch in tqdm(
            custom_dataloader, disable=not accelerator.is_local_main_process
        ):
            batch_input_ids = batch["input_ids"].to(device)
            batch_attention_mask = batch["attention_mask"].to(device)
            # For DDP models use accelerator.unwrap_model(model).generate(inputs)
            # Taken from: https://github.com/huggingface/transformers/issues/18974
            batch_outs = accelerator.unwrap_model(model).generate(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                max_length=max_decoding_length,
                prefix_allowed_tokens_fn=build_prefix_allowed_tokens_fn(
                    allowed_sequences
                ),
            )
            batch_outs = accelerator.pad_across_processes(
                batch_outs, dim=1, pad_index=tokenizer.pad_token_id
            )
            batch_outs = accelerator.gather_for_metrics(batch_outs).cpu().numpy()
            preds.extend(tokenizer.batch_decode(batch_outs, skip_special_tokens=True))
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if len(preds) != len(dataset):
            raise ValueError(
                f"Predictions and labels have different lengths. preds: {len(preds)} "
                f"labels: {len(dataset)}"
            )
        pred_col = get_prediction_name(model_full_name)
        df_data[pred_col] = preds
        # Map generated label text back to the raw class names.
        class_text_map_reversed = {val: key for key, val in class_text_map.items()}
        df_data[pred_col] = df_data[pred_col].apply(lambda x: class_text_map_reversed[x])
        # eligible = ~df_test[pred_col].isin(non_eligible_classes)
        eligible = ~df_data[pred_col].isin(
            {v for v in non_eligible_classes if v != "Cricket"}
        )
        df_data["eligible_pred"] = eligible
        output_path = data_path.replace(".tsv", f".predicted.{offset}.{nrows}.tsv")
        print(df_data)
        print(f"Writing df_data with predictions to {output_path}")
        df_data.to_csv(output_path, sep="\t", index=False)
    return df_data


def main():
    data_path = "data.tsv"
    offset = 400_000
    nrows = 153  # 600_000
    model_local_ckpt = "./model_path/checkpoint-2830"
    print(data_path)
    print(nrows)
    print(model_local_ckpt)
    get_predictions_accelerate(data_path, model_local_ckpt, nrows=nrows, offset=offset)


if __name__ == "__main__":
    main()
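
Before the utilities file, here is a minimal, self-contained sketch (not part of the original gist) of the constrained-decoding idea the script above relies on: `prefix_allowed_tokens_fn` restricts `generate` so the decoder can only ever emit one of the label strings. It assumes the public "t5-base" checkpoint and the `build_prefix_allowed_tokens_fn` helper defined in t5_training_utils.py below; the prompt mirrors the task prefix the utilities build.

import transformers

from t5_training_utils import build_prefix_allowed_tokens_fn

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-base")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base").eval()

labels = ["Soccer", "Cricket", "Handball", "Snow Cycling"]
# Decoder sequences start with the pad token (id 0), hence the leading [0].
allowed_sequences = [[0] + tokenizer.encode(x) for x in labels]

prompt = (
    "Classify query intent into one of the following categories: "
    "'Soccer', 'Cricket', 'Handball', 'Snow Cycling'. query: premier league scores"
)
inputs = tokenizer(prompt, return_tensors="pt")
out = model.generate(
    **inputs,
    max_length=max(len(s) for s in allowed_sequences),
    prefix_allowed_tokens_fn=build_prefix_allowed_tokens_fn(allowed_sequences),
)
# The decoded string is guaranteed to be one of the four label strings,
# even for an untuned checkpoint.
print(tokenizer.batch_decode(out, skip_special_tokens=True))
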
File 2 of 2 — t5_training_utils.py:

import enum
import os
import string
from typing import List, Mapping

import marisa_trie
import torch


class GenerationType(enum.Enum):
    ALL_TOKENS = "all"
    THREE_TOKENS = "gen3"
    TWO_TOKENS = "gen2"
    ONE_TOKEN = "gen1"
    OPTION_ID = "abcd"


def get_model_full_name(
    model_ckpt, gen_type: GenerationType, epochs: int, data_version: str = "0"
):
    model_base_name = model_ckpt.split("/")[-1]
    gen_type = gen_type.value
    model_full_name = f"{model_base_name}_{gen_type}_ep{epochs}_dt{data_version}"
    return model_full_name


def get_outdir(model_name: str):
    return os.path.join("./classifers/", model_name)


def get_prediction_name(model_name: str):
    return f"{model_name}_predicted"


class MarisaTrie(object):
    """Trie over token-id sequences, backed by marisa_trie's string trie.

    Token ids are packed into single unicode characters; the 55000-65000 gap
    skips the UTF-16 surrogate range, which cannot be encoded for storage.
    """

    def __init__(
        self,
        sequences: List[List[int]] = [],
        cache_first_branch=True,
        max_token_id=256001,
    ):
        self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
            [chr(i) for i in range(65000, max_token_id + 10000)]
            if max_token_id >= 55000
            else []
        )
        self.char2int = {self.int2char[i]: i for i in range(max_token_id)}

        self.cache_first_branch = cache_first_branch
        if self.cache_first_branch:
            # Cache the first two levels; assumes every sequence shares the
            # same first token (the decoder start token).
            self.zero_iter = list({sequence[0] for sequence in sequences})
            assert len(self.zero_iter) == 1
            self.first_iter = list({sequence[1] for sequence in sequences})

        self.trie = marisa_trie.Trie(
            "".join([self.int2char[i] for i in sequence]) for sequence in sequences
        )

    def get(self, prefix_sequence: List[int]):
        if self.cache_first_branch and len(prefix_sequence) == 0:
            return self.zero_iter
        elif (
            self.cache_first_branch
            and len(prefix_sequence) == 1
            and self.zero_iter == prefix_sequence
        ):
            return self.first_iter
        else:
            key = "".join([self.int2char[i] for i in prefix_sequence])
            return list(
                {
                    self.char2int[e[len(key)]]
                    for e in self.trie.keys(key)
                    if len(e) > len(key)
                }
            )

    def __iter__(self):
        for sequence in self.trie.iterkeys():
            yield [self.char2int[e] for e in sequence]

    def __len__(self):
        return len(self.trie)

    def __getitem__(self, value):
        return self.get(value)
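
# A minimal sketch, not in the original file: how the trie answers "which
# token may come next?" for a set of allowed decoder sequences. `_demo_trie`
# is a hypothetical name added for this illustration.
def _demo_trie():
    # Three toy "tokenized" label sequences, each starting with decoder pad id 0.
    sequences = [[0, 5, 8], [0, 5, 9], [0, 7]]
    trie = MarisaTrie(sequences)
    print(trie.get([]))      # -> [0]      (every sequence starts with 0)
    print(trie.get([0]))     # -> [5, 7]   (set order, not guaranteed)
    print(trie.get([0, 5]))  # -> [8, 9]   (set order, not guaranteed)
    print(trie.get([0, 7]))  # -> []       (sequence exhausted)
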
def map_class_name(tokenizer, class_raw_name, num_tokens=None, delim=" "):
    if num_tokens is None:
        return class_raw_name
    # Return the shortest word prefix of the class name that tokenizes to
    # exactly `num_tokens` tokens.
    class_raw_words = class_raw_name.split(delim)
    for i in range(1, len(class_raw_words) + 1):
        class_name_candidate = " ".join(class_raw_words[:i])
        tokens = tokenizer.tokenize(class_name_candidate)
        if len(tokens) == num_tokens:
            return class_name_candidate
    raise ValueError(
        f"Cannot find class name at the specified num_tokens: {class_raw_name}, {num_tokens}"
    )


def create_class_text_map(tokenizer, class_raw_names, num_tokens, delim="_"):
    res = {}
    for raw_name in class_raw_names:
        res[raw_name] = map_class_name(tokenizer, raw_name, num_tokens, delim=delim)
    return res


def get_task_prefix(gen_type, class_text_map):
    task_prefix = "Classify query intent into one of the following categories: "
    if gen_type in (
        GenerationType.TWO_TOKENS,
        GenerationType.ONE_TOKEN,
        GenerationType.ALL_TOKENS,
    ):
        # Note: THREE_TOKENS is not listed here, so gen3 falls through to the
        # bare prefix with no class list appended.
        classes = [f"'{x}'" for x in class_text_map.values()]
        task_prefix += ", ".join(classes)
        task_prefix += ". query: "
    elif gen_type == GenerationType.OPTION_ID:
        classes = [
            f"{val}: {' '.join(key.split('_')[:-1])}"
            for key, val in class_text_map.items()
        ]
        task_prefix += "\n" + "\n".join(classes)
        task_prefix += "\nquery: "
    return task_prefix


def get_gen_type_attributes(gen_type, tokenizer, class_names):
    if gen_type == GenerationType.THREE_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 3)
        max_decoding_length = 3
    elif gen_type == GenerationType.TWO_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, 2)
        max_decoding_length = 2
    elif gen_type == GenerationType.ONE_TOKEN:
        class_text_map = create_class_text_map(tokenizer, class_names, 1)
        max_decoding_length = 1
    elif gen_type == GenerationType.ALL_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, None)
        max_decoding_length = max(
            [len([0] + tokenizer.encode(x)) for x in class_text_map.values()]
        )
    elif gen_type == GenerationType.OPTION_ID:
        class_text_map = {}
        for i, raw_name in enumerate(class_names):
            class_text_map[raw_name] = string.ascii_uppercase[i]
        max_decoding_length = 2
    else:
        raise ValueError(f"Non-existent `gen_type`: {gen_type}")
    task_prefix = get_task_prefix(gen_type, class_text_map)
    return class_text_map, max_decoding_length, task_prefix


def convert_to_features(
    example_batch,
    class_text_map: Mapping[str, str],
    task_prefix: str,
    input_max_length=512,
    label_max_length=16,
    query_key="query",
    label_key="expected_single",
    class_names=None,
    tokenizer=None,
):
    q = example_batch[query_key]
    example_batch["input_text"] = f"{task_prefix}{q}"
    input_encodings = tokenizer(
        example_batch["input_text"],
        padding="max_length",
        max_length=input_max_length,
        truncation=True,
    )
    encodings = {
        "inputs": example_batch["input_text"],
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
    }
    if label_key:
        label = class_text_map[class_names[example_batch[label_key]]]
        example_batch["target_text"] = f"{label}"
        target_encodings = tokenizer(
            example_batch["target_text"],
            padding="max_length",
            max_length=label_max_length,
            truncation=True,
        )
        encodings["labels"] = target_encodings["input_ids"]
    return encodings


def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels


def build_prefix_allowed_tokens_fn(allowed_sequences):
    """Returns a function that provides next allowed tokens based on the prefix `seq`."""
    t = MarisaTrie(allowed_sequences)

    def fn(unused_batch_id, seq):
        return t.get(seq)

    return fn


def process_golden_labels(example_batch, class_text_map, class_names):
    def fn(expected_text):
        return [class_text_map[y.strip()] for y in expected_text.split(",")]

    # example_batch['golden_labels'] = fn(example_batch['expected'])
    example_batch["golden_labels"] = fn(class_names[example_batch["Label"]])
    return example_batch


def compute_metrics(preds):
    # Placeholder stub; one possible implementation is sketched below.
    return {}
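
# A minimal sketch, not part of the original file: one way the `compute_metrics`
# stub above could be filled in when paired with `preprocess_logits_for_metrics`
# in a transformers Trainer. Assumptions: `eval_pred.predictions` carries the
# (pred_ids, labels) pair that hook returns, and ignored label positions are
# marked with -100 (e.g. by DataCollatorForSeq2Seq; `convert_to_features` above
# pads with the pad token instead, so the mask would need adjusting for it).
def _demo_compute_metrics(eval_pred):
    pred_ids, _ = eval_pred.predictions
    labels = eval_pred.label_ids
    ignored = labels == -100
    # A sequence counts as correct if every non-ignored position matches.
    correct = (pred_ids == labels) | ignored
    return {"sequence_accuracy": float(correct.all(axis=-1).mean())}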