Skip to content

Instantly share code, notes, and snippets.

@LTNMinh
Created July 7, 2020 10:08
Show Gist options
  • Save LTNMinh/bd1d28e5c1c77de934cdf4b424c02751 to your computer and use it in GitHub Desktop.
Save LTNMinh/bd1d28e5c1c77de934cdf4b424c02751 to your computer and use it in GitHub Desktop.
import pkbar
import pytorch_lightning as pl
class PKProgressBar(pl.callbacks.progress.ProgressBarBase):
    """
    Keras-style progress bar callback for the PyTorch Lightning framework.

    Requires the ``pkbar`` package, which can be installed with pip::

        pip install pkbar

    Usage::

        bar = PKProgressBar()
        trainer = pl.Trainer(callbacks=[bar], max_epochs=100)
    """

    def __init__(self):
        super().__init__()
        # Enabled/disabled state lives in ``_enabled`` rather than in an
        # attribute named ``enable``: assigning ``self.enable = True`` would
        # shadow the ``enable()`` method with a bool and make it uncallable.
        self._enabled = True
        # The active pkbar.Kbar; recreated at the start of every training
        # epoch and every validation run.
        self.kbar = None

    @property
    def is_enabled(self):
        """Return True while this bar renders progress output."""
        return self._enabled

    @property
    def is_disabled(self):
        """Return True while this bar is silenced."""
        return not self._enabled

    def disable(self):
        """Stop rendering progress output (batch counters keep updating)."""
        self._enabled = False

    def enable(self):
        """Resume rendering progress output."""
        self._enabled = True

    @staticmethod
    def _kbar_target(total):
        """Return *total* as a Kbar target, or None if it is not a positive finite count.

        pkbar treats a ``None`` target as "unknown length". Passing an
        infinite or zero total straight through would break its digit-width
        computation. (Presumably an infinite total comes from a dataloader
        without a length; confirm against the Lightning version in use.)
        """
        # The chained comparison is also False for NaN, which maps to None.
        if total is None or not 0 < total < float('inf'):
            return None
        return total

    @staticmethod
    def _numeric_metrics(metrics):
        """Return ``(name, float)`` pairs for every entry of *metrics* that converts to float.

        Entries that cannot be converted (non-numeric strings, multi-element
        tensors, ...) are skipped, so they cannot abort training.
        """
        pairs = []
        for name, value in metrics.items():
            try:
                pairs.append((name, float(value)))
            except (TypeError, ValueError):
                # Nothing sensible to plot for this entry; leave it out.
                continue
        return pairs

    def on_epoch_start(self, trainer, pl_module):
        """Print the epoch number and start a fresh bar sized to the training batches."""
        super().on_epoch_start(trainer, pl_module)
        if not self._enabled:
            return
        print('Epoch: {}'.format(trainer.current_epoch))
        self.kbar = pkbar.Kbar(
            target=self._kbar_target(self.total_train_batches), width=20
        )

    def on_batch_end(self, trainer, pl_module):
        """Advance the training bar and show the current progress-bar metrics."""
        # super() runs first so the base class can update its batch counters.
        super().on_batch_end(trainer, pl_module)
        if not self._enabled or self.kbar is None:
            return
        self.kbar.update(
            self.train_batch_idx,
            values=self._numeric_metrics(trainer.progress_bar_dict),
        )

    def on_validation_start(self, trainer, pl_module):
        """Start a fresh bar sized to the validation batches."""
        super().on_validation_start(trainer, pl_module)
        if not self._enabled:
            return
        self.kbar = pkbar.Kbar(
            target=self._kbar_target(self.total_val_batches), width=22
        )

    def on_validation_batch_end(self, trainer, pl_module):
        """Advance the validation bar, staying one step behind the batch counter."""
        super().on_validation_batch_end(trainer, pl_module)
        if not self._enabled or self.kbar is None:
            return
        # Stay one step short; on_validation_end supplies the final step via
        # add(1), so the bar completes with the end-of-validation metrics.
        self.kbar.update(
            self.val_batch_idx - 1,
            values=self._numeric_metrics(trainer.progress_bar_dict),
        )

    def on_validation_end(self, trainer, pl_module):
        """Complete the validation bar with the final progress-bar metrics."""
        super().on_validation_end(trainer, pl_module)
        if not self._enabled or self.kbar is None:
            return
        self.kbar.add(1, values=self._numeric_metrics(trainer.progress_bar_dict))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment