Created
October 21, 2021 19:15
-
-
Save Eleobert/6e3dd16f64f63cd927aa7c13da238922 to your computer and use it in GitHub Desktop.
Revisions
-
Eleobert created this gist
Oct 21, 2021. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,105 @@ #include <armadillo> #include <vector> auto get_cluster_sim(const arma::mat& sim, const arma::uvec& cluster_a, const arma::uvec& cluster_b) { arma::uvec combined_clusters = arma::join_cols(cluster_a, cluster_b); arma::mat combined_sims = sim(combined_clusters, combined_clusters); auto exemplar = combined_clusters(arma::mean(combined_sims).index_max()); arma::vec exe_sims = sim.col(exemplar); return (arma::mean(exe_sims(cluster_a)) + arma::mean(exe_sims(cluster_b))) / 2.0; } auto get_clusters_sim(const arma::mat& sim, const std::vector<arma::uvec>& clusters) { arma::mat res(clusters.size(), clusters.size()); res.fill(arma::datum::nan); for(size_t i = 0; i < clusters.size(); i++) { res(i, i) = -arma::datum::inf; for(size_t j = i + 1; j < clusters.size(); j++) { res(i, j) = get_cluster_sim(sim, clusters[i], clusters[j]); res(j, i) = res(i, j); } } return res; } auto get_index_max(const arma::mat& mat) { auto index = arma::index_max(mat.as_col()); return std::make_pair(index % mat.n_cols, index / mat.n_cols); } auto remove_clusters(std::vector<arma::uvec>& clusters, size_t idx1, size_t idx2) { auto [idx_min, idx_max] = std::minmax(idx1, idx2); clusters.erase(clusters.begin() + idx_max); clusters.erase(clusters.begin() + idx_min); } auto remove_cluster_similarities(arma::mat& clusters_sim, size_t idx1, size_t idx2) { auto [idx_min, idx_max] = std::minmax(idx1, idx2); clusters_sim.shed_col(idx_max); clusters_sim.shed_col(idx_min); clusters_sim.shed_row(idx_max); clusters_sim.shed_row(idx_min); } auto update_clusters_sim(const arma::mat& sim, arma::mat& clusters_sim, std::vector<arma::uvec>& clusters, const arma::uvec& new_cluster) { arma::vec 
dumb_vec(clusters_sim.n_cols); dumb_vec.fill(arma::datum::nan); clusters_sim = arma::join_cols(clusters_sim, dumb_vec.t()); dumb_vec = arma::vec(clusters_sim.n_cols + 1); dumb_vec.fill(arma::datum::nan); clusters_sim = arma::join_rows(clusters_sim, dumb_vec); auto j = clusters_sim.n_cols - 1; for(size_t i = 0; i < clusters_sim.n_rows - 1; i++) { clusters_sim(i, j) = get_cluster_sim(sim, clusters[i], new_cluster); clusters_sim(j, i) = clusters_sim(i, j); } clusters_sim(j, j) = -arma::datum::inf; } auto agcluster(const arma::mat& sim, std::vector<arma::uvec> clusters, float cut_height) { arma::mat clusters_sim = get_clusters_sim(sim, clusters); while(true) { auto [i_max, j_max] = get_index_max(clusters_sim); auto height = clusters_sim(i_max, j_max); if(height < cut_height) { return clusters; } arma::uvec new_cluster = arma::join_cols(clusters[i_max], clusters[j_max]); // remove the old clusters and insert the new one remove_clusters(clusters, i_max, j_max); clusters.emplace_back(new_cluster); if(clusters.size() == 1) { return clusters; } remove_cluster_similarities(clusters_sim, i_max, j_max); // add the new cluster similarities update_clusters_sim(sim, clusters_sim, clusters, new_cluster); } }