diff --git a/src/main.cpp b/src/main.cpp index 1c1de636a..91fe6ba8f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2262,6 +2262,127 @@ CMerkleBlock::CMerkleBlock(const CBlock& block, CBloomFilter& filter) +uint256 CPartialMerkleTree::CalcHash(int height, unsigned int pos, const std::vector &vTxid) { + if (height == 0) { + // hash at height 0 is the txids themself + return vTxid[pos]; + } else { + // calculate left hash + uint256 left = CalcHash(height-1, pos*2, vTxid), right; + // calculate right hash if not beyong the end of the array - copy left hash otherwise1 + if (pos*2+1 < CalcTreeWidth(height-1)) + right = CalcHash(height-1, pos*2+1, vTxid); + else + right = left; + // combine subhashes + return Hash(BEGIN(left), END(left), BEGIN(right), END(right)); + } +} + +void CPartialMerkleTree::TraverseAndBuild(int height, unsigned int pos, const std::vector &vTxid, const std::vector &vMatch) { + // determine whether this node is the parent of at least one matched txid + bool fParentOfMatch = false; + for (unsigned int p = pos << height; p < (pos+1) << height && p < nTransactions; p++) + fParentOfMatch |= vMatch[p]; + // store as flag bit + vBits.push_back(fParentOfMatch); + if (height==0 || !fParentOfMatch) { + // if at height 0, or nothing interesting below, store hash and stop + vHash.push_back(CalcHash(height, pos, vTxid)); + } else { + // otherwise, don't store any hash, but descend into the subtrees + TraverseAndBuild(height-1, pos*2, vTxid, vMatch); + if (pos*2+1 < CalcTreeWidth(height-1)) + TraverseAndBuild(height-1, pos*2+1, vTxid, vMatch); + } +} + +uint256 CPartialMerkleTree::TraverseAndExtract(int height, unsigned int pos, unsigned int &nBitsUsed, unsigned int &nHashUsed, std::vector &vMatch) { + if (nBitsUsed >= vBits.size()) { + // overflowed the bits array - failure + fBad = true; + return 0; + } + bool fParentOfMatch = vBits[nBitsUsed++]; + if (height==0 || !fParentOfMatch) { + // if at height 0, or nothing interesting below, use stored hash and do not descend + if (nHashUsed >= vHash.size()) { + // overflowed the hash array - failure + fBad = true; + return 0; + } + const uint256 &hash = vHash[nHashUsed++]; + if (height==0 && fParentOfMatch) // in case of height 0, we have a matched txid + vMatch.push_back(hash); + return hash; + } else { + // otherwise, descend into the subtrees to extract matched txids and hashes + uint256 left = TraverseAndExtract(height-1, pos*2, nBitsUsed, nHashUsed, vMatch), right; + if (pos*2+1 < CalcTreeWidth(height-1)) + right = TraverseAndExtract(height-1, pos*2+1, nBitsUsed, nHashUsed, vMatch); + else + right = left; + // and combine them before returning + return Hash(BEGIN(left), END(left), BEGIN(right), END(right)); + } +} + +CPartialMerkleTree::CPartialMerkleTree(const std::vector &vTxid, const std::vector &vMatch) : nTransactions(vTxid.size()), fBad(false) { + // reset state + vBits.clear(); + vHash.clear(); + + // calculate height of tree + int nHeight = 0; + while (CalcTreeWidth(nHeight) > 1) + nHeight++; + + // traverse the partial tree + TraverseAndBuild(nHeight, 0, vTxid, vMatch); +} + +CPartialMerkleTree::CPartialMerkleTree() : nTransactions(0), fBad(true) {} + +uint256 CPartialMerkleTree::ExtractMatches(std::vector &vMatch) { + vMatch.clear(); + // An empty set will not work + if (nTransactions == 0) + return 0; + // check for excessively high numbers of transactions + if (nTransactions > MAX_BLOCK_SIZE / 60) // 60 is the lower bound for the size of a serialized CTransaction + return 0; + // there can never be more hashes provided than one for every txid + if (vHash.size() > nTransactions) + return 0; + // there must be at least one bit per node in the partial tree, and at least one node per hash + if (vBits.size() < vHash.size()) + return 0; + // calculate height of tree + int nHeight = 0; + while (CalcTreeWidth(nHeight) > 1) + nHeight++; + // traverse the partial tree + unsigned int nBitsUsed = 0, nHashUsed = 0; + uint256 hashMerkleRoot = TraverseAndExtract(nHeight, 0, nBitsUsed, nHashUsed, vMatch); + // verify that no problems occured during the tree traversal + if (fBad) + return 0; + // verify that all bits were consumed (except for the padding caused by serializing it as a byte sequence) + if ((nBitsUsed+7)/8 != (vBits.size()+7)/8) + return 0; + // verify that all hashes were consumed + if (nHashUsed != vHash.size()) + return 0; + return hashMerkleRoot; +} + + + + + + + + bool CheckDiskSpace(uint64 nAdditionalBytes) { uint64 nFreeBytesAvailable = filesystem::space(GetDataDir()).available; diff --git a/src/main.h b/src/main.h index 77aac71d2..f6086e92c 100644 --- a/src/main.h +++ b/src/main.h @@ -1110,11 +1110,101 @@ public: +/** Data structure that represents a partial merkle tree. + * + * It respresents a subset of the txid's of a known block, in a way that + * allows recovery of the list of txid's and the merkle root, in an + * authenticated way. + * + * The encoding works as follows: we traverse the tree in depth-first order, + * storing a bit for each traversed node, signifying whether the node is the + * parent of at least one matched leaf txid (or a matched txid itself). In + * case we are at the leaf level, or this bit is 0, its merkle node hash is + * stored, and its children are not explorer further. Otherwise, no hash is + * stored, but we recurse into both (or the only) child branch. During + * decoding, the same depth-first traversal is performed, consuming bits and + * hashes as they written during encoding. + * + * The serialization is fixed and provides a hard guarantee about the + * encoded size: + * + * SIZE <= 10 + ceil(32.25*N) + * + * Where N represents the number of leaf nodes of the partial tree. N itself + * is bounded by: + * + * N <= total_transactions + * N <= 1 + matched_transactions*tree_height + * + * The serialization format: + * - uint32 total_transactions (4 bytes) + * - varint number of hashes (1-3 bytes) + * - uint256[] hashes in depth-first order (<= 32*N bytes) + * - varint number of bytes of flag bits (1-3 bytes) + * - byte[] flag bits, packed per 8 in a byte, least significant bit first (<= 2*N-1 bits) + * The size constraints follow from this. + */ +class CPartialMerkleTree +{ +protected: + // the total number of transactions in the block + unsigned int nTransactions; + + // node-is-parent-of-matched-txid bits + std::vector vBits; + + // txids and internal hashes + std::vector vHash; + // flag set when encountering invalid data + bool fBad; + // helper function to efficiently calculate the number of nodes at given height in the merkle tree + unsigned int CalcTreeWidth(int height) { + return (nTransactions+(1 << height)-1) >> height; + } + + // calculate the hash of a node in the merkle tree (at leaf level: the txid's themself) + uint256 CalcHash(int height, unsigned int pos, const std::vector &vTxid); + + // recursive function that traverses tree nodes, storing the data as bits and hashes + void TraverseAndBuild(int height, unsigned int pos, const std::vector &vTxid, const std::vector &vMatch); + + // recursive function that traverses tree nodes, consuming the bits and hashes produced by TraverseAndBuild. + // it returns the hash of the respective node. + uint256 TraverseAndExtract(int height, unsigned int pos, unsigned int &nBitsUsed, unsigned int &nHashUsed, std::vector &vMatch); + +public: + // serialization implementation + IMPLEMENT_SERIALIZE( + READWRITE(nTransactions); + READWRITE(vHash); + std::vector vBytes; + if (fRead) { + READWRITE(vBytes); + CPartialMerkleTree &us = *(const_cast(this)); + us.vBits.resize(vBytes.size() * 8); + for (unsigned int p = 0; p < us.vBits.size(); p++) + us.vBits[p] = (vBytes[p / 8] & (1 << (p % 8))) != 0; + us.fBad = false; + } else { + vBytes.resize((vBits.size()+7)/8); + for (unsigned int p = 0; p < vBits.size(); p++) + vBytes[p / 8] |= vBits[p] << (p % 8); + READWRITE(vBytes); + } + ) + + // Construct a partial merkle tree from a list of transaction id's, and a mask that selects a subset of them + CPartialMerkleTree(const std::vector &vTxid, const std::vector &vMatch); + CPartialMerkleTree(); + // extract the matching txid's represented by this partial merkle tree. + // returns the merkle root, or 0 in case of failure + uint256 ExtractMatches(std::vector &vMatch); +}; /** Nodes collect new transactions into a block, hash them into a hash tree, diff --git a/src/test/pmt_tests.cpp b/src/test/pmt_tests.cpp new file mode 100644 index 000000000..cf0942161 --- /dev/null +++ b/src/test/pmt_tests.cpp @@ -0,0 +1,98 @@ +#include + +#include "uint256.h" +#include "main.h" + +using namespace std; + +class CPartialMerkleTreeTester : public CPartialMerkleTree +{ +public: + // flip one bit in one of the hashes - this should break the authentication + void Damage() { + unsigned int n = rand() % vHash.size(); + int bit = rand() % 256; + uint256 &hash = vHash[n]; + hash ^= ((uint256)1 << bit); + } +}; + +BOOST_AUTO_TEST_SUITE(pmt_tests) + +BOOST_AUTO_TEST_CASE(pmt_test1) +{ + static const unsigned int nTxCounts[] = {1, 4, 7, 17, 56, 100, 127, 256, 312, 513, 1000, 4095}; + + for (int n = 0; n < 12; n++) { + unsigned int nTx = nTxCounts[n]; + + // build a block with some dummy transactions + CBlock block; + for (unsigned int j=0; j vTxid(nTx, 0); + for (unsigned int j=0; j 1) { + nTx_ = (nTx_+1)/2; + nHeight++; + } + + // check with random subsets with inclusion chances 1, 1/2, 1/4, ..., 1/128 + for (int att = 1; att < 15; att++) { + // build random subset of txid's + std::vector vMatch(nTx, false); + std::vector vMatchTxid1; + for (unsigned int j=0; j(nTx, 1 + vMatchTxid1.size()*nHeight); + BOOST_CHECK(ss.size() <= 10 + (258*n+7)/8); + + // deserialize into a tester copy + CPartialMerkleTreeTester pmt2; + ss >> pmt2; + + // extract merkle root and matched txids from copy + std::vector vMatchTxid2; + uint256 merkleRoot2 = pmt2.ExtractMatches(vMatchTxid2); + + // check that it has the same merkle root as the original, and a valid one + BOOST_CHECK(merkleRoot1 == merkleRoot2); + BOOST_CHECK(merkleRoot2 != 0); + + // check that it contains the matched transactions (in the same order!) + BOOST_CHECK(vMatchTxid1 == vMatchTxid2); + + // check that random bit flips break the authentication + for (int j=0; j<4; j++) { + CPartialMerkleTreeTester pmt3(pmt2); + pmt3.Damage(); + std::vector vMatchTxid3; + uint256 merkleRoot3 = pmt3.ExtractMatches(vMatchTxid3); + BOOST_CHECK(merkleRoot3 != merkleRoot1); + } + } + } +} + +BOOST_AUTO_TEST_SUITE_END()