import marimo __generated_with = "0.14.17" app = marimo.App() @app.cell def _(): import marimo as mo import matplotlib.pylab as plt from mofresh import refresh_matplotlib, ImageRefreshWidget return ImageRefreshWidget, mo, plt, refresh_matplotlib @app.cell def _(): import unsloth import os import torch import pandas as pd import numpy as np from trl import SFTConfig, SFTTrainer from unsloth import FastLanguageModel from dotenv import load_dotenv load_dotenv() return FastLanguageModel, SFTConfig, SFTTrainer, os, pd, torch @app.cell(hide_code=True) def _(mo): mo.md(r"""# Model Config""") return @app.cell def _(mo): MODEL_NAME = mo.ui.text(value="unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit", label="Model name", full_width=True) MODEL_IS_4BIT = mo.ui.checkbox(value=True, label='Load in 4bit') MAX_SEQ_LENGTH = mo.ui.number(value=2048, start=64, stop=131072, step=1, label="Max sequence length") OUTPUT_MODEL_NAME = mo.ui.text(value="Luau-Qwen3-4B-Instruct-v0.1", label="Output model name", full_width=True) mo.vstack([ MODEL_NAME, MODEL_IS_4BIT, MAX_SEQ_LENGTH, OUTPUT_MODEL_NAME, ]) return MAX_SEQ_LENGTH, MODEL_IS_4BIT, MODEL_NAME, OUTPUT_MODEL_NAME @app.cell def _(FastLanguageModel, MAX_SEQ_LENGTH, MODEL_IS_4BIT, MODEL_NAME, os): # Load model model, tokenizer = FastLanguageModel.from_pretrained( model_name=MODEL_NAME.value, max_seq_length=MAX_SEQ_LENGTH.value, load_in_4bit=MODEL_IS_4BIT.value, dtype=None, # None for auto detection token=os.getenv("HF_TOKEN"), ) # Do model patching and add fast LoRA weights model = FastLanguageModel.get_peft_model( model, r=64, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_alpha=64, lora_dropout=0, # Dropout = 0 is currently optimized bias="none", # Bias = "none" is currently optimized use_gradient_checkpointing="unsloth", random_state=3407, ) return model, tokenizer @app.cell(hide_code=True) def _(mo): mo.md(r"""# Training Data""") return @app.cell def _(mo): num_datasets = mo.ui.number(value=2, label="Number of datasets") num_datasets return (num_datasets,) @app.cell def _(mo, num_datasets): dataset_names = mo.ui.array([ mo.ui.text(placeholder="Dataset name", full_width=True) for _ in range(num_datasets.value) ]) dataset_names return (dataset_names,) @app.cell def _(dataset_names, mo): dataset_formatters = mo.ui.array([ mo.ui.text_area(placeholder="Dataset format string", label=dataset_name, full_width=True, rows=10) for dataset_name in dataset_names.value ]) dataset_formatters return (dataset_formatters,) @app.cell def _(dataset_formatters, dataset_names, os, pd, tokenizer): from datasets import load_dataset, concatenate_datasets datasets = [] for i in range(len(dataset_names.value)): dataset_name = dataset_names.value[i] dataset_formatter = dataset_formatters.value[i] + tokenizer.eos_token datasets.append(load_dataset( dataset_name, split="train", token=os.getenv("HF_TOKEN"), ).map( lambda example: { "text": dataset_formatter.format(**example), "source": dataset_name, }, batched=False, ).select_columns([ "text", "source" ])) dataset = concatenate_datasets(datasets).shuffle(seed=42) pd.DataFrame(dataset) return (dataset,) @app.cell(hide_code=True) def _(mo): mo.md(r"""# Training""") return @app.cell(hide_code=True) def _(mo): LOGGING_STEPS = mo.ui.number(value=1, start=1, stop=1000, step=1, label="Logging steps") SAVE_STEPS = mo.ui.number(value=100, start=1, stop=1000, step=1, label="Save steps") EPOCHS = mo.ui.number(value=2, start=0.1, stop=5, step=0.1, label="Epochs") MAX_STEPS = mo.ui.number(value=-1, start=-1, stop=50000, step=1, label="Max steps (-1 for unlimited)") RESUME_FROM_CHECKPOINT = mo.ui.checkbox(value=True, label="Resume training from latest checkpoint") run_button = mo.ui.run_button(label="Start training") mo.vstack([ LOGGING_STEPS, SAVE_STEPS, EPOCHS, MAX_STEPS, RESUME_FROM_CHECKPOINT, run_button, ]) return ( EPOCHS, LOGGING_STEPS, MAX_STEPS, RESUME_FROM_CHECKPOINT, SAVE_STEPS, run_button, ) @app.cell(hide_code=True) def _(mo, torch): gpu_stats = torch.cuda.get_device_properties(0) start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) mo.hstack([ mo.stat(value=f"{max_memory} GB", label="VRAM Total", caption=gpu_stats.name), mo.stat(value=f"{start_gpu_memory} GB", label="VRAM Reserved"), ],justify="center") return gpu_stats, max_memory, start_gpu_memory @app.cell def _( EPOCHS, LOGGING_STEPS, MAX_SEQ_LENGTH, MAX_STEPS, SAVE_STEPS, SFTConfig, mo, run_button, ): mo.stop(not run_button.value) training_args = SFTConfig( output_dir="./outputs/checkpoints", max_length=MAX_SEQ_LENGTH.value, logging_steps=LOGGING_STEPS.value, torch_empty_cache_steps=150, save_strategy = "steps", save_steps=SAVE_STEPS.value, # Allows us to resume training from latest checkpoint learning_rate = 2e-5, # Reduce to 2e-5 for long training runs dataset_num_proc=1, num_train_epochs=EPOCHS.value, max_steps=MAX_STEPS.value, per_device_train_batch_size=1, gradient_accumulation_steps=8, # Use GA to mimic batch size! ) return (training_args,) @app.cell def _(plt, refresh_matplotlib): @refresh_matplotlib def loss_linechart(step_loss: dict[int, float]): x_values = [step for step in step_loss.keys()] y_values = [loss for loss in step_loss.values()] plt.plot(x_values, y_values) plt.xlabel("Step") plt.ylabel("Loss") plt.ylim(min(0.2, min(y_values)), max(y_values) * 1.05) plt.title("Training Loss Curve") return (loss_linechart,) @app.cell def _(ImageRefreshWidget, loss_linechart): widget = ImageRefreshWidget(src=loss_linechart({1: 1})) widget return (widget,) @app.cell def _( RESUME_FROM_CHECKPOINT, SFTTrainer, dataset, loss_linechart, mo, model, os, run_button, training_args, widget, ): mo.stop(not run_button.value) from transformers import TrainerCallback class PlotLogs(TrainerCallback): loss_history: dict[int, float] = {} def on_log(self, args, state, control, logs=None, **kwargs): if logs is not None and "loss" in logs: self.loss_history[state.global_step] = logs["loss"] widget.src = loss_linechart(self.loss_history) trainer = SFTTrainer( model=model, args=training_args, train_dataset=dataset, callbacks=[PlotLogs()], ) os.environ["UNSLOTH_RETURN_LOGITS"] = "1" trainer_stats = trainer.train(resume_from_checkpoint = RESUME_FROM_CHECKPOINT.value) return (trainer_stats,) @app.cell(hide_code=True) def _( gpu_stats, max_memory, mo, run_button, start_gpu_memory, torch, trainer_stats, ): mo.stop(not run_button.value) used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) used_memory_for_lora = round(used_memory - start_gpu_memory, 3) used_percentage = round(used_memory / max_memory * 100, 3) lora_percentage = round(used_memory_for_lora / max_memory * 100, 3) mo.hstack([ mo.stat(value=f"{round(trainer_stats.metrics['train_runtime'] / 60, 1)} minutes", label="Training runtime"), mo.stat(value=f"{max_memory} GB", label="VRAM Total", caption=gpu_stats.name), mo.stat(value=f"{used_memory} GB", label="Peak VRAM Reserved"), mo.stat(value=f"{used_memory_for_lora} GB", label="Peak VRAM Reserved For Training"), ],justify="center") return @app.cell def _(MODEL_IS_4BIT, OUTPUT_MODEL_NAME, mo, model, run_button, tokenizer): mo.stop(not run_button.value) model.save_pretrained_merged(f"outputs/{OUTPUT_MODEL_NAME.value}", tokenizer, save_method = "merge_4bit_forced" if MODEL_IS_4BIT.value else "merged_16bit",) return if __name__ == "__main__": app.run()