# History object whose `train` field collects the per-epoch training metric.
logger <- mx.metric.logger$new()

#' Build a training callback that live-plots the training metric.
#'
#' Each time the returned callback fires with `nbatch` divisible by `period`
#' and `env$metric` is non-NULL, it appends the current training metric value
#' to `logger$train` and redraws the whole history as a thick red line (one
#' point per recorded call) on an x-axis spanning `0 .. env$end.round`.
#'
#' @param period Record/plot only when `nbatch %% period == 0`.
#'   NOTE(review): if the trainer passes a constant `nbatch` (e.g. 0) to
#'   epoch-end callbacks, this check is effectively always true there —
#'   confirm against the mxnet version in use.
#' @param logger Object with a `train` field that accumulates the history,
#'   e.g. an `mx.metric.logger`, mutated in place. When `NULL`, a private
#'   environment is used so the history still accumulates across calls
#'   (but is not visible to the caller).
#' @param ylim Optional y-axis limits. When `NULL`, the axis covers
#'   `c(0.5, 1)` and every recorded finite value, so points outside that
#'   window are not clipped.
#' @return A callback `function(iteration, nbatch, env, verbose = TRUE)` that
#'   always returns `TRUE`. `iteration` and `verbose` are accepted but unused.
mx.callback.plot.train.metric <- function(period, logger = NULL, ylim = NULL) {
  force(period)
  force(ylim)
  if (is.null(logger)) {
    # Assigning `$train` on NULL would create a fresh local list on every
    # call, losing the history; an environment is mutated by reference.
    logger <- new.env()
    logger$train <- numeric(0)
  }
  function(iteration, nbatch, env, verbose = TRUE) {
    if (nbatch %% period == 0 && !is.null(env$metric)) {
      n_rounds <- env$end.round
      result <- env$metric$get(env$train.metric)
      # Append first so the axis range below includes the newest value.
      logger$train <- c(logger$train, result$value)
      y_range <- if (is.null(ylim)) {
        range(c(0.5, 1), logger$train, finite = TRUE)
      } else {
        ylim
      }
      # Empty plot that only establishes the axes; the data is drawn by lines().
      plot(NA, xlim = c(0, n_rounds), ylim = y_range,
           ylab = paste0("Train-", result$name), xlab = "")
      lines(logger$train, lwd = 3, col = "red")
    }
    TRUE
  }
}

# `softmax`, `train.x` and `train.y` are expected to be defined earlier in
# the script. `ctx = mx.gpu()` assumes a GPU-enabled mxnet build.
mx.set.seed(0)
model <- mx.model.FeedForward.create(
  softmax,
  X = train.x,
  y = train.y,
  ctx = mx.gpu(),
  num.round = 10,
  array.batch.size = 100,
  learning.rate = 0.05,
  momentum = 0.9,
  eval.metric = mx.metric.accuracy,
  initializer = mx.init.uniform(0.07),
  epoch.end.callback = mx.callback.plot.train.metric(100, logger)
)