Skip to content

Instantly share code, notes, and snippets.

@tlepoint
Last active March 14, 2021 17:00
Show Gist options
  • Save tlepoint/11d6fc3e8c763b080334009e98c14147 to your computer and use it in GitHub Desktop.
Save tlepoint/11d6fc3e8c763b080334009e98c14147 to your computer and use it in GitHub Desktop.
prg_bench
/*
* Copyright 2021 https://github.com/tlepoint
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TLEPOINT_BARRETT_H
#define TLEPOINT_BARRETT_H
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

#include "absl/numeric/int128.h"
#include "absl/types/span.h"
// Modular reduction using Barrett reduction
// https://en.wikipedia.org/wiki/Barrett_reduction
//
// Barrett<Int> reduces values of the unsigned type `Int` modulo a fixed
// `modulus`. It stores precomputed = floor((2^nbits - 1) / modulus) so that
// each reduction needs only multiplications, a shift and at most one
// conditional subtraction instead of a division.
//
// Supported `Int`: unsigned integers of at most 64 bits, and absl::uint128.
// Precondition: modulus != 0.
template <typename Int>
struct Barrett {
  Barrett(Int modulus)
      : modulus(modulus),
        // (2^nbits - 1) / modulus rather than 2^nbits / modulus: it always
        // fits in `Int` (2^nbits / 1 does not) and needs no type wider than
        // `Int`. It is >= 2^nbits / modulus - 1, which is enough to keep the
        // quotient estimate in Reduce() within one of the true quotient.
        precomputed(static_cast<Int>(static_cast<Int>(-1) / modulus)) {}

  // Returns a vector whose i-th element is in[i] % modulus.
  std::vector<Int> Reduce(absl::Span<const Int> in) const {
    std::vector<Int> out;
    out.reserve(in.size());
    std::transform(in.begin(), in.end(), std::back_inserter(out),
                   [this](const Int &n) -> Int {
                     // q is floor(n / modulus) or one less, hence
                     // r = n - q * modulus lies in [0, 2 * modulus) and the
                     // subtraction never wraps around.
                     const Int q = MulHigh(precomputed, n);
                     const Int r = static_cast<Int>(n - q * modulus);
                     return (r >= modulus) ? static_cast<Int>(r - modulus) : r;
                   });
    return out;
  }

  Int modulus;
  Int precomputed;
  static constexpr size_t nbits = sizeof(Int) * 8;

 private:
  // Returns the high `nbits` bits of the (2 * nbits)-bit product a * b, for
  // types of at most 64 bits (whose full product fits in absl::uint128).
  template <typename T>
  static Int MulHigh(T a, T b) {
    static_assert(sizeof(T) <= 8,
                  "Barrett supports unsigned types of at most 64 bits and "
                  "absl::uint128");
    return static_cast<Int>(
        (static_cast<absl::uint128>(a) * static_cast<absl::uint128>(b)) >>
        nbits);
  }

  // 128-bit overload: the 256-bit product fits in no available type, so its
  // high half is assembled from four 64x64 -> 128-bit partial products.
  static absl::uint128 MulHigh(absl::uint128 a, absl::uint128 b) {
    const uint64_t a_lo = absl::Uint128Low64(a);
    const uint64_t a_hi = absl::Uint128High64(a);
    const uint64_t b_lo = absl::Uint128Low64(b);
    const uint64_t b_hi = absl::Uint128High64(b);
    const absl::uint128 lo_lo = static_cast<absl::uint128>(a_lo) * b_lo;
    const absl::uint128 lo_hi = static_cast<absl::uint128>(a_lo) * b_hi;
    const absl::uint128 hi_lo = static_cast<absl::uint128>(a_hi) * b_lo;
    const absl::uint128 hi_hi = static_cast<absl::uint128>(a_hi) * b_hi;
    // Sum of the terms landing on bits [64, 128) of the product; it is below
    // 3 * 2^64, so it cannot overflow, and its high part is the carry into
    // bit 128.
    const absl::uint128 middle = (lo_lo >> 64) +
                                 absl::Uint128Low64(lo_hi) +
                                 absl::Uint128Low64(hi_lo);
    return hi_hi + (lo_hi >> 64) + (hi_lo >> 64) + (middle >> 64);
  }
};
#endif // TLEPOINT_BARRETT_H
// Copyright 2021 https://github.com/tlepoint
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "barrett.h"
#include "absl/numeric/int128.h"
#include "gtest/gtest.h"
#include "openssl/rand.h"
namespace {
template <typename Int>
class BarrettTest : public testing::Test {
protected:
void SetUp() {
values_.resize(number_tests_);
for (int i = 0; i < number_tests_; i++) {
RAND_bytes(reinterpret_cast<uint8_t *>(&values_[i]), sizeof(values_[i]));
}
}
std::vector<Int> values_;
const int number_tests_ = 100;
};
typedef testing::Types<uint16_t, uint32_t, uint64_t> IntTypes;
TYPED_TEST_SUITE(BarrettTest, IntTypes);
TYPED_TEST(BarrettTest, Correctness) {
for (TypeParam modulus :
{static_cast<TypeParam>(1) << 2, static_cast<TypeParam>(1) << 8,
(static_cast<TypeParam>(-1) >> 6)}) {
Barrett<TypeParam> barrett(modulus);
std::vector<TypeParam> out = barrett.Reduce(this->values_);
for (int i = 0; i < this->number_tests_; i++) {
ASSERT_EQ(out[i], (this->values_[i] % modulus));
}
}
}
} // namespace
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
# Load every C++ rule used in this file from rules_cc, not just cc_test.
load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")

# Header-only Barrett modular reduction (barrett.h).
cc_library(
    name = "barrett",
    hdrs = ["barrett.h"],
    deps = [
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/types:span",
    ],
)
# Typed GoogleTest suite for barrett.h; random inputs come from BoringSSL's
# RAND_bytes.
cc_test(
name = "barrett_test",
srcs = ["barrett_test.cc"],
deps = [
":barrett",
"@boringssl//:crypto",
"@com_google_absl//absl/numeric:int128",
"@com_google_googletest//:gtest",
],
)
# Benchmarks of ChaCha20-based sampling modulo an integer (extra bits +
# Barrett reduction vs. rejection sampling), plus raw ChaCha20 and ECDH.
cc_test(
    name = "prg_bench",
    srcs = ["prg_bench.cc"],
    deps = [
        ":barrett",
        "@boringssl//:crypto",
        "@com_google_absl//absl/numeric:int128",
        # prg_bench.cc uses absl::Status / absl::*Error and absl::StrCat
        # directly, so depend on their libraries instead of transitive deps.
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_benchmark//:benchmark",
    ],
)
// Copyright 2021 https://github.com/tlepoint
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "absl/numeric/int128.h"
#include "absl/status/statusor.h"
#include "barrett.h"
#include "benchmark/benchmark.h"
#include "openssl/base.h"
#include "openssl/bn.h"
#include "openssl/chacha.h"
#include "openssl/ec.h"
#include "openssl/nid.h"
#include "openssl/rand.h"
namespace {
// Fixed 12-byte nonce passed to every CRYPTO_chacha_20 call in this file.
constexpr uint8_t kChaChaNonce[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};

// Returns a newly allocated 32-byte ChaCha20 key filled by RAND_bytes.
std::vector<uint8_t> SampleChaChaKey() {
  constexpr size_t kKeyBytes = 32;
  std::vector<uint8_t> fresh_key(kKeyBytes);
  RAND_bytes(&fresh_key[0], fresh_key.size());
  return fresh_key;
}
// PRG based on extra bits sampled and modular reduction.
//
// Expands a ChaCha20 key into `n` words of type `Int` and reduces each one
// modulo `modulus` with Barrett reduction. The wider `Int` is compared to the
// modulus, the closer the reduced values are to uniform (the bias is at most
// about modulus / 2^(8 * sizeof(Int))).
template <typename Int>
class PrgExtraBits {
 public:
  // `modulus` must be non-zero.
  PrgExtraBits(Int modulus) : barrett_(modulus) {}

  // Returns `n` values in [0, modulus) derived from `key`, which must be a
  // 32-byte ChaCha20 key. Returns InvalidArgumentError if the key has the
  // wrong length or if n < 0.
  absl::StatusOr<std::vector<Int>> Expand(absl::Span<const uint8_t> key,
                                          int n) {
    if (key.size() != 32) {
      return absl::InvalidArgumentError(
          absl::StrCat("ChaCha20 key must be 32 bytes, got ", key.size()));
    }
    if (n < 0) {
      return absl::InvalidArgumentError(
          absl::StrCat("Cannot produce a negative number of elements: ", n));
    }
    if (n == 0) {
      return std::vector<Int>();
    }
    // Generate numbers with extra bits: ChaCha20 applied in place to a zeroed
    // buffer yields the raw keystream (BoringSSL allows in == out).
    std::vector<Int> out_with_extra_bits(n);
    uint8_t *bytes = reinterpret_cast<uint8_t *>(out_with_extra_bits.data());
    CRYPTO_chacha_20(bytes, bytes, out_with_extra_bits.size() * sizeof(Int),
                     key.data(), kChaChaNonce, 0);
    // Reduce!
    return barrett_.Reduce(out_with_extra_bits);
  }

 private:
  Barrett<Int> barrett_;
};
// PRG expansion using rejection sampling: generate random values with just
// enough bits to represent every residue (so that modulus - 1 fits), and keep
// only those smaller than the modulus.
template <typename Int>
class PrgRejectionSampling {
 public:
  // Sizes the PRG so that Expand(key, numbers_to_generate) succeeds with
  // probability at least `success_probability` over the choice of key, when
  // the ChaCha20 output is modelled as uniformly random.
  //
  // Preconditions: modulus > 0 and success_probability < 1.
  PrgRejectionSampling(Int modulus, int numbers_to_generate,
                       double success_probability)
      : modulus_(modulus),
        mask_(ComputeMask(modulus)),
        N_(ComputeNumberOfDraws(modulus, mask_, numbers_to_generate,
                                success_probability)) {}

  // Expands `key` (a 32-byte ChaCha20 key) into the first `n` accepted values,
  // all in [0, modulus). Returns InvalidArgumentError if the key has the wrong
  // length, if n < 0, or if the N_ draws produced from this key contain fewer
  // than `n` accepted values.
  absl::StatusOr<std::vector<Int>> Expand(absl::Span<const uint8_t> key,
                                          int n) {
    if (key.size() != 32) {
      return absl::InvalidArgumentError(
          absl::StrCat("ChaCha20 key must be 32 bytes, got ", key.size()));
    }
    if (n < 0) {
      return absl::InvalidArgumentError(
          absl::StrCat("Cannot produce a negative number of elements: ", n));
    }
    // Generate N_ random values: ChaCha20 applied in place to a zeroed buffer
    // yields the raw keystream (BoringSSL allows in == out).
    std::vector<Int> out(N_);
    if (!out.empty()) {
      uint8_t *bytes = reinterpret_cast<uint8_t *>(out.data());
      CRYPTO_chacha_20(bytes, bytes, out.size() * sizeof(Int), key.data(),
                       kChaChaNonce, 0);
    }
    // Clear out most significant bits (if needed).
    std::transform(out.begin(), out.end(), out.begin(),
                   [this](Int value) { return value & mask_; });
    // Rejection sampling; std::remove_if keeps accepted values in order.
    auto accepted_end = std::remove_if(
        out.begin(), out.end(), [this](Int value) { return value >= modulus_; });
    if (std::distance(out.begin(), accepted_end) < n) {
      return absl::InvalidArgumentError(
          absl::StrCat("Key does not allow to produce ", n, " elements."));
    }
    out.resize(n);
    return out;
  }

 private:
  // Returns the smallest all-ones mask 2^b - 1 that is >= modulus - 1, i.e.
  // one keeping exactly the bits needed to represent every value in
  // [0, modulus). Truncating to floor(log2(modulus)) bits instead would make
  // the values in [2^b, modulus) impossible to draw whenever modulus is not a
  // power of two.
  static Int ComputeMask(Int modulus) {
    const Int largest_residue = static_cast<Int>(modulus - 1);
    Int mask = static_cast<Int>(-1);
    while (mask != 0 && static_cast<Int>(mask >> 1) >= largest_residue) {
      mask = static_cast<Int>(mask >> 1);
    }
    return mask;
  }

  // Returns a number of draws N such that at least `numbers_to_generate` of
  // them are accepted with probability >= success_probability. Each draw is
  // accepted independently with probability p = modulus / (mask + 1) > 1/2.
  //
  // Hoeffding's inequality: for k < N p,
  //   Pr[#accept <= k] <= exp(-2N(p - k/N)^2).
  // With k = numbers_to_generate - 1 we want
  //   Pr[#accept <= k] <= 1 - success_probability,
  // so this returns the smallest N > k / p for which the bound achieves it.
  static int ComputeNumberOfDraws(Int modulus, Int mask,
                                  int numbers_to_generate,
                                  double success_probability) {
    if (numbers_to_generate <= 0) {
      return 0;  // Nothing has to be accepted.
    }
    const double p =
        static_cast<double>(modulus) / (static_cast<double>(mask) + 1.0);
    const double k = numbers_to_generate - 1;
    // The bound only holds when N p > k, so start at the first such N.
    int draws = static_cast<int>(std::floor(k / p)) + 1;
    while (true) {
      const double gap = p - k / draws;
      if (gap > 0 &&
          std::exp(-2.0 * draws * gap * gap) <= 1.0 - success_probability) {
        return draws;
      }
      ++draws;
    }
  }

  Int modulus_;
  Int mask_;
  int N_;
};
// Benchmarks
void BM_Extra32Bits_4293918721(benchmark::State &state) {
int len = state.range(0);
PrgExtraBits<uint64_t> prg(4293918721ULL);
std::vector<uint8_t> key = SampleChaChaKey(); // Any key works
for (auto _ : state) {
::benchmark::DoNotOptimize(prg.Expand(key, len));
}
}
BENCHMARK(BM_Extra32Bits_4293918721)->Range(8, 8 << 12);
void BM_Extra96Bits_4293918721(benchmark::State &state) {
int len = state.range(0);
PrgExtraBits<absl::uint128> prg(4293918721ULL);
std::vector<uint8_t> key = SampleChaChaKey(); // Any key works
for (auto _ : state) {
::benchmark::DoNotOptimize(prg.Expand(key, len));
}
}
BENCHMARK(BM_Extra96Bits_4293918721)->Range(8, 8 << 12);
void BM_Extra64Bits_15564440312192434177(benchmark::State &state) {
int len = state.range(0);
PrgExtraBits<absl::uint128> prg(15564440312192434177ULL);
std::vector<uint8_t> key = SampleChaChaKey(); // Any key works
for (auto _ : state) {
::benchmark::DoNotOptimize(prg.Expand(key, len));
}
}
BENCHMARK(BM_Extra64Bits_15564440312192434177)->Range(8, 8 << 12);
void BM_RejectionSampling_4293918721(benchmark::State &state) {
int len = state.range(0);
PrgRejectionSampling<uint32_t> prg(4293918721ULL, len, 0.5);
// The cost of finding a valid key is on the client.
absl::Status status = absl::NotFoundError("");
std::vector<uint8_t> key;
while (!status.ok()) {
key = SampleChaChaKey();
status = prg.Expand(key, len).status();
}
for (auto _ : state) {
auto status_or_data = prg.Expand(key, len);
assert(status_or_data.ok()); // Expand should never fail.
}
}
BENCHMARK(BM_RejectionSampling_4293918721)->Range(8, 8 << 12);
void BM_RejectionSampling_15564440312192434177(benchmark::State &state) {
int len = state.range(0);
PrgRejectionSampling<uint64_t> prg(15564440312192434177ULL, len, 0.5);
// The cost of finding a valid key is on the client.
absl::Status status = absl::NotFoundError("");
std::vector<uint8_t> key;
while (!status.ok()) {
key = SampleChaChaKey();
status = prg.Expand(key, len).status();
}
for (auto _ : state) {
auto status_or_data = prg.Expand(key, len);
assert(status_or_data.ok()); // Expand should never fail.
}
}
BENCHMARK(BM_RejectionSampling_15564440312192434177)->Range(8, 8 << 12);
// Baseline: raw ChaCha20 over state.range(0) bytes per iteration (separate
// all-zero input and output buffers). Reports throughput in bytes/second.
static void BM_Prg(benchmark::State &state) {
  const std::vector<uint8_t> key = SampleChaChaKey();
  const size_t n = static_cast<size_t>(state.range(0));
  std::vector<uint8_t> in(n);
  std::vector<uint8_t> out(n);
  for (auto _ : state) {
    CRYPTO_chacha_20(out.data(), in.data(), in.size(), key.data(), kChaChaNonce,
                     0);
    ::benchmark::DoNotOptimize(out);
  }
  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) *
                          static_cast<int64_t>(n));
}
BENCHMARK(BM_Prg)->Range(8, 8 << 12);
// Times one P-256 ECDH-style derivation with a fixed key pair: multiplication
// of the public point by the private key, an on-curve check, and big-endian
// serialization of the resulting x coordinate.
static void BM_ECDH(benchmark::State &state) {
  bssl::UniquePtr<EC_GROUP> ec_group(
      EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
  bssl::UniquePtr<EC_KEY> ec_key(EC_KEY_new());
  // Every BoringSSL call is made outside assert(): assert() expressions are
  // not evaluated at all in NDEBUG (optimized) builds, which would skip key
  // generation and leave the timed loop empty.
  if (ec_group == nullptr || ec_key == nullptr ||
      EC_KEY_set_group(ec_key.get(), ec_group.get()) != 1 ||
      EC_KEY_generate_key(ec_key.get()) != 1) {
    state.SkipWithError("Failed to generate a P-256 key pair");
    return;
  }
  const BIGNUM *priv_key = EC_KEY_get0_private_key(ec_key.get());
  const EC_POINT *pub_key = EC_KEY_get0_public_key(ec_key.get());
  for (auto _ : state) {
    // Compute the shared point.
    bssl::UniquePtr<EC_POINT> shared_point(EC_POINT_new(ec_group.get()));
    if (shared_point == nullptr ||
        EC_POINT_mul(ec_group.get(), shared_point.get(), nullptr, pub_key,
                     priv_key, nullptr) != 1 ||
        EC_POINT_is_on_curve(ec_group.get(), shared_point.get(), nullptr) !=
            1) {
      state.SkipWithError("Failed to compute the shared point");
      break;
    }
    // Get shared point's x coordinate.
    bssl::UniquePtr<BIGNUM> shared_x(BN_new());
    if (shared_x == nullptr ||
        EC_POINT_get_affine_coordinates_GFp(ec_group.get(), shared_point.get(),
                                            shared_x.get(), nullptr,
                                            nullptr) != 1) {
      state.SkipWithError("Failed to extract the x coordinate");
      break;
    }
    // Serialize BIGNUM.
    std::vector<uint8_t> res((EC_GROUP_get_degree(ec_group.get()) + 7) / 8);
    if (!BN_bn2bin_padded(res.data(), res.size(), shared_x.get())) {
      state.SkipWithError("Failed to serialize the x coordinate");
      break;
    }
    ::benchmark::DoNotOptimize(res);
  }
}
BENCHMARK(BM_ECDH);
}  // namespace
// Google Benchmark's standard entry point: defines main(), which parses the
// benchmark flags and runs every benchmark registered with BENCHMARK() above.
BENCHMARK_MAIN();
# Bazel workspace for the barrett_test and prg_bench targets.
workspace(name = "libprio_prg")
# Repository rules used below: http_archive for source archives and
# git_repository for BoringSSL.
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
# rules_cc provides the cc_test rule loaded by the BUILD file; pinned by
# commit and sha256.
http_archive(
name = "rules_cc",
sha256 = "92a89a2bbe6c6db2a8b87da4ce723aff6253656e8417f37e50d362817c39b98b",
strip_prefix = "rules_cc-88ef31b429631b787ceb5e4556d773b20ad797c8",
urls = ["https://github.com/bazelbuild/rules_cc/archive/88ef31b429631b787ceb5e4556d773b20ad797c8.zip"],
)
# GoogleTest (used by :barrett_test), pinned by commit and sha256.
http_archive(
name = "com_google_googletest",
sha256 = "3c3e9ec31fe35a230d0fa335a31c5d2262dc50245a1cb1e5969b51c6f038cafc",
strip_prefix = "googletest-763eaa430540926fa16060654427149802c97fba",
urls = [
"https://github.com/google/googletest/archive/763eaa430540926fa16060654427149802c97fba.zip",
],
)
# Google Benchmark (used by :prg_bench), pinned by commit and sha256.
http_archive(
name = "com_google_benchmark",
sha256 = "bc60957389e8d9e37d1a40fad22da7a1950e382850cec80b0133fcbfa7d41016",
strip_prefix = "benchmark-cc9abfc8f12577ea83b2d093693ba70c3c0fd2c7",
urls = [
"https://github.com/google/benchmark/archive/cc9abfc8f12577ea83b2d093693ba70c3c0fd2c7.zip",
],
)
# Abseil, pinned to a commit.
# NOTE(review): unlike the other http_archive rules in this file, this one has
# no sha256, so the download is not integrity-checked; consider adding one.
http_archive(
name = "com_google_absl",
strip_prefix = "abseil-cpp-ab21820d47e4f83875dda008b600514d3520fd35",
urls = ["https://github.com/abseil/abseil-cpp/archive/ab21820d47e4f83875dda008b600514d3520fd35.zip"],
)
# BoringSSL (the @boringssl//:crypto dependency of both test targets), pinned
# to a commit.
# NOTE(review): Bazel's documentation recommends http_archive over
# git_repository for speed and reproducibility; consider switching.
git_repository(
name = "boringssl",
commit = "afd67cd00e55e4e22b14f096361c732531a0c539",
remote = "https://boringssl.googlesource.com/boringssl",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment