diff --git a/src/gtest/test_joinsplit.cpp b/src/gtest/test_joinsplit.cpp index 0e732e057..18f293f1d 100644 --- a/src/gtest/test_joinsplit.cpp +++ b/src/gtest/test_joinsplit.cpp @@ -204,14 +204,12 @@ void invokeAPIFailure( { try { invokeAPI(js, inputs, outputs, vpub_old, vpub_new, rt); + FAIL() << "It worked, when it shouldn't have!"; } catch(std::invalid_argument const & err) { EXPECT_EQ(err.what(), reason); - return; } catch(...) { FAIL() << "Expected invalid_argument exception."; } - - FAIL() << "It worked, when it shouldn't have!"; } TEST(joinsplit, h_sig) diff --git a/src/zcash/IncrementalMerkleTree.cpp b/src/zcash/IncrementalMerkleTree.cpp index 3a6501555..cf2d00af7 100644 --- a/src/zcash/IncrementalMerkleTree.cpp +++ b/src/zcash/IncrementalMerkleTree.cpp @@ -70,6 +70,17 @@ void IncrementalMerkleTree::wfcheck() const { } } +template +Hash IncrementalMerkleTree::last() const { + if (right) { + return *right; + } else if (left) { + return *left; + } else { + throw std::runtime_error("tree has no cursor"); + } +} + template void IncrementalMerkleTree::append(Hash obj) { if (is_complete(Depth)) { diff --git a/src/zcash/IncrementalMerkleTree.hpp b/src/zcash/IncrementalMerkleTree.hpp index cd21bf651..6c50192c8 100644 --- a/src/zcash/IncrementalMerkleTree.hpp +++ b/src/zcash/IncrementalMerkleTree.hpp @@ -79,6 +79,7 @@ public: Hash root() const { return root(Depth, std::deque()); } + Hash last() const; IncrementalWitness witness() const { return IncrementalWitness(*this); @@ -138,6 +139,12 @@ public: return tree.path(partial_path()); } + // Return the element being witnessed (should be a note + // commitment!) + Hash element() const { + return tree.last(); + } + Hash root() const { return tree.root(Depth, partial_path()); } diff --git a/src/zcash/JoinSplit.cpp b/src/zcash/JoinSplit.cpp index 702c3bac6..ebcef6b5b 100644 --- a/src/zcash/JoinSplit.cpp +++ b/src/zcash/JoinSplit.cpp @@ -205,10 +205,17 @@ public: for (size_t i = 0; i < NumInputs; i++) { // Sanity checks of input { - // If note has nonzero value, its witness's root must be equal to the - // input. - if ((inputs[i].note.value != 0) && (inputs[i].witness.root() != rt)) { - throw std::invalid_argument("joinsplit not anchored to the correct root"); + // If note has nonzero value + if (inputs[i].note.value != 0) { + // The witness root must equal the input root. + if (inputs[i].witness.root() != rt) { + throw std::invalid_argument("joinsplit not anchored to the correct root"); + } + + // The tree must witness the correct element + if (inputs[i].note.cm() != inputs[i].witness.element()) { + throw std::invalid_argument("witness of wrong element for joinsplit input"); + } } // Ensure we have the key to this note.