#include #include #include #include // using namespace at; using namespace torch; void submodular_select(Tensor candidate_points, Tensor features_done, Tensor features) { int max_idx = -1; float max_value = -1e-9; for (int i=0; i < candidate_points.size(0); i++) { std::vector temp; if (candidate_points.item() == 1) { temp.push_back(features_done); temp.push_back(features[candidate_points[i]]); auto stacked_temp = stack(temp); std::cout << std::get<0>(stacked_temp.max(1,false)) << std::endl; float value = std::get<0>(stacked_temp.max(1,false)).sum().item(); if (value > max_value) { max_value = value; max_idx = i; } } } std::cout<<"Max Value" << max_value << std::endl; std::cout << "Max Index" << max_idx << std::endl; // return max_idx; } int main() { int num_data_points = 6000; int num_features = 256; int batch_size = 64; Tensor features = torch::randn({num_data_points, num_features}, dtype(kFloat)); Tensor done = torch::randint(0, num_data_points, batch_size*4, dtype(kLong)); // Already Sampled Points Tensor done_index = torch::arange(0, batch_size*3, kLong).squeeze(); Tensor features_done = features.index(done_index); Tensor candidate_points = torch::ones(num_data_points, dtype(kLong)); auto scatter_val = torch::zeros(num_data_points, dtype(kLong)); candidate_points = candidate_points.scatter_(0, done, scatter_val); for (int batch=0; batch < batch_size; batch++) { submodular_select(candidate_points, features_done, features); // std::cout<