// Copyright (c) 2016 Jack Grigg // Copyright (c) 2016 The Zcash developers // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. // Implementation of the Equihash Proof-of-Work algorithm. // // Reference // ========= // Alex Biryukov and Dmitry Khovratovich // Equihash: Asymmetric Proof-of-Work Based on the Generalized Birthday Problem // NDSS ’16, 21-24 February 2016, San Diego, CA, USA // https://www.internetsociety.org/sites/default/files/blogs-media/equihash-asymmetric-proof-of-work-based-generalized-birthday-problem.pdf #include "crypto/equihash.h" #include "util.h" #include #include #include #include EhSolverCancelledException solver_cancelled; template int Equihash::InitialiseState(eh_HashState& base_state) { uint32_t le_N = htole32(N); uint32_t le_K = htole32(K); unsigned char personalization[crypto_generichash_blake2b_PERSONALBYTES] = {}; memcpy(personalization, "ZcashPoW", 8); memcpy(personalization+8, &le_N, 4); memcpy(personalization+12, &le_K, 4); return crypto_generichash_blake2b_init_salt_personal(&base_state, NULL, 0, // No key. N/8, NULL, // No salt. personalization); } // Big-endian so that lexicographic array comparison is equivalent to integer // comparison void EhIndexToArray(const eh_index i, unsigned char* array) { assert(sizeof(eh_index) == 4); eh_index bei = htobe32(i); memcpy(array, &bei, sizeof(eh_index)); } // Big-endian so that lexicographic array comparison is equivalent to integer // comparison eh_index ArrayToEhIndex(const unsigned char* array) { assert(sizeof(eh_index) == 4); eh_index bei; memcpy(&bei, array, sizeof(eh_index)); return be32toh(bei); } eh_trunc TruncateIndex(const eh_index i, const unsigned int ilen) { // Truncate to 8 bits assert(sizeof(eh_trunc) == 1); return (i >> (ilen - 8)) & 0xff; } eh_index UntruncateIndex(const eh_trunc t, const eh_index r, const unsigned int ilen) { eh_index i{t}; return (i << (ilen - 8)) | r; } template StepRow::StepRow(const eh_HashState& base_state, eh_index i, size_t hLen, size_t cBitLen, size_t cByteLen) { eh_HashState state; state = base_state; unsigned char array[sizeof(eh_index)]; eh_index lei = htole32(i); memcpy(array, &lei, sizeof(eh_index)); crypto_generichash_blake2b_update(&state, array, sizeof(eh_index)); crypto_generichash_blake2b_final(&state, hash, hLen); if (8*cByteLen != cBitLen) { // We are not colliding an integer number of bytes, expand // TODO fix this to expand from the correct length instead of clearing bits // When this is done, change hLen to be N/8 instead of HashLength for (size_t i = 0; i < hLen; i += cByteLen) { hash[i] &= 0xFF >> (8*cByteLen - cBitLen); } } } template template StepRow::StepRow(const StepRow& a) { assert(W <= WIDTH); std::copy(a.hash, a.hash+W, hash); } template FullStepRow::FullStepRow(const eh_HashState& base_state, eh_index i, size_t hLen, size_t cBitLen, size_t cByteLen) : StepRow {base_state, i, hLen, cBitLen, cByteLen} { EhIndexToArray(i, hash+hLen); } template template FullStepRow::FullStepRow(const FullStepRow& a, const FullStepRow& b, size_t len, size_t lenIndices, int trim) : StepRow {a} { assert(len+lenIndices <= W); assert(len-trim+(2*lenIndices) <= WIDTH); for (int i = trim; i < len; i++) hash[i-trim] = a.hash[i] ^ b.hash[i]; if (a.IndicesBefore(b, len, lenIndices)) { std::copy(a.hash+len, a.hash+len+lenIndices, hash+len-trim); std::copy(b.hash+len, b.hash+len+lenIndices, hash+len-trim+lenIndices); } else { std::copy(b.hash+len, b.hash+len+lenIndices, hash+len-trim); std::copy(a.hash+len, a.hash+len+lenIndices, hash+len-trim+lenIndices); } } template FullStepRow& FullStepRow::operator=(const FullStepRow& a) { std::copy(a.hash, a.hash+WIDTH, hash); return *this; } template bool StepRow::IsZero(size_t len) { // This doesn't need to be constant time. for (int i = 0; i < len; i++) { if (hash[i] != 0) return false; } return true; } template std::vector FullStepRow::GetIndices(size_t len, size_t lenIndices) const { std::vector ret; for (int i = 0; i < lenIndices; i += sizeof(eh_index)) { ret.push_back(ArrayToEhIndex(hash+len+i)); } return ret; } template bool HasCollision(StepRow& a, StepRow& b, int l) { // This doesn't need to be constant time. for (int j = 0; j < l; j++) { if (a.hash[j] != b.hash[j]) return false; } return true; } template TruncatedStepRow::TruncatedStepRow(const eh_HashState& base_state, eh_index i, size_t hLen, size_t cBitLen, size_t cByteLen, unsigned int ilen) : StepRow {base_state, i, hLen, cBitLen, cByteLen} { hash[hLen] = TruncateIndex(i, ilen); } template template TruncatedStepRow::TruncatedStepRow(const TruncatedStepRow& a, const TruncatedStepRow& b, size_t len, size_t lenIndices, int trim) : StepRow {a} { assert(len+lenIndices <= W); assert(len-trim+(2*lenIndices) <= WIDTH); for (int i = trim; i < len; i++) hash[i-trim] = a.hash[i] ^ b.hash[i]; if (a.IndicesBefore(b, len, lenIndices)) { std::copy(a.hash+len, a.hash+len+lenIndices, hash+len-trim); std::copy(b.hash+len, b.hash+len+lenIndices, hash+len-trim+lenIndices); } else { std::copy(b.hash+len, b.hash+len+lenIndices, hash+len-trim); std::copy(a.hash+len, a.hash+len+lenIndices, hash+len-trim+lenIndices); } } template TruncatedStepRow& TruncatedStepRow::operator=(const TruncatedStepRow& a) { std::copy(a.hash, a.hash+WIDTH, hash); return *this; } template std::shared_ptr TruncatedStepRow::GetTruncatedIndices(size_t len, size_t lenIndices) const { std::shared_ptr p (new eh_trunc[lenIndices]); std::copy(hash+len, hash+len+lenIndices, p.get()); return p; } template bool Equihash::BasicSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled) { eh_index init_size { 1 << (CollisionBitLength + 1) }; // 1) Generate first list LogPrint("pow", "Generating first list\n"); size_t hashLen = HashLength; size_t lenIndices = sizeof(eh_index); std::vector> X; X.reserve(init_size); for (eh_index i = 0; i < init_size; i++) { X.emplace_back(base_state, i, HashLength, CollisionBitLength, CollisionByteLength); if (cancelled(ListGeneration)) throw solver_cancelled; } // 3) Repeat step 2 until 2n/(k+1) bits remain for (int r = 1; r < K && X.size() > 0; r++) { LogPrint("pow", "Round %d:\n", r); // 2a) Sort the list LogPrint("pow", "- Sorting list\n"); std::sort(X.begin(), X.end(), CompareSR(CollisionByteLength)); if (cancelled(ListSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; int posFree = 0; std::vector> Xc; while (i < X.size() - 1) { // 2b) Find next set of unordered pairs with collisions on the next n/(k+1) bits int j = 1; while (i+j < X.size() && HasCollision(X[i], X[i+j], CollisionByteLength)) { j++; } // 2c) Calculate tuples (X_i ^ X_j, (i, j)) for (int l = 0; l < j - 1; l++) { for (int m = l + 1; m < j; m++) { if (DistinctIndices(X[i+l], X[i+m], hashLen, lenIndices)) { Xc.emplace_back(X[i+l], X[i+m], hashLen, lenIndices, CollisionByteLength); } } } // 2d) Store tuples on the table in-place if possible while (posFree < i+j && Xc.size() > 0) { X[posFree++] = Xc.back(); Xc.pop_back(); } i += j; if (cancelled(ListColliding)) throw solver_cancelled; } // 2e) Handle edge case where final table entry has no collision while (posFree < X.size() && Xc.size() > 0) { X[posFree++] = Xc.back(); Xc.pop_back(); } if (Xc.size() > 0) { // 2f) Add overflow to end of table X.insert(X.end(), Xc.begin(), Xc.end()); } else if (posFree < X.size()) { // 2g) Remove empty space at the end X.erase(X.begin()+posFree, X.end()); X.shrink_to_fit(); } hashLen -= CollisionByteLength; lenIndices *= 2; if (cancelled(RoundEnd)) throw solver_cancelled; } // k+1) Find a collision on last 2n(k+1) bits LogPrint("pow", "Final round:\n"); if (X.size() > 1) { LogPrint("pow", "- Sorting list\n"); std::sort(X.begin(), X.end(), CompareSR(hashLen)); if (cancelled(FinalSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; while (i < X.size() - 1) { int j = 1; while (i+j < X.size() && HasCollision(X[i], X[i+j], hashLen)) { j++; } for (int l = 0; l < j - 1; l++) { for (int m = l + 1; m < j; m++) { FullStepRow res(X[i+l], X[i+m], hashLen, lenIndices, 0); if (DistinctIndices(X[i+l], X[i+m], hashLen, lenIndices) && validBlock(res.GetIndices(hashLen, 2*lenIndices))) { return true; } } } i += j; if (cancelled(FinalColliding)) throw solver_cancelled; } } else LogPrint("pow", "- List is empty\n"); return false; } template void CollideBranches(std::vector>& X, const size_t hlen, const size_t lenIndices, const unsigned int clen, const unsigned int ilen, const eh_trunc lt, const eh_trunc rt) { int i = 0; int posFree = 0; std::vector> Xc; while (i < X.size() - 1) { // 2b) Find next set of unordered pairs with collisions on the next n/(k+1) bits int j = 1; while (i+j < X.size() && HasCollision(X[i], X[i+j], clen)) { j++; } // 2c) Calculate tuples (X_i ^ X_j, (i, j)) for (int l = 0; l < j - 1; l++) { for (int m = l + 1; m < j; m++) { if (DistinctIndices(X[i+l], X[i+m], hlen, lenIndices)) { if (IsValidBranch(X[i+l], hlen, ilen, lt) && IsValidBranch(X[i+m], hlen, ilen, rt)) { Xc.emplace_back(X[i+l], X[i+m], hlen, lenIndices, clen); } else if (IsValidBranch(X[i+m], hlen, ilen, lt) && IsValidBranch(X[i+l], hlen, ilen, rt)) { Xc.emplace_back(X[i+m], X[i+l], hlen, lenIndices, clen); } } } } // 2d) Store tuples on the table in-place if possible while (posFree < i+j && Xc.size() > 0) { X[posFree++] = Xc.back(); Xc.pop_back(); } i += j; } // 2e) Handle edge case where final table entry has no collision while (posFree < X.size() && Xc.size() > 0) { X[posFree++] = Xc.back(); Xc.pop_back(); } if (Xc.size() > 0) { // 2f) Add overflow to end of table X.insert(X.end(), Xc.begin(), Xc.end()); } else if (posFree < X.size()) { // 2g) Remove empty space at the end X.erase(X.begin()+posFree, X.end()); X.shrink_to_fit(); } } template bool Equihash::OptimisedSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled) { eh_index init_size { 1 << (CollisionBitLength + 1) }; eh_index recreate_size { UntruncateIndex(1, 0, CollisionBitLength + 1) }; // First run the algorithm with truncated indices const eh_index soln_size { 1 << K }; std::vector> partialSolns; int invalidCount = 0; { // 1) Generate first list LogPrint("pow", "Generating first list\n"); size_t hashLen = HashLength; size_t lenIndices = sizeof(eh_trunc); std::vector> Xt; Xt.reserve(init_size); for (eh_index i = 0; i < init_size; i++) { Xt.emplace_back(base_state, i, HashLength, CollisionBitLength, CollisionByteLength, CollisionBitLength + 1); if (cancelled(ListGeneration)) throw solver_cancelled; } // 3) Repeat step 2 until 2n/(k+1) bits remain for (int r = 1; r < K && Xt.size() > 0; r++) { LogPrint("pow", "Round %d:\n", r); // 2a) Sort the list LogPrint("pow", "- Sorting list\n"); std::sort(Xt.begin(), Xt.end(), CompareSR(CollisionByteLength)); if (cancelled(ListSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; int posFree = 0; std::vector> Xc; while (i < Xt.size() - 1) { // 2b) Find next set of unordered pairs with collisions on the next n/(k+1) bits int j = 1; while (i+j < Xt.size() && HasCollision(Xt[i], Xt[i+j], CollisionByteLength)) { j++; } // 2c) Calculate tuples (X_i ^ X_j, (i, j)) bool checking_for_zero = (i == 0 && Xt[0].IsZero(hashLen)); for (int l = 0; l < j - 1; l++) { for (int m = l + 1; m < j; m++) { // We truncated, so don't check for distinct indices here TruncatedStepRow Xi {Xt[i+l], Xt[i+m], hashLen, lenIndices, CollisionByteLength}; if (!(Xi.IsZero(hashLen-CollisionByteLength) && IsProbablyDuplicate(Xi.GetTruncatedIndices(hashLen-CollisionByteLength, 2*lenIndices), 2*lenIndices))) { Xc.emplace_back(Xi); } } } // 2d) Store tuples on the table in-place if possible while (posFree < i+j && Xc.size() > 0) { Xt[posFree++] = Xc.back(); Xc.pop_back(); } i += j; if (cancelled(ListColliding)) throw solver_cancelled; } // 2e) Handle edge case where final table entry has no collision while (posFree < Xt.size() && Xc.size() > 0) { Xt[posFree++] = Xc.back(); Xc.pop_back(); } if (Xc.size() > 0) { // 2f) Add overflow to end of table Xt.insert(Xt.end(), Xc.begin(), Xc.end()); } else if (posFree < Xt.size()) { // 2g) Remove empty space at the end Xt.erase(Xt.begin()+posFree, Xt.end()); Xt.shrink_to_fit(); } hashLen -= CollisionByteLength; lenIndices *= 2; if (cancelled(RoundEnd)) throw solver_cancelled; } // k+1) Find a collision on last 2n(k+1) bits LogPrint("pow", "Final round:\n"); if (Xt.size() > 1) { LogPrint("pow", "- Sorting list\n"); std::sort(Xt.begin(), Xt.end(), CompareSR(hashLen)); if (cancelled(FinalSorting)) throw solver_cancelled; LogPrint("pow", "- Finding collisions\n"); int i = 0; while (i < Xt.size() - 1) { int j = 1; while (i+j < Xt.size() && HasCollision(Xt[i], Xt[i+j], hashLen)) { j++; } for (int l = 0; l < j - 1; l++) { for (int m = l + 1; m < j; m++) { TruncatedStepRow res(Xt[i+l], Xt[i+m], hashLen, lenIndices, 0); auto soln = res.GetTruncatedIndices(hashLen, 2*lenIndices); if (!IsProbablyDuplicate(soln, 2*lenIndices)) { partialSolns.push_back(soln); } } } i += j; if (cancelled(FinalColliding)) throw solver_cancelled; } } else LogPrint("pow", "- List is empty\n"); } // Ensure Xt goes out of scope and is destroyed LogPrint("pow", "Found %d partial solutions\n", partialSolns.size()); // Now for each solution run the algorithm again to recreate the indices LogPrint("pow", "Culling solutions\n"); for (std::shared_ptr partialSoln : partialSolns) { std::set> solns; size_t hashLen; size_t lenIndices; std::vector>>> X; X.reserve(K+1); // 3) Repeat steps 1 and 2 for each partial index for (eh_index i = 0; i < soln_size; i++) { // 1) Generate first list of possibilities std::vector> icv; icv.reserve(recreate_size); for (eh_index j = 0; j < recreate_size; j++) { eh_index newIndex { UntruncateIndex(partialSoln.get()[i], j, CollisionBitLength + 1) }; icv.emplace_back(base_state, newIndex, HashLength, CollisionBitLength, CollisionByteLength); if (cancelled(PartialGeneration)) throw solver_cancelled; } boost::optional>> ic = icv; // 2a) For each pair of lists: hashLen = HashLength; lenIndices = sizeof(eh_index); size_t rti = i; for (size_t r = 0; r <= K; r++) { // 2b) Until we are at the top of a subtree: if (r < X.size()) { if (X[r]) { // 2c) Merge the lists ic->reserve(ic->size() + X[r]->size()); ic->insert(ic->end(), X[r]->begin(), X[r]->end()); std::sort(ic->begin(), ic->end(), CompareSR(hashLen)); if (cancelled(PartialSorting)) throw solver_cancelled; size_t lti = rti-(1<size() == 0) goto invalidsolution; X[r] = boost::none; hashLen -= CollisionByteLength; lenIndices *= 2; rti = lti; } else { X[r] = *ic; break; } } else { X.push_back(ic); break; } if (cancelled(PartialSubtreeEnd)) throw solver_cancelled; } if (cancelled(PartialIndexEnd)) throw solver_cancelled; } // We are at the top of the tree assert(X.size() == K+1); for (FullStepRow row : *X[K]) { solns.insert(row.GetIndices(hashLen, lenIndices)); } for (auto soln : solns) { if (validBlock(soln)) return true; } if (cancelled(PartialEnd)) throw solver_cancelled; continue; invalidsolution: invalidCount++; } LogPrint("pow", "- Number of invalid solutions found: %d\n", invalidCount); return false; } template bool Equihash::IsValidSolution(const eh_HashState& base_state, std::vector soln) { eh_index soln_size { 1u << K }; if (soln.size() != soln_size) { LogPrint("pow", "Invalid solution size: %d\n", soln.size()); return false; } std::vector> X; X.reserve(soln_size); for (eh_index i : soln) { X.emplace_back(base_state, i, HashLength, CollisionBitLength, CollisionByteLength); } size_t hashLen = HashLength; size_t lenIndices = sizeof(eh_index); while (X.size() > 1) { std::vector> Xc; for (int i = 0; i < X.size(); i += 2) { if (!HasCollision(X[i], X[i+1], CollisionByteLength)) { LogPrint("pow", "Invalid solution: invalid collision length between StepRows\n"); LogPrint("pow", "X[i] = %s\n", X[i].GetHex(hashLen)); LogPrint("pow", "X[i+1] = %s\n", X[i+1].GetHex(hashLen)); return false; } if (X[i+1].IndicesBefore(X[i], hashLen, lenIndices)) { return false; LogPrint("pow", "Invalid solution: Index tree incorrectly ordered\n"); } if (!DistinctIndices(X[i], X[i+1], hashLen, lenIndices)) { LogPrint("pow", "Invalid solution: duplicate indices\n"); return false; } Xc.emplace_back(X[i], X[i+1], hashLen, lenIndices, CollisionByteLength); } X = Xc; hashLen -= CollisionByteLength; lenIndices *= 2; } assert(X.size() == 1); return X[0].IsZero(hashLen); } // Explicit instantiations for Equihash<96,3> template int Equihash<96,3>::InitialiseState(eh_HashState& base_state); template bool Equihash<96,3>::BasicSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<96,3>::OptimisedSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<96,3>::IsValidSolution(const eh_HashState& base_state, std::vector soln); // Explicit instantiations for Equihash<200,9> template int Equihash<200,9>::InitialiseState(eh_HashState& base_state); template bool Equihash<200,9>::BasicSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<200,9>::OptimisedSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<200,9>::IsValidSolution(const eh_HashState& base_state, std::vector soln); // Explicit instantiations for Equihash<96,5> template int Equihash<96,5>::InitialiseState(eh_HashState& base_state); template bool Equihash<96,5>::BasicSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<96,5>::OptimisedSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<96,5>::IsValidSolution(const eh_HashState& base_state, std::vector soln); // Explicit instantiations for Equihash<48,5> template int Equihash<48,5>::InitialiseState(eh_HashState& base_state); template bool Equihash<48,5>::BasicSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<48,5>::OptimisedSolve(const eh_HashState& base_state, const std::function)> validBlock, const std::function cancelled); template bool Equihash<48,5>::IsValidSolution(const eh_HashState& base_state, std::vector soln);