Last active
          January 29, 2023 20:29 
        
      - 
      
 - 
        
Save shi510/b97c044d75b386b9ee7a9e706837c2cf to your computer and use it in GitHub Desktop.  
Revisions
- 
        
shi510 revised this gist
Dec 4, 2019 . 1 changed file with 34 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,34 @@ class OHEMCallback(tf.keras.callbacks.Callback): def __init__(self, generator): super(OHEMCallback, self).__init__() self.generator = generator def on_epoch_begin(self, epoch, logs=None): self.generator.on_epoch_end() def make_dataset(x, y, batch): ds = tf.data.Dataset.from_tensor_slices((x, y)) ds = ds.repeat() ds = ds.shuffle(size) ds = ds.batch(batch) ds = ds.prefetch(1024) return ds def train(x, y, lr=1e-4, batch=512, valid_x, valid_y): # ... # build your model. # mode.compile() ... valid_ds = make_dataset(valid_x, valid_y, batch) x_gen = HardExampleMiner(model, x, y, batch) ohem_callback = OHEMCallback(x_gen) model.fit( x=x_gen, validation_data=valid_ds, validation_steps=int(math.floor(len(valid_x) / batch)), epochs=1000, callbacks=[ohem_callback], max_queue_size=256, workers=8 )  - 
        
shi510 revised this gist
Dec 4, 2019 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -31,7 +31,7 @@ def on_epoch_end(self): diff = np.abs(outputs - sample_y).reshape(-1) self.errors[batch_id*self.batch_size:(batch_id+1) * self.batch_size] = diff self.hard_idxs = np.argsort(-self.errors) def _slice_batch(self, x, id): sample = x[self.batch_size * id: self.batch_size * (id + 1)]  - 
        
shi510 created this gist
Dec 4, 2019 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,38 @@ class HardExampleMiner(tf.keras.utils.Sequence): def __init__(self, model, x, y, batch_size, map_fn=None, ratio=0.8): self.x = np.array(x, dtype=np.float32) self.y = np.array(y, dtype=np.float32) self.batch_size = batch_size self.model = model self.ratio = ratio self.num_of_batch = int(math.floor((len(self.x) / self.batch_size))) self.hard_idxs = np.arange(self.num_of_batch * self.batch_size) self.errors = np.empty((self.num_of_batch * self.batch_size)) self.sample_x = np.empty((self.batch_size, self.x.shape[1])) self.sample_y = np.empty((self.batch_size, self.y.shape[1])) def __len__(self): return int(self.num_of_batch * self.ratio) def __getitem__(self, batch_id): start = self.batch_size * batch_id end = self.batch_size * (batch_id + 1) for seq, idx in enumerate(self.hard_idxs[start:end]): self.sample_x[seq,] = self.x[idx] self.sample_y[seq,] = self.y[idx] return (self.sample_x, self.sample_y) def on_epoch_end(self): for batch_id in range(self.num_of_batch): sample_x = self._slice_batch(self.x, batch_id) sample_y = self._slice_batch(self.y, batch_id) outputs, _ = self.model.predict_on_batch(sample_x) diff = np.abs(outputs - sample_y).reshape(-1) self.errors[batch_id*self.batch_size:(batch_id+1) * self.batch_size] = diff self.hard_idxs = np.argsort(self.errors) def _slice_batch(self, x, id): sample = x[self.batch_size * id: self.batch_size * (id + 1)] return sample.reshape(self.batch_size, -1)