diff --git a/src/Makefile.gtest.include b/src/Makefile.gtest.include index 5b7fa31e2..99fdf1e75 100644 --- a/src/Makefile.gtest.include +++ b/src/Makefile.gtest.include @@ -8,6 +8,7 @@ zcash_gtest_SOURCES = \ gtest/test_checktransaction.cpp \ gtest/test_equihash.cpp \ gtest/test_joinsplit.cpp \ + gtest/test_keystore.cpp \ gtest/test_noteencryption.cpp \ gtest/test_merkletree.cpp \ gtest/test_circuit.cpp \ diff --git a/src/gtest/test_keystore.cpp b/src/gtest/test_keystore.cpp new file mode 100644 index 000000000..987bdd13a --- /dev/null +++ b/src/gtest/test_keystore.cpp @@ -0,0 +1,26 @@ +#include + +#include "keystore.h" +#include "zcash/Address.hpp" + +TEST(keystore_tests, store_and_retrieve_spending_key) { + CBasicKeyStore keyStore; + + std::set addrs; + keyStore.GetPaymentAddresses(addrs); + ASSERT_EQ(0, addrs.size()); + + auto sk = libzcash::SpendingKey::random(); + keyStore.AddSpendingKey(sk); + + auto addr = sk.address(); + ASSERT_TRUE(keyStore.HaveSpendingKey(addr)); + + libzcash::SpendingKey keyOut; + keyStore.GetSpendingKey(addr, keyOut); + ASSERT_EQ(sk, keyOut); + + keyStore.GetPaymentAddresses(addrs); + ASSERT_EQ(1, addrs.size()); + ASSERT_EQ(1, addrs.count(addr)); +} diff --git a/src/keystore.cpp b/src/keystore.cpp index 3bae24b7b..376654e7f 100644 --- a/src/keystore.cpp +++ b/src/keystore.cpp @@ -23,6 +23,10 @@ bool CKeyStore::AddKey(const CKey &key) { return AddKeyPubKey(key, key.GetPubKey()); } +bool CKeyStore::AddSpendingKey(const libzcash::SpendingKey &key) { + return AddSpendingKeyPaymentAddress(key, key.address()); +} + bool CBasicKeyStore::AddKeyPubKey(const CKey& key, const CPubKey &pubkey) { LOCK(cs_KeyStore); @@ -83,3 +87,10 @@ bool CBasicKeyStore::HaveWatchOnly() const LOCK(cs_KeyStore); return (!setWatchOnly.empty()); } + +bool CBasicKeyStore::AddSpendingKeyPaymentAddress(const libzcash::SpendingKey& key, const libzcash::PaymentAddress &address) +{ + LOCK(cs_KeyStore); + mapSpendingKeys[address] = key; + return true; +} diff --git a/src/keystore.h b/src/keystore.h index 4a4b6d20a..bbd04f235 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -11,6 +11,7 @@ #include "script/script.h" #include "script/standard.h" #include "sync.h" +#include "zcash/Address.hpp" #include #include @@ -44,11 +45,21 @@ public: virtual bool RemoveWatchOnly(const CScript &dest) =0; virtual bool HaveWatchOnly(const CScript &dest) const =0; virtual bool HaveWatchOnly() const =0; + + //! Add a spending key to the store. + virtual bool AddSpendingKeyPaymentAddress(const libzcash::SpendingKey &key, const libzcash::PaymentAddress &address) =0; + virtual bool AddSpendingKey(const libzcash::SpendingKey &key); + + //! Check whether a spending key corresponding to a given payment address is present in the store. + virtual bool HaveSpendingKey(const libzcash::PaymentAddress &address) const =0; + virtual bool GetSpendingKey(const libzcash::PaymentAddress &address, libzcash::SpendingKey& keyOut) const =0; + virtual void GetPaymentAddresses(std::set &setAddress) const =0; }; typedef std::map KeyMap; typedef std::map ScriptMap; typedef std::set WatchOnlySet; +typedef std::map SpendingKeyMap; /** Basic key store, that keeps keys in an address->secret map */ class CBasicKeyStore : public CKeyStore @@ -57,6 +68,7 @@ protected: KeyMap mapKeys; ScriptMap mapScripts; WatchOnlySet setWatchOnly; + SpendingKeyMap mapSpendingKeys; public: bool AddKeyPubKey(const CKey& key, const CPubKey &pubkey); @@ -103,6 +115,43 @@ public: virtual bool RemoveWatchOnly(const CScript &dest); virtual bool HaveWatchOnly(const CScript &dest) const; virtual bool HaveWatchOnly() const; + + bool AddSpendingKeyPaymentAddress(const libzcash::SpendingKey &key, const libzcash::PaymentAddress &address); + bool HaveSpendingKey(const libzcash::PaymentAddress &address) const + { + bool result; + { + LOCK(cs_KeyStore); + result = (mapSpendingKeys.count(address) > 0); + } + return result; + } + bool GetSpendingKey(const libzcash::PaymentAddress &address, libzcash::SpendingKey &keyOut) const + { + { + LOCK(cs_KeyStore); + SpendingKeyMap::const_iterator mi = mapSpendingKeys.find(address); + if (mi != mapSpendingKeys.end()) + { + keyOut = mi->second; + return true; + } + } + return false; + } + void GetPaymentAddresses(std::set &setAddress) const + { + setAddress.clear(); + { + LOCK(cs_KeyStore); + SpendingKeyMap::const_iterator mi = mapSpendingKeys.begin(); + while (mi != mapSpendingKeys.end()) + { + setAddress.insert((*mi).first); + mi++; + } + } + } }; typedef std::vector > CKeyingMaterial; diff --git a/src/uint252.h b/src/uint252.h index c5ab1a380..6281e8533 100644 --- a/src/uint252.h +++ b/src/uint252.h @@ -43,6 +43,8 @@ public: uint256 inner() const { return contents; } + + friend inline bool operator==(const uint252& a, const uint252& b) { return a.contents == b.contents; } }; #endif diff --git a/src/zcash/Address.cpp b/src/zcash/Address.cpp index 446a63db1..9bb32fb6c 100644 --- a/src/zcash/Address.cpp +++ b/src/zcash/Address.cpp @@ -8,7 +8,7 @@ uint256 ViewingKey::pk_enc() { return ZCNoteEncryption::generate_pubkey(*this); } -ViewingKey SpendingKey::viewing_key() { +ViewingKey SpendingKey::viewing_key() const { return ViewingKey(ZCNoteEncryption::generate_privkey(*this)); } @@ -16,8 +16,8 @@ SpendingKey SpendingKey::random() { return SpendingKey(random_uint252()); } -PaymentAddress SpendingKey::address() { +PaymentAddress SpendingKey::address() const { return PaymentAddress(PRF_addr_a_pk(*this), viewing_key().pk_enc()); } -} \ No newline at end of file +} diff --git a/src/zcash/Address.hpp b/src/zcash/Address.hpp index 86e351cf7..36b9402a3 100644 --- a/src/zcash/Address.hpp +++ b/src/zcash/Address.hpp @@ -22,6 +22,8 @@ public: READWRITE(a_pk); READWRITE(pk_enc); } + + friend inline bool operator<(const PaymentAddress& a, const PaymentAddress& b) { return a.a_pk < b.a_pk; } }; class ViewingKey : public uint256 { @@ -38,8 +40,8 @@ public: static SpendingKey random(); - ViewingKey viewing_key(); - PaymentAddress address(); + ViewingKey viewing_key() const; + PaymentAddress address() const; }; }