Skip to content

Instantly share code, notes, and snippets.

@shi510
Last active January 29, 2023 20:29
Show Gist options
  • Save shi510/b97c044d75b386b9ee7a9e706837c2cf to your computer and use it in GitHub Desktop.
Save shi510/b97c044d75b386b9ee7a9e706837c2cf to your computer and use it in GitHub Desktop.

Revisions

  1. shi510 revised this gist Dec 4, 2019. 1 changed file with 34 additions and 0 deletions.
    34 changes: 34 additions & 0 deletions train.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,34 @@
    class OHEMCallback(tf.keras.callbacks.Callback):
    def __init__(self, generator):
    super(OHEMCallback, self).__init__()
    self.generator = generator

    def on_epoch_begin(self, epoch, logs=None):
    self.generator.on_epoch_end()


    def make_dataset(x, y, batch):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.repeat()
    ds = ds.shuffle(size)
    ds = ds.batch(batch)
    ds = ds.prefetch(1024)
    return ds


    def train(x, y, lr=1e-4, batch=512, valid_x, valid_y):
    # ...
    # build your model.
    # mode.compile() ...
    valid_ds = make_dataset(valid_x, valid_y, batch)
    x_gen = HardExampleMiner(model, x, y, batch)
    ohem_callback = OHEMCallback(x_gen)
    model.fit(
    x=x_gen,
    validation_data=valid_ds,
    validation_steps=int(math.floor(len(valid_x) / batch)),
    epochs=1000,
    callbacks=[ohem_callback],
    max_queue_size=256,
    workers=8
    )
  2. shi510 revised this gist Dec 4, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion ohem_generator.py
    Original file line number Diff line number Diff line change
    @@ -31,7 +31,7 @@ def on_epoch_end(self):
    diff = np.abs(outputs - sample_y).reshape(-1)
    self.errors[batch_id*self.batch_size:(batch_id+1) *
    self.batch_size] = diff
    self.hard_idxs = np.argsort(self.errors)
    self.hard_idxs = np.argsort(-self.errors)

    def _slice_batch(self, x, id):
    sample = x[self.batch_size * id: self.batch_size * (id + 1)]
  3. shi510 created this gist Dec 4, 2019.
    38 changes: 38 additions & 0 deletions ohem_generator.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,38 @@
    class HardExampleMiner(tf.keras.utils.Sequence):

    def __init__(self, model, x, y, batch_size, map_fn=None, ratio=0.8):
    self.x = np.array(x, dtype=np.float32)
    self.y = np.array(y, dtype=np.float32)
    self.batch_size = batch_size
    self.model = model
    self.ratio = ratio
    self.num_of_batch = int(math.floor((len(self.x) / self.batch_size)))
    self.hard_idxs = np.arange(self.num_of_batch * self.batch_size)
    self.errors = np.empty((self.num_of_batch * self.batch_size))
    self.sample_x = np.empty((self.batch_size, self.x.shape[1]))
    self.sample_y = np.empty((self.batch_size, self.y.shape[1]))

    def __len__(self):
    return int(self.num_of_batch * self.ratio)

    def __getitem__(self, batch_id):
    start = self.batch_size * batch_id
    end = self.batch_size * (batch_id + 1)
    for seq, idx in enumerate(self.hard_idxs[start:end]):
    self.sample_x[seq,] = self.x[idx]
    self.sample_y[seq,] = self.y[idx]
    return (self.sample_x, self.sample_y)

    def on_epoch_end(self):
    for batch_id in range(self.num_of_batch):
    sample_x = self._slice_batch(self.x, batch_id)
    sample_y = self._slice_batch(self.y, batch_id)
    outputs, _ = self.model.predict_on_batch(sample_x)
    diff = np.abs(outputs - sample_y).reshape(-1)
    self.errors[batch_id*self.batch_size:(batch_id+1) *
    self.batch_size] = diff
    self.hard_idxs = np.argsort(self.errors)

    def _slice_batch(self, x, id):
    sample = x[self.batch_size * id: self.batch_size * (id + 1)]
    return sample.reshape(self.batch_size, -1)