Created July 7, 2020 10:08
        
      - 
      
- 
        Save LTNMinh/bd1d28e5c1c77de934cdf4b424c02751 to your computer and use it in GitHub Desktop. 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import pkbar | |
| import pytorch_lightning as pl | |
class PKProgressBar(pl.callbacks.progress.ProgressBarBase):
    """Keras-style progress bar callback for the PyTorch Lightning framework.

    Training and validation progress are drawn with ``pkbar.Kbar``. Requires
    the ``pkbar`` package, e.g. ``pip install pkbar`` (``!pip install pkbar``
    inside a notebook).

    Usage::

        bar = PKProgressBar()
        trainer = pl.Trainer(callbacks=[bar], max_epochs=100)
    """

    def __init__(self):
        super().__init__()
        # Kept under a private name: a public ``self.enable`` attribute would
        # shadow the ``enable()`` method of the ProgressBarBase API, so calling
        # ``bar.enable()`` would fail with "'bool' object is not callable".
        self._enabled = True
        # Bar most recently drawn (training or validation); None until the
        # first epoch or validation run starts.
        self.kbar = None
        # Training bar of the current epoch, kept apart from the validation
        # bar so a validation run in the middle of an epoch cannot hijack it.
        self._train_kbar = None
        # Validation bar of the current validation run, if any.
        self._val_kbar = None

    @property
    def is_enabled(self):
        """True while the bar is drawing output."""
        return self._enabled

    @property
    def is_disabled(self):
        """Inverse of ``is_enabled``."""
        return not self._enabled

    def disable(self):
        """Stop drawing; later hooks still update counters but print nothing."""
        self._enabled = False

    def enable(self):
        """Resume drawing after ``disable()``."""
        self._enabled = True

    @staticmethod
    def _progress_values(trainer):
        """Return ``trainer.progress_bar_dict`` as a list of (name, float) pairs.

        Entries that ``float()`` cannot convert (for example a string logger
        version) are skipped instead of crashing the hook.
        """
        values = []
        for name, value in trainer.progress_bar_dict.items():
            try:
                values.append((name, float(value)))
            except (TypeError, ValueError):
                continue
        return values

    def on_epoch_start(self, trainer, pl_module):
        """Print the epoch number and start a fresh training bar."""
        # Always call super() first so the base-class batch counters stay
        # correct even while the bar is disabled.
        super().on_epoch_start(trainer, pl_module)
        if self.is_disabled:
            return
        print('Epoch: {}'.format(trainer.current_epoch))
        self._train_kbar = pkbar.Kbar(target=self.total_train_batches, width=20)
        self.kbar = self._train_kbar

    def on_batch_end(self, trainer, pl_module):
        """Advance the training bar with the current progress-bar metrics."""
        super().on_batch_end(trainer, pl_module)
        if self.is_disabled or self._train_kbar is None:
            return
        # NOTE(review): train_batch_idx is read after super(), i.e. presumably
        # already counting this batch (validation below uses
        # val_batch_idx - 1 plus a final add(1)) — confirm against pkbar.
        self._train_kbar.update(
            self.train_batch_idx, values=self._progress_values(trainer)
        )

    def on_validation_start(self, trainer, pl_module):
        """Start a fresh validation bar."""
        super().on_validation_start(trainer, pl_module)
        self._val_kbar = None
        if self.is_disabled:
            return
        self._val_kbar = pkbar.Kbar(target=self.total_val_batches, width=22)
        self.kbar = self._val_kbar

    def on_validation_batch_end(self, trainer, pl_module):
        """Advance the validation bar with the current progress-bar metrics."""
        super().on_validation_batch_end(trainer, pl_module)
        if self.is_disabled or self._val_kbar is None:
            return
        self._val_kbar.update(
            self.val_batch_idx - 1, values=self._progress_values(trainer)
        )

    def on_validation_end(self, trainer, pl_module):
        """Finish the validation bar, then hand ``kbar`` back to the training bar."""
        super().on_validation_end(trainer, pl_module)
        if self.is_disabled or self._val_kbar is None:
            return
        self._val_kbar.add(1, values=self._progress_values(trainer))
        self._val_kbar = None
        # If validation ran mid-epoch, training resumes on its own bar.
        self.kbar = self._train_kbar
  
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.