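"""Pretrain a small Llama-style causal LM on FineWeb samples with the Hugging Face Trainer.

Parquet shards are loaded with the `datasets` library, tokenized once and cached to
disk, then training runs on the chosen LlamaConfig preset, resuming from the newest
checkpoint in OUTPUT_DIR when one exists.
"""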
import glob
import os

import datasets
import psutil
import torch

from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    LlamaConfig,
    LlamaForCausalLM,
    Trainer,
    TrainingArguments,
)
N_CTX = 512  # Don't touch: used for tokenization length and max_position_embeddings below.

LOG_DIR = '/home/dylan/Documents/AI/train/logs'
OUTPUT_DIR = '/home/dylan/Documents/AI/train/output'

DATA_DIR_350BT = '/media/dylan/SanDisk/350BT'
TOKENIZED_DIR_350BT = f"{DATA_DIR_350BT}/tokenized"  # Where to store processed data

DATA_DIR_10BT = '/home/dylan/Documents/AI/datasets/fineweb/sample/10BT'
TOKENIZED_DIR_10BT = f"{DATA_DIR_10BT}/tokenized"  # Where to store processed data

DATA_FILE_1BT = '/home/dylan/Documents/AI/datasets/fineweb/sample/1BT/1BT.parquet'
TOKENIZED_FILE_1BT = f'{DATA_FILE_1BT}.tokenized'

DATA_DIR_WIKITEXT = '/home/dylan/Documents/AI/datasets/wikitext/wikitext-103-raw-v1'
DATA_FILE_EVAL = f'{DATA_DIR_WIKITEXT}/train-00000-of-00002.parquet'
TOKENIZED_FILE_EVAL = f"{DATA_FILE_EVAL}.tokenized"
def print_used_ram():
    memory_info = psutil.virtual_memory()
    used_ram_gib = memory_info.used / (1024 ** 3)
    print(f"Used System RAM: {used_ram_gib:.2f} GiB")


print("Script start.")
print_used_ram()
print(f"Loading tokenizer ...") |
|
|
tokenizer = AutoTokenizer.from_pretrained('./tokenizer/') |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
data_collator = DataCollatorForLanguageModeling( |
|
|
tokenizer=tokenizer, |
|
|
mlm=False # For causal LM |
|
|
) |
|
|
|
|
|
def tokenize_function(examples): |
|
|
return tokenizer(examples['text'], truncation=True, max_length=N_CTX) |
|
|
|
|
|
n_cpu = os.cpu_count() |
|
|
|
|
|
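# Data loading helpers: FineWeb samples are stored as Parquet, either as a single
# file or as a directory of shards; both loaders use keep_in_memory=False so the
# Arrow cache is memory-mapped instead of held entirely in RAM.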
def dataset_from_parquet(file_path: str) -> datasets.Dataset:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f'file {file_path!r} does not exist')
    if os.path.isdir(file_path):
        raise IsADirectoryError(f'{file_path!r} is a directory, not a file')

    print(f'Loading parquet file {file_path!r} ...')
    ds = datasets.Dataset.from_parquet(
        path_or_paths=file_path,
        keep_in_memory=False,  # XXX
        num_proc=n_cpu
    )

    print('Finished loading parquet file.')
    return ds
def dataset_from_parquet_dir(dir_path: str) -> datasets.Dataset:
    if not os.path.exists(dir_path):
        raise FileNotFoundError(f'directory {dir_path!r} does not exist')
    if not os.path.isdir(dir_path):
        raise NotADirectoryError(f'{dir_path!r} is a file, not a directory')

    abs_dir_path = os.path.abspath(dir_path)

    file_paths: list[str] = []

    print(f'Looking for Parquet files in {abs_dir_path!r}:')
    for file_name in sorted(os.listdir(abs_dir_path)):  # sorted for a deterministic shard order
        if file_name.endswith('.parquet'):
            print(f'-- Found {file_name!r}')
            file_path = os.path.join(abs_dir_path, file_name)
            file_paths.append(file_path)

    n_file_paths = len(file_paths)
    if n_file_paths == 0:
        raise RuntimeError('No Parquet files were found.')

    print(f'Loading {n_file_paths} Parquet files ...')
    ds = datasets.Dataset.from_parquet(
        path_or_paths=file_paths,
        keep_in_memory=False,  # XXX
        num_proc=n_cpu
    )

    print(f'Finished loading {n_file_paths} Parquet files.')
    return ds
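# Tokenization is done once and cached: if a tokenized copy already exists on disk
# it is loaded directly, otherwise the raw data is tokenized and saved for later runs.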
def get_tokenized_dataset(data_dir: str, tokenized_dir: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching."""
    if os.path.exists(tokenized_dir):
        print(f"Loading pre-tokenized dataset from {tokenized_dir}")
        return datasets.load_from_disk(tokenized_dir)

    print(f"Tokenizing and caching dataset to {tokenized_dir}")
    raw_dataset = dataset_from_parquet_dir(data_dir)

    # Tokenize with parallel processing
    tokenized_dataset = raw_dataset.map(
        tokenize_function,
        batched=True,
        batch_size=1024,
        num_proc=n_cpu,
        remove_columns=["text"]
    )

    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_dir)
    return tokenized_dataset
def get_tokenized_dataset_file(data_file: str, tokenized_file: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching."""
    if os.path.exists(tokenized_file):
        print(f"Loading pre-tokenized dataset from {tokenized_file}")
        return datasets.load_from_disk(tokenized_file)

    print(f"Tokenizing and caching dataset to {tokenized_file}")
    raw_dataset = dataset_from_parquet(data_file)

    # Tokenize with parallel processing
    tokenized_dataset = raw_dataset.map(
        tokenize_function,
        batched=True,
        batch_size=1024,
        num_proc=1,
        remove_columns=["text"]
    )

    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_file)
    return tokenized_dataset
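# Three model presets follow, all sharing the 128256-token vocabulary and the fixed
# N_CTX context window: a nano config for quick smoke tests, a micro config for
# exercising the training loop, and the full-size target (Llama 3.2 1B-like).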
# Extremely tiny model used for faster testing.
nano_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=1,
    hidden_act="gelu",
    hidden_size=256,
    initializer_range=0.02,
    intermediate_size=512,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=1,
    num_key_value_heads=1,
    num_hidden_layers=1,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
# Small model used for testing training before training the actual model.
micro_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=1024,
    initializer_range=0.02,
    intermediate_size=1536,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_hidden_layers=6,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
# Actual target model to train; architecture similar to Llama 3.2 1B.
model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=16,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=100_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
)
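# Trainer writes checkpoints to OUTPUT_DIR as checkpoint-<step> directories; the most
# recently created one is used to resume an interrupted run.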
def get_latest_checkpoint(output_dir):
    """Get the path of the latest checkpoint in the output directory."""
    checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=os.path.getctime)
    return latest_checkpoint
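# Training entry point: build the model, move it to the best available device,
# load the cached training data, and hand everything to the Hugging Face Trainer,
# resuming from the latest checkpoint if one exists.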
def main() -> int:
    print("Start main")
    print("Init model ...")
    model = LlamaForCausalLM(model_config)  # actual model for real training (largest)
    print(f"n_params: {model.num_parameters():,}")

    if torch.cuda.is_available():
        print('Using CUDA device')
        device = torch.device("cuda")
        device_str = "CUDA device"
    else:
        print('Using CPU')
        device = torch.device("cpu")
        device_str = "CPU"

    bf16_enabled = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    print(f"bf16 enabled: {bf16_enabled}")

    print(f"Moving model to {device_str} ...")
    # Train the full model in bfloat16 when the hardware supports it; otherwise stay in float32.
    model.to(device, dtype=torch.bfloat16 if bf16_enabled else torch.float32)
    # Training data: the tokenized 1BT FineWeb sample (the 10BT/350BT directories
    # defined above can be used via get_tokenized_dataset instead).
    print("Loading training data ...")
    training_data = get_tokenized_dataset_file(DATA_FILE_1BT, TOKENIZED_FILE_1BT)
    training_data.set_format("torch", columns=["input_ids", "attention_mask"])
    trainer_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=16,  # effective batch size: 2 * 16 = 32 sequences per optimizer step
        save_steps=64,
        save_total_limit=16,
        logging_dir=LOG_DIR,
        logging_steps=1,
        eval_strategy="no",
        learning_rate=2e-5,
        bf16=bf16_enabled,
        bf16_full_eval=bf16_enabled,
    )
print(f"Create Trainer ...") |
|
|
trainer = Trainer( |
|
|
model=model, |
|
|
args=trainer_args, |
|
|
train_dataset=training_data, |
|
|
data_collator=data_collator |
|
|
) |
|
|
|
|
|
    # Check for the latest checkpoint
    latest_checkpoint = get_latest_checkpoint(OUTPUT_DIR)
    try:
        if latest_checkpoint:
            print(f"Resuming training from checkpoint: {latest_checkpoint}")
            trainer.train(resume_from_checkpoint=latest_checkpoint)
        else:
            print("Starting training from scratch ...")
            trainer.train()
        print("Done training.")
    except KeyboardInterrupt:
        print("KeyboardInterrupt: Training interrupted by user.")
    except Exception as e:
        print(f"Caught exception: {e}")
    finally:
        print("Save model ...")
        model.save_pretrained(f"{OUTPUT_DIR}/final", safe_serialization=True)
        tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
    return 0
if __name__ == '__main__':
    raise SystemExit(main())