Skip to content

Instantly share code, notes, and snippets.

@w32zhong
Forked from ddh0/train.py
Created May 17, 2025 19:42
Show Gist options
  • Save w32zhong/f294c3620feb6c87b149e415b4d63d4b to your computer and use it in GitHub Desktop.
Save w32zhong/f294c3620feb6c87b149e415b4d63d4b to your computer and use it in GitHub Desktop.

Revisions

  1. @ddh0 ddh0 created this gist Feb 6, 2025.
    292 changes: 292 additions & 0 deletions train.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,292 @@
    import os
    import torch
    import psutil
    import datasets
    import glob

    from transformers import (
    AutoTokenizer, LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments,
    DataCollatorForLanguageModeling
    )

    N_CTX = 512 # dont touch
    LOG_DIR = f'/home/dylan/Documents/AI/train/logs'
    OUTPUT_DIR = f'/home/dylan/Documents/AI/train/output'
    DATA_DIR_350BT = f'/media/dylan/SanDisk/350BT'
    TOKENIZED_DIR_350BT = f"{DATA_DIR_350BT}/tokenized" # Where to store processed data
    DATA_DIR_10BT = f'/home/dylan/Documents/AI/datasets/fineweb/sample/10BT'
    TOKENIZED_DIR_10BT = f"{DATA_DIR_10BT}/tokenized" # Where to store processed data
    DATA_FILE_1BT = f'/home/dylan/Documents/AI/datasets/fineweb/sample/1BT/1BT.parquet'
    TOKENIZED_FILE_1BT = f'{DATA_FILE_1BT}.tokenized'
    DATA_DIR_WIKITEXT = f'/home/dylan/Documents/AI/datasets/wikitext/wikitext-103-raw-v1'
    DATA_FILE_EVAL = f'{DATA_DIR_WIKITEXT}/train-00000-of-00002.parquet'
    TOKENIZED_FILE_EVAL = f"{DATA_FILE_EVAL}.tokenized"

    def print_used_ram():
    memory_info = psutil.virtual_memory()
    used_ram_gib = memory_info.used / (1024 ** 3)
    print(f"Used System RAM: {used_ram_gib:.2f} GiB")

    print(f"Script start.")
    print_used_ram()

    print(f"Loading tokenizer ...")
    tokenizer = AutoTokenizer.from_pretrained('./tokenizer/')
    tokenizer.pad_token = tokenizer.eos_token

    data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False # For causal LM
    )

    def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, max_length=N_CTX)

    n_cpu = os.cpu_count()

    def dataset_from_parquet(file_path: str) -> datasets.Dataset:
    if not os.path.exists(file_path):
    raise FileNotFoundError(f'file {file_path!r} does not exist')
    if os.path.isdir(file_path):
    raise IsADirectoryError(f'{file_path!r} is a directory, not a file')

    print(f'Loading parquet file {file_path!r} ...')
    ds = datasets.Dataset.from_parquet(
    path_or_paths=file_path,
    keep_in_memory=False, # XXX
    num_proc=n_cpu
    )

    print(f'Finished loading parquet file.')
    return ds

    def dataset_from_parquet_dir(dir_path: str) -> datasets.Dataset:
    if not os.path.exists(dir_path):
    raise FileNotFoundError(f'directory {dir_path!r} does not exist')
    if not os.path.isdir(dir_path):
    raise FileNotFoundError(f'{dir_path!r} is a file, not a directory')

    abs_dir_path = os.path.abspath(dir_path)

    file_paths: list[str] = []

    print(f'Looking for Parquet files in {abs_dir_path!r}:')
    for file_name in os.listdir(abs_dir_path):
    if file_name.endswith('.parquet'):
    print(f'-- Found {file_name!r}')
    file_path = os.path.join(abs_dir_path, file_name)
    file_paths.append(file_path)

    n_file_paths = len(file_paths)
    if n_file_paths == 0:
    raise RuntimeError('No Parquet files were found.')

    print(f'Loading {n_file_paths} Parquet files ...')
    ds = datasets.Dataset.from_parquet(
    path_or_paths=file_paths,
    keep_in_memory=False, # XXX
    num_proc=n_cpu
    )

    print(f'Finished loading {n_file_paths} Parquet files.')
    return ds

    def get_tokenized_dataset(data_dir: str, tokenized_dir: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching"""
    if os.path.exists(tokenized_dir):
    print(f"Loading pre-tokenized dataset from {tokenized_dir}")
    return datasets.load_from_disk(tokenized_dir)

    print(f"Tokenizing and caching dataset to {tokenized_dir}")
    raw_dataset = dataset_from_parquet_dir(data_dir)

    # Tokenize with parallel processing
    tokenized_dataset = raw_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1024,
    num_proc=n_cpu,
    remove_columns=["text"]
    )

    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_dir)
    return tokenized_dataset

    def get_tokenized_dataset_file(data_file: str, tokenized_file: str) -> datasets.Dataset:
    """Load or create tokenized dataset with caching"""
    if os.path.exists(tokenized_file):
    print(f"Loading pre-tokenized dataset from {tokenized_file}")
    return datasets.load_from_disk(tokenized_file)

    print(f"Tokenizing and caching dataset to {tokenized_file}")
    raw_dataset = dataset_from_parquet(data_file)

    # Tokenize with parallel processing
    tokenized_dataset = raw_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1024,
    num_proc=1,
    remove_columns=["text"]
    )

    # Save for future runs
    tokenized_dataset.save_to_disk(tokenized_file)
    return tokenized_dataset

    # extremely tiny model used for faster testing
    nano_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=1,
    hidden_act="gelu",
    hidden_size=256,
    initializer_range=0.02,
    intermediate_size=512,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=1,
    num_key_value_heads=1,
    num_hidden_layers=1,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
    )

    # small model used for testing training before training actual model.
    micro_model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=1024,
    initializer_range=0.02,
    intermediate_size=1536,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_hidden_layers=6,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=10_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
    )

    # actual target model to train - similar arch to llama 3.2 1B
    model_config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=64,
    hidden_act="gelu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=N_CTX,
    mlp_bias=False,
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=16,
    rms_norm_eps=1e-05,
    pretraining_tp=1,
    tie_word_embeddings=False,
    rope_theta=100_000.0,
    rope_scaling=None,
    use_cache=True,
    vocab_size=128256
    )

    def get_latest_checkpoint(output_dir):
    """Get the path of the latest checkpoint in the output directory."""
    checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    if not checkpoints:
    return None
    latest_checkpoint = max(checkpoints, key=os.path.getctime)
    return latest_checkpoint

    def main() -> int:
    print(f"Start main")
    print(f"Init model ...")
    model = LlamaForCausalLM(model_config) # actual model for real training (largest)
    print(f"n_params: {model.num_parameters():,}")

    if torch.cuda.is_available():
    print('Using CUDA device')
    device = torch.device("cuda")
    device_str = "CUDA device"
    else:
    print('Using CPU')
    device = "cpu"
    device_str = "CPU"

    bf16_enabled = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    print(f"bf16 enabled: {bf16_enabled}")

    print(f"Moving model to {device_str} ...")
    model.to(device, dtype=torch.bfloat16)

    # For training data
    print(f"Loading training data ...")
    training_data = get_tokenized_dataset_file(DATA_FILE_1BT, TOKENIZED_FILE_1BT)
    training_data.set_format("torch", columns=["input_ids", "attention_mask"])

    trainer_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    save_steps=64,
    save_total_limit=16,
    logging_dir=LOG_DIR,
    logging_steps=1,
    eval_strategy="no",
    learning_rate=2e-5,
    bf16=bf16_enabled,
    bf16_full_eval=bf16_enabled,
    )

    print(f"Create Trainer ...")
    trainer = Trainer(
    model=model,
    args=trainer_args,
    train_dataset=training_data,
    data_collator=data_collator
    )

    # Check for the latest checkpoint
    latest_checkpoint = get_latest_checkpoint(OUTPUT_DIR)
    try:
    if latest_checkpoint:
    print(f"Resuming training from checkpoint: {latest_checkpoint}")
    trainer.train(resume_from_checkpoint=latest_checkpoint)
    else:
    print(f"Starting training from scratch ...")
    trainer.train()
    print(f"Done training.")
    except KeyboardInterrupt:
    print(f"KeyboardInterrupt: Training interrupted by user.")
    except Exception as e:
    print(f"Caught exception: {e}")
    finally:
    print(f"Save model ...")
    model.save_pretrained(f"{OUTPUT_DIR}/final", safe_serialization=True)
    tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
    return 0

    if __name__ == '__main__':
    exit(main())