Last active
March 14, 2021 17:00
-
-
Save tlepoint/11d6fc3e8c763b080334009e98c14147 to your computer and use it in GitHub Desktop.
prg_bench
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| * Copyright 2021 https://github.com/tlepoint | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| #ifndef TLEPOINT_BARRETT_H | |
| #define TLEPOINT_BARRETT_H | |
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

#include "absl/numeric/int128.h"
#include "absl/types/span.h"
| // Modular reduction using Barrett reduction | |
| // https://en.wikipedia.org/wiki/Barrett_reduction | |
| template <typename Int> | |
| struct Barrett { | |
| Barrett(Int modulus) | |
| : modulus(modulus), | |
| precomputed(static_cast<Int>( | |
| (static_cast<absl::uint128>(1) << (sizeof(Int) * 8)) / modulus)) {} | |
| std::vector<Int> Reduce(absl::Span<const Int> in) { | |
| std::vector<Int> out; | |
| out.reserve(in.size()); | |
| std::transform( | |
| in.begin(), in.end(), std::back_inserter(out), | |
| [this](const Int &n) -> Int { | |
| Int q = static_cast<Int>( | |
| (static_cast<absl::uint128>(precomputed) * n) >> nbits); | |
| q = n - q * modulus; | |
| return (q >= modulus) ? q - modulus : q; | |
| }); | |
| return out; | |
| } | |
| Int modulus; | |
| Int precomputed; | |
| static constexpr size_t nbits = sizeof(Int) * 8; | |
| }; | |
| #endif // TLEPOINT_BARRETT_H |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Copyright 2021 https://github.com/tlepoint | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| #include "barrett.h" | |
| #include "absl/numeric/int128.h" | |
| #include "gtest/gtest.h" | |
| #include "openssl/rand.h" | |
| namespace { | |
| template <typename Int> | |
| class BarrettTest : public testing::Test { | |
| protected: | |
| void SetUp() { | |
| values_.resize(number_tests_); | |
| for (int i = 0; i < number_tests_; i++) { | |
| RAND_bytes(reinterpret_cast<uint8_t *>(&values_[i]), sizeof(values_[i])); | |
| } | |
| } | |
| std::vector<Int> values_; | |
| const int number_tests_ = 100; | |
| }; | |
| typedef testing::Types<uint16_t, uint32_t, uint64_t> IntTypes; | |
| TYPED_TEST_SUITE(BarrettTest, IntTypes); | |
| TYPED_TEST(BarrettTest, Correctness) { | |
| for (TypeParam modulus : | |
| {static_cast<TypeParam>(1) << 2, static_cast<TypeParam>(1) << 8, | |
| (static_cast<TypeParam>(-1) >> 6)}) { | |
| Barrett<TypeParam> barrett(modulus); | |
| std::vector<TypeParam> out = barrett.Reduce(this->values_); | |
| for (int i = 0; i < this->number_tests_; i++) { | |
| ASSERT_EQ(out[i], (this->values_[i] % modulus)); | |
| } | |
| } | |
| } | |
| } // namespace | |
| int main(int argc, char **argv) { | |
| ::testing::InitGoogleTest(&argc, argv); | |
| return RUN_ALL_TESTS(); | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| load("@rules_cc//cc:defs.bzl", "cc_test") | |
# Header-only library providing Barrett<Int> modular reduction.
cc_library(
    name = "barrett",
    hdrs = [
        "barrett.h",
    ],
    deps = [
        # absl::uint128 for the widening multiply, absl::Span for Reduce()'s input.
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/types:span",
    ],
)
# Typed gtest suite for Barrett<Int>::Reduce. BoringSSL supplies RAND_bytes
# for the random inputs; the test file defines its own main(), hence plain gtest.
cc_test(
    name = "barrett_test",
    srcs = [
        "barrett_test.cc",
    ],
    deps = [
        ":barrett",
        "@boringssl//:crypto",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_googletest//:gtest",
    ],
)
# Benchmarks comparing PRG expansion strategies (extra bits + Barrett reduction
# vs. rejection sampling) against raw ChaCha20 and an ECDH baseline.
cc_test(
    name = "prg_bench",
    srcs = ["prg_bench.cc"],
    deps = [
        ":barrett",
        "@boringssl//:crypto",
        "@com_google_absl//absl/numeric:int128",
        # absl::Status and its error constructors, and absl::StrCat, are used
        # directly by prg_bench.cc, so depend on their libraries explicitly.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_benchmark//:benchmark",
    ],
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Copyright 2021 https://github.com/tlepoint | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| #include <iostream> | |
| #include "absl/numeric/int128.h" | |
| #include "absl/status/statusor.h" | |
| #include "barrett.h" | |
| #include "benchmark/benchmark.h" | |
| #include "openssl/base.h" | |
| #include "openssl/bn.h" | |
| #include "openssl/chacha.h" | |
| #include "openssl/ec.h" | |
| #include "openssl/nid.h" | |
| #include "openssl/rand.h" | |
| namespace { | |
| constexpr uint8_t kChaChaNonce[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; | |
| std::vector<uint8_t> SampleChaChaKey() { | |
| std::vector<uint8_t> key(32); | |
| RAND_bytes(key.data(), key.size()); | |
| return key; | |
| } | |
| // PRG based on extra bits sampled and modular reduction. | |
| template <typename Int> | |
| class PrgExtraBits { | |
| public: | |
| PrgExtraBits(Int modulus) : barrett_(modulus) {} | |
| absl::StatusOr<std::vector<Int>> Expand(absl::Span<const uint8_t> key, | |
| int n) { | |
| // Generate numbers with extra bits | |
| std::vector<Int> out_with_extra_bits(n); | |
| { | |
| std::vector<Int> in(n); | |
| CRYPTO_chacha_20(reinterpret_cast<uint8_t *>(out_with_extra_bits.data()), | |
| reinterpret_cast<const uint8_t *>(in.data()), | |
| in.size() * sizeof(Int), key.data(), kChaChaNonce, 0); | |
| } | |
| // Reduce! | |
| return barrett_.Reduce(out_with_extra_bits); | |
| } | |
| private: | |
| Barrett<Int> barrett_; | |
| }; | |
| // PRG Expansion using rejection sampling: generate a random value | |
| // of the same size as the modulus, and use it if it is smaller than | |
| // the modulus. | |
| template <typename Int> | |
| class PrgRejectionSampling { | |
| public: | |
| PrgRejectionSampling(Int modulus, int numbers_to_generate, | |
| double success_probability) | |
| : modulus_(modulus) { | |
| // Create the mask that may clean the MSBs. | |
| mask_ = static_cast<Int>(-1); | |
| if (std::log2(modulus) < sizeof(Int) * 8) { | |
| mask_ >>= (sizeof(Int) * 8 - static_cast<int>(std::log2(modulus))); | |
| } | |
| // The PRG will be used to generate n values. With probability p > 1/2, | |
| // a value will be accepted, hence we are looking for how many elements | |
| // N to generate so that at least n are accepted. | |
| // | |
| // Hoeffding's inequality: Pr[#accept <= k] <= exp(-2N(p - k/N)^2)) | |
| // Now p = modulus / 2^bitsize(Int). First, this gives a lower bound on | |
| // N: (n-1) / N < p => N > (n-1) * 2^bitsize(Int) / modulus | |
| N_ = std::ceil((numbers_to_generate - 1) * | |
| static_cast<double>(static_cast<absl::uint128>(1) | |
| << (sizeof(Int) * 8)) / | |
| modulus); | |
| // Next, we want that | |
| // Pr[#accepts >= n] = 1 - Pr[#accepts <= n - 1] > success_probability | |
| // i.e. | |
| // Pr[#accepts <= n - 1] < 1 - success_probability | |
| while (true) { | |
| // Since p > 1/2 and N > 2(n-1), -(p - (n-1)/N)^2 < -(1/2 - (n-1)/N)^2. | |
| double q = std::exp( | |
| -2 * N_ * | |
| std::pow(0.5 - static_cast<double>(numbers_to_generate - 1) / N_, 2)); | |
| if (q <= 1 - success_probability) { | |
| return; | |
| } | |
| N_++; | |
| } | |
| } | |
| absl::StatusOr<std::vector<Int>> Expand(absl::Span<const uint8_t> key, | |
| int n) { | |
| // Generate N_ random values. | |
| std::vector<Int> out(N_); | |
| { | |
| std::vector<Int> in(N_); | |
| CRYPTO_chacha_20(reinterpret_cast<uint8_t *>(out.data()), | |
| reinterpret_cast<const uint8_t *>(in.data()), | |
| in.size() * sizeof(Int), key.data(), kChaChaNonce, 0); | |
| } | |
| // Clear out most significant bits (if needed). | |
| std::transform(out.begin(), out.end(), out.begin(), | |
| [this](Int n) { return n & mask_; }); | |
| // Rejection sampling. | |
| auto it = std::remove_if(out.begin(), out.end(), | |
| [this](Int n) { return n >= modulus_; }); | |
| if (std::distance(out.begin(), it) < n) { | |
| return absl::InvalidArgumentError( | |
| absl::StrCat("Key does not allow to produce ", n, " elements.")); | |
| } else { | |
| out.resize(n); | |
| return out; | |
| } | |
| } | |
| private: | |
| Int modulus_; | |
| Int mask_; | |
| int N_; | |
| }; | |
| // Benchmarks | |
| void BM_Extra32Bits_4293918721(benchmark::State &state) { | |
| int len = state.range(0); | |
| PrgExtraBits<uint64_t> prg(4293918721ULL); | |
| std::vector<uint8_t> key = SampleChaChaKey(); // Any key works | |
| for (auto _ : state) { | |
| ::benchmark::DoNotOptimize(prg.Expand(key, len)); | |
| } | |
| } | |
| BENCHMARK(BM_Extra32Bits_4293918721)->Range(8, 8 << 12); | |
| void BM_Extra96Bits_4293918721(benchmark::State &state) { | |
| int len = state.range(0); | |
| PrgExtraBits<absl::uint128> prg(4293918721ULL); | |
| std::vector<uint8_t> key = SampleChaChaKey(); // Any key works | |
| for (auto _ : state) { | |
| ::benchmark::DoNotOptimize(prg.Expand(key, len)); | |
| } | |
| } | |
| BENCHMARK(BM_Extra96Bits_4293918721)->Range(8, 8 << 12); | |
| void BM_Extra64Bits_15564440312192434177(benchmark::State &state) { | |
| int len = state.range(0); | |
| PrgExtraBits<absl::uint128> prg(15564440312192434177ULL); | |
| std::vector<uint8_t> key = SampleChaChaKey(); // Any key works | |
| for (auto _ : state) { | |
| ::benchmark::DoNotOptimize(prg.Expand(key, len)); | |
| } | |
| } | |
| BENCHMARK(BM_Extra64Bits_15564440312192434177)->Range(8, 8 << 12); | |
| void BM_RejectionSampling_4293918721(benchmark::State &state) { | |
| int len = state.range(0); | |
| PrgRejectionSampling<uint32_t> prg(4293918721ULL, len, 0.5); | |
| // The cost of finding a valid key is on the client. | |
| absl::Status status = absl::NotFoundError(""); | |
| std::vector<uint8_t> key; | |
| while (!status.ok()) { | |
| key = SampleChaChaKey(); | |
| status = prg.Expand(key, len).status(); | |
| } | |
| for (auto _ : state) { | |
| auto status_or_data = prg.Expand(key, len); | |
| assert(status_or_data.ok()); // Expand should never fail. | |
| } | |
| } | |
| BENCHMARK(BM_RejectionSampling_4293918721)->Range(8, 8 << 12); | |
| void BM_RejectionSampling_15564440312192434177(benchmark::State &state) { | |
| int len = state.range(0); | |
| PrgRejectionSampling<uint64_t> prg(15564440312192434177ULL, len, 0.5); | |
| // The cost of finding a valid key is on the client. | |
| absl::Status status = absl::NotFoundError(""); | |
| std::vector<uint8_t> key; | |
| while (!status.ok()) { | |
| key = SampleChaChaKey(); | |
| status = prg.Expand(key, len).status(); | |
| } | |
| for (auto _ : state) { | |
| auto status_or_data = prg.Expand(key, len); | |
| assert(status_or_data.ok()); // Expand should never fail. | |
| } | |
| } | |
| BENCHMARK(BM_RejectionSampling_15564440312192434177)->Range(8, 8 << 12); | |
| static void BM_Prg(benchmark::State &state) { | |
| std::vector<uint8_t> key(32); | |
| RAND_bytes(key.data(), key.size()); | |
| int n = state.range(0); | |
| std::vector<uint8_t> in(n); | |
| std::vector<uint8_t> out(n); | |
| for (auto _ : state) { | |
| CRYPTO_chacha_20(out.data(), in.data(), in.size(), key.data(), kChaChaNonce, | |
| 0); | |
| ::benchmark::DoNotOptimize(out); | |
| } | |
| } | |
| BENCHMARK(BM_Prg)->Range(8, 8 << 12); | |
| static void BM_ECDH(benchmark::State &state) { | |
| bssl::UniquePtr<EC_GROUP> ec_group( | |
| EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1)); | |
| bssl::UniquePtr<EC_KEY> ec_key(EC_KEY_new()); | |
| assert(EC_KEY_set_group(ec_key.get(), ec_group.get()) == 1); | |
| assert(EC_KEY_generate_key(ec_key.get()) == 1); | |
| const BIGNUM *priv_key = EC_KEY_get0_private_key(ec_key.get()); | |
| const EC_POINT *pub_key = EC_KEY_get0_public_key(ec_key.get()); | |
| for (auto _ : state) { | |
| // Compute the shared point. | |
| bssl::UniquePtr<EC_POINT> shared_point(EC_POINT_new(ec_group.get())); | |
| assert(EC_POINT_mul(ec_group.get(), shared_point.get(), nullptr, pub_key, | |
| priv_key, nullptr) == 1); | |
| assert(EC_POINT_is_on_curve(ec_group.get(), shared_point.get(), nullptr) == | |
| 1); | |
| // Get shared point's x coordinate. | |
| bssl::UniquePtr<BIGNUM> shared_x(BN_new()); | |
| assert(EC_POINT_get_affine_coordinates_GFp( | |
| ec_group.get(), shared_point.get(), shared_x.get(), nullptr, | |
| nullptr) == 1); | |
| // Serialize BIGNUM. | |
| std::vector<uint8_t> res((EC_GROUP_get_degree(ec_group.get()) + 7) / 8); | |
| assert(BN_bn2bin_padded(res.data(), res.size(), shared_x.get())); | |
| } | |
| } | |
| BENCHMARK(BM_ECDH); | |
| } // namespace | |
// Supplies main(): runs the benchmarks registered above with BENCHMARK(...).
BENCHMARK_MAIN();
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| workspace(name = "libprio_prg") | |
| load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") | |
| load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") | |
# --- Build rules -------------------------------------------------------------

http_archive(
    name = "rules_cc",
    sha256 = "92a89a2bbe6c6db2a8b87da4ce723aff6253656e8417f37e50d362817c39b98b",
    strip_prefix = "rules_cc-88ef31b429631b787ceb5e4556d773b20ad797c8",
    urls = [
        "https://github.com/bazelbuild/rules_cc/archive/88ef31b429631b787ceb5e4556d773b20ad797c8.zip",
    ],
)

# --- Testing and benchmarking frameworks ---------------------------------------

http_archive(
    name = "com_google_googletest",
    sha256 = "3c3e9ec31fe35a230d0fa335a31c5d2262dc50245a1cb1e5969b51c6f038cafc",
    strip_prefix = "googletest-763eaa430540926fa16060654427149802c97fba",
    urls = ["https://github.com/google/googletest/archive/763eaa430540926fa16060654427149802c97fba.zip"],
)

http_archive(
    name = "com_google_benchmark",
    sha256 = "bc60957389e8d9e37d1a40fad22da7a1950e382850cec80b0133fcbfa7d41016",
    strip_prefix = "benchmark-cc9abfc8f12577ea83b2d093693ba70c3c0fd2c7",
    urls = ["https://github.com/google/benchmark/archive/cc9abfc8f12577ea83b2d093693ba70c3c0fd2c7.zip"],
)

# --- Libraries -----------------------------------------------------------------

# NOTE(review): unlike the archives above, Abseil has no sha256 pin, so Bazel
# cannot verify the downloaded archive. TODO: add its sha256.
http_archive(
    name = "com_google_absl",
    strip_prefix = "abseil-cpp-ab21820d47e4f83875dda008b600514d3520fd35",
    urls = [
        "https://github.com/abseil/abseil-cpp/archive/ab21820d47e4f83875dda008b600514d3520fd35.zip",
    ],
)

# BoringSSL is cloned with git at a fixed commit.
git_repository(
    name = "boringssl",
    commit = "afd67cd00e55e4e22b14f096361c732531a0c539",
    remote = "https://boringssl.googlesource.com/boringssl",
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment