From 62d1e78a19b8552a2e928b48f00b162f12380fb6 Mon Sep 17 00:00:00 2001 From: FireMartZ Date: Mon, 12 Feb 2018 18:46:04 -0500 Subject: [PATCH 1/2] Merge with libsnarks used for HUSH 1.0.12 (https://github.com/radix42/libsnark/commit/4015f558bb7a1b603de4ee1559c974389d3abb54) --- .../algebra/curves/alt_bn128/alt_bn128_g1.cpp | 8 +- .../algebra/curves/alt_bn128/alt_bn128_g1.hpp | 8 +- .../algebra/curves/alt_bn128/alt_bn128_g2.cpp | 8 +- .../algebra/curves/alt_bn128/alt_bn128_g2.hpp | 8 +- .../curves/alt_bn128/alt_bn128_pairing.cpp | 6 +- src/snark/src/algebra/curves/curve_utils.tcc | 2 +- .../algebra/curves/tests/test_bilinearity.cpp | 42 ++- .../src/algebra/curves/tests/test_groups.cpp | 101 ++++--- .../domains/basic_radix2_domain_aux.tcc | 10 +- .../domains/extended_radix2_domain.hpp | 48 ++++ .../domains/extended_radix2_domain.tcc | 180 +++++++++++++ .../domains/step_radix2_domain.hpp | 50 ++++ .../domains/step_radix2_domain.tcc | 247 ++++++++++++++++++ .../evaluation_domain/evaluation_domain.tcc | 12 +- .../algebra/exponentiation/exponentiation.hpp | 2 +- .../algebra/exponentiation/exponentiation.tcc | 4 +- src/snark/src/algebra/fields/bigint.hpp | 4 +- src/snark/src/algebra/fields/bigint.tcc | 13 +- src/snark/src/algebra/fields/field_utils.hpp | 8 +- src/snark/src/algebra/fields/field_utils.tcc | 32 +-- src/snark/src/algebra/fields/fp.hpp | 28 +- src/snark/src/algebra/fields/fp.tcc | 22 +- .../src/algebra/fields/fp12_2over3over2.hpp | 4 +- .../src/algebra/fields/fp12_2over3over2.tcc | 10 +- src/snark/src/algebra/fields/fp2.hpp | 8 +- src/snark/src/algebra/fields/fp2.tcc | 8 +- src/snark/src/algebra/fields/fp3.hpp | 6 +- src/snark/src/algebra/fields/fp3.tcc | 6 +- src/snark/src/algebra/fields/fp6_3over2.hpp | 4 +- src/snark/src/algebra/fields/fp6_3over2.tcc | 4 +- .../src/algebra/fields/tests/test_bigint.cpp | 90 ++++--- .../src/algebra/fields/tests/test_fields.cpp | 108 ++++++-- .../scalar_multiplication/kc_multiexp.tcc | 1 + .../scalar_multiplication/multiexp.tcc | 16 +- .../algebra/scalar_multiplication/wnaf.hpp | 2 +- .../algebra/scalar_multiplication/wnaf.tcc | 18 +- src/snark/src/common/assert_except.hpp | 8 +- .../common/data_structures/merkle_tree.tcc | 18 +- .../common/data_structures/sparse_vector.hpp | 14 +- .../common/data_structures/sparse_vector.tcc | 36 +-- src/snark/src/common/profiling.cpp | 67 +++-- src/snark/src/common/profiling.hpp | 10 +- src/snark/src/common/utils.cpp | 28 +- src/snark/src/common/utils.hpp | 10 +- .../src/gadgetlib1/gadgets/basic_gadgets.tcc | 40 +-- .../gadgets/hashes/sha256/sha256_aux.tcc | 2 +- .../hashes/sha256/sha256_components.hpp | 4 +- .../hashes/sha256/sha256_components.tcc | 6 +- .../sha256/tests/test_sha256_gadget.cpp | 6 +- .../merkle_authentication_path_variable.tcc | 4 +- .../merkle_tree_check_read_gadget.tcc | 4 +- .../merkle_tree_check_update_gadget.hpp | 1 + .../merkle_tree_check_update_gadget.tcc | 4 +- .../tests/test_merkle_tree_gadgets.cpp | 22 +- src/snark/src/gadgetlib1/pb_variable.hpp | 4 +- src/snark/src/gadgetlib1/pb_variable.tcc | 4 +- .../qap/tests/test_qap.cpp | 31 ++- src/snark/src/relations/variable.hpp | 2 +- .../examples/run_r1cs_ppzksnark.tcc | 2 +- .../tests/test_r1cs_ppzksnark.cpp | 14 +- 60 files changed, 1069 insertions(+), 400 deletions(-) create mode 100644 src/snark/src/algebra/evaluation_domain/domains/extended_radix2_domain.hpp create mode 100644 src/snark/src/algebra/evaluation_domain/domains/extended_radix2_domain.tcc create mode 100644 src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.hpp create mode 100644 src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.tcc diff --git a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.cpp b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.cpp index bf7f43d6f..9cc29c614 100644 --- a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.cpp +++ b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.cpp @@ -10,12 +10,12 @@ namespace libsnark { #ifdef PROFILE_OP_COUNTS -long long alt_bn128_G1::add_cnt = 0; -long long alt_bn128_G1::dbl_cnt = 0; +int64_t alt_bn128_G1::add_cnt = 0; +int64_t alt_bn128_G1::dbl_cnt = 0; #endif -std::vector alt_bn128_G1::wnaf_window_table; -std::vector alt_bn128_G1::fixed_base_exp_window_table; +std::vector alt_bn128_G1::wnaf_window_table; +std::vector alt_bn128_G1::fixed_base_exp_window_table; alt_bn128_G1 alt_bn128_G1::G1_zero; alt_bn128_G1 alt_bn128_G1::G1_one; diff --git a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp index da11a2e8c..567f2fa3f 100644 --- a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp +++ b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp @@ -20,11 +20,11 @@ std::istream& operator>>(std::istream &, alt_bn128_G1&); class alt_bn128_G1 { public: #ifdef PROFILE_OP_COUNTS - static long long add_cnt; - static long long dbl_cnt; + static int64_t add_cnt; + static int64_t dbl_cnt; #endif - static std::vector wnaf_window_table; - static std::vector fixed_base_exp_window_table; + static std::vector wnaf_window_table; + static std::vector fixed_base_exp_window_table; static alt_bn128_G1 G1_zero; static alt_bn128_G1 G1_one; diff --git a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.cpp b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.cpp index c4152e437..6f3e430d3 100644 --- a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.cpp +++ b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.cpp @@ -10,12 +10,12 @@ namespace libsnark { #ifdef PROFILE_OP_COUNTS -long long alt_bn128_G2::add_cnt = 0; -long long alt_bn128_G2::dbl_cnt = 0; +int64_t alt_bn128_G2::add_cnt = 0; +int64_t alt_bn128_G2::dbl_cnt = 0; #endif -std::vector alt_bn128_G2::wnaf_window_table; -std::vector alt_bn128_G2::fixed_base_exp_window_table; +std::vector alt_bn128_G2::wnaf_window_table; +std::vector alt_bn128_G2::fixed_base_exp_window_table; alt_bn128_G2 alt_bn128_G2::G2_zero; alt_bn128_G2 alt_bn128_G2::G2_one; diff --git a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp index a996a2d1a..57bad1a4b 100644 --- a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp +++ b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp @@ -20,11 +20,11 @@ std::istream& operator>>(std::istream &, alt_bn128_G2&); class alt_bn128_G2 { public: #ifdef PROFILE_OP_COUNTS - static long long add_cnt; - static long long dbl_cnt; + static int64_t add_cnt; + static int64_t dbl_cnt; #endif - static std::vector wnaf_window_table; - static std::vector fixed_base_exp_window_table; + static std::vector wnaf_window_table; + static std::vector fixed_base_exp_window_table; static alt_bn128_G2 G2_zero; static alt_bn128_G2 G2_one; diff --git a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_pairing.cpp b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_pairing.cpp index db556c5b2..07b6a8c71 100644 --- a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_pairing.cpp +++ b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_pairing.cpp @@ -324,7 +324,7 @@ alt_bn128_ate_G2_precomp alt_bn128_ate_precompute_G2(const alt_bn128_G2& Q) bool found_one = false; alt_bn128_ate_ell_coeffs c; - for (long i = loop_count.max_bits(); i >= 0; --i) + for (int64_t i = loop_count.max_bits(); i >= 0; --i) { const bool bit = loop_count.test_bit(i); if (!found_one) @@ -378,7 +378,7 @@ alt_bn128_Fq12 alt_bn128_ate_miller_loop(const alt_bn128_ate_G1_precomp &prec_P, const bigint &loop_count = alt_bn128_ate_loop_count; alt_bn128_ate_ell_coeffs c; - for (long i = loop_count.max_bits(); i >= 0; --i) + for (int64_t i = loop_count.max_bits(); i >= 0; --i) { const bool bit = loop_count.test_bit(i); if (!found_one) @@ -432,7 +432,7 @@ alt_bn128_Fq12 alt_bn128_ate_double_miller_loop(const alt_bn128_ate_G1_precomp & size_t idx = 0; const bigint &loop_count = alt_bn128_ate_loop_count; - for (long i = loop_count.max_bits(); i >= 0; --i) + for (int64_t i = loop_count.max_bits(); i >= 0; --i) { const bool bit = loop_count.test_bit(i); if (!found_one) diff --git a/src/snark/src/algebra/curves/curve_utils.tcc b/src/snark/src/algebra/curves/curve_utils.tcc index 251d75d8b..38140cd48 100644 --- a/src/snark/src/algebra/curves/curve_utils.tcc +++ b/src/snark/src/algebra/curves/curve_utils.tcc @@ -16,7 +16,7 @@ GroupT scalar_mul(const GroupT &base, const bigint &scalar) GroupT result = GroupT::zero(); bool found_one = false; - for (long i = scalar.max_bits() - 1; i >= 0; --i) + for (int64_t i = scalar.max_bits() - 1; i >= 0; --i) { if (found_one) { diff --git a/src/snark/src/algebra/curves/tests/test_bilinearity.cpp b/src/snark/src/algebra/curves/tests/test_bilinearity.cpp index fe6593bae..18e68f7bb 100644 --- a/src/snark/src/algebra/curves/tests/test_bilinearity.cpp +++ b/src/snark/src/algebra/curves/tests/test_bilinearity.cpp @@ -4,13 +4,17 @@ * and contributors (see AUTHORS). * @copyright MIT license (see LICENSE file) *****************************************************************************/ +#include #include "common/profiling.hpp" +//#include "algebra/curves/edwards/edwards_pp.hpp" #ifdef CURVE_BN128 #include "algebra/curves/bn128/bn128_pp.hpp" #endif #include "algebra/curves/alt_bn128/alt_bn128_pp.hpp" - -#include +//#include "algebra/curves/mnt/mnt4/mnt4_pp.hpp" +//#include "algebra/curves/mnt/mnt6/mnt6_pp.hpp" +#include "algebra/curves/alt_bn128/alt_bn128_pairing.hpp" +#include "algebra/curves/alt_bn128/alt_bn128_pairing.cpp" using namespace libsnark; @@ -45,11 +49,11 @@ void pairing_test() ans1.print(); ans2.print(); ans3.print(); - EXPECT_EQ(ans1, ans2); - EXPECT_EQ(ans2, ans3); + assert(ans1 == ans2); + assert(ans2 == ans3); - EXPECT_NE(ans1, GT_one); - EXPECT_EQ((ans1^Fr::field_char()), GT_one); + assert(ans1 != GT_one); + assert((ans1^Fr::field_char()) == GT_one); printf("\n\n"); } @@ -69,7 +73,7 @@ void double_miller_loop_test() const Fqk ans_1 = ppT::miller_loop(prec_P1, prec_Q1); const Fqk ans_2 = ppT::miller_loop(prec_P2, prec_Q2); const Fqk ans_12 = ppT::double_miller_loop(prec_P1, prec_Q1, prec_P2, prec_Q2); - EXPECT_EQ(ans_1 * ans_2, ans_12); + assert(ans_1 * ans_2 == ans_12); } template @@ -98,17 +102,31 @@ void affine_pairing_test() ans1.print(); ans2.print(); ans3.print(); - EXPECT_EQ(ans1, ans2); - EXPECT_EQ(ans2, ans3); + assert(ans1 == ans2); + assert(ans2 == ans3); - EXPECT_NE(ans1, GT_one); - EXPECT_EQ((ans1^Fr::field_char()), GT_one); + assert(ans1 != GT_one); + assert((ans1^Fr::field_char()) == GT_one); printf("\n\n"); } -TEST(algebra, bilinearity) +int main(void) { start_profiling(); + edwards_pp::init_public_params(); + pairing_test(); + double_miller_loop_test(); + + mnt6_pp::init_public_params(); + pairing_test(); + double_miller_loop_test(); + affine_pairing_test(); + + mnt4_pp::init_public_params(); + pairing_test(); + double_miller_loop_test(); + affine_pairing_test(); + alt_bn128_pp::init_public_params(); pairing_test(); double_miller_loop_test(); diff --git a/src/snark/src/algebra/curves/tests/test_groups.cpp b/src/snark/src/algebra/curves/tests/test_groups.cpp index 7bb7c31cc..4f64334ba 100644 --- a/src/snark/src/algebra/curves/tests/test_groups.cpp +++ b/src/snark/src/algebra/curves/tests/test_groups.cpp @@ -5,14 +5,15 @@ * @copyright MIT license (see LICENSE file) *****************************************************************************/ #include "common/profiling.hpp" +//#include "algebra/curves/edwards/edwards_pp.hpp" +//#include "algebra/curves/mnt/mnt4/mnt4_pp.hpp" +//#include "algebra/curves/mnt/mnt6/mnt6_pp.hpp" #ifdef CURVE_BN128 #include "algebra/curves/bn128/bn128_pp.hpp" #endif #include "algebra/curves/alt_bn128/alt_bn128_pp.hpp" #include -#include - using namespace libsnark; template @@ -24,31 +25,31 @@ void test_mixed_add() el = GroupT::zero(); el.to_special(); result = base.mixed_add(el); - EXPECT_EQ(result, base + el); + assert(result == base + el); base = GroupT::zero(); el = GroupT::random_element(); el.to_special(); result = base.mixed_add(el); - EXPECT_EQ(result, base + el); + assert(result == base + el); base = GroupT::random_element(); el = GroupT::zero(); el.to_special(); result = base.mixed_add(el); - EXPECT_EQ(result, base + el); + assert(result == base + el); base = GroupT::random_element(); el = GroupT::random_element(); el.to_special(); result = base.mixed_add(el); - EXPECT_EQ(result, base + el); + assert(result == base + el); base = GroupT::random_element(); el = base; el.to_special(); result = base.mixed_add(el); - EXPECT_EQ(result, base.dbl()); + assert(result == base.dbl()); } template @@ -59,53 +60,53 @@ void test_group() bigint<1> randsum = bigint<1>("121160274"); GroupT zero = GroupT::zero(); - EXPECT_EQ(zero, zero); + assert(zero == zero); GroupT one = GroupT::one(); - EXPECT_EQ(one, one); + assert(one == one); GroupT two = bigint<1>(2l) * GroupT::one(); - EXPECT_EQ(two, two); + assert(two == two); GroupT five = bigint<1>(5l) * GroupT::one(); GroupT three = bigint<1>(3l) * GroupT::one(); GroupT four = bigint<1>(4l) * GroupT::one(); - EXPECT_EQ(two+five, three+four); + assert(two+five == three+four); GroupT a = GroupT::random_element(); GroupT b = GroupT::random_element(); - EXPECT_NE(one, zero); - EXPECT_NE(a, zero); - EXPECT_NE(a, one); + assert(one != zero); + assert(a != zero); + assert(a != one); - EXPECT_NE(b, zero); - EXPECT_NE(b, one); + assert(b != zero); + assert(b != one); - EXPECT_EQ(a.dbl(), a + a); - EXPECT_EQ(b.dbl(), b + b); - EXPECT_EQ(one.add(two), three); - EXPECT_EQ(two.add(one), three); - EXPECT_EQ(a + b, b + a); - EXPECT_EQ(a - a, zero); - EXPECT_EQ(a - b, a + (-b)); - EXPECT_EQ(a - b, (-b) + a); + assert(a.dbl() == a + a); + assert(b.dbl() == b + b); + assert(one.add(two) == three); + assert(two.add(one) == three); + assert(a + b == b + a); + assert(a - a == zero); + assert(a - b == a + (-b)); + assert(a - b == (-b) + a); // handle special cases - EXPECT_EQ(zero + (-a), -a); - EXPECT_EQ(zero - a, -a); - EXPECT_EQ(a - zero, a); - EXPECT_EQ(a + zero, a); - EXPECT_EQ(zero + a, a); + assert(zero + (-a) == -a); + assert(zero - a == -a); + assert(a - zero == a); + assert(a + zero == a); + assert(zero + a == a); - EXPECT_EQ((a + b).dbl(), (a + b) + (b + a)); - EXPECT_EQ(bigint<1>("2") * (a + b), (a + b) + (b + a)); + assert((a + b).dbl() == (a + b) + (b + a)); + assert(bigint<1>("2") * (a + b) == (a + b) + (b + a)); - EXPECT_EQ((rand1 * a) + (rand2 * a), (randsum * a)); + assert((rand1 * a) + (rand2 * a) == (randsum * a)); - EXPECT_EQ(GroupT::order() * a, zero); - EXPECT_EQ(GroupT::order() * one, zero); - EXPECT_NE((GroupT::order() * a) - a, zero); - EXPECT_NE((GroupT::order() * one) - one, zero); + assert(GroupT::order() * a == zero); + assert(GroupT::order() * one == zero); + assert((GroupT::order() * a) - a != zero); + assert((GroupT::order() * one) - one != zero); test_mixed_add(); } @@ -114,7 +115,7 @@ template void test_mul_by_q() { GroupT a = GroupT::random_element(); - EXPECT_EQ((GroupT::base_field_char()*a), a.mul_by_q()); + assert((GroupT::base_field_char()*a) == a.mul_by_q()); } template @@ -128,14 +129,36 @@ void test_output() ss << g; GroupT gg; ss >> gg; - EXPECT_EQ(g, gg); + assert(g == gg); /* use a random point in next iteration */ g = GroupT::random_element(); } } -TEST(algebra, groups) +int main(void) { +/* + edwards_pp::init_public_params(); + test_group >(); + test_output >(); + test_group >(); + test_output >(); + test_mul_by_q >(); + + mnt4_pp::init_public_params(); + test_group >(); + test_output >(); + test_group >(); + test_output >(); + test_mul_by_q >(); + + mnt6_pp::init_public_params(); + test_group >(); + test_output >(); + test_group >(); + test_output >(); + test_mul_by_q >(); +*/ alt_bn128_pp::init_public_params(); test_group >(); test_output >(); diff --git a/src/snark/src/algebra/evaluation_domain/domains/basic_radix2_domain_aux.tcc b/src/snark/src/algebra/evaluation_domain/domains/basic_radix2_domain_aux.tcc index 138b82dbc..79f2ffec7 100644 --- a/src/snark/src/algebra/evaluation_domain/domains/basic_radix2_domain_aux.tcc +++ b/src/snark/src/algebra/evaluation_domain/domains/basic_radix2_domain_aux.tcc @@ -74,11 +74,11 @@ void _basic_serial_radix2_FFT(std::vector &a, const FieldT &omega) template void _basic_parallel_radix2_FFT_inner(std::vector &a, const FieldT &omega, const size_t log_cpus) { - const size_t num_cpus = 1ul< &a, const FieldT &omeg std::vector > tmp(num_cpus); for (size_t j = 0; j < num_cpus; ++j) { - tmp[j].resize(1ul<<(log_m-log_cpus), FieldT::zero()); + tmp[j].resize(UINT64_C(1)<<(log_m-log_cpus), FieldT::zero()); } #ifdef MULTICORE @@ -102,7 +102,7 @@ void _basic_parallel_radix2_FFT_inner(std::vector &a, const FieldT &omeg const FieldT omega_step = omega^(j<<(log_m - log_cpus)); FieldT elt = FieldT::one(); - for (size_t i = 0; i < 1ul<<(log_m - log_cpus); ++i) + for (size_t i = 0; i < UINT64_C(1)<<(log_m - log_cpus); ++i) { for (size_t s = 0; s < num_cpus; ++s) { @@ -135,7 +135,7 @@ void _basic_parallel_radix2_FFT_inner(std::vector &a, const FieldT &omeg #endif for (size_t i = 0; i < num_cpus; ++i) { - for (size_t j = 0; j < 1ul<<(log_m - log_cpus); ++j) + for (size_t j = 0; j < UINT64_C(1)<<(log_m - log_cpus); ++j) { // now: i = idx >> (log_m - log_cpus) and j = idx % (1u << (log_m - log_cpus)), for idx = ((i<<(log_m-log_cpus))+j) % (1u << log_m) a[(j< +class extended_radix2_domain : public evaluation_domain { +public: + + size_t small_m; + FieldT omega; + FieldT shift; + + extended_radix2_domain(const size_t m); + + void FFT(std::vector &a); + void iFFT(std::vector &a); + void cosetFFT(std::vector &a, const FieldT &g); + void icosetFFT(std::vector &a, const FieldT &g); + std::vector lagrange_coeffs(const FieldT &t); + FieldT get_element(const size_t idx); + FieldT compute_Z(const FieldT &t); + void add_poly_Z(const FieldT &coeff, std::vector &H); + void divide_by_Z_on_coset(std::vector &P); + +}; + +} // libsnark + +#include "algebra/evaluation_domain/domains/extended_radix2_domain.tcc" + +#endif // EXTENDED_RADIX2_DOMAIN_HPP_ diff --git a/src/snark/src/algebra/evaluation_domain/domains/extended_radix2_domain.tcc b/src/snark/src/algebra/evaluation_domain/domains/extended_radix2_domain.tcc new file mode 100644 index 000000000..bd5c700aa --- /dev/null +++ b/src/snark/src/algebra/evaluation_domain/domains/extended_radix2_domain.tcc @@ -0,0 +1,180 @@ +/** @file + ***************************************************************************** + + Implementation of interfaces for the "extended radix-2" evaluation domain. + + See extended_radix2_domain.hpp . + + ***************************************************************************** + * @author This file is part of libsnark, developed by SCIPR Lab + * and contributors (see AUTHORS). + * @copyright MIT license (see LICENSE file) + *****************************************************************************/ + +#ifndef EXTENDED_RADIX2_DOMAIN_TCC_ + +#include "algebra/evaluation_domain/domains/basic_radix2_domain_aux.hpp" + +namespace libsnark { + +template +extended_radix2_domain::extended_radix2_domain(const size_t m) : evaluation_domain(m) +{ + assert(m > 1); + + const size_t logm = log2(m); + + assert(logm == FieldT::s + 1); + + small_m = m/2; + omega = get_root_of_unity(small_m); + shift = coset_shift(); +} + +template +void extended_radix2_domain::FFT(std::vector &a) +{ + assert(a.size() == this->m); + + std::vector a0(small_m, FieldT::zero()); + std::vector a1(small_m, FieldT::zero()); + + const FieldT shift_to_small_m = shift^bigint<1>(small_m); + + FieldT shift_i = FieldT::one(); + for (size_t i = 0; i < small_m; ++i) + { + a0[i] = a[i] + a[small_m + i]; + a1[i] = shift_i * (a[i] + shift_to_small_m * a[small_m + i]); + + shift_i *= shift; + } + + _basic_radix2_FFT(a0, omega); + _basic_radix2_FFT(a1, omega); + + for (size_t i = 0; i < small_m; ++i) + { + a[i] = a0[i]; + a[i+small_m] = a1[i]; + } +} + +template +void extended_radix2_domain::iFFT(std::vector &a) +{ + assert(a.size() == this->m); + + // note: this is not in-place + std::vector a0(a.begin(), a.begin() + small_m); + std::vector a1(a.begin() + small_m, a.end()); + + const FieldT omega_inverse = omega.inverse(); + _basic_radix2_FFT(a0, omega_inverse); + _basic_radix2_FFT(a1, omega_inverse); + + const FieldT shift_to_small_m = shift^bigint<1>(small_m); + const FieldT sconst = (FieldT(small_m) * (FieldT::one()-shift_to_small_m)).inverse(); + + const FieldT shift_inverse = shift.inverse(); + FieldT shift_inverse_i = FieldT::one(); + + for (size_t i = 0; i < small_m; ++i) + { + a[i] = sconst * (-shift_to_small_m * a0[i] + shift_inverse_i * a1[i]); + a[i+small_m] = sconst * (a0[i] - shift_inverse_i * a1[i]); + + shift_inverse_i *= shift_inverse; + } +} + +template +void extended_radix2_domain::cosetFFT(std::vector &a, const FieldT &g) +{ + _multiply_by_coset(a, g); + FFT(a); +} + +template +void extended_radix2_domain::icosetFFT(std::vector &a, const FieldT &g) +{ + iFFT(a); + _multiply_by_coset(a, g.inverse()); +} + +template +std::vector extended_radix2_domain::lagrange_coeffs(const FieldT &t) +{ + const std::vector T0 = _basic_radix2_lagrange_coeffs(small_m, t); + const std::vector T1 = _basic_radix2_lagrange_coeffs(small_m, t * shift.inverse()); + + std::vector result(this->m, FieldT::zero()); + + const FieldT t_to_small_m = t ^ bigint<1>(small_m); + const FieldT shift_to_small_m = shift ^ bigint<1>(small_m); + const FieldT one_over_denom = (shift_to_small_m - FieldT::one()).inverse(); + const FieldT T0_coeff = (t_to_small_m - shift_to_small_m) * (-one_over_denom); + const FieldT T1_coeff = (t_to_small_m - FieldT::one()) * one_over_denom; + for (size_t i = 0; i < small_m; ++i) + { + result[i] = T0[i] * T0_coeff; + result[i+small_m] = T1[i] * T1_coeff; + } + + return result; +} + +template +FieldT extended_radix2_domain::get_element(const size_t idx) +{ + if (idx < small_m) + { + return omega^idx; + } + else + { + return shift*(omega^(idx-small_m)); + } +} + +template +FieldT extended_radix2_domain::compute_Z(const FieldT &t) +{ + return ((t^small_m) - FieldT::one()) * ((t^small_m) - (shift^small_m)); +} + +template +void extended_radix2_domain::add_poly_Z(const FieldT &coeff, std::vector &H) +{ + assert(H.size() == this->m+1); + const FieldT shift_to_small_m = shift^small_m; + + H[this->m] += coeff; + H[small_m] -= coeff * (shift_to_small_m + FieldT::one()); + H[0] += coeff * shift_to_small_m; +} + +template +void extended_radix2_domain::divide_by_Z_on_coset(std::vector &P) +{ + const FieldT coset = FieldT::multiplicative_generator; + + const FieldT coset_to_small_m = coset^small_m; + const FieldT shift_to_small_m = shift^small_m; + + const FieldT Z0 = (coset_to_small_m - FieldT::one()) * (coset_to_small_m - shift_to_small_m); + const FieldT Z1 = (coset_to_small_m*shift_to_small_m - FieldT::one()) * (coset_to_small_m * shift_to_small_m - shift_to_small_m); + + const FieldT Z0_inverse = Z0.inverse(); + const FieldT Z1_inverse = Z1.inverse(); + + for (size_t i = 0; i < small_m; ++i) + { + P[i] *= Z0_inverse; + P[i+small_m] *= Z1_inverse; + } +} + +} // libsnark + +#endif // EXTENDED_RADIX2_DOMAIN_TCC_ diff --git a/src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.hpp b/src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.hpp new file mode 100644 index 000000000..ae9818a07 --- /dev/null +++ b/src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.hpp @@ -0,0 +1,50 @@ +/** @file + ***************************************************************************** + + Declaration of interfaces for the "step radix-2" evaluation domain. + + Roughly, the domain has size m = 2^k + 2^r and consists of + "the 2^k-th roots of unity" union "a coset of 2^r-th roots of unity". + + ***************************************************************************** + * @author This file is part of libsnark, developed by SCIPR Lab + * and contributors (see AUTHORS). + * @copyright MIT license (see LICENSE file) + *****************************************************************************/ + +#ifndef STEP_RADIX2_DOMAIN_HPP_ +#define STEP_RADIX2_DOMAIN_HPP_ + +#include "algebra/evaluation_domain/evaluation_domain.hpp" + +namespace libsnark { + +template +class step_radix2_domain : public evaluation_domain { +public: + + size_t big_m; + size_t small_m; + FieldT omega; + FieldT big_omega; + FieldT small_omega; + + step_radix2_domain(const size_t m); + + void FFT(std::vector &a); + void iFFT(std::vector &a); + void cosetFFT(std::vector &a, const FieldT &g); + void icosetFFT(std::vector &a, const FieldT &g); + std::vector lagrange_coeffs(const FieldT &t); + FieldT get_element(const size_t idx); + FieldT compute_Z(const FieldT &t); + void add_poly_Z(const FieldT &coeff, std::vector &H); + void divide_by_Z_on_coset(std::vector &P); + +}; + +} // libsnark + +#include "algebra/evaluation_domain/domains/step_radix2_domain.tcc" + +#endif // STEP_RADIX2_DOMAIN_HPP_ diff --git a/src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.tcc b/src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.tcc new file mode 100644 index 000000000..c3baf6969 --- /dev/null +++ b/src/snark/src/algebra/evaluation_domain/domains/step_radix2_domain.tcc @@ -0,0 +1,247 @@ +/** @file + ***************************************************************************** + + Implementation of interfaces for the "step radix-2" evaluation domain. + + See step_radix2_domain.hpp . + + ***************************************************************************** + * @author This file is part of libsnark, developed by SCIPR Lab + * and contributors (see AUTHORS). + * @copyright MIT license (see LICENSE file) + *****************************************************************************/ + +#ifndef STEP_RADIX2_DOMAIN_TCC_ + +#include "algebra/evaluation_domain/domains/basic_radix2_domain_aux.hpp" + +namespace libsnark { + +template +step_radix2_domain::step_radix2_domain(const size_t m) : evaluation_domain(m) +{ + assert(m > 1); + + big_m = UINT64_C(1)<<(log2(m)-1); + small_m = m - big_m; + + assert(small_m == UINT64_C(1)<(UINT64_C(1)<(small_m); +} + +template +void step_radix2_domain::FFT(std::vector &a) +{ + assert(a.size() == this->m); + std::vector c(big_m, FieldT::zero()); + std::vector d(big_m, FieldT::zero()); + + FieldT omega_i = FieldT::one(); + for (size_t i = 0; i < big_m; ++i) + { + c[i] = (i < small_m ? a[i] + a[i+big_m] : a[i]); + d[i] = omega_i * (i < small_m ? a[i] - a[i+big_m] : a[i]); + omega_i *= omega; + } + + std::vector e(small_m, FieldT::zero()); + const size_t compr = UINT64_C(1)<<(log2(big_m) - log2(small_m)); + for (size_t i = 0; i < small_m; ++i) + { + for (size_t j = 0; j < compr; ++j) + { + e[i] += d[i + j * small_m]; + } + } + + _basic_radix2_FFT(c, omega.squared()); + _basic_radix2_FFT(e, get_root_of_unity(small_m)); + + for (size_t i = 0; i < big_m; ++i) + { + a[i] = c[i]; + } + + for (size_t i = 0; i < small_m; ++i) + { + a[i+big_m] = e[i]; + } +} + +template +void step_radix2_domain::iFFT(std::vector &a) +{ + assert(a.size() == this->m); + + std::vector U0(a.begin(), a.begin() + big_m); + std::vector U1(a.begin() + big_m, a.end()); + + _basic_radix2_FFT(U0, omega.squared().inverse()); + _basic_radix2_FFT(U1, get_root_of_unity(small_m).inverse()); + + const FieldT U0_size_inv = FieldT(big_m).inverse(); + for (size_t i = 0; i < big_m; ++i) + { + U0[i] *= U0_size_inv; + } + + const FieldT U1_size_inv = FieldT(small_m).inverse(); + for (size_t i = 0; i < small_m; ++i) + { + U1[i] *= U1_size_inv; + } + + std::vector tmp = U0; + FieldT omega_i = FieldT::one(); + for (size_t i = 0; i < big_m; ++i) + { + tmp[i] *= omega_i; + omega_i *= omega; + } + + // save A_suffix + for (size_t i = small_m; i < big_m; ++i) + { + a[i] = U0[i]; + } + + const size_t compr = UINT64_C(1)<<(log2(big_m) - log2(small_m)); + for (size_t i = 0; i < small_m; ++i) + { + for (size_t j = 1; j < compr; ++j) + { + U1[i] -= tmp[i + j * small_m]; + } + } + + const FieldT omega_inv = omega.inverse(); + FieldT omega_inv_i = FieldT::one(); + for (size_t i = 0; i < small_m; ++i) + { + U1[i] *= omega_inv_i; + omega_inv_i *= omega_inv; + } + + // compute A_prefix + const FieldT over_two = FieldT(2).inverse(); + for (size_t i = 0; i < small_m; ++i) + { + a[i] = (U0[i]+U1[i]) * over_two; + } + + // compute B2 + for (size_t i = 0; i < small_m; ++i) + { + a[big_m + i] = (U0[i]-U1[i]) * over_two; + } +} + +template +void step_radix2_domain::cosetFFT(std::vector &a, const FieldT &g) +{ + _multiply_by_coset(a, g); + FFT(a); +} + +template +void step_radix2_domain::icosetFFT(std::vector &a, const FieldT &g) +{ + iFFT(a); + _multiply_by_coset(a, g.inverse()); +} + +template +std::vector step_radix2_domain::lagrange_coeffs(const FieldT &t) +{ + std::vector inner_big = _basic_radix2_lagrange_coeffs(big_m, t); + std::vector inner_small = _basic_radix2_lagrange_coeffs(small_m, t * omega.inverse()); + + std::vector result(this->m, FieldT::zero()); + + const FieldT L0 = (t^small_m)-(omega^small_m); + const FieldT omega_to_small_m = omega^small_m; + const FieldT big_omega_to_small_m = big_omega ^ small_m; + FieldT elt = FieldT::one(); + for (size_t i = 0; i < big_m; ++i) + { + result[i] = inner_big[i] * L0 * (elt - omega_to_small_m).inverse(); + elt *= big_omega_to_small_m; + } + + const FieldT L1 = ((t^big_m)-FieldT::one()) * ((omega^big_m) - FieldT::one()).inverse(); + + for (size_t i = 0; i < small_m; ++i) + { + result[big_m + i] = L1 * inner_small[i]; + } + + return result; +} + +template +FieldT step_radix2_domain::get_element(const size_t idx) +{ + if (idx < big_m) + { + return big_omega^idx; + } + else + { + return omega * (small_omega^(idx-big_m)); + } +} + +template +FieldT step_radix2_domain::compute_Z(const FieldT &t) +{ + return ((t^big_m) - FieldT::one()) * ((t^small_m) - (omega^small_m)); +} + +template +void step_radix2_domain::add_poly_Z(const FieldT &coeff, std::vector &H) +{ + assert(H.size() == this->m+1); + const FieldT omega_to_small_m = omega^small_m; + + H[this->m] += coeff; + H[big_m] -= coeff * omega_to_small_m; + H[small_m] -= coeff; + H[0] += coeff * omega_to_small_m; +} + +template +void step_radix2_domain::divide_by_Z_on_coset(std::vector &P) +{ + // (c^{2^k}-1) * (c^{2^r} * w^{2^{r+1}*i) - w^{2^r}) + const FieldT coset = FieldT::multiplicative_generator; + + const FieldT Z0 = (coset^big_m) - FieldT::one(); + const FieldT coset_to_small_m_times_Z0 = (coset^small_m) * Z0; + const FieldT omega_to_small_m_times_Z0 = (omega^small_m) * Z0; + const FieldT omega_to_2small_m = omega^(2*small_m); + FieldT elt = FieldT::one(); + + for (size_t i = 0; i < big_m; ++i) + { + P[i] *= (coset_to_small_m_times_Z0 * elt - omega_to_small_m_times_Z0).inverse(); + elt *= omega_to_2small_m; + } + + // (c^{2^k}*w^{2^k}-1) * (c^{2^k} * w^{2^r} - w^{2^r}) + + const FieldT Z1 = ((((coset*omega)^big_m) - FieldT::one()) * (((coset * omega)^small_m) - (omega^small_m))); + const FieldT Z1_inverse = Z1.inverse(); + + for (size_t i = 0; i < small_m; ++i) + { + P[big_m + i] *= Z1_inverse; + } + +} + +} // libsnark + +#endif // STEP_RADIX2_DOMAIN_TCC_ diff --git a/src/snark/src/algebra/evaluation_domain/evaluation_domain.tcc b/src/snark/src/algebra/evaluation_domain/evaluation_domain.tcc index 8e3ea7a62..5b4b22e48 100644 --- a/src/snark/src/algebra/evaluation_domain/evaluation_domain.tcc +++ b/src/snark/src/algebra/evaluation_domain/evaluation_domain.tcc @@ -22,6 +22,8 @@ #include #include "algebra/fields/field_utils.hpp" #include "algebra/evaluation_domain/domains/basic_radix2_domain.hpp" +#include "algebra/evaluation_domain/domains/extended_radix2_domain.hpp" +#include "algebra/evaluation_domain/domains/step_radix2_domain.hpp" namespace libsnark { @@ -41,7 +43,7 @@ std::shared_ptr > get_evaluation_domain(const size_t m { print_indent(); printf("* Selected domain: extended_radix2\n"); } - assert(0); + result.reset(new extended_radix2_domain(min_size)); } else { @@ -54,9 +56,9 @@ std::shared_ptr > get_evaluation_domain(const size_t m } else { - const size_t big = 1ul<<(log2(min_size)-1); + const size_t big = UINT64_C(1)<<(log2(min_size)-1); const size_t small = min_size - big; - const size_t rounded_small = (1ul< > get_evaluation_domain(const size_t m { print_indent(); printf("* Selected domain: extended_radix2\n"); } - assert(0); + result.reset(new extended_radix2_domain(big + rounded_small)); } } else @@ -82,7 +84,7 @@ std::shared_ptr > get_evaluation_domain(const size_t m { print_indent(); printf("* Selected domain: step_radix2\n"); } - assert(0); + result.reset(new step_radix2_domain(big + rounded_small)); } } diff --git a/src/snark/src/algebra/exponentiation/exponentiation.hpp b/src/snark/src/algebra/exponentiation/exponentiation.hpp index a8a2c925c..836ebf002 100644 --- a/src/snark/src/algebra/exponentiation/exponentiation.hpp +++ b/src/snark/src/algebra/exponentiation/exponentiation.hpp @@ -22,7 +22,7 @@ template FieldT power(const FieldT &base, const bigint &exponent); template -FieldT power(const FieldT &base, const unsigned long exponent); +FieldT power(const FieldT &base, const uint64_t exponent); } // libsnark diff --git a/src/snark/src/algebra/exponentiation/exponentiation.tcc b/src/snark/src/algebra/exponentiation/exponentiation.tcc index dd557eb12..7ac3bf5d3 100644 --- a/src/snark/src/algebra/exponentiation/exponentiation.tcc +++ b/src/snark/src/algebra/exponentiation/exponentiation.tcc @@ -25,7 +25,7 @@ FieldT power(const FieldT &base, const bigint &exponent) bool found_one = false; - for (long i = exponent.max_bits() - 1; i >= 0; --i) + for (int64_t i = exponent.max_bits() - 1; i >= 0; --i) { if (found_one) { @@ -43,7 +43,7 @@ FieldT power(const FieldT &base, const bigint &exponent) } template -FieldT power(const FieldT &base, const unsigned long exponent) +FieldT power(const FieldT &base, const uint64_t exponent) { return power(base, bigint<1>(exponent)); } diff --git a/src/snark/src/algebra/fields/bigint.hpp b/src/snark/src/algebra/fields/bigint.hpp index ff00dd5cf..dc47a7efc 100644 --- a/src/snark/src/algebra/fields/bigint.hpp +++ b/src/snark/src/algebra/fields/bigint.hpp @@ -33,7 +33,7 @@ public: mp_limb_t data[n] = {0}; bigint() = default; - bigint(const unsigned long x); /// Initalize from a small integer + bigint(const uint64_t x); /// Initalize from a small integer bigint(const char* s); /// Initialize from a string containing an integer in decimal notation bigint(const mpz_t r); /// Initialize from MPZ element @@ -46,7 +46,7 @@ public: size_t max_bits() const { return n * GMP_NUMB_BITS; } size_t num_bits() const; - unsigned long as_ulong() const; /* return the last limb of the integer */ + uint64_t as_ulong() const; /* return the last limb of the integer */ void to_mpz(mpz_t r) const; bool test_bit(const std::size_t bitno) const; diff --git a/src/snark/src/algebra/fields/bigint.tcc b/src/snark/src/algebra/fields/bigint.tcc index f81addf45..fd295f479 100644 --- a/src/snark/src/algebra/fields/bigint.tcc +++ b/src/snark/src/algebra/fields/bigint.tcc @@ -9,6 +9,7 @@ #ifndef BIGINT_TCC_ #define BIGINT_TCC_ +#include #include #include #include @@ -17,9 +18,9 @@ namespace libsnark { template -bigint::bigint(const unsigned long x) /// Initalize from a small integer +bigint::bigint(const uint64_t x) /// Initalize from a small integer { - static_assert(ULONG_MAX <= GMP_NUMB_MAX, "unsigned long does not fit in a GMP limb"); + static_assert(UINT64_MAX <= GMP_NUMB_MAX, "uint64_t does not fit in a GMP limb"); this->data[0] = x; } @@ -105,7 +106,7 @@ template size_t bigint::num_bits() const { /* - for (long i = max_bits(); i >= 0; --i) + for (int64_t i = max_bits(); i >= 0; --i) { if (this->test_bit(i)) { @@ -115,7 +116,7 @@ size_t bigint::num_bits() const return 0; */ - for (long i = n-1; i >= 0; --i) + for (int64_t i = n-1; i >= 0; --i) { mp_limb_t x = this->data[i]; if (x == 0) @@ -124,14 +125,14 @@ size_t bigint::num_bits() const } else { - return ((i+1) * GMP_NUMB_BITS) - __builtin_clzl(x); + return (((i+1) * GMP_NUMB_BITS) - __builtin_clzl(x)) / 2; } } return 0; } template -unsigned long bigint::as_ulong() const +uint64_t bigint::as_ulong() const { return this->data[0]; } diff --git a/src/snark/src/algebra/fields/field_utils.hpp b/src/snark/src/algebra/fields/field_utils.hpp index a07ecfe28..9fac6c38d 100644 --- a/src/snark/src/algebra/fields/field_utils.hpp +++ b/src/snark/src/algebra/fields/field_utils.hpp @@ -16,13 +16,13 @@ namespace libsnark { // returns root of unity of order n (for n a power of 2), if one exists template -FieldT get_root_of_unity(const size_t n); +FieldT get_root_of_unity(const unsigned long long n); template -std::vector pack_int_vector_into_field_element_vector(const std::vector &v, const size_t w); +std::vector pack_int_vector_into_field_element_vector(const std::vector &v, const unsigned long long w); template -std::vector pack_bit_vector_into_field_element_vector(const bit_vector &v, const size_t chunk_bits); +std::vector pack_bit_vector_into_field_element_vector(const bit_vector &v, const unsigned long long chunk_bits); template std::vector pack_bit_vector_into_field_element_vector(const bit_vector &v); @@ -37,7 +37,7 @@ template bit_vector convert_field_element_to_bit_vector(const FieldT &el); template -bit_vector convert_field_element_to_bit_vector(const FieldT &el, const size_t bitcount); +bit_vector convert_field_element_to_bit_vector(const FieldT &el, const unsigned long long bitcount); template FieldT convert_bit_vector_to_field_element(const bit_vector &v); diff --git a/src/snark/src/algebra/fields/field_utils.tcc b/src/snark/src/algebra/fields/field_utils.tcc index 13197b226..449aaec9b 100644 --- a/src/snark/src/algebra/fields/field_utils.tcc +++ b/src/snark/src/algebra/fields/field_utils.tcc @@ -21,14 +21,14 @@ FieldT coset_shift() } template -FieldT get_root_of_unity(const size_t n) +FieldT get_root_of_unity(const unsigned long long n) { - const size_t logn = log2(n); + const unsigned long long logn = log2(n); assert(n == (1u << logn)); assert(logn <= FieldT::s); FieldT omega = FieldT::root_of_unity; - for (size_t i = FieldT::s; i > logn; --i) + for (unsigned long long i = FieldT::s; i > logn; --i) { omega *= omega; } @@ -37,21 +37,21 @@ FieldT get_root_of_unity(const size_t n) } template -std::vector pack_int_vector_into_field_element_vector(const std::vector &v, const size_t w) +std::vector pack_int_vector_into_field_element_vector(const std::vector &v, const unsigned long long w) { - const size_t chunk_bits = FieldT::capacity(); - const size_t repacked_size = div_ceil(v.size() * w, chunk_bits); + const unsigned long long chunk_bits = FieldT::capacity(); + const unsigned long long repacked_size = div_ceil(v.size() * w, chunk_bits); std::vector result(repacked_size); - for (size_t i = 0; i < repacked_size; ++i) + for (unsigned long long i = 0; i < repacked_size; ++i) { bigint b; - for (size_t j = 0; j < chunk_bits; ++j) + for (unsigned long long j = 0; j < chunk_bits; ++j) { - const size_t word_index = (i * chunk_bits + j) / w; - const size_t pos_in_word = (i * chunk_bits + j) % w; - const size_t word_or_0 = (word_index < v.size() ? v[word_index] : 0); - const size_t bit = (word_or_0 >> pos_in_word) & 1; + const unsigned long long word_index = (i * chunk_bits + j) / w; + const unsigned long long pos_in_word = (i * chunk_bits + j) % w; + const unsigned long long word_or_0 = (word_index < v.size() ? v[word_index] : 0); + const unsigned long long bit = (word_or_0 >> pos_in_word) & 1; b.data[j / GMP_NUMB_BITS] |= bit << (j % GMP_NUMB_BITS); } @@ -62,11 +62,11 @@ std::vector pack_int_vector_into_field_element_vector(const std::vector< } template -std::vector pack_bit_vector_into_field_element_vector(const bit_vector &v, const size_t chunk_bits) +std::vector pack_bit_vector_into_field_element_vector(const bit_vector &v, const unsigned long long chunk_bits) { assert(chunk_bits <= FieldT::capacity()); - const size_t repacked_size = div_ceil(v.size(), chunk_bits); + const unsigned long long repacked_size = div_ceil(v.size(), chunk_bits); std::vector result(repacked_size); for (size_t i = 0; i < repacked_size; ++i) @@ -131,7 +131,7 @@ bit_vector convert_field_element_to_bit_vector(const FieldT &el) } template -bit_vector convert_field_element_to_bit_vector(const FieldT &el, const size_t bitcount) +bit_vector convert_field_element_to_bit_vector(const FieldT &el, const unsigned long long bitcount) { bit_vector result = convert_field_element_to_bit_vector(el); result.resize(bitcount); @@ -171,7 +171,7 @@ void batch_invert(std::vector &vec) FieldT acc_inverse = acc.inverse(); - for (long i = vec.size()-1; i >= 0; --i) + for (int64_t i = vec.size()-1; i >= 0; --i) { const FieldT old_el = vec[i]; vec[i] = acc_inverse * prod[i]; diff --git a/src/snark/src/algebra/fields/fp.hpp b/src/snark/src/algebra/fields/fp.hpp index a4986833c..1dce26c2d 100644 --- a/src/snark/src/algebra/fields/fp.hpp +++ b/src/snark/src/algebra/fields/fp.hpp @@ -44,11 +44,11 @@ public: static const mp_size_t num_limbs = n; static const constexpr bigint& mod = modulus; #ifdef PROFILE_OP_COUNTS - static long long add_cnt; - static long long sub_cnt; - static long long mul_cnt; - static long long sqr_cnt; - static long long inv_cnt; + static int64_t add_cnt; + static int64_t sub_cnt; + static int64_t mul_cnt; + static int64_t sqr_cnt; + static int64_t inv_cnt; #endif static size_t num_bits; static bigint euler; // (modulus-1)/2 @@ -67,9 +67,9 @@ public: Fp_model() {}; Fp_model(const bigint &b); - Fp_model(const long x, const bool is_unsigned=false); + Fp_model(const int64_t x, const bool is_unsigned=false); - void set_ulong(const unsigned long x); + void set_ulong(const uint64_t x); void mul_reduce(const bigint &other); @@ -82,7 +82,7 @@ public: /* Return the last limb of the standard representation of the field element. E.g. on 64-bit architectures Fp(123).as_ulong() and Fp(2^64+123).as_ulong() would both return 123. */ - unsigned long as_ulong() const; + uint64_t as_ulong() const; bool operator==(const Fp_model& other) const; bool operator!=(const Fp_model& other) const; @@ -93,7 +93,7 @@ public: Fp_model& operator+=(const Fp_model& other); Fp_model& operator-=(const Fp_model& other); Fp_model& operator*=(const Fp_model& other); - Fp_model& operator^=(const unsigned long pow); + Fp_model& operator^=(const uint64_t pow); template Fp_model& operator^=(const bigint &pow); @@ -107,12 +107,12 @@ public: Fp_model inverse() const; Fp_model sqrt() const; // HAS TO BE A SQUARE (else does not terminate) - Fp_model operator^(const unsigned long pow) const; + Fp_model operator^(const unsigned long long pow) const; template Fp_model operator^(const bigint &pow) const; - static size_t size_in_bits() { return num_bits; } - static size_t capacity() { return num_bits - 1; } + static unsigned long long size_in_bits() { return num_bits; } + static unsigned long long capacity() { return num_bits - 1; } static bigint field_char() { return modulus; } static Fp_model zero(); @@ -141,13 +141,13 @@ long long Fp_model::inv_cnt = 0; #endif template& modulus> -size_t Fp_model::num_bits; +unsigned long long Fp_model::num_bits; template& modulus> bigint Fp_model::euler; template& modulus> -size_t Fp_model::s; +unsigned long long Fp_model::s; template& modulus> bigint Fp_model::t; diff --git a/src/snark/src/algebra/fields/fp.tcc b/src/snark/src/algebra/fields/fp.tcc index 566e99324..6c14f0eed 100644 --- a/src/snark/src/algebra/fields/fp.tcc +++ b/src/snark/src/algebra/fields/fp.tcc @@ -194,7 +194,7 @@ Fp_model::Fp_model(const bigint &b) } template& modulus> -Fp_model::Fp_model(const long x, const bool is_unsigned) +Fp_model::Fp_model(const int64_t x, const bool is_unsigned) { if (is_unsigned || x >= 0) { @@ -210,7 +210,7 @@ Fp_model::Fp_model(const long x, const bool is_unsigned) } template& modulus> -void Fp_model::set_ulong(const unsigned long x) +void Fp_model::set_ulong(const uint64_t x) { this->mont_repr.clear(); this->mont_repr.data[0] = x; @@ -237,7 +237,7 @@ bigint Fp_model::as_bigint() const } template& modulus> -unsigned long Fp_model::as_ulong() const +uint64_t Fp_model::as_ulong() const { return this->as_bigint().as_ulong(); } @@ -502,7 +502,7 @@ Fp_model& Fp_model::operator*=(const Fp_model& } template& modulus> -Fp_model& Fp_model::operator^=(const unsigned long pow) +Fp_model& Fp_model::operator^=(const uint64_t pow) { (*this) = power >(*this, pow); return (*this); @@ -538,7 +538,7 @@ Fp_model Fp_model::operator*(const Fp_model& ot } template& modulus> -Fp_model Fp_model::operator^(const unsigned long pow) const +Fp_model Fp_model::operator^(const uint64_t pow) const { Fp_model r(*this); return (r ^= pow); @@ -684,13 +684,13 @@ Fp_model Fp_model::random_element() /// returns random el r.mont_repr.randomize(); /* clear all bits higher than MSB of modulus */ - size_t bitno = GMP_NUMB_BITS * n - 1; + unsigned long long bitno = GMP_NUMB_BITS * n - 1; while (modulus.test_bit(bitno) == false) { - const std::size_t part = bitno/GMP_NUMB_BITS; - const std::size_t bit = bitno - (GMP_NUMB_BITS*part); + const unsigned long long part = bitno/GMP_NUMB_BITS; + const unsigned long long bit = bitno - (GMP_NUMB_BITS*part); - r.mont_repr.data[part] &= ~(1ul< Fp_model::sqrt() const Fp_model one = Fp_model::one(); - size_t v = Fp_model::s; + unsigned long long v = Fp_model::s; Fp_model z = Fp_model::nqr_to_t; Fp_model w = (*this)^Fp_model::t_minus_1_over_2; Fp_model x = (*this) * w; @@ -734,7 +734,7 @@ Fp_model Fp_model::sqrt() const while (b != one) { - size_t m = 0; + unsigned long long m = 0; Fp_model b2m = b; while (b2m != one) { diff --git a/src/snark/src/algebra/fields/fp12_2over3over2.hpp b/src/snark/src/algebra/fields/fp12_2over3over2.hpp index 1de9d88b4..61a6c57c7 100644 --- a/src/snark/src/algebra/fields/fp12_2over3over2.hpp +++ b/src/snark/src/algebra/fields/fp12_2over3over2.hpp @@ -66,7 +66,7 @@ public: Fp12_2over3over2_model squared_karatsuba() const; Fp12_2over3over2_model squared_complex() const; Fp12_2over3over2_model inverse() const; - Fp12_2over3over2_model Frobenius_map(unsigned long power) const; + Fp12_2over3over2_model Frobenius_map(uint64_t power) const; Fp12_2over3over2_model unitary_inverse() const; Fp12_2over3over2_model cyclotomic_squared() const; @@ -78,7 +78,7 @@ public: Fp12_2over3over2_model cyclotomic_exp(const bigint &exponent) const; static bigint base_field_char() { return modulus; } - static size_t extension_degree() { return 12; } + static unsigned long long extension_degree() { return 12; } friend std::ostream& operator<< (std::ostream &out, const Fp12_2over3over2_model &el); friend std::istream& operator>> (std::istream &in, Fp12_2over3over2_model &el); diff --git a/src/snark/src/algebra/fields/fp12_2over3over2.tcc b/src/snark/src/algebra/fields/fp12_2over3over2.tcc index 2fbc0b649..680d9429e 100644 --- a/src/snark/src/algebra/fields/fp12_2over3over2.tcc +++ b/src/snark/src/algebra/fields/fp12_2over3over2.tcc @@ -156,7 +156,7 @@ Fp12_2over3over2_model Fp12_2over3over2_model::inverse() c } template& modulus> -Fp12_2over3over2_model Fp12_2over3over2_model::Frobenius_map(unsigned long power) const +Fp12_2over3over2_model Fp12_2over3over2_model::Frobenius_map(uint64_t power) const { return Fp12_2over3over2_model(c0.Frobenius_map(power), Frobenius_coeffs_c1[power % 12] * c1.Frobenius_map(power)); @@ -339,16 +339,16 @@ Fp12_2over3over2_model Fp12_2over3over2_model::cyclotomic Fp12_2over3over2_model res = Fp12_2over3over2_model::one(); bool found_one = false; - for (long i = m-1; i >= 0; --i) + for (int64_t i = m-1; i >= 0; --i) { - for (long j = GMP_NUMB_BITS - 1; j >= 0; --j) + for (int64_t j = GMP_NUMB_BITS - 1; j >= 0; --j) { if (found_one) { res = res.cyclotomic_squared(); } - if (exponent.data[i] & (1ul<>(std::istream& in, std::vector> s; char b; diff --git a/src/snark/src/algebra/fields/fp2.hpp b/src/snark/src/algebra/fields/fp2.hpp index f07726918..449e60807 100644 --- a/src/snark/src/algebra/fields/fp2.hpp +++ b/src/snark/src/algebra/fields/fp2.hpp @@ -37,7 +37,7 @@ public: typedef Fp_model my_Fp; static bigint<2*n> euler; // (modulus^2-1)/2 - static size_t s; // modulus^2 = 2^s * t + 1 + static unsigned long long s; // modulus^2 = 2^s * t + 1 static bigint<2*n> t; // with t odd static bigint<2*n> t_minus_1_over_2; // (t-1)/2 static my_Fp non_residue; // X^4-non_residue irreducible over Fp; used for constructing Fp2 = Fp[X] / (X^2 - non_residue) @@ -66,7 +66,7 @@ public: Fp2_model operator-() const; Fp2_model squared() const; // default is squared_complex Fp2_model inverse() const; - Fp2_model Frobenius_map(unsigned long power) const; + Fp2_model Frobenius_map(uint64_t power) const; Fp2_model sqrt() const; // HAS TO BE A SQUARE (else does not terminate) Fp2_model squared_karatsuba() const; Fp2_model squared_complex() const; @@ -74,7 +74,7 @@ public: template Fp2_model operator^(const bigint &other) const; - static size_t size_in_bits() { return 2*my_Fp::size_in_bits(); } + static unsigned long long size_in_bits() { return 2*my_Fp::size_in_bits(); } static bigint base_field_char() { return modulus; } friend std::ostream& operator<< (std::ostream &out, const Fp2_model &el); @@ -94,7 +94,7 @@ template& modulus> bigint<2*n> Fp2_model::euler; template& modulus> -size_t Fp2_model::s; +unsigned long long Fp2_model::s; template& modulus> bigint<2*n> Fp2_model::t; diff --git a/src/snark/src/algebra/fields/fp2.tcc b/src/snark/src/algebra/fields/fp2.tcc index 1632a04c7..84aa3035c 100644 --- a/src/snark/src/algebra/fields/fp2.tcc +++ b/src/snark/src/algebra/fields/fp2.tcc @@ -136,7 +136,7 @@ Fp2_model Fp2_model::inverse() const } template& modulus> -Fp2_model Fp2_model::Frobenius_map(unsigned long power) const +Fp2_model Fp2_model::Frobenius_map(uint64_t power) const { return Fp2_model(c0, Frobenius_coeffs_c1[power % 2] * c1); @@ -151,7 +151,7 @@ Fp2_model Fp2_model::sqrt() const Fp2_model one = Fp2_model::one(); - size_t v = Fp2_model::s; + unsigned long long v = Fp2_model::s; Fp2_model z = Fp2_model::nqr_to_t; Fp2_model w = (*this)^Fp2_model::t_minus_1_over_2; Fp2_model x = (*this) * w; @@ -175,7 +175,7 @@ Fp2_model Fp2_model::sqrt() const while (b != one) { - size_t m = 0; + unsigned long long m = 0; Fp2_model b2m = b; while (b2m != one) { @@ -239,7 +239,7 @@ std::istream& operator>>(std::istream& in, std::vector > & { v.clear(); - size_t s; + unsigned long long s; in >> s; char b; diff --git a/src/snark/src/algebra/fields/fp3.hpp b/src/snark/src/algebra/fields/fp3.hpp index 53b178a27..c07a4635b 100644 --- a/src/snark/src/algebra/fields/fp3.hpp +++ b/src/snark/src/algebra/fields/fp3.hpp @@ -37,7 +37,7 @@ public: typedef Fp_model my_Fp; static bigint<3*n> euler; // (modulus^3-1)/2 - static size_t s; // modulus^3 = 2^s * t + 1 + static unsigned long long s; // modulus^3 = 2^s * t + 1 static bigint<3*n> t; // with t odd static bigint<3*n> t_minus_1_over_2; // (t-1)/2 static my_Fp non_residue; // X^6-non_residue irreducible over Fp; used for constructing Fp3 = Fp[X] / (X^3 - non_residue) @@ -73,7 +73,7 @@ public: template Fp3_model operator^(const bigint &other) const; - static size_t size_in_bits() { return 3*my_Fp::size_in_bits(); } + static unsigned long long size_in_bits() { return 3*my_Fp::size_in_bits(); } static bigint base_field_char() { return modulus; } friend std::ostream& operator<< (std::ostream &out, const Fp3_model &el); @@ -93,7 +93,7 @@ template& modulus> bigint<3*n> Fp3_model::euler; template& modulus> -size_t Fp3_model::s; +unsigned long long Fp3_model::s; template& modulus> bigint<3*n> Fp3_model::t; diff --git a/src/snark/src/algebra/fields/fp3.tcc b/src/snark/src/algebra/fields/fp3.tcc index 590a2a987..fb0e0fe02 100644 --- a/src/snark/src/algebra/fields/fp3.tcc +++ b/src/snark/src/algebra/fields/fp3.tcc @@ -149,7 +149,7 @@ Fp3_model Fp3_model::sqrt() const { Fp3_model one = Fp3_model::one(); - size_t v = Fp3_model::s; + unsigned long long v = Fp3_model::s; Fp3_model z = Fp3_model::nqr_to_t; Fp3_model w = (*this)^Fp3_model::t_minus_1_over_2; Fp3_model x = (*this) * w; @@ -173,7 +173,7 @@ Fp3_model Fp3_model::sqrt() const while (b != one) { - size_t m = 0; + unsigned long long m = 0; Fp3_model b2m = b; while (b2m != one) { @@ -237,7 +237,7 @@ std::istream& operator>>(std::istream& in, std::vector > & { v.clear(); - size_t s; + unsigned long long s; in >> s; char b; diff --git a/src/snark/src/algebra/fields/fp6_3over2.hpp b/src/snark/src/algebra/fields/fp6_3over2.hpp index 335d61c53..4441fb36a 100644 --- a/src/snark/src/algebra/fields/fp6_3over2.hpp +++ b/src/snark/src/algebra/fields/fp6_3over2.hpp @@ -63,7 +63,7 @@ public: Fp6_3over2_model operator-() const; Fp6_3over2_model squared() const; Fp6_3over2_model inverse() const; - Fp6_3over2_model Frobenius_map(unsigned long power) const; + Fp6_3over2_model Frobenius_map(uint64_t power) const; static my_Fp2 mul_by_non_residue(const my_Fp2 &elt); @@ -71,7 +71,7 @@ public: Fp6_3over2_model operator^(const bigint &other) const; static bigint base_field_char() { return modulus; } - static size_t extension_degree() { return 6; } + static unsigned long long extension_degree() { return 6; } friend std::ostream& operator<< (std::ostream &out, const Fp6_3over2_model &el); friend std::istream& operator>> (std::istream &in, Fp6_3over2_model &el); diff --git a/src/snark/src/algebra/fields/fp6_3over2.tcc b/src/snark/src/algebra/fields/fp6_3over2.tcc index f4fffde04..de9b83d11 100644 --- a/src/snark/src/algebra/fields/fp6_3over2.tcc +++ b/src/snark/src/algebra/fields/fp6_3over2.tcc @@ -149,7 +149,7 @@ Fp6_3over2_model Fp6_3over2_model::inverse() const } template& modulus> -Fp6_3over2_model Fp6_3over2_model::Frobenius_map(unsigned long power) const +Fp6_3over2_model Fp6_3over2_model::Frobenius_map(uint64_t power) const { return Fp6_3over2_model(c0.Frobenius_map(power), Frobenius_coeffs_c1[power % 6] * c1.Frobenius_map(power), @@ -194,7 +194,7 @@ std::istream& operator>>(std::istream& in, std::vector> s; char b; diff --git a/src/snark/src/algebra/fields/tests/test_bigint.cpp b/src/snark/src/algebra/fields/tests/test_bigint.cpp index d2da59e73..6392f27c9 100644 --- a/src/snark/src/algebra/fields/tests/test_bigint.cpp +++ b/src/snark/src/algebra/fields/tests/test_bigint.cpp @@ -7,13 +7,11 @@ #include "algebra/fields/bigint.hpp" -#include - using namespace libsnark; -TEST(algebra, bigint) +void test_bigint() { - static_assert(ULONG_MAX == 0xFFFFFFFFFFFFFFFFul, "unsigned long not 64-bit"); + static_assert(UINT64_MAX == 0xFFFFFFFFFFFFFFFFul, "uint64_t not 64-bit"); static_assert(GMP_NUMB_BITS == 64, "GMP limb not 64-bit"); const char *b1_decimal = "76749407"; @@ -22,76 +20,88 @@ TEST(algebra, bigint) const char *b2_binary = "0000000000000000000000000000010101111101101000000110100001011010" "1101101010001001000001101000101000100110011001110001111110100010"; - bigint<1> b0 = bigint<1>(0ul); + bigint<1> b0 = bigint<1>(UINT64_C(0)); bigint<1> b1 = bigint<1>(b1_decimal); bigint<2> b2 = bigint<2>(b2_decimal); - EXPECT_EQ(b0.as_ulong(), 0ul); - EXPECT_TRUE(b0.is_zero()); - EXPECT_EQ(b1.as_ulong(), 76749407ul); - EXPECT_FALSE(b1.is_zero()); - EXPECT_EQ(b2.as_ulong(), 15747124762497195938ul); - EXPECT_FALSE(b2.is_zero()); - EXPECT_NE(b0, b1); - EXPECT_FALSE(b0 == b1); - - EXPECT_EQ(b2.max_bits(), 128); - EXPECT_EQ(b2.num_bits(), 99); + assert(b0.as_ulong() == UINT64_C(0)); + assert(b0.is_zero()); + assert(b1.as_ulong() == UINT64_C(76749407)); + assert(!(b1.is_zero())); + assert(b2.as_ulong() == UINT64_C(15747124762497195938)); + assert(!(b2.is_zero())); + assert(b0 != b1); + assert(!(b0 == b1)); + + assert(b2.max_bits() == 128); + assert(b2.num_bits() == 99); for (size_t i = 0; i < 128; i++) { - EXPECT_EQ(b2.test_bit(i), (b2_binary[127-i] == '1')); + assert(b2.test_bit(i) == (b2_binary[127-i] == '1')); } bigint<3> b3 = b2 * b1; - EXPECT_EQ(b3, bigint<3>(b3_decimal)); - EXPECT_FALSE(b3.is_zero()); + assert(b3 == bigint<3>(b3_decimal)); + assert(!(b3.is_zero())); bigint<3> b3a { b3 }; - EXPECT_EQ(b3a, bigint<3>(b3_decimal)); - EXPECT_EQ(b3a, b3); - EXPECT_FALSE(b3a.is_zero()); + assert(b3a == bigint<3>(b3_decimal)); + assert(b3a == b3); + assert(!(b3a.is_zero())); mpz_t m3; mpz_init(m3); b3.to_mpz(m3); bigint<3> b3b { m3 }; - EXPECT_EQ(b3b, b3); + assert(b3b == b3); bigint<2> quotient; bigint<2> remainder; bigint<3>::div_qr(quotient, remainder, b3, b2); - EXPECT_LT(quotient.num_bits(), GMP_NUMB_BITS); - EXPECT_EQ(quotient.as_ulong(), b1.as_ulong()); + assert(quotient.num_bits() < GMP_NUMB_BITS); + assert(quotient.as_ulong() == b1.as_ulong()); bigint<1> b1inc = bigint<1>("76749408"); bigint<1> b1a = quotient.shorten(b1inc, "test"); - EXPECT_EQ(b1a, b1); - EXPECT_TRUE(remainder.is_zero()); + assert(b1a == b1); + assert(remainder.is_zero()); remainder.limit(b2, "test"); - EXPECT_THROW((void)(quotient.shorten(b1, "test")), std::domain_error); - EXPECT_THROW(remainder.limit(remainder, "test"), std::domain_error); + try { + (void)(quotient.shorten(b1, "test")); + assert(false); + } catch (std::domain_error) {} + try { + remainder.limit(remainder, "test"); + assert(false); + } catch (std::domain_error) {} bigint<1> br = bigint<1>("42"); b3 += br; - EXPECT_NE(b3, b3a); - EXPECT_GT(b3, b3a); - EXPECT_FALSE(b3a > b3); + assert(b3 != b3a); + assert(b3 > b3a); + assert(!(b3a > b3)); bigint<3>::div_qr(quotient, remainder, b3, b2); - EXPECT_LT(quotient.num_bits(), GMP_NUMB_BITS); - EXPECT_EQ(quotient.as_ulong(), b1.as_ulong()); - EXPECT_LT(remainder.num_bits(), GMP_NUMB_BITS); - EXPECT_EQ(remainder.as_ulong(), 42); + assert(quotient.num_bits() < GMP_NUMB_BITS); + assert(quotient.as_ulong() == b1.as_ulong()); + assert(remainder.num_bits() < GMP_NUMB_BITS); + assert(remainder.as_ulong() == 42); b3a.clear(); - EXPECT_TRUE(b3a.is_zero()); - EXPECT_EQ(b3a.num_bits(), 0); - EXPECT_FALSE(b3.is_zero()); + assert(b3a.is_zero()); + assert(b3a.num_bits() == 0); + assert(!(b3.is_zero())); bigint<4> bx = bigint<4>().randomize(); bigint<4> by = bigint<4>().randomize(); - EXPECT_FALSE(bx == by); + assert(!(bx == by)); // TODO: test serialization } +int main(void) +{ + test_bigint(); + return 0; +} + diff --git a/src/snark/src/algebra/fields/tests/test_fields.cpp b/src/snark/src/algebra/fields/tests/test_fields.cpp index 969800d8b..a05f601e6 100644 --- a/src/snark/src/algebra/fields/tests/test_fields.cpp +++ b/src/snark/src/algebra/fields/tests/test_fields.cpp @@ -5,6 +5,9 @@ * @copyright MIT license (see LICENSE file) *****************************************************************************/ #include "common/profiling.hpp" +#include "algebra/curves/edwards/edwards_pp.hpp" +#include "algebra/curves/mnt/mnt4/mnt4_pp.hpp" +#include "algebra/curves/mnt/mnt6/mnt6_pp.hpp" #ifdef CURVE_BN128 #include "algebra/curves/bn128/bn128_pp.hpp" #endif @@ -12,8 +15,6 @@ #include "algebra/fields/fp6_3over2.hpp" #include "algebra/fields/fp12_2over3over2.hpp" -#include - using namespace libsnark; template @@ -28,25 +29,25 @@ void test_field() FieldT a = FieldT::random_element(); FieldT a_ser; a_ser = reserialize(a); - EXPECT_EQ(a_ser, a); + assert(a_ser == a); FieldT b = FieldT::random_element(); FieldT c = FieldT::random_element(); FieldT d = FieldT::random_element(); - EXPECT_NE(a, zero); - EXPECT_NE(a, one); + assert(a != zero); + assert(a != one); - EXPECT_EQ(a * a, a.squared()); - EXPECT_EQ((a + b).squared(), a.squared() + a*b + b*a + b.squared()); - EXPECT_EQ((a + b)*(c + d), a*c + a*d + b*c + b*d); - EXPECT_EQ(a - b, a + (-b)); - EXPECT_EQ(a - b, (-b) + a); + assert(a * a == a.squared()); + assert((a + b).squared() == a.squared() + a*b + b*a + b.squared()); + assert((a + b)*(c + d) == a*c + a*d + b*c + b*d); + assert(a - b == a + (-b)); + assert(a - b == (-b) + a); - EXPECT_EQ((a ^ rand1) * (a ^ rand2), (a^randsum)); + assert((a ^ rand1) * (a ^ rand2) == (a^randsum)); - EXPECT_EQ(a * a.inverse(), one); - EXPECT_EQ((a + b) * c.inverse(), a * c.inverse() + (b.inverse() * c).inverse()); + assert(a * a.inverse() == one); + assert((a + b) * c.inverse() == a * c.inverse() + (b.inverse() * c).inverse()); } @@ -57,7 +58,7 @@ void test_sqrt() { FieldT a = FieldT::random_element(); FieldT asq = a.squared(); - EXPECT_TRUE(asq.sqrt() == a || asq.sqrt() == -a); + assert(asq.sqrt() == a || asq.sqrt() == -a); } } @@ -65,21 +66,21 @@ template void test_two_squarings() { FieldT a = FieldT::random_element(); - EXPECT_EQ(a.squared(), a * a); - EXPECT_EQ(a.squared(), a.squared_complex()); - EXPECT_EQ(a.squared(), a.squared_karatsuba()); + assert(a.squared() == a * a); + assert(a.squared() == a.squared_complex()); + assert(a.squared() == a.squared_karatsuba()); } template void test_Frobenius() { FieldT a = FieldT::random_element(); - EXPECT_EQ(a.Frobenius_map(0), a); + assert(a.Frobenius_map(0) == a); FieldT a_q = a ^ FieldT::base_field_char(); for (size_t power = 1; power < 10; ++power) { const FieldT a_qi = a.Frobenius_map(power); - EXPECT_EQ(a_qi, a_q); + assert(a_qi == a_q); a_q = a_q ^ FieldT::base_field_char(); } @@ -88,10 +89,49 @@ void test_Frobenius() template void test_unitary_inverse() { - EXPECT_EQ(FieldT::extension_degree() % 2, 0); + assert(FieldT::extension_degree() % 2 == 0); FieldT a = FieldT::random_element(); FieldT aqcubed_minus1 = a.Frobenius_map(FieldT::extension_degree()/2) * a.inverse(); - EXPECT_EQ(aqcubed_minus1.inverse(), aqcubed_minus1.unitary_inverse()); + assert(aqcubed_minus1.inverse() == aqcubed_minus1.unitary_inverse()); +} + +template +void test_cyclotomic_squaring(); + +template<> +void test_cyclotomic_squaring >() +{ + typedef Fqk FieldT; + assert(FieldT::extension_degree() % 2 == 0); + FieldT a = FieldT::random_element(); + FieldT a_unitary = a.Frobenius_map(FieldT::extension_degree()/2) * a.inverse(); + // beta = a^((q^(k/2)-1)*(q+1)) + FieldT beta = a_unitary.Frobenius_map(1) * a_unitary; + assert(beta.cyclotomic_squared() == beta.squared()); +} + +template<> +void test_cyclotomic_squaring >() +{ + typedef Fqk FieldT; + assert(FieldT::extension_degree() % 2 == 0); + FieldT a = FieldT::random_element(); + FieldT a_unitary = a.Frobenius_map(FieldT::extension_degree()/2) * a.inverse(); + // beta = a^(q^(k/2)-1) + FieldT beta = a_unitary; + assert(beta.cyclotomic_squared() == beta.squared()); +} + +template<> +void test_cyclotomic_squaring >() +{ + typedef Fqk FieldT; + assert(FieldT::extension_degree() % 2 == 0); + FieldT a = FieldT::random_element(); + FieldT a_unitary = a.Frobenius_map(FieldT::extension_degree()/2) * a.inverse(); + // beta = a^((q^(k/2)-1)*(q+1)) + FieldT beta = a_unitary.Frobenius_map(1) * a_unitary; + assert(beta.cyclotomic_squared() == beta.squared()); } template @@ -157,16 +197,16 @@ void test_Fp4_tom_cook() c2 = - (FieldT(5)*(FieldT(4).inverse()))* v0 + (FieldT(2)*(FieldT(3).inverse()))*(v1 + v2) - FieldT(24).inverse()*(v3 + v4) + FieldT(4)*v6 + beta*v6; c3 = FieldT(12).inverse() * (FieldT(5)*v0 - FieldT(7)*v1) - FieldT(24).inverse()*(v2 - FieldT(7)*v3 + v4 + v5) + FieldT(15)*v6; - EXPECT_EQ(res, correct_res); + assert(res == correct_res); // {v0, v3, v4, v5} const FieldT u = (FieldT::one() - beta).inverse(); - EXPECT_EQ(v0, u * c0 + beta * u * c2 - beta * u * FieldT(2).inverse() * v1 - beta * u * FieldT(2).inverse() * v2 + beta * v6); - EXPECT_EQ(v3, - FieldT(15) * u * c0 - FieldT(30) * u * c1 - FieldT(3) * (FieldT(4) + beta) * u * c2 - FieldT(6) * (FieldT(4) + beta) * u * c3 + (FieldT(24) - FieldT(3) * beta * FieldT(2).inverse()) * u * v1 + (-FieldT(8) + beta * FieldT(2).inverse()) * u * v2 + assert(v0 == u * c0 + beta * u * c2 - beta * u * FieldT(2).inverse() * v1 - beta * u * FieldT(2).inverse() * v2 + beta * v6); + assert(v3 == - FieldT(15) * u * c0 - FieldT(30) * u * c1 - FieldT(3) * (FieldT(4) + beta) * u * c2 - FieldT(6) * (FieldT(4) + beta) * u * c3 + (FieldT(24) - FieldT(3) * beta * FieldT(2).inverse()) * u * v1 + (-FieldT(8) + beta * FieldT(2).inverse()) * u * v2 - FieldT(3) * (-FieldT(16) + beta) * v6); - EXPECT_EQ(v4, - FieldT(15) * u * c0 + FieldT(30) * u * c1 - FieldT(3) * (FieldT(4) + beta) * u * c2 + FieldT(6) * (FieldT(4) + beta) * u * c3 + (FieldT(24) - FieldT(3) * beta * FieldT(2).inverse()) * u * v2 + (-FieldT(8) + beta * FieldT(2).inverse()) * u * v1 + assert(v4 == - FieldT(15) * u * c0 + FieldT(30) * u * c1 - FieldT(3) * (FieldT(4) + beta) * u * c2 + FieldT(6) * (FieldT(4) + beta) * u * c3 + (FieldT(24) - FieldT(3) * beta * FieldT(2).inverse()) * u * v2 + (-FieldT(8) + beta * FieldT(2).inverse()) * u * v1 - FieldT(3) * (-FieldT(16) + beta) * v6); - EXPECT_EQ(v5, - FieldT(80) * u * c0 - FieldT(240) * u * c1 - FieldT(8) * (FieldT(9) + beta) * u * c2 - FieldT(24) * (FieldT(9) + beta) * u * c3 - FieldT(2) * (-FieldT(81) + beta) * u * v1 + (-FieldT(81) + beta) * u * v2 + assert(v5 == - FieldT(80) * u * c0 - FieldT(240) * u * c1 - FieldT(8) * (FieldT(9) + beta) * u * c2 - FieldT(24) * (FieldT(9) + beta) * u * c3 - FieldT(2) * (-FieldT(81) + beta) * u * v1 + (-FieldT(81) + beta) * u * v2 - FieldT(8) * (-FieldT(81) + beta) * v6); // c0 + beta c2 - (beta v1)/2 - (beta v2)/ 2 - (-1 + beta) beta v6, @@ -176,8 +216,22 @@ void test_Fp4_tom_cook() } } -TEST(algebra, fields) +int main(void) { + edwards_pp::init_public_params(); + test_all_fields(); + test_cyclotomic_squaring >(); + + mnt4_pp::init_public_params(); + test_all_fields(); + test_Fp4_tom_cook(); + test_two_squarings >(); + test_cyclotomic_squaring >(); + + mnt6_pp::init_public_params(); + test_all_fields(); + test_cyclotomic_squaring >(); + alt_bn128_pp::init_public_params(); test_field(); test_Frobenius(); diff --git a/src/snark/src/algebra/scalar_multiplication/kc_multiexp.tcc b/src/snark/src/algebra/scalar_multiplication/kc_multiexp.tcc index e9c08d4bc..c71a4c82b 100644 --- a/src/snark/src/algebra/scalar_multiplication/kc_multiexp.tcc +++ b/src/snark/src/algebra/scalar_multiplication/kc_multiexp.tcc @@ -8,6 +8,7 @@ #ifndef KC_MULTIEXP_TCC_ #define KC_MULTIEXP_TCC_ + namespace libsnark { template diff --git a/src/snark/src/algebra/scalar_multiplication/multiexp.tcc b/src/snark/src/algebra/scalar_multiplication/multiexp.tcc index a6b14c4df..e1783a881 100644 --- a/src/snark/src/algebra/scalar_multiplication/multiexp.tcc +++ b/src/snark/src/algebra/scalar_multiplication/multiexp.tcc @@ -40,7 +40,7 @@ public: #if defined(__x86_64__) && defined(USE_ASM) if (n == 3) { - long res; + int64_t res; __asm__ ("// check for overflow \n\t" "mov $0, %[res] \n\t" @@ -58,7 +58,7 @@ public: } else if (n == 4) { - long res; + int64_t res; __asm__ ("// check for overflow \n\t" "mov $0, %[res] \n\t" @@ -77,7 +77,7 @@ public: } else if (n == 5) { - long res; + int64_t res; __asm__ ("// check for overflow \n\t" "mov $0, %[res] \n\t" @@ -190,7 +190,7 @@ T multi_exp_inner(typename std::vector::const_iterator vec_start, if (vec_len != odd_vec_len) { g.emplace_back(T::zero()); - opt_q.emplace_back(ordered_exponent(odd_vec_len - 1, bigint(0ul))); + opt_q.emplace_back(ordered_exponent(odd_vec_len - 1, bigint(UINT64_C(0)))); } assert(g.size() % 2 == 1); assert(opt_q.size() == g.size()); @@ -214,7 +214,7 @@ T multi_exp_inner(typename std::vector::const_iterator vec_start, const size_t bbits = b.r.num_bits(); const size_t limit = (abits-bbits >= 20 ? 20 : abits-bbits); - if (bbits < 1ul<= 0; --i) + for (int64_t i = T::fixed_base_exp_window_table.size()-1; i >= 0; --i) { #ifdef DEBUG if (!inhibit_profiling_info) @@ -420,9 +420,9 @@ window_table get_window_table(const size_t scalar_size, const size_t window, const T &g) { - const size_t in_window = 1ul< -std::vector find_wnaf(const size_t window_size, const bigint &scalar); +std::vector find_wnaf(const size_t window_size, const bigint &scalar); /** * In additive notation, use wNAF exponentiation (with the given window size) to compute scalar * base. diff --git a/src/snark/src/algebra/scalar_multiplication/wnaf.tcc b/src/snark/src/algebra/scalar_multiplication/wnaf.tcc index a5e47e8e2..4f2e4072c 100644 --- a/src/snark/src/algebra/scalar_multiplication/wnaf.tcc +++ b/src/snark/src/algebra/scalar_multiplication/wnaf.tcc @@ -17,15 +17,15 @@ namespace libsnark { template -std::vector find_wnaf(const size_t window_size, const bigint &scalar) +std::vector find_wnaf(const size_t window_size, const bigint &scalar) { const size_t length = scalar.max_bits(); // upper bound - std::vector res(length+1); + std::vector res(length+1); bigint c = scalar; - long j = 0; + int64_t j = 0; while (!c.is_zero()) { - long u; + int64_t u; if ((c.data[0] & 1) == 1) { u = c.data[0] % (1u << (window_size+1)); @@ -59,11 +59,11 @@ std::vector find_wnaf(const size_t window_size, const bigint &scalar) template T fixed_window_wnaf_exp(const size_t window_size, const T &base, const bigint &scalar) { - std::vector naf = find_wnaf(window_size, scalar); - std::vector table(1ul<<(window_size-1)); + std::vector naf = find_wnaf(window_size, scalar); + std::vector table(UINT64_C(1)<<(window_size-1)); T tmp = base; T dbl = base.dbl(); - for (size_t i = 0; i < 1ul<<(window_size-1); ++i) + for (size_t i = 0; i < UINT64_C(1)<<(window_size-1); ++i) { table[i] = tmp; tmp = tmp + dbl; @@ -71,7 +71,7 @@ T fixed_window_wnaf_exp(const size_t window_size, const T &base, const bigint T res = T::zero(); bool found_nonzero = false; - for (long i = naf.size()-1; i >= 0; --i) + for (int64_t i = naf.size()-1; i >= 0; --i) { if (found_nonzero) { @@ -99,7 +99,7 @@ template T opt_window_wnaf_exp(const T &base, const bigint &scalar, const size_t scalar_bits) { size_t best = 0; - for (long i = T::wnaf_window_table.size() - 1; i >= 0; --i) + for (int64_t i = T::wnaf_window_table.size() - 1; i >= 0; --i) { if (scalar_bits >= T::wnaf_window_table[i]) { diff --git a/src/snark/src/common/assert_except.hpp b/src/snark/src/common/assert_except.hpp index 781923044..01559aabc 100644 --- a/src/snark/src/common/assert_except.hpp +++ b/src/snark/src/common/assert_except.hpp @@ -3,10 +3,10 @@ #include -inline void assert_except(bool condition) { - if (!condition) { - throw std::runtime_error("Assertion failed."); - } +inline void assert_except (bool condition) { + if (! condition) { + throw std :: runtime_error ("Assertion failed."); + } } #endif diff --git a/src/snark/src/common/data_structures/merkle_tree.tcc b/src/snark/src/common/data_structures/merkle_tree.tcc index 281700b33..ce28b124f 100644 --- a/src/snark/src/common/data_structures/merkle_tree.tcc +++ b/src/snark/src/common/data_structures/merkle_tree.tcc @@ -66,14 +66,14 @@ merkle_tree::merkle_tree(const size_t depth, assert(log2(contents_as_vector.size()) <= depth); for (size_t address = 0; address < contents_as_vector.size(); ++address) { - const size_t idx = address + (1ul< 0; --layer) { @@ -100,13 +100,13 @@ merkle_tree::merkle_tree(const size_t depth, if (!contents.empty()) { - assert(contents.rbegin()->first < 1ul<first < UINT64_C(1)<first; const bit_vector value = it->second; - const size_t idx = address + (1ul<::set_value(const size_t address, const bit_vector &value) { assert(log2(address) <= depth); - size_t idx = address + (1ul<::get_path(con { typename HashT::merkle_authentication_path_type result(depth); assert(log2(address) <= depth); - size_t idx = address + (1ul< 0; --layer) { @@ -209,7 +209,7 @@ typename HashT::merkle_authentication_path_type merkle_tree::get_path(con auto it = hashes.find(sibling_idx); if (layer == depth) { - auto it2 = values.find(sibling_idx - ((1ul<second); result[layer-1].resize(digest_size); } @@ -227,7 +227,7 @@ typename HashT::merkle_authentication_path_type merkle_tree::get_path(con template void merkle_tree::dump() const { - for (size_t i = 0; i < 1ul< ", i); diff --git a/src/snark/src/common/data_structures/sparse_vector.hpp b/src/snark/src/common/data_structures/sparse_vector.hpp index 8b134f42e..4ec09b98e 100644 --- a/src/snark/src/common/data_structures/sparse_vector.hpp +++ b/src/snark/src/common/data_structures/sparse_vector.hpp @@ -32,9 +32,9 @@ std::istream& operator>>(std::istream &in, sparse_vector &v); template struct sparse_vector { - std::vector indices; + std::vector indices; std::vector values; - size_t domain_size_ = 0; + unsigned long long domain_size_ = 0; sparse_vector() = default; sparse_vector(const sparse_vector &other) = default; @@ -44,7 +44,7 @@ struct sparse_vector { sparse_vector& operator=(const sparse_vector &other) = default; sparse_vector& operator=(sparse_vector &&other) = default; - T operator[](const size_t idx) const; + T operator[](const unsigned long long idx) const; bool operator==(const sparse_vector &other) const; bool operator==(const std::vector &other) const; @@ -52,15 +52,15 @@ struct sparse_vector { bool is_valid() const; bool empty() const; - size_t domain_size() const; // return domain_size_ - size_t size() const; // return the number of indices (representing the number of non-zero entries) - size_t size_in_bits() const; // return the number bits needed to store the sparse vector + unsigned long long domain_size() const; // return domain_size_ + unsigned long long size() const; // return the number of indices (representing the number of non-zero entries) + unsigned long long size_in_bits() const; // return the number bits needed to store the sparse vector /* return a pair consisting of the accumulated value and the sparse vector of non-accumuated values */ template std::pair > accumulate(const typename std::vector::const_iterator &it_begin, const typename std::vector::const_iterator &it_end, - const size_t offset) const; + const unsigned long long offset) const; friend std::ostream& operator<< (std::ostream &out, const sparse_vector &v); friend std::istream& operator>> (std::istream &in, sparse_vector &v); diff --git a/src/snark/src/common/data_structures/sparse_vector.tcc b/src/snark/src/common/data_structures/sparse_vector.tcc index cfc5d7559..a12c6439d 100644 --- a/src/snark/src/common/data_structures/sparse_vector.tcc +++ b/src/snark/src/common/data_structures/sparse_vector.tcc @@ -29,7 +29,7 @@ sparse_vector::sparse_vector(std::vector &&v) : } template -T sparse_vector::operator[](const size_t idx) const +T sparse_vector::operator[](const unsigned long long idx) const { auto it = std::lower_bound(indices.begin(), indices.end(), idx); return (it != indices.end() && *it == idx) ? values[it - indices.begin()] : T(); @@ -43,7 +43,7 @@ bool sparse_vector::operator==(const sparse_vector &other) const return false; } - size_t this_pos = 0, other_pos = 0; + unsigned long long this_pos = 0, other_pos = 0; while (this_pos < this->indices.size() && other_pos < other.indices.size()) { if (this->indices[this_pos] == other.indices[other_pos]) @@ -103,8 +103,8 @@ bool sparse_vector::operator==(const std::vector &other) const return false; } - size_t j = 0; - for (size_t i = 0; i < other.size(); ++i) + unsigned long long j = 0; + for (unsigned long long i = 0; i < other.size(); ++i) { if (this->indices[j] == i) { @@ -134,7 +134,7 @@ bool sparse_vector::is_valid() const return false; } - for (size_t i = 0; i + 1 < indices.size(); ++i) + for (unsigned long long i = 0; i + 1 < indices.size(); ++i) { if (indices[i] >= indices[i+1]) { @@ -157,42 +157,42 @@ bool sparse_vector::empty() const } template -size_t sparse_vector::domain_size() const +unsigned long long sparse_vector::domain_size() const { return domain_size_; } template -size_t sparse_vector::size() const +unsigned long long sparse_vector::size() const { return indices.size(); } template -size_t sparse_vector::size_in_bits() const +unsigned long long sparse_vector::size_in_bits() const { - return indices.size() * (sizeof(size_t) * 8 + T::size_in_bits()); + return indices.size() * (sizeof(unsigned long long) * 8 + T::size_in_bits()); } template template std::pair > sparse_vector::accumulate(const typename std::vector::const_iterator &it_begin, const typename std::vector::const_iterator &it_end, - const size_t offset) const + const unsigned long long offset) const { // TODO: does not really belong here. - const size_t chunks = 1; + const unsigned long long chunks = 1; const bool use_multiexp = true; T accumulated_value = T::zero(); sparse_vector resulting_vector; resulting_vector.domain_size_ = domain_size_; - const size_t range_len = it_end - it_begin; + const unsigned long long range_len = it_end - it_begin; bool in_block = false; - size_t first_pos = -1, last_pos = -1; // g++ -flto emits unitialized warning, even though in_block guards for such cases. + unsigned long long first_pos = -1, last_pos = -1; // g++ -flto emits unitialized warning, even though in_block guards for such cases. - for (size_t i = 0; i < indices.size(); ++i) + for (unsigned long long i = 0; i < indices.size(); ++i) { const bool matching_pos = (offset <= indices[i] && indices[i] < offset + range_len); // printf("i = %zu, pos[i] = %zu, offset = %zu, w_size = %zu\n", i, indices[i], offset, w_size); @@ -265,7 +265,7 @@ std::ostream& operator<<(std::ostream& out, const sparse_vector &v) { out << v.domain_size_ << "\n"; out << v.indices.size() << "\n"; - for (const size_t& i : v.indices) + for (const unsigned long long& i : v.indices) { out << i << "\n"; } @@ -285,11 +285,11 @@ std::istream& operator>>(std::istream& in, sparse_vector &v) in >> v.domain_size_; consume_newline(in); - size_t s; + unsigned long long s; in >> s; consume_newline(in); v.indices.resize(s); - for (size_t i = 0; i < s; ++i) + for (unsigned long long i = 0; i < s; ++i) { in >> v.indices[i]; consume_newline(in); @@ -300,7 +300,7 @@ std::istream& operator>>(std::istream& in, sparse_vector &v) consume_newline(in); v.values.reserve(s); - for (size_t i = 0; i < s; ++i) + for (unsigned long long i = 0; i < s; ++i) { T t; in >> t; diff --git a/src/snark/src/common/profiling.cpp b/src/snark/src/common/profiling.cpp index d227203a0..3a4fb9e80 100644 --- a/src/snark/src/common/profiling.cpp +++ b/src/snark/src/common/profiling.cpp @@ -26,6 +26,13 @@ #include #endif +#ifdef __MACH__ // required to build on MacOS +#include +#include +#include +#include +#endif + namespace libsnark { long long get_nsec_time() @@ -38,10 +45,20 @@ long long get_nsec_time() long long get_nsec_cpu_time() { ::timespec ts; + #ifdef __MACH__ + clock_serv_t cclock; + mach_timespec_t mts; + host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock); + clock_get_time(cclock, &mts); + mach_port_deallocate(mach_task_self(), cclock); + ts.tv_sec = mts.tv_sec; + ts.tv_nsec = mts.tv_nsec; + #else if ( ::clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &ts) ) throw ::std::runtime_error("clock_gettime(CLOCK_PROCESS_CPUTIME_ID) failed"); // If we expected this to work, don't silently ignore failures, because that would hide the problem and incur an unnecessarily system-call overhead. So if we ever observe this exception, we should probably add a suitable #ifdef . //TODO: clock_gettime(CLOCK_PROCESS_CPUTIME_ID) is not supported by native Windows. What about Cygwin? Should we #ifdef on CLOCK_PROCESS_CPUTIME_ID or on __linux__? + #endif return ts.tv_sec * 1000000000ll + ts.tv_nsec; } @@ -57,20 +74,20 @@ void start_profiling() } std::map invocation_counts; -std::map enter_times; -std::map last_times; -std::map cumulative_times; +std::map enter_times; +std::map last_times; +std::map cumulative_times; //TODO: Instead of analogous maps for time and cpu_time, use a single struct-valued map -std::map enter_cpu_times; -std::map last_cpu_times; -std::map, long long> op_counts; -std::map, long long> cumulative_op_counts; // ((msg, data_point), value) +std::map enter_cpu_times; +std::map last_cpu_times; +std::map, int64_t> op_counts; +std::map, int64_t> cumulative_op_counts; // ((msg, data_point), value) // TODO: Convert op_counts and cumulative_op_counts from pair to structs size_t indentation = 0; std::vector block_names; -std::list > op_data_points = { +std::list > op_data_points = { #ifdef PROFILE_OP_COUNTS std::make_pair("Fradd", &Fr::add_cnt), std::make_pair("Frsub", &Fr::sub_cnt), @@ -98,7 +115,7 @@ void clear_profiling_counters() cumulative_times.clear(); } -void print_cumulative_time_entry(const std::string &key, const long long factor) +void print_cumulative_time_entry(const std::string &key, const int64_t factor) { const double total_ms = (cumulative_times.at(key) * 1e-6); const size_t cnt = invocation_counts.at(key); @@ -106,7 +123,7 @@ void print_cumulative_time_entry(const std::string &key, const long long factor) printf(" %-45s: %12.5fms = %lld * %0.5fms (%zu invocations, %0.5fms = %lld * %0.5fms per invocation)\n", key.c_str(), total_ms, factor, total_ms/factor, cnt, avg_ms, factor, avg_ms/factor); } -void print_cumulative_times(const long long factor) +void print_cumulative_times(const int64_t factor) { printf("Dumping times:\n"); for (auto& kv : cumulative_times) @@ -155,7 +172,7 @@ void print_op_profiling(const std::string &msg) printf("(opcounts) = ("); bool first = true; - for (std::pair p : op_data_points) + for (std::pair p : op_data_points) { if (!first) { @@ -171,14 +188,14 @@ void print_op_profiling(const std::string &msg) #endif } -static void print_times_from_last_and_start(long long now, long long last, - long long cpu_now, long long cpu_last) +static void print_times_from_last_and_start(int64_t now, int64_t last, + int64_t cpu_now, int64_t cpu_last) { - long long time_from_start = now - start_time; - long long time_from_last = now - last; + int64_t time_from_start = now - start_time; + int64_t time_from_last = now - last; - long long cpu_time_from_start = cpu_now - start_cpu_time; - long long cpu_time_from_last = cpu_now - cpu_last; + int64_t cpu_time_from_start = cpu_now - start_cpu_time; + int64_t cpu_time_from_last = cpu_now - cpu_last; if (time_from_last != 0) { double parallelism_from_last = 1.0 * cpu_time_from_last / time_from_last; @@ -199,8 +216,8 @@ void print_time(const char* msg) return; } - long long now = get_nsec_time(); - long long cpu_now = get_nsec_cpu_time(); + int64_t now = get_nsec_time(); + int64_t cpu_now = get_nsec_cpu_time(); printf("%-35s\t", msg); print_times_from_last_and_start(now, last_time, cpu_now, last_cpu_time); @@ -231,7 +248,7 @@ void print_indent() void op_profiling_enter(const std::string &msg) { - for (std::pair p : op_data_points) + for (std::pair p : op_data_points) { op_counts[std::make_pair(msg, p.first)] = *(p.second); } @@ -245,9 +262,9 @@ void enter_block(const std::string &msg, const bool indent) } block_names.emplace_back(msg); - long long t = get_nsec_time(); + int64_t t = get_nsec_time(); enter_times[msg] = t; - long long cpu_t = get_nsec_cpu_time(); + int64_t cpu_t = get_nsec_cpu_time(); enter_cpu_times[msg] = cpu_t; if (inhibit_profiling_info) @@ -288,15 +305,15 @@ void leave_block(const std::string &msg, const bool indent) ++invocation_counts[msg]; - long long t = get_nsec_time(); + int64_t t = get_nsec_time(); last_times[msg] = (t - enter_times[msg]); cumulative_times[msg] += (t - enter_times[msg]); - long long cpu_t = get_nsec_cpu_time(); + int64_t cpu_t = get_nsec_cpu_time(); last_cpu_times[msg] = (cpu_t - enter_cpu_times[msg]); #ifdef PROFILE_OP_COUNTS - for (std::pair p : op_data_points) + for (std::pair p : op_data_points) { cumulative_op_counts[std::make_pair(msg, p.first)] += *(p.second)-op_counts[std::make_pair(msg, p.first)]; } diff --git a/src/snark/src/common/profiling.hpp b/src/snark/src/common/profiling.hpp index 9619117f4..4a496107b 100644 --- a/src/snark/src/common/profiling.hpp +++ b/src/snark/src/common/profiling.hpp @@ -22,7 +22,7 @@ namespace libsnark { void start_profiling(); -long long get_nsec_time(); +int64_t get_nsec_time(); void print_time(const char* msg); void print_header(const char* msg); @@ -31,13 +31,13 @@ void print_indent(); extern bool inhibit_profiling_info; extern bool inhibit_profiling_counters; extern std::map invocation_counts; -extern std::map last_times; -extern std::map cumulative_times; +extern std::map last_times; +extern std::map cumulative_times; void clear_profiling_counters(); -void print_cumulative_time_entry(const std::string &key, const long long factor=1); -void print_cumulative_times(const long long factor=1); +void print_cumulative_time_entry(const std::string &key, const int64_t factor=1); +void print_cumulative_times(const int64_t factor=1); void print_cumulative_op_counts(const bool only_fq=false); void enter_block(const std::string &msg, const bool indent=true); diff --git a/src/snark/src/common/utils.cpp b/src/snark/src/common/utils.cpp index dd114fdf0..f8f32d143 100644 --- a/src/snark/src/common/utils.cpp +++ b/src/snark/src/common/utils.cpp @@ -15,11 +15,11 @@ namespace libsnark { -size_t log2(size_t n) +unsigned long long log2(unsigned long long n) /* returns ceil(log2(n)), so 1ul< 1) { @@ -30,10 +30,10 @@ size_t log2(size_t n) return r; } -size_t bitreverse(size_t n, const size_t l) +unsigned long long bitreverse(unsigned long long n, const unsigned long long l) { - size_t r = 0; - for (size_t k = 0; k < l; ++k) + unsigned long long r = 0; + for (unsigned long long k = 0; k < l; ++k) { r = (r << 1) | (n & 1); n >>= 1; @@ -41,20 +41,20 @@ size_t bitreverse(size_t n, const size_t l) return r; } -bit_vector int_list_to_bits(const std::initializer_list &l, const size_t wordsize) +bit_vector int_list_to_bits(const std::initializer_list &l, const size_t wordsize) { bit_vector res(wordsize*l.size()); - for (size_t i = 0; i < l.size(); ++i) + for (uint64_t i = 0; i < l.size(); ++i) { - for (size_t j = 0; j < wordsize; ++j) + for (uint64_t j = 0; j < wordsize; ++j) { - res[i*wordsize + j] = (*(l.begin()+i) & (1ul<<(wordsize-1-j))); + res[i*wordsize + j] = (*(l.begin()+i) & (UINT64_C(1)<<(wordsize-1-j))); } } return res; } -long long div_ceil(long long x, long long y) +int64_t div_ceil(int64_t x, int64_t y) { return (x + (y-1)) / y; } @@ -68,7 +68,7 @@ bool is_little_endian() std::string FORMAT(const std::string &prefix, const char* format, ...) { - const static size_t MAX_FMT = 256; + const static unsigned long long MAX_FMT = 256; char buf[MAX_FMT]; va_list args; va_start(args, format); @@ -81,7 +81,7 @@ std::string FORMAT(const std::string &prefix, const char* format, ...) void serialize_bit_vector(std::ostream &out, const bit_vector &v) { out << v.size() << "\n"; - for (size_t i = 0; i < v.size(); ++i) + for (unsigned long long i = 0; i < v.size(); ++i) { out << v[i] << "\n"; } @@ -89,10 +89,10 @@ void serialize_bit_vector(std::ostream &out, const bit_vector &v) void deserialize_bit_vector(std::istream &in, bit_vector &v) { - size_t size; + unsigned long long size; in >> size; v.resize(size); - for (size_t i = 0; i < size; ++i) + for (unsigned long long i = 0; i < size; ++i) { bool b; in >> b; diff --git a/src/snark/src/common/utils.hpp b/src/snark/src/common/utils.hpp index d7d9e8947..5505cc765 100644 --- a/src/snark/src/common/utils.hpp +++ b/src/snark/src/common/utils.hpp @@ -21,12 +21,12 @@ namespace libsnark { typedef std::vector bit_vector; /// returns ceil(log2(n)), so 1ul< &l, const size_t wordsize); +unsigned long long bitreverse(unsigned long long n, const unsigned long long l); +bit_vector int_list_to_bits(const std::initializer_list &l, const unsigned long long wordsize); long long div_ceil(long long x, long long y); bool is_little_endian(); @@ -47,7 +47,7 @@ void serialize_bit_vector(std::ostream &out, const bit_vector &v); void deserialize_bit_vector(std::istream &in, bit_vector &v); template -size_t size_in_bits(const std::vector &v); +unsigned long long size_in_bits(const std::vector &v); #define ARRAY_SIZE(arr) (sizeof(arr)/sizeof(arr[0])) diff --git a/src/snark/src/gadgetlib1/gadgets/basic_gadgets.tcc b/src/snark/src/gadgetlib1/gadgets/basic_gadgets.tcc index 213b1906f..bcd2f2c72 100644 --- a/src/snark/src/gadgetlib1/gadgets/basic_gadgets.tcc +++ b/src/snark/src/gadgetlib1/gadgets/basic_gadgets.tcc @@ -275,11 +275,11 @@ void test_disjunction_gadget(const size_t n) disjunction_gadget d(pb, inputs, output, "d"); d.generate_r1cs_constraints(); - for (size_t w = 0; w < 1ul< c(pb, inputs, output, "c"); c.generate_r1cs_constraints(); - for (size_t w = 0; w < 1ul< cmp(pb, n, A, B, less, less_or_eq, "cmp"); cmp.generate_r1cs_constraints(); - for (size_t a = 0; a < 1ul< g(pb, A, B, result, "g"); g.generate_r1cs_constraints(); - for (size_t i = 0; i < 1ul<::generate_r1cs_witness() { /* assumes that idx can be fit in ulong; true for our purposes for now */ const bigint valint = this->pb.val(index).as_bigint(); - unsigned long idx = valint.as_ulong(); + uint64_t idx = valint.as_ulong(); const bigint arrsize(arr.size()); if (idx >= arr.size() || mpn_cmp(valint.data, arrsize.data, FieldT::num_limbs) >= 0) @@ -619,7 +619,7 @@ void test_loose_multiplexing_gadget(const size_t n) protoboard pb; pb_variable_array arr; - arr.allocate(pb, 1ul< index, result, success_flag; index.allocate(pb, "index"); result.allocate(pb, "result"); @@ -628,20 +628,20 @@ void test_loose_multiplexing_gadget(const size_t n) loose_multiplexing_gadget g(pb, arr, index, result, success_flag, "g"); g.generate_r1cs_constraints(); - for (size_t i = 0; i < 1ul<::generate_r1cs_witness() { for (size_t i = 0; i < 32; ++i) { - const long v = (this->pb.lc_val(X[i]) + this->pb.lc_val(Y[i]) + this->pb.lc_val(Z[i])).as_ulong(); + const int64_t v = (this->pb.lc_val(X[i]) + this->pb.lc_val(Y[i]) + this->pb.lc_val(Z[i])).as_ulong(); this->pb.val(result_bits[i]) = FieldT(v / 2); } diff --git a/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.hpp b/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.hpp index c2f31e3af..13bbc075c 100644 --- a/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.hpp +++ b/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.hpp @@ -78,7 +78,7 @@ public: pb_linear_combination_array g; pb_linear_combination_array h; pb_variable W; - long K; + int64_t K; pb_linear_combination_array new_a; pb_linear_combination_array new_e; @@ -92,7 +92,7 @@ public: const pb_linear_combination_array &g, const pb_linear_combination_array &h, const pb_variable &W, - const long &K, + const int64_t &K, const pb_linear_combination_array &new_a, const pb_linear_combination_array &new_e, const std::string &annotation_prefix); diff --git a/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.tcc b/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.tcc index e8f233a54..b0e006388 100644 --- a/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.tcc +++ b/src/snark/src/gadgetlib1/gadgets/hashes/sha256/sha256_components.tcc @@ -16,7 +16,7 @@ namespace libsnark { -const unsigned long SHA256_K[64] = { +const uint64_t SHA256_K[64] = { 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, @@ -27,7 +27,7 @@ const unsigned long SHA256_K[64] = { 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 }; -const unsigned long SHA256_H[8] = { +const uint64_t SHA256_H[8] = { 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19 }; @@ -149,7 +149,7 @@ sha256_round_function_gadget::sha256_round_function_gadget(protoboard &g, const pb_linear_combination_array &h, const pb_variable &W, - const long &K, + const int64_t &K, const pb_linear_combination_array &new_a, const pb_linear_combination_array &new_e, const std::string &annotation_prefix) : diff --git a/src/snark/src/gadgetlib1/gadgets/hashes/sha256/tests/test_sha256_gadget.cpp b/src/snark/src/gadgetlib1/gadgets/hashes/sha256/tests/test_sha256_gadget.cpp index 0bfaf3a12..471928f6a 100644 --- a/src/snark/src/gadgetlib1/gadgets/hashes/sha256/tests/test_sha256_gadget.cpp +++ b/src/snark/src/gadgetlib1/gadgets/hashes/sha256/tests/test_sha256_gadget.cpp @@ -10,8 +10,6 @@ #include "common/profiling.hpp" #include "gadgetlib1/gadgets/hashes/sha256/sha256_gadget.hpp" -#include - using namespace libsnark; template @@ -37,10 +35,10 @@ void test_two_to_one() f.generate_r1cs_witness(); output.generate_r1cs_witness(hash_bv); - EXPECT_TRUE(pb.is_satisfied()); + assert(pb.is_satisfied()); } -TEST(gadgetlib1, sha256) +int main(void) { start_profiling(); default_ec_pp::init_public_params(); diff --git a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_authentication_path_variable.tcc b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_authentication_path_variable.tcc index d773051ab..b3d805d8e 100644 --- a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_authentication_path_variable.tcc +++ b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_authentication_path_variable.tcc @@ -41,7 +41,7 @@ void merkle_authentication_path_variable::generate_r1cs_witness(c for (size_t i = 0; i < tree_depth; ++i) { - if (address & (1ul << (tree_depth-1-i))) + if (address & (UINT64_C(1) << (tree_depth-1-i))) { left_digests[i].generate_r1cs_witness(path[i]); } @@ -58,7 +58,7 @@ merkle_authentication_path merkle_authentication_path_variable::g merkle_authentication_path result; for (size_t i = 0; i < tree_depth; ++i) { - if (address & (1ul << (tree_depth-1-i))) + if (address & (UINT64_C(1) << (tree_depth-1-i))) { result.emplace_back(left_digests[i].get_digest()); } diff --git a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_read_gadget.tcc b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_read_gadget.tcc index 6002a5886..2fde4f68c 100644 --- a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_read_gadget.tcc +++ b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_read_gadget.tcc @@ -144,10 +144,10 @@ void test_merkle_tree_check_read_gadget() bit_vector address_bits; size_t address = 0; - for (long level = tree_depth-1; level >= 0; --level) + for (int64_t level = tree_depth-1; level >= 0; --level) { const bool computed_is_right = (std::rand() % 2); - address |= (computed_is_right ? 1ul << (tree_depth-1-level) : 0); + address |= (computed_is_right ? UINT64_C(1) << (tree_depth-1-level) : 0); address_bits.push_back(computed_is_right); bit_vector other(digest_len); std::generate(other.begin(), other.end(), [&]() { return std::rand() % 2; }); diff --git a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.hpp b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.hpp index 6ec0ca11f..2d6840d61 100644 --- a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.hpp +++ b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.hpp @@ -19,6 +19,7 @@ #include "common/data_structures/merkle_tree.hpp" #include "gadgetlib1/gadget.hpp" +#include "gadgetlib1/gadgets/hashes/crh_gadget.hpp" #include "gadgetlib1/gadgets/hashes/hash_io.hpp" #include "gadgetlib1/gadgets/hashes/digest_selector_gadget.hpp" #include "gadgetlib1/gadgets/merkle_tree/merkle_authentication_path_variable.hpp" diff --git a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.tcc b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.tcc index 1ac08edbb..3e73904c1 100644 --- a/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.tcc +++ b/src/snark/src/gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.tcc @@ -197,10 +197,10 @@ void test_merkle_tree_check_update_gadget() bit_vector address_bits; size_t address = 0; - for (long level = tree_depth-1; level >= 0; --level) + for (int64_t level = tree_depth-1; level >= 0; --level) { const bool computed_is_right = (std::rand() % 2); - address |= (computed_is_right ? 1ul << (tree_depth-1-level) : 0); + address |= (computed_is_right ? UINT64_C(1) << (tree_depth-1-level) : 0); address_bits.push_back(computed_is_right); bit_vector other(digest_len); std::generate(other.begin(), other.end(), [&]() { return std::rand() % 2; }); diff --git a/src/snark/src/gadgetlib1/gadgets/merkle_tree/tests/test_merkle_tree_gadgets.cpp b/src/snark/src/gadgetlib1/gadgets/merkle_tree/tests/test_merkle_tree_gadgets.cpp index 27b52f9ec..8d52c579b 100644 --- a/src/snark/src/gadgetlib1/gadgets/merkle_tree/tests/test_merkle_tree_gadgets.cpp +++ b/src/snark/src/gadgetlib1/gadgets/merkle_tree/tests/test_merkle_tree_gadgets.cpp @@ -5,36 +5,44 @@ * @copyright MIT license (see LICENSE file) *****************************************************************************/ -#include "algebra/curves/alt_bn128/alt_bn128_pp.hpp" #ifdef CURVE_BN128 #include "algebra/curves/bn128/bn128_pp.hpp" #endif +#include "algebra/curves/edwards/edwards_pp.hpp" +#include "algebra/curves/mnt/mnt4/mnt4_pp.hpp" +#include "algebra/curves/mnt/mnt6/mnt6_pp.hpp" #include "gadgetlib1/gadgets/merkle_tree/merkle_tree_check_read_gadget.hpp" #include "gadgetlib1/gadgets/merkle_tree/merkle_tree_check_update_gadget.hpp" #include "gadgetlib1/gadgets/hashes/sha256/sha256_gadget.hpp" -#include - using namespace libsnark; template void test_all_merkle_tree_gadgets() { typedef Fr FieldT; + test_merkle_tree_check_read_gadget >(); test_merkle_tree_check_read_gadget >(); + test_merkle_tree_check_update_gadget >(); test_merkle_tree_check_update_gadget >(); } -TEST(gadgetlib1, merkle_tree) +int main(void) { start_profiling(); - alt_bn128_pp::init_public_params(); - test_all_merkle_tree_gadgets(); - #ifdef CURVE_BN128 // BN128 has fancy dependencies so it may be disabled bn128_pp::init_public_params(); test_all_merkle_tree_gadgets(); #endif + + edwards_pp::init_public_params(); + test_all_merkle_tree_gadgets(); + + mnt4_pp::init_public_params(); + test_all_merkle_tree_gadgets(); + + mnt6_pp::init_public_params(); + test_all_merkle_tree_gadgets(); } diff --git a/src/snark/src/gadgetlib1/pb_variable.hpp b/src/snark/src/gadgetlib1/pb_variable.hpp index fdf64d014..a6c71748d 100644 --- a/src/snark/src/gadgetlib1/pb_variable.hpp +++ b/src/snark/src/gadgetlib1/pb_variable.hpp @@ -59,7 +59,7 @@ public: void fill_with_field_elements(protoboard &pb, const std::vector& vals) const; void fill_with_bits(protoboard &pb, const bit_vector& bits) const; - void fill_with_bits_of_ulong(protoboard &pb, const unsigned long i) const; + void fill_with_bits_of_ulong(protoboard &pb, const uint64_t i) const; void fill_with_bits_of_field_element(protoboard &pb, const FieldT &r) const; std::vector get_vals(const protoboard &pb) const; @@ -120,7 +120,7 @@ public: void fill_with_field_elements(protoboard &pb, const std::vector& vals) const; void fill_with_bits(protoboard &pb, const bit_vector& bits) const; - void fill_with_bits_of_ulong(protoboard &pb, const unsigned long i) const; + void fill_with_bits_of_ulong(protoboard &pb, const uint64_t i) const; void fill_with_bits_of_field_element(protoboard &pb, const FieldT &r) const; std::vector get_vals(const protoboard &pb) const; diff --git a/src/snark/src/gadgetlib1/pb_variable.tcc b/src/snark/src/gadgetlib1/pb_variable.tcc index b36b3f8d7..77c9f13f8 100644 --- a/src/snark/src/gadgetlib1/pb_variable.tcc +++ b/src/snark/src/gadgetlib1/pb_variable.tcc @@ -65,7 +65,7 @@ void pb_variable_array::fill_with_bits_of_field_element(protoboard -void pb_variable_array::fill_with_bits_of_ulong(protoboard &pb, const unsigned long i) const +void pb_variable_array::fill_with_bits_of_ulong(protoboard &pb, const uint64_t i) const { this->fill_with_bits_of_field_element(pb, FieldT(i, true)); } @@ -232,7 +232,7 @@ void pb_linear_combination_array::fill_with_bits_of_field_element(protob } template -void pb_linear_combination_array::fill_with_bits_of_ulong(protoboard &pb, const unsigned long i) const +void pb_linear_combination_array::fill_with_bits_of_ulong(protoboard &pb, const uint64_t i) const { this->fill_with_bits_of_field_element(pb, FieldT(i)); } diff --git a/src/snark/src/relations/arithmetic_programs/qap/tests/test_qap.cpp b/src/snark/src/relations/arithmetic_programs/qap/tests/test_qap.cpp index e20f589c9..0054eaf8a 100644 --- a/src/snark/src/relations/arithmetic_programs/qap/tests/test_qap.cpp +++ b/src/snark/src/relations/arithmetic_programs/qap/tests/test_qap.cpp @@ -10,15 +10,13 @@ #include #include -#include "algebra/curves/alt_bn128/alt_bn128_pp.hpp" +#include "algebra/curves/mnt/mnt6/mnt6_pp.hpp" #include "algebra/fields/field_utils.hpp" #include "common/profiling.hpp" #include "common/utils.hpp" #include "reductions/r1cs_to_qap/r1cs_to_qap.hpp" #include "relations/constraint_satisfaction_problems/r1cs/examples/r1cs_examples.hpp" -#include - using namespace libsnark; template @@ -30,7 +28,7 @@ void test_qap(const size_t qap_degree, const size_t num_inputs, const bool binar See the transformation from R1CS to QAP for why this is the case. So we need that qap_degree >= num_inputs + 1. */ - ASSERT_LE(num_inputs + 1, qap_degree); + assert(num_inputs + 1 <= qap_degree); enter_block("Call to test_qap"); const size_t num_constraints = qap_degree - num_inputs - 1; @@ -53,7 +51,7 @@ void test_qap(const size_t qap_degree, const size_t num_inputs, const bool binar leave_block("Generate constraint system and assignment"); enter_block("Check satisfiability of constraint system"); - EXPECT_TRUE(example.constraint_system.is_satisfied(example.primary_input, example.auxiliary_input)); + assert(example.constraint_system.is_satisfied(example.primary_input, example.auxiliary_input)); leave_block("Check satisfiability of constraint system"); const FieldT t = FieldT::random_element(), @@ -74,31 +72,44 @@ void test_qap(const size_t qap_degree, const size_t num_inputs, const bool binar leave_block("Compute QAP witness"); enter_block("Check satisfiability of QAP instance 1"); - EXPECT_TRUE(qap_inst_1.is_satisfied(qap_wit)); + assert(qap_inst_1.is_satisfied(qap_wit)); leave_block("Check satisfiability of QAP instance 1"); enter_block("Check satisfiability of QAP instance 2"); - EXPECT_TRUE(qap_inst_2.is_satisfied(qap_wit)); + assert(qap_inst_2.is_satisfied(qap_wit)); leave_block("Check satisfiability of QAP instance 2"); leave_block("Call to test_qap"); } -TEST(relations, qap) +int main() { start_profiling(); + mnt6_pp::init_public_params(); + const size_t num_inputs = 10; + const size_t basic_domain_size = UINT64_C(1)< >(1ul << 21, num_inputs, true); + test_qap >(basic_domain_size, num_inputs, true); + test_qap >(step_domain_size, num_inputs, true); + test_qap >(extended_domain_size, num_inputs, true); + test_qap >(extended_domain_size_special, num_inputs, true); leave_block("Test QAP with binary input"); enter_block("Test QAP with field input"); - test_qap >(1ul << 21, num_inputs, false); + test_qap >(basic_domain_size, num_inputs, false); + test_qap >(step_domain_size, num_inputs, false); + test_qap >(extended_domain_size, num_inputs, false); + test_qap >(extended_domain_size_special, num_inputs, false); leave_block("Test QAP with field input"); } diff --git a/src/snark/src/relations/variable.hpp b/src/snark/src/relations/variable.hpp index a9a1449b8..c63f57b42 100644 --- a/src/snark/src/relations/variable.hpp +++ b/src/snark/src/relations/variable.hpp @@ -26,7 +26,7 @@ namespace libsnark { * Mnemonic typedefs. */ typedef size_t var_index_t; -typedef long integer_coeff_t; +typedef int64_t integer_coeff_t; /** * Forward declaration. diff --git a/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/examples/run_r1cs_ppzksnark.tcc b/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/examples/run_r1cs_ppzksnark.tcc index 00af6fe25..9bc875869 100644 --- a/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/examples/run_r1cs_ppzksnark.tcc +++ b/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/examples/run_r1cs_ppzksnark.tcc @@ -83,7 +83,7 @@ bool run_r1cs_ppzksnark(const r1cs_example > &example, } print_header("R1CS ppzkSNARK Prover"); - r1cs_ppzksnark_proof proof = r1cs_ppzksnark_prover(keypair.pk, example.primary_input, example.auxiliary_input, example.constraint_system); + r1cs_ppzksnark_proof proof = r1cs_ppzksnark_prover(keypair.pk, example.primary_input, example.auxiliary_input); printf("\n"); print_indent(); print_mem("after prover"); if (test_serialization) diff --git a/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/tests/test_r1cs_ppzksnark.cpp b/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/tests/test_r1cs_ppzksnark.cpp index 6c6e51857..9e7ab67d5 100644 --- a/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/tests/test_r1cs_ppzksnark.cpp +++ b/src/snark/src/zk_proof_systems/ppzksnark/r1cs_ppzksnark/tests/test_r1cs_ppzksnark.cpp @@ -11,14 +11,14 @@ #include #include -#include "algebra/curves/alt_bn128/alt_bn128_pp.hpp" +#include "zk_proof_systems/ppzksnark/r1cs_ppzksnark/r1cs_ppzksnark_params.hpp" +#include "common/default_types/ec_pp.hpp" +#include "common/default_types/r1cs_ppzksnark_pp.hpp" #include "common/profiling.hpp" #include "common/utils.hpp" #include "relations/constraint_satisfaction_problems/r1cs/examples/r1cs_examples.hpp" #include "zk_proof_systems/ppzksnark/r1cs_ppzksnark/examples/run_r1cs_ppzksnark.hpp" -#include - using namespace libsnark; template @@ -29,16 +29,16 @@ void test_r1cs_ppzksnark(size_t num_constraints, const bool test_serialization = true; r1cs_example > example = generate_r1cs_example_with_binary_input >(num_constraints, input_size); - example.constraint_system.swap_AB_if_beneficial(); const bool bit = run_r1cs_ppzksnark(example, test_serialization); - EXPECT_TRUE(bit); + assert(bit); print_header("(leave) Test R1CS ppzkSNARK"); } -TEST(zk_proof_systems, r1cs_ppzksnark) +int main() { + default_r1cs_ppzksnark_pp::init_public_params(); start_profiling(); - test_r1cs_ppzksnark(1000, 20); + test_r1cs_ppzksnark(1000, 100); } From 033626dac9f998cc8f461e0a8619ac0bc9b56b43 Mon Sep 17 00:00:00 2001 From: FireMartZ Date: Mon, 12 Feb 2018 22:13:28 -0500 Subject: [PATCH 2/2] Fix linux compilation. Now works for both linux and windows. --- src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp | 2 +- src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp | 2 +- src/snark/src/algebra/fields/fp.hpp | 6 +++--- src/snark/src/common/profiling.cpp | 2 +- src/snark/src/common/utils.hpp | 2 +- src/snark/src/common/utils.tcc | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp index 567f2fa3f..9f3929a24 100644 --- a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp +++ b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g1.hpp @@ -63,7 +63,7 @@ public: static alt_bn128_G1 one(); static alt_bn128_G1 random_element(); - static size_t size_in_bits() { return base_field::size_in_bits() + 1; } + static unsigned long long size_in_bits() { return base_field::size_in_bits() + 1; } static bigint base_field_char() { return base_field::field_char(); } static bigint order() { return scalar_field::field_char(); } diff --git a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp index 57bad1a4b..1c2380469 100644 --- a/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp +++ b/src/snark/src/algebra/curves/alt_bn128/alt_bn128_g2.hpp @@ -67,7 +67,7 @@ public: static alt_bn128_G2 one(); static alt_bn128_G2 random_element(); - static size_t size_in_bits() { return twist_field::size_in_bits() + 1; } + static unsigned long long size_in_bits() { return twist_field::size_in_bits() + 1; } static bigint base_field_char() { return base_field::field_char(); } static bigint order() { return scalar_field::field_char(); } diff --git a/src/snark/src/algebra/fields/fp.hpp b/src/snark/src/algebra/fields/fp.hpp index 1dce26c2d..973c32e8e 100644 --- a/src/snark/src/algebra/fields/fp.hpp +++ b/src/snark/src/algebra/fields/fp.hpp @@ -50,9 +50,9 @@ public: static int64_t sqr_cnt; static int64_t inv_cnt; #endif - static size_t num_bits; + static unsigned long long num_bits; static bigint euler; // (modulus-1)/2 - static size_t s; // modulus = 2^s * t + 1 + static unsigned long long s; // modulus = 2^s * t + 1 static bigint t; // with t odd static bigint t_minus_1_over_2; // (t-1)/2 static Fp_model nqr; // a quadratic nonresidue @@ -107,7 +107,7 @@ public: Fp_model inverse() const; Fp_model sqrt() const; // HAS TO BE A SQUARE (else does not terminate) - Fp_model operator^(const unsigned long long pow) const; + Fp_model operator^(const uint64_t pow) const; template Fp_model operator^(const bigint &pow) const; diff --git a/src/snark/src/common/profiling.cpp b/src/snark/src/common/profiling.cpp index 3a4fb9e80..a594b4f39 100644 --- a/src/snark/src/common/profiling.cpp +++ b/src/snark/src/common/profiling.cpp @@ -35,7 +35,7 @@ namespace libsnark { -long long get_nsec_time() +int64_t get_nsec_time() { auto timepoint = std::chrono::high_resolution_clock::now(); return std::chrono::duration_cast(timepoint.time_since_epoch()).count(); diff --git a/src/snark/src/common/utils.hpp b/src/snark/src/common/utils.hpp index 5505cc765..4223377b4 100644 --- a/src/snark/src/common/utils.hpp +++ b/src/snark/src/common/utils.hpp @@ -27,7 +27,7 @@ inline unsigned long long exp2(unsigned long long k) { return 1ull << k; } unsigned long long bitreverse(unsigned long long n, const unsigned long long l); bit_vector int_list_to_bits(const std::initializer_list &l, const unsigned long long wordsize); -long long div_ceil(long long x, long long y); +int64_t div_ceil(int64_t x, int64_t y); bool is_little_endian(); diff --git a/src/snark/src/common/utils.tcc b/src/snark/src/common/utils.tcc index f97178f8c..2d349d9b6 100644 --- a/src/snark/src/common/utils.tcc +++ b/src/snark/src/common/utils.tcc @@ -13,7 +13,7 @@ namespace libsnark { template -size_t size_in_bits(const std::vector &v) +unsigned long long size_in_bits(const std::vector &v) { return v.size() * T::size_in_bits(); }