import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint

model = LightningMNISTClassifier(lr_rate=1e-3)

# log the learning rate during training
lr_logger = LearningRateLogger()

# stop training once 'val_loss' has not improved for 5 consecutive epochs
early_stopping = EarlyStopping('val_loss', mode='min', patience=5)

# save the top-3 checkpoints to 'model_path', ranked by minimum 'val_loss'
checkpoint_callback = ModelCheckpoint(filepath=model_path + 'mnist_{epoch}-{val_loss:.2f}',
                                      monitor='val_loss', mode='min', save_top_k=3)

trainer = pl.Trainer(max_epochs=30,
                     profiler=True,
                     callbacks=[lr_logger],
                     early_stop_callback=early_stopping,
                     checkpoint_callback=checkpoint_callback,
                     default_root_dir=model_path)  # add gpus=1 to train on a GPU

trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
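Once training finishes, the saved checkpoints can be loaded back for inference. A minimal sketch, assuming a Lightning version where ModelCheckpoint exposes best_model_path (otherwise pass the path of one of the saved mnist_{epoch}-{val_loss:.2f} files directly) and passing lr_rate explicitly because it is an init argument of LightningMNISTClassifier:

# restore the weights of the best checkpoint found during training
best_path = checkpoint_callback.best_model_path  # or the path of any saved 'mnist_{epoch}-{val_loss:.2f}' file
model = LightningMNISTClassifier.load_from_checkpoint(best_path, lr_rate=1e-3)
model.eval()  # switch to evaluation mode before running predictions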