Skip to content

Instantly share code, notes, and snippets.

@thirdwing
Created June 3, 2016 22:17
Show Gist options
  • Select an option

  • Save thirdwing/ceaf2f8725349d98249e7631227b5f35 to your computer and use it in GitHub Desktop.

Select an option

Save thirdwing/ceaf2f8725349d98249e7631227b5f35 to your computer and use it in GitHub Desktop.

Revisions

  1. Qiang Kou (KK) revised this gist Jun 3, 2016. 1 changed file with 11 additions and 1 deletion.
    12 changes: 11 additions & 1 deletion callback.plot.R
    Original file line number Diff line number Diff line change
    @@ -1,3 +1,5 @@
    logger <- mx.metric.logger$new()

    mx.callback.plot.train.metric <- function(period, logger=NULL) {
    function(iteration, nbatch, env, verbose=TRUE) {
    if (nbatch %% period == 0 && !is.null(env$metric)) {
    @@ -9,4 +11,12 @@ mx.callback.plot.train.metric <- function(period, logger=NULL) {
    }
    return(TRUE)
    }
    }
    }

    mx.set.seed(0)
    model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
    ctx=mx.gpu(), num.round=10, array.batch.size=100,
    learning.rate=0.05, momentum=0.9,
    eval.metric=mx.metric.accuracy,
    initializer=mx.init.uniform(0.07),
    epoch.end.callback=mx.callback.plot.train.metric(100, logger))
  2. Qiang Kou (KK) created this gist Jun 3, 2016.
    12 changes: 12 additions & 0 deletions callback.plot.R
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,12 @@
    mx.callback.plot.train.metric <- function(period, logger=NULL) {
    function(iteration, nbatch, env, verbose=TRUE) {
    if (nbatch %% period == 0 && !is.null(env$metric)) {
    N = env$end.round
    result <- env$metric$get(env$train.metric)
    plot(c(0.5,1)~c(0,N), col=NA, ylab = paste0("Train-", result$name),xlab = "")
    logger$train <- c(logger$train, result$value)
    lines(logger$train, lwd = 3, col="red")
    }
    return(TRUE)
    }
    }