Skip to content

Instantly share code, notes, and snippets.

@rubentea16
Created June 27, 2020 11:55
Show Gist options
  • Save rubentea16/005718aa1943826dc2641cb93096ef5f to your computer and use it in GitHub Desktop.
Save rubentea16/005718aa1943826dc2641cb93096ef5f to your computer and use it in GitHub Desktop.

Revisions

  1. rubentea16 created this gist Jun 27, 2020.
    17 changes: 17 additions & 0 deletions train_model_pl_mnist.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,17 @@
    model = LightningMNISTClassifier(lr_rate=1e-3)

    # Learning Rate Logger
    lr_logger = LearningRateLogger()

    # Set Early Stopping
    early_stopping = EarlyStopping('val_loss', mode='min', patience=5)

    # saves checkpoints to 'model_path' whenever 'val_loss' has a new min
    checkpoint_callback = ModelCheckpoint(filepath=model_path+'mnist_{epoch}-{val_loss:.2f}',
    monitor='val_loss', mode='min', save_top_k=3)

    trainer = pl.Trainer(max_epochs=30, profiler=True, callbacks=[lr_logger],
    early_stop_callback=early_stopping, checkpoint_callback=checkpoint_callback,
    default_root_dir=model_path) #gpus=1

    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)