@rubentea16
Created June 27, 2020 11:55
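import pytorch_lightning as pl
# Import paths assume a PyTorch Lightning release from around mid-2020 (0.8.x)
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint

# LightningMNISTClassifier, model_path, train_loader and val_loader are defined elsewhere
model = LightningMNISTClassifier(lr_rate=1e-3)
# Log the learning rate during training
lr_logger = LearningRateLogger()
# Stop training once 'val_loss' has failed to improve for 5 validation epochs in a row
early_stopping = EarlyStopping('val_loss', mode='min', patience=5)
# Keep the 3 checkpoints with the lowest 'val_loss', saved under 'model_path'
checkpoint_callback = ModelCheckpoint(filepath=model_path+'mnist_{epoch}-{val_loss:.2f}',
                                      monitor='val_loss', mode='min', save_top_k=3)
trainer = pl.Trainer(max_epochs=30, profiler=True, callbacks=[lr_logger],
                     early_stop_callback=early_stopping, checkpoint_callback=checkpoint_callback,
                     default_root_dir=model_path)  # add gpus=1 to train on a GPU
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)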
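
To make the `patience=5` and `save_top_k=3` settings concrete, here is a minimal pure-Python sketch of the two rules applied to a list of per-epoch validation losses. It is an illustration under stated assumptions, not PyTorch Lightning's implementation: the function names `stop_epoch` and `best_k_epochs` and the sample losses are hypothetical, and the library's exact counting may differ between versions.

```python
# Illustrative sketch only; not PyTorch Lightning's code.
# stop_epoch and best_k_epochs are hypothetical helper names.

def stop_epoch(val_losses, patience=5):
    """Return the 0-based epoch at which patience-based early stopping
    (mode='min') would trigger, or None if it never triggers."""
    best = float("inf")
    wait = 0  # consecutive epochs without a new minimum
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


def best_k_epochs(val_losses, k=3):
    """Return (epoch, loss) pairs for the k lowest losses, best first,
    mimicking which checkpoints a top-k policy would keep."""
    return sorted(enumerate(val_losses), key=lambda pair: pair[1])[:k]


# Made-up validation losses: a new minimum at epoch 2, then no improvement.
losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65]
stopped_at = stop_epoch(losses, patience=5)
kept = best_k_epochs(losses, k=3)
print(stopped_at, kept)
```

With these sample losses the last improvement happens at epoch 2, so five epochs without a new minimum have elapsed by epoch 7, and the three retained entries are the epochs with the lowest losses.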