From 5b4ebcd5e259ea997e3ae516b675a43dc8a81d92 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Thu, 21 Jul 2016 16:39:32 +1200 Subject: [PATCH] Add tests that exercise the cancellation code branches --- src/Makefile.gtest.include | 1 + src/crypto/equihash.cpp | 51 +++++----- src/crypto/equihash.h | 33 +++++-- src/gtest/test_equihash.cpp | 184 ++++++++++++++++++++++++++++++++++++ src/miner.cpp | 4 +- 5 files changed, 238 insertions(+), 35 deletions(-) create mode 100644 src/gtest/test_equihash.cpp diff --git a/src/Makefile.gtest.include b/src/Makefile.gtest.include index 7369f9431..ba07d2c43 100644 --- a/src/Makefile.gtest.include +++ b/src/Makefile.gtest.include @@ -6,6 +6,7 @@ zcash_gtest_SOURCES = \ gtest/main.cpp \ gtest/test_tautology.cpp \ gtest/test_checktransaction.cpp \ + gtest/test_equihash.cpp \ gtest/test_joinsplit.cpp \ gtest/test_noteencryption.cpp \ gtest/test_merkletree.cpp \ diff --git a/src/crypto/equihash.cpp b/src/crypto/equihash.cpp index 0e66e4886..64011ab59 100644 --- a/src/crypto/equihash.cpp +++ b/src/crypto/equihash.cpp @@ -193,7 +193,7 @@ eh_trunc* TruncatedStepRow::GetTruncatedIndices(size_t len, size_t lenInd } template -std::set> Equihash::BasicSolve(const eh_HashState& base_state, const std::function cancelled) +std::set> Equihash::BasicSolve(const eh_HashState& base_state, const std::function cancelled) { eh_index init_size { 1 << (CollisionBitLength + 1) }; @@ -206,7 +206,7 @@ std::set> Equihash::BasicSolve(const eh_HashState& ba for (eh_index i = 0; i < init_size; i++) { X.emplace_back(N, base_state, i); // Slow down checking to prevent segfaults (??) - if (i % 10000 == 0 && cancelled()) throw solver_cancelled; + if (i % 10000 == 0 && cancelled(ListGeneration)) throw solver_cancelled; } // 3) Repeat step 2 until 2n/(k+1) bits remain @@ -215,7 +215,7 @@ std::set> Equihash::BasicSolve(const eh_HashState& ba // 2a) Sort the list LogPrint("pow", "- Sorting list\n"); std::sort(X.begin(), X.end(), CompareSR(CollisionByteLength)); - if (cancelled()) throw solver_cancelled; + if (cancelled(ListSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; @@ -245,7 +245,7 @@ std::set> Equihash::BasicSolve(const eh_HashState& ba } i += j; - if (cancelled()) throw solver_cancelled; + if (cancelled(ListColliding)) throw solver_cancelled; } // 2e) Handle edge case where final table entry has no collision @@ -265,7 +265,7 @@ std::set> Equihash::BasicSolve(const eh_HashState& ba hashLen -= CollisionByteLength; lenIndices *= 2; - if (cancelled()) throw solver_cancelled; + if (cancelled(RoundEnd)) throw solver_cancelled; } // k+1) Find a collision on last 2n(k+1) bits @@ -274,6 +274,7 @@ std::set> Equihash::BasicSolve(const eh_HashState& ba if (X.size() > 1) { LogPrint("pow", "- Sorting list\n"); std::sort(X.begin(), X.end(), CompareSR(hashLen)); + if (cancelled(FinalSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; while (i < X.size() - 1) { @@ -293,7 +294,7 @@ std::set> Equihash::BasicSolve(const eh_HashState& ba } i += j; - if (cancelled()) throw solver_cancelled; + if (cancelled(FinalColliding)) throw solver_cancelled; } } else LogPrint("pow", "- List is empty\n"); @@ -354,7 +355,7 @@ void CollideBranches(std::vector>& X, const size_t hlen, cons } template -std::set> Equihash::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled) +std::set> Equihash::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled) { eh_index init_size { 1 << (CollisionBitLength + 1) }; @@ -375,7 +376,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState for (eh_index i = 0; i < init_size; i++) { Xt.emplace_back(N, base_state, i, CollisionBitLength + 1); // Slow down checking to prevent segfaults (??) - if (i % 10000 == 0 && cancelled()) throw solver_cancelled; + if (i % 10000 == 0 && cancelled(ListGeneration)) throw solver_cancelled; } // 3) Repeat step 2 until 2n/(k+1) bits remain @@ -384,7 +385,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState // 2a) Sort the list LogPrint("pow", "- Sorting list\n"); std::sort(Xt.begin(), Xt.end(), CompareSR(CollisionByteLength)); - if (cancelled()) throw solver_cancelled; + if (cancelled(ListSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; @@ -413,7 +414,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState } i += j; - if (cancelled()) throw solver_cancelled; + if (cancelled(ListColliding)) throw solver_cancelled; } // 2e) Handle edge case where final table entry has no collision @@ -433,7 +434,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState hashLen -= CollisionByteLength; lenIndices *= 2; - if (cancelled()) throw solver_cancelled; + if (cancelled(RoundEnd)) throw solver_cancelled; } // k+1) Find a collision on last 2n(k+1) bits @@ -441,7 +442,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState if (Xt.size() > 1) { LogPrint("pow", "- Sorting list\n"); std::sort(Xt.begin(), Xt.end(), CompareSR(hashLen)); - if (cancelled()) throw solver_cancelled; + if (cancelled(FinalSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; while (i < Xt.size() - 1) { @@ -459,7 +460,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState } i += j; - if (cancelled()) break; + if (cancelled(FinalColliding)) break; } } else LogPrint("pow", "- List is empty\n"); @@ -473,7 +474,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState std::set> solns; eh_index recreate_size { UntruncateIndex(1, 0, CollisionBitLength + 1) }; int invalidCount = 0; - if (cancelled()) goto cancelsolver; + if (cancelled(StartCulling)) goto cancelsolver; for (eh_trunc* partialSoln : partialSolns) { size_t hashLen; size_t lenIndices; @@ -488,7 +489,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState for (eh_index j = 0; j < recreate_size; j++) { eh_index newIndex { UntruncateIndex(partialSoln[i], j, CollisionBitLength + 1) }; icv.emplace_back(N, base_state, newIndex); - if (cancelled()) goto cancelsolver; + if (cancelled(PartialGeneration)) goto cancelsolver; } boost::optional>> ic = icv; @@ -504,7 +505,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState ic->reserve(ic->size() + X[r]->size()); ic->insert(ic->end(), X[r]->begin(), X[r]->end()); std::sort(ic->begin(), ic->end(), CompareSR(hashLen)); - if (cancelled()) goto cancelsolver; + if (cancelled(PartialSorting)) goto cancelsolver; size_t lti = rti-(1<> Equihash::OptimisedSolve(const eh_HashState X.push_back(ic); break; } - if (cancelled()) goto cancelsolver; + if (cancelled(PartialSubtreeEnd)) goto cancelsolver; } - if (cancelled()) goto cancelsolver; + if (cancelled(PartialIndexEnd)) goto cancelsolver; } // We are at the top of the tree @@ -537,7 +538,7 @@ std::set> Equihash::OptimisedSolve(const eh_HashState for (FullStepRow row : *X[K]) { solns.insert(row.GetIndices(hashLen, lenIndices)); } - if (cancelled()) goto cancelsolver; + if (cancelled(PartialEnd)) goto cancelsolver; continue; invalidsolution: @@ -604,18 +605,18 @@ bool Equihash::IsValidSolution(const eh_HashState& base_state, std::vector< // Explicit instantiations for Equihash<96,3> template int Equihash<96,3>::InitialiseState(eh_HashState& base_state); -template std::set> Equihash<96,3>::BasicSolve(const eh_HashState& base_state, const std::function cancelled); -template std::set> Equihash<96,3>::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); +template std::set> Equihash<96,3>::BasicSolve(const eh_HashState& base_state, const std::function cancelled); +template std::set> Equihash<96,3>::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); template bool Equihash<96,3>::IsValidSolution(const eh_HashState& base_state, std::vector soln); // Explicit instantiations for Equihash<96,5> template int Equihash<96,5>::InitialiseState(eh_HashState& base_state); -template std::set> Equihash<96,5>::BasicSolve(const eh_HashState& base_state, const std::function cancelled); -template std::set> Equihash<96,5>::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); +template std::set> Equihash<96,5>::BasicSolve(const eh_HashState& base_state, const std::function cancelled); +template std::set> Equihash<96,5>::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); template bool Equihash<96,5>::IsValidSolution(const eh_HashState& base_state, std::vector soln); // Explicit instantiations for Equihash<48,5> template int Equihash<48,5>::InitialiseState(eh_HashState& base_state); -template std::set> Equihash<48,5>::BasicSolve(const eh_HashState& base_state, const std::function cancelled); -template std::set> Equihash<48,5>::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); +template std::set> Equihash<48,5>::BasicSolve(const eh_HashState& base_state, const std::function cancelled); +template std::set> Equihash<48,5>::OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); template bool Equihash<48,5>::IsValidSolution(const eh_HashState& base_state, std::vector soln); diff --git a/src/crypto/equihash.h b/src/crypto/equihash.h index e87c693a2..f561b0249 100644 --- a/src/crypto/equihash.h +++ b/src/crypto/equihash.h @@ -110,12 +110,27 @@ public: eh_trunc* GetTruncatedIndices(size_t len, size_t lenIndices) const; }; +enum EhSolverCancelCheck +{ + ListGeneration, + ListSorting, + ListColliding, + RoundEnd, + FinalSorting, + FinalColliding, + StartCulling, + PartialGeneration, + PartialSorting, + PartialSubtreeEnd, + PartialIndexEnd, + PartialEnd +}; + class EhSolverCancelledException : public std::exception { - virtual const char* what() const throw() - { - return "Equihash solver was cancelled"; - } + virtual const char* what() const throw() { + return "Equihash solver was cancelled"; + } }; inline constexpr const size_t max(const size_t A, const size_t B) { return A > B ? A : B; } @@ -140,8 +155,8 @@ public: Equihash() { } int InitialiseState(eh_HashState& base_state); - std::set> BasicSolve(const eh_HashState& base_state, const std::function cancelled); - std::set> OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); + std::set> BasicSolve(const eh_HashState& base_state, const std::function cancelled); + std::set> OptimisedSolve(const eh_HashState& base_state, const std::function cancelled); bool IsValidSolution(const eh_HashState& base_state, std::vector soln); }; @@ -172,8 +187,8 @@ static Equihash<48,5> Eh48_5; } else { \ throw std::invalid_argument("Unsupported Equihash parameters"); \ } -#define EhBasicSolveUncancellable(n, k, base_state, solns) \ - EhBasicSolve(n, k, base_state, solns, [] { return false; }) +#define EhBasicSolveUncancellable(n, k, base_state, solns) \ + EhBasicSolve(n, k, base_state, solns, [](EhSolverCancelCheck pos) { return false; }) #define EhOptimisedSolve(n, k, base_state, solns, cancelled) \ if (n == 96 && k == 3) { \ @@ -186,7 +201,7 @@ static Equihash<48,5> Eh48_5; throw std::invalid_argument("Unsupported Equihash parameters"); \ } #define EhOptimisedSolveUncancellable(n, k, base_state, solns) \ - EhOptimisedSolve(n, k, base_state, solns, [] { return false; }) + EhOptimisedSolve(n, k, base_state, solns, [](EhSolverCancelCheck pos) { return false; }) #define EhIsValidSolution(n, k, base_state, soln, ret) \ if (n == 96 && k == 3) { \ diff --git a/src/gtest/test_equihash.cpp b/src/gtest/test_equihash.cpp new file mode 100644 index 000000000..52636ba7d --- /dev/null +++ b/src/gtest/test_equihash.cpp @@ -0,0 +1,184 @@ +#include +#include + +#include "crypto/equihash.h" + +TEST(equihash_tests, check_basic_solver_cancelled) { + Equihash<48,5> Eh48_5; + crypto_generichash_blake2b_state state; + Eh48_5.InitialiseState(state); + std::set> solns; + + { + ASSERT_NO_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return false; + })); + } + + { + ASSERT_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == ListGeneration; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == ListSorting; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == ListColliding; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == RoundEnd; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == FinalSorting; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == FinalColliding; + }), EhSolverCancelledException); + } + + { + ASSERT_NO_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == StartCulling; + })); + } + + { + ASSERT_NO_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialGeneration; + })); + } + + { + ASSERT_NO_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialSorting; + })); + } + + { + ASSERT_NO_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialSubtreeEnd; + })); + } + + { + ASSERT_NO_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialIndexEnd; + })); + } + + { + ASSERT_NO_THROW(Eh48_5.BasicSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialEnd; + })); + } +} + +TEST(equihash_tests, check_optimised_solver_cancelled) { + Equihash<48,5> Eh48_5; + crypto_generichash_blake2b_state state; + Eh48_5.InitialiseState(state); + std::set> solns; + + { + ASSERT_NO_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return false; + })); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == ListGeneration; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == ListSorting; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == ListColliding; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == RoundEnd; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == FinalSorting; + }), EhSolverCancelledException); + } + + { + // More state required here, because in OptimisedSolve() the + // FinalColliding cancellation check can't throw because it will leak + // memory, and it can't goto because that steps over initialisations. + bool triggered = false; + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [=](EhSolverCancelCheck pos) mutable { + if (triggered) + return pos == StartCulling; + if (pos == FinalColliding) { + triggered = true; + return true; + } + return false; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == StartCulling; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialGeneration; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialSorting; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialSubtreeEnd; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialIndexEnd; + }), EhSolverCancelledException); + } + + { + ASSERT_THROW(Eh48_5.OptimisedSolve(state, [](EhSolverCancelCheck pos) { + return pos == PartialEnd; + }), EhSolverCancelledException); + } +} diff --git a/src/miner.cpp b/src/miner.cpp index 5065977cb..48de9ab60 100644 --- a/src/miner.cpp +++ b/src/miner.cpp @@ -521,7 +521,9 @@ void static BitcoinMiner(CWallet *pwallet) pblock->nNonce.ToString()); std::set> solns; try { - std::function cancelled = [pindexPrev] { return pindexPrev != chainActive.Tip(); }; + std::function cancelled = [pindexPrev](EhSolverCancelCheck pos) { + return pindexPrev != chainActive.Tip(); + }; EhOptimisedSolve(n, k, curr_state, solns, cancelled); } catch (EhSolverCancelledException&) { LogPrint("pow", "Equihash solver cancelled\n");