okapies created this gist on Jun 12, 2019.
#!/usr/bin/env python
import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

import numpy as np


# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def forward(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--device', '-d', type=str, default='-1',
                        help='Device specifier. Either ChainerX device '
                        'specifier or an integer. If non-negative integer, '
                        'CuPy arrays with specified device id are used. If '
                        'negative integer, NumPy arrays are used')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    device = chainer.get_device(args.device)

    print('Device: {}'.format(device))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the LogReport extension below.
    model = L.Classifier(MLP(args.unit, 10))
    model.to_device(device)
    device.use()

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset and keep only its first example, so the model
    # is repeatedly trained on a single (image, label) pair
    train, _ = chainer.datasets.get_mnist()
    train = chainer.datasets.TupleDataset(
        np.stack([train[0][0]]), np.stack([train[0][1]]))

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Run the training
    trainer.run()


if __name__ == '__main__':
    main()
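For reference, a typical invocation of this script might look like the following, using the command-line flags defined above (the filename train_mnist_single.py is assumed here; the gist does not name the file):

python train_mnist_single.py --device -1 --epoch 20 --batchsize 100 --unit 1000 --out result

With --device -1 the script trains on NumPy (CPU) arrays; per the help text above, a non-negative integer such as 0 selects CuPy arrays on that GPU device, and a ChainerX device specifier is also accepted. Training statistics are written to result/log by the LogReport extension.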