Skip to content

Instantly share code, notes, and snippets.

@m-Py
Last active February 25, 2020 19:16
Show Gist options
  • Select an option

  • Save m-Py/a844fe03838a4f3de017d76f2f18d8ae to your computer and use it in GitHub Desktop.

Select an option

Save m-Py/a844fe03838a4f3de017d76f2f18d8ae to your computer and use it in GitHub Desktop.

Revisions

  1. m-Py revised this gist Feb 25, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion KNN_RANN.R
    Original file line number Diff line number Diff line change
    @@ -1,4 +1,4 @@
    # Author: Martin Papenber
    # Author: Martin Papenberg
    # Year: 2019

    # Perform fast KNN classifier using RANN for nearest neighbour search
  2. m-Py revised this gist Feb 25, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion KNN_RANN.R
    Original file line number Diff line number Diff line change
    @@ -24,7 +24,7 @@ knn_rann_dt <- function(data, labels, k = 10) {

    # convert indices to category
    nn_categories <- labels[nn_idx]
    # restore dimensionality of matrix disance.index
    # restore dimensionality of nearest neighbour matrix
    dim(nn_categories) <- dim(nn_idx)

    # By category: determine the number of nearest neighbours having
  3. m-Py revised this gist Feb 25, 2020. 1 changed file with 3 additions and 3 deletions.
    6 changes: 3 additions & 3 deletions KNN_RANN.R
    Original file line number Diff line number Diff line change
    @@ -3,7 +3,7 @@

    # Perform fast KNN classifier using RANN for nearest neighbour search

    libray("RANN")
    library("RANN")
    library("data.table")

    # param data: The numeric data matrix used
    @@ -32,7 +32,7 @@ knn_rann_dt <- function(data, labels, k = 10) {
    nn_by_category <- function(i) {
    colSums(t(nn_categories) == i)
    }
    ncats <- length(unique(labels))
    ncats <- length(factor_levels)
    sum_nn_by_category <- sapply(1:ncats, nn_by_category)

    # use `data.table` to get index of maximum column, which corresponds
    @@ -51,7 +51,7 @@ mean(knns == iris$Species) # performance of KNN classifier
    # randomly generate some data for testing running time.
    # Runs KNN for N = 100000 in 0.3 sec
    # in ~ 5 sec for N = 1 million
    N <- 100000
    N <- 1000000
    data <- rnorm(N)
    labels <- sample(1:4, size = N, replace = TRUE)

  4. m-Py created this gist Feb 25, 2020.
    60 changes: 60 additions & 0 deletions KNN_RANN.R
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,60 @@
    # Author: Martin Papenber
    # Year: 2019

    # Perform fast KNN classifier using RANN for nearest neighbour search

    libray("RANN")
    library("data.table")

    # param data: The numeric data matrix used
    # param labels: the labels to predict
    # param k: The k used in KNN
    # return: The predicted KNN labels

    knn_rann_dt <- function(data, labels, k = 10) {
    data <- as.matrix(data)
    # use numeric representation of factor levels
    labels <- factor(labels)
    factor_levels <- levels(labels)
    labels <- as.numeric(labels)

    # imperfect approximation of removing self as neighbour,
    # just removing first column
    nn_idx <- nn2(data, k = min(k, nrow(data)))$nn.idx[, -1]

    # convert indices to category
    nn_categories <- labels[nn_idx]
    # restore dimensionality of matrix disance.index
    dim(nn_categories) <- dim(nn_idx)

    # By category: determine the number of nearest neighbours having
    # this category
    nn_by_category <- function(i) {
    colSums(t(nn_categories) == i)
    }
    ncats <- length(unique(labels))
    sum_nn_by_category <- sapply(1:ncats, nn_by_category)

    # use `data.table` to get index of maximum column, which corresponds
    # to the most frequent category across the nearest neighbours
    sum_nn_by_category <- as.data.table(sum_nn_by_category)
    sum_nn_by_category[, maximum_element := do.call(pmax, .SD), .SDcols = 1:ncats]
    factor_levels[sum_nn_by_category[, maximum_column := max.col(.SD), .SDcols = 1:ncats]$maximum_column]
    }


    ## Some example applications:

    knns <- knn_rann_dt(iris[, 1:4], iris[, 5], k = 10)
    mean(knns == iris$Species) # performance of KNN classifier

    # randomly generate some data for testing running time.
    # Runs KNN for N = 100000 in 0.3 sec
    # in ~ 5 sec for N = 1 million
    N <- 100000
    data <- rnorm(N)
    labels <- sample(1:4, size = N, replace = TRUE)

    start <- Sys.time()
    knns1 <- knn_rann_dt(data, labels)
    Sys.time() - start