import os import torch import psutil import datasets import glob from transformers import ( AutoTokenizer, LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling ) N_CTX = 512 # dont touch LOG_DIR = f'/home/dylan/Documents/AI/train/logs' OUTPUT_DIR = f'/home/dylan/Documents/AI/train/output' DATA_DIR_350BT = f'/media/dylan/SanDisk/350BT' TOKENIZED_DIR_350BT = f"{DATA_DIR_350BT}/tokenized" # Where to store processed data DATA_DIR_10BT = f'/home/dylan/Documents/AI/datasets/fineweb/sample/10BT' TOKENIZED_DIR_10BT = f"{DATA_DIR_10BT}/tokenized" # Where to store processed data DATA_FILE_1BT = f'/home/dylan/Documents/AI/datasets/fineweb/sample/1BT/1BT.parquet' TOKENIZED_FILE_1BT = f'{DATA_FILE_1BT}.tokenized' DATA_DIR_WIKITEXT = f'/home/dylan/Documents/AI/datasets/wikitext/wikitext-103-raw-v1' DATA_FILE_EVAL = f'{DATA_DIR_WIKITEXT}/train-00000-of-00002.parquet' TOKENIZED_FILE_EVAL = f"{DATA_FILE_EVAL}.tokenized" def print_used_ram(): memory_info = psutil.virtual_memory() used_ram_gib = memory_info.used / (1024 ** 3) print(f"Used System RAM: {used_ram_gib:.2f} GiB") print(f"Script start.") print_used_ram() print(f"Loading tokenizer ...") tokenizer = AutoTokenizer.from_pretrained('./tokenizer/') tokenizer.pad_token = tokenizer.eos_token data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False # For causal LM ) def tokenize_function(examples): return tokenizer(examples['text'], truncation=True, max_length=N_CTX) n_cpu = os.cpu_count() def dataset_from_parquet(file_path: str) -> datasets.Dataset: if not os.path.exists(file_path): raise FileNotFoundError(f'file {file_path!r} does not exist') if os.path.isdir(file_path): raise IsADirectoryError(f'{file_path!r} is a directory, not a file') print(f'Loading parquet file {file_path!r} ...') ds = datasets.Dataset.from_parquet( path_or_paths=file_path, keep_in_memory=False, # XXX num_proc=n_cpu ) print(f'Finished loading parquet file.') return ds def dataset_from_parquet_dir(dir_path: str) -> datasets.Dataset: if not os.path.exists(dir_path): raise FileNotFoundError(f'directory {dir_path!r} does not exist') if not os.path.isdir(dir_path): raise FileNotFoundError(f'{dir_path!r} is a file, not a directory') abs_dir_path = os.path.abspath(dir_path) file_paths: list[str] = [] print(f'Looking for Parquet files in {abs_dir_path!r}:') for file_name in os.listdir(abs_dir_path): if file_name.endswith('.parquet'): print(f'-- Found {file_name!r}') file_path = os.path.join(abs_dir_path, file_name) file_paths.append(file_path) n_file_paths = len(file_paths) if n_file_paths == 0: raise RuntimeError('No Parquet files were found.') print(f'Loading {n_file_paths} Parquet files ...') ds = datasets.Dataset.from_parquet( path_or_paths=file_paths, keep_in_memory=False, # XXX num_proc=n_cpu ) print(f'Finished loading {n_file_paths} Parquet files.') return ds def get_tokenized_dataset(data_dir: str, tokenized_dir: str) -> datasets.Dataset: """Load or create tokenized dataset with caching""" if os.path.exists(tokenized_dir): print(f"Loading pre-tokenized dataset from {tokenized_dir}") return datasets.load_from_disk(tokenized_dir) print(f"Tokenizing and caching dataset to {tokenized_dir}") raw_dataset = dataset_from_parquet_dir(data_dir) # Tokenize with parallel processing tokenized_dataset = raw_dataset.map( tokenize_function, batched=True, batch_size=1024, num_proc=n_cpu, remove_columns=["text"] ) # Save for future runs tokenized_dataset.save_to_disk(tokenized_dir) return tokenized_dataset def get_tokenized_dataset_file(data_file: str, tokenized_file: str) -> datasets.Dataset: """Load or create tokenized dataset with caching""" if os.path.exists(tokenized_file): print(f"Loading pre-tokenized dataset from {tokenized_file}") return datasets.load_from_disk(tokenized_file) print(f"Tokenizing and caching dataset to {tokenized_file}") raw_dataset = dataset_from_parquet(data_file) # Tokenize with parallel processing tokenized_dataset = raw_dataset.map( tokenize_function, batched=True, batch_size=1024, num_proc=1, remove_columns=["text"] ) # Save for future runs tokenized_dataset.save_to_disk(tokenized_file) return tokenized_dataset # extremely tiny model used for faster testing nano_model_config = LlamaConfig( attention_bias=False, attention_dropout=0.0, bos_token_id=128000, eos_token_id=128001, head_dim=1, hidden_act="gelu", hidden_size=256, initializer_range=0.02, intermediate_size=512, max_position_embeddings=N_CTX, mlp_bias=False, num_attention_heads=1, num_key_value_heads=1, num_hidden_layers=1, rms_norm_eps=1e-05, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10_000.0, rope_scaling=None, use_cache=True, vocab_size=128256 ) # small model used for testing training before training actual model. micro_model_config = LlamaConfig( attention_bias=False, attention_dropout=0.0, bos_token_id=128000, eos_token_id=128001, head_dim=64, hidden_act="gelu", hidden_size=1024, initializer_range=0.02, intermediate_size=1536, max_position_embeddings=N_CTX, mlp_bias=False, num_attention_heads=4, num_key_value_heads=2, num_hidden_layers=6, rms_norm_eps=1e-05, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10_000.0, rope_scaling=None, use_cache=True, vocab_size=128256 ) # actual target model to train - similar arch to llama 3.2 1B model_config = LlamaConfig( attention_bias=False, attention_dropout=0.0, bos_token_id=128000, eos_token_id=128001, head_dim=64, hidden_act="gelu", hidden_size=2048, initializer_range=0.02, intermediate_size=8192, max_position_embeddings=N_CTX, mlp_bias=False, num_attention_heads=32, num_key_value_heads=8, num_hidden_layers=16, rms_norm_eps=1e-05, pretraining_tp=1, tie_word_embeddings=False, rope_theta=100_000.0, rope_scaling=None, use_cache=True, vocab_size=128256 ) def get_latest_checkpoint(output_dir): """Get the path of the latest checkpoint in the output directory.""" checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) if not checkpoints: return None latest_checkpoint = max(checkpoints, key=os.path.getctime) return latest_checkpoint def main() -> int: print(f"Start main") print(f"Init model ...") model = LlamaForCausalLM(model_config) # actual model for real training (largest) print(f"n_params: {model.num_parameters():,}") if torch.cuda.is_available(): print('Using CUDA device') device = torch.device("cuda") device_str = "CUDA device" else: print('Using CPU') device = "cpu" device_str = "CPU" bf16_enabled = torch.cuda.is_available() and torch.cuda.is_bf16_supported() print(f"bf16 enabled: {bf16_enabled}") print(f"Moving model to {device_str} ...") model.to(device, dtype=torch.bfloat16) # For training data print(f"Loading training data ...") training_data = get_tokenized_dataset_file(DATA_FILE_1BT, TOKENIZED_FILE_1BT) training_data.set_format("torch", columns=["input_ids", "attention_mask"]) trainer_args = TrainingArguments( output_dir=OUTPUT_DIR, overwrite_output_dir=True, num_train_epochs=1, per_device_train_batch_size=2, gradient_accumulation_steps=16, save_steps=64, save_total_limit=16, logging_dir=LOG_DIR, logging_steps=1, eval_strategy="no", learning_rate=2e-5, bf16=bf16_enabled, bf16_full_eval=bf16_enabled, ) print(f"Create Trainer ...") trainer = Trainer( model=model, args=trainer_args, train_dataset=training_data, data_collator=data_collator ) # Check for the latest checkpoint latest_checkpoint = get_latest_checkpoint(OUTPUT_DIR) try: if latest_checkpoint: print(f"Resuming training from checkpoint: {latest_checkpoint}") trainer.train(resume_from_checkpoint=latest_checkpoint) else: print(f"Starting training from scratch ...") trainer.train() print(f"Done training.") except KeyboardInterrupt: print(f"KeyboardInterrupt: Training interrupted by user.") except Exception as e: print(f"Caught exception: {e}") finally: print(f"Save model ...") model.save_pretrained(f"{OUTPUT_DIR}/final", safe_serialization=True) tokenizer.save_pretrained(f"{OUTPUT_DIR}/final") return 0 if __name__ == '__main__': exit(main())