Skip to content

Instantly share code, notes, and snippets.

@okapies
Created June 12, 2019 03:21
Show Gist options
  • Save okapies/ab7c8f413c3bb46c81b4b4e8ba0e7603 to your computer and use it in GitHub Desktop.
Save okapies/ab7c8f413c3bb46c81b4b4e8ba0e7603 to your computer and use it in GitHub Desktop.

Revisions

  1. okapies created this gist Jun 12, 2019.
    86 changes: 86 additions & 0 deletions train_mnist_logreport.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,86 @@
    #!/usr/bin/env python
    import argparse

    import chainer
    import chainer.functions as F
    import chainer.links as L
    from chainer import training
    from chainer.training import extensions

    import numpy as np


    # Network definition
    class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
    super(MLP, self).__init__()
    with self.init_scope():
    # the size of the inputs to each layer will be inferred
    self.l1 = L.Linear(None, n_units) # n_in -> n_units
    self.l2 = L.Linear(None, n_units) # n_units -> n_units
    self.l3 = L.Linear(None, n_out) # n_units -> n_out

    def forward(self, x):
    h1 = F.relu(self.l1(x))
    h2 = F.relu(self.l2(h1))
    return self.l3(h2)


    def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
    help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
    help='Number of sweeps over the dataset to train')
    parser.add_argument('--device', '-d', type=str, default='-1',
    help='Device specifier. Either ChainerX device '
    'specifier or an integer. If non-negative integer, '
    'CuPy arrays with specified device id are used. If '
    'negative integer, NumPy arrays are used')
    parser.add_argument('--out', '-o', default='result',
    help='Directory to output the result')
    parser.add_argument('--unit', '-u', type=int, default=1000,
    help='Number of units')
    args = parser.parse_args()

    device = chainer.get_device(args.device)

    print('Device: {}'.format(device))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))
    model.to_device(device)
    device.use()

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, _ = chainer.datasets.get_mnist()
    train = chainer.datasets.TupleDataset(
    np.stack([train[0][0]]), np.stack([train[0][1]]))

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
    train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Run the training
    trainer.run()


    if __name__ == '__main__':
    main()