From 856f1a1ea7b29d0508443663c9084284a1cece63 Mon Sep 17 00:00:00 2001 From: DenioD <41270280+DenioD@users.noreply.github.com> Date: Wed, 30 Sep 2020 21:17:04 +0200 Subject: [PATCH] add wif import support --- Cargo.lock | 1 + lib/Cargo.toml | 1 + lib/src/commands.rs | 37 ++++++ lib/src/lightclient.rs | 172 +++++++++++++++++----------- lib/src/lightwallet.rs | 184 ++++++++++++++++++++---------- lib/src/lightwallet/walletzkey.rs | 162 +++++++++++++++++++++++++- 6 files changed, 426 insertions(+), 131 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eba7ad6..d723abd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1744,6 +1744,7 @@ version = "0.1.0" dependencies = [ "base58 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "bellman 0.1.0 (git+https://github.com/MyHush/librustzcash.git?rev=1a0204113d487cdaaf183c2967010e5214ff9e37)", + "bs58 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)", "dirs 2.0.2 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 9e7e41f..b61970d 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -9,6 +9,7 @@ embed_params = [] [dependencies] base58 = "0.1.0" +bs58 = { version = "0.2", features = ["check"] } log = "0.4" log4rs = "0.8.3" dirs = "2.0.2" diff --git a/lib/src/commands.rs b/lib/src/commands.rs index 8c65a8a..3d853fa 100644 --- a/lib/src/commands.rs +++ b/lib/src/commands.rs @@ -701,6 +701,42 @@ impl Command for ImportCommand { } } +struct TImportCommand {} +impl Command for TImportCommand { + fn help(&self) -> String { + let mut h = vec![]; + h.push("Import an external WIF"); + h.push("Usage:"); + h.push("timport wif (Begins with U"); + h.push(""); + + h.join("\n") + } + + + fn short_help(&self) -> String { + "Import wif to the wallet".to_string() + } + + fn exec(&self, args: &[&str], lightclient: &LightClient) -> String { + + let key = args[0]; + + let r = match lightclient.do_import_tk(key.to_string()){ + Ok(r) => r.pretty(2), + Err(e) => return format!("Error: {}", e), + }; + + match lightclient.do_rescan() { + Ok(_) => {}, + Err(e) => return format!("Error: Rescan failed: {}", e), + }; + + + return r; + } +} + struct HeightCommand {} impl Command for HeightCommand { fn help(&self) -> String { @@ -864,6 +900,7 @@ pub fn get_commands() -> Box>> { map.insert("addresses".to_string(), Box::new(AddressCommand{})); map.insert("height".to_string(), Box::new(HeightCommand{})); map.insert("import".to_string(), Box::new(ImportCommand{})); + map.insert("timport".to_string(), Box::new(TImportCommand{})); map.insert("export".to_string(), Box::new(ExportCommand{})); map.insert("info".to_string(), Box::new(InfoCommand{})); map.insert("coinsupply".to_string(), Box::new(CoinsupplyCommand{})); diff --git a/lib/src/lightclient.rs b/lib/src/lightclient.rs index 6154396..38c48ea 100644 --- a/lib/src/lightclient.rs +++ b/lib/src/lightclient.rs @@ -1,7 +1,5 @@ use crate::lightwallet::LightWallet; -use rand::{rngs::OsRng, seq::SliceRandom}; - use std::sync::{Arc, RwLock, Mutex, mpsc::channel}; use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering}; use std::path::{Path, PathBuf}; @@ -284,7 +282,7 @@ impl LightClientConfig { pub fn base58_secretkey_prefix(&self) -> [u8; 1] { match &self.chain_name[..] { - "main" => [0x80], + "main" => [0xBC], "test" => [0xEF], "regtest" => [0xEF], c => panic!("Unknown chain {}", c) @@ -612,7 +610,7 @@ impl LightClient { let z_addresses = wallet.get_all_zaddresses(); // Collect t addresses - let t_addresses = wallet.taddresses.read().unwrap().iter().map( |a| a.clone() ) + let t_addresses = wallet.get_all_taddresses().iter().map( |a| a.clone() ) .collect::>(); object!{ @@ -635,7 +633,7 @@ impl LightClient { }).collect::>(); // Collect t addresses - let t_addresses = wallet.taddresses.read().unwrap().iter().map( |address| { + let t_addresses = wallet.get_all_taddresses().iter().map( |address| { // Get the balance for this address let balance = wallet.tbalance(Some(address.clone())) ; @@ -1012,6 +1010,31 @@ impl LightClient { Ok(array![new_address]) } + /// Import a new private key + pub fn do_import_tk(&self, sk: String) -> Result { + if !self.wallet.read().unwrap().is_unlocked_for_spending() { + error!("Wallet is locked"); + return Err("Wallet is locked".to_string()); + } + + let new_address = { + let wallet = self.wallet.write().unwrap(); + + let addr = wallet.import_taddr(sk); + if addr.starts_with("Error") { + let e = format!("Error creating new address{}", addr); + error!("{}", e); + return Err(e); + } + + addr + }; + + self.do_save()?; + + Ok(array![new_address]) + } + /// Convinence function to determine what type of key this is and import it pub fn do_import_key(&self, key: String, birthday: u64) -> Result { if key.starts_with(self.config.hrp_sapling_private_key()) { @@ -1303,53 +1326,18 @@ impl LightClient { // So, reset the total_reorg total_reorg = 0; - // We'll also fetch all the txids that our transparent addresses are involved with - { - // Copy over addresses so as to not lock up the wallet, which we'll use inside the callback below. - let addresses = self.wallet.read().unwrap() - .taddresses.read().unwrap().iter().map(|a| a.clone()) - .collect::>(); - - // Create a channel so the fetch_transparent_txids can send the results back - let (ctx, crx) = channel(); - let num_addresses = addresses.len(); - - for address in addresses { - let wallet = self.wallet.clone(); - let block_times_inner = block_times.clone(); - - // If this is the first pass after a retry, fetch older t address txids too, becuse - // they might have been missed last time. - let transparent_start_height = if pass == 1 && retry_count > 0 { - start_height - scan_batch_size - } else { - start_height - }; + // If this is the first pass after a retry, fetch older t address txids too, becuse + // they might have been missed last time. + let transparent_start_height = if pass == 1 && retry_count > 0 { + start_height - scan_batch_size + } else { + start_height + }; - let pool = pool.clone(); - let server_uri = self.get_server_uri(); - let ctx = ctx.clone(); - let no_cert = self.config.no_cert_verification; - - pool.execute(move || { - // Fetch the transparent transactions for this address, and send the results - // via the channel - let r = fetch_transparent_txids(&server_uri, address, transparent_start_height, end_height, no_cert, - move |tx_bytes: &[u8], height: u64| { - let tx = Transaction::read(tx_bytes).unwrap(); - - // Scan this Tx for transparent inputs and outputs - let datetime = block_times_inner.read().unwrap().get(&height).map(|v| *v).unwrap_or(0); - wallet.read().unwrap().scan_full_tx(&tx, height as i32, datetime as u64); - }); - ctx.send(r).unwrap(); - }); - } + let no_cert = true; - // Collect all results from the transparent fetches, and make sure everything was OK. - // If it was not, we return an error, which will go back to the retry - crx.iter().take(num_addresses).collect::, String>>()?; - } + // We'll also fetch all the txids that our transparent addresses are involved with + self.scan_taddress_txids(&pool, block_times, transparent_start_height, end_height, no_cert)?; // Do block height accounting last_scanned_height = end_height; @@ -1376,6 +1364,61 @@ impl LightClient { // Get the Raw transaction for all the wallet transactions + { + let decoy_txids = all_new_txs.read().unwrap(); + match self.scan_fill_fulltxs(&pool, decoy_txids.to_vec()) { + Ok(_) => Ok(object!{ + "result" => "success", + "latest_block" => latest_block, + "downloaded_bytes" => bytes_downloaded.load(Ordering::SeqCst) + }), + Err(e) => Err(format!("Error fetching all txns for memos: {}", e)) + } + } + } + + fn scan_taddress_txids(&self, pool: &ThreadPool, block_times: Arc>>, start_height: u64, end_height: u64, no_cert: bool) -> Result, String> { + // Copy over addresses so as to not lock up the wallet, which we'll use inside the callback below. + let addresses = self.wallet.read().unwrap() + .get_all_taddresses().iter() + .map(|a| a.clone()) + .collect::>(); + + // Create a channel so the fetch_transparent_txids can send the results back + let (ctx, crx) = channel(); + let num_addresses = addresses.len(); + + for address in addresses { + let wallet = self.wallet.clone(); + + let pool = pool.clone(); + let server_uri = self.get_server_uri(); + let ctx = ctx.clone(); + + let block_times = block_times.clone(); + + pool.execute(move || { + // Fetch the transparent transactions for this address, and send the results + // via the channel + let r = fetch_transparent_txids(&server_uri, address, start_height, end_height,no_cert, + move |tx_bytes: &[u8], height: u64| { + let tx = Transaction::read(tx_bytes).unwrap(); + + // Scan this Tx for transparent inputs and outputs + let datetime = block_times.read().unwrap().get(&height).map(|v| *v).unwrap_or(0); + wallet.read().unwrap().scan_full_tx(&tx, height as i32, datetime as u64); + }); + ctx.send(r).unwrap(); + }); + } + + // Collect all results from the transparent fetches, and make sure everything was OK. + // If it was not, we return an error, which will go back to the retry + crx.iter().take(num_addresses).collect::, String>>() + } + + fn scan_fill_fulltxs(&self, pool: &ThreadPool, decoy_txids: Vec<(TxId, i32)>) -> Result, String> { + // We need to first copy over the Txids from the wallet struct, because // we need to free the read lock from here (Because we'll self.wallet.txs later) let mut txids_to_fetch: Vec<(TxId, i32)> = self.wallet.read().unwrap().txs.read().unwrap().values() @@ -1383,14 +1426,11 @@ impl LightClient { .map(|wtx| (wtx.txid.clone(), wtx.block)) .collect::>(); - info!("Fetching {} new txids, total {} with decoy", txids_to_fetch.len(), all_new_txs.read().unwrap().len()); - txids_to_fetch.extend_from_slice(&all_new_txs.read().unwrap()[..]); + info!("Fetching {} new txids, total {} with decoy", txids_to_fetch.len(), decoy_txids.len()); + txids_to_fetch.extend_from_slice(&decoy_txids[..]); txids_to_fetch.sort(); txids_to_fetch.dedup(); - let mut rng = OsRng; - txids_to_fetch.shuffle(&mut rng); - let num_fetches = txids_to_fetch.len(); let (ctx, crx) = channel(); @@ -1420,15 +1460,7 @@ impl LightClient { }; // Wait for all the fetches to finish. - let result = crx.iter().take(num_fetches).collect::, String>>(); - match result { - Ok(_) => Ok(object!{ - "result" => "success", - "latest_block" => latest_block, - "downloaded_bytes" => bytes_downloaded.load(Ordering::SeqCst) - }), - Err(e) => Err(format!("Error fetching all txns for memos: {}", e)) - } + crx.iter().take(num_fetches).collect::, String>>() } pub fn do_send(&self, addrs: Vec<(&str, u64, Option)>) -> Result { @@ -1486,11 +1518,15 @@ pub mod tests { assert!(!lc.do_export(None).is_err()); assert!(!lc.do_seed_phrase().is_err()); - // This will lock the wallet again, so after this, we'll need to unlock again - assert!(!lc.do_new_address("R").is_err()); - lc.wallet.write().unwrap().unlock("password".to_string()).unwrap(); + // Can't add keys while unlocked but encrypted + assert!(lc.do_new_address("R").is_err()); + assert!(lc.do_new_address("zs").is_err()); + + // Remove encryption, which will allow adding + lc.wallet.write().unwrap().remove_encryption("password".to_string()).unwrap(); - assert!(!lc.do_new_address("zs").is_err()); + assert!(lc.do_new_address("R").is_ok()); + assert!(lc.do_new_address("zs").is_ok()); } #[test] diff --git a/lib/src/lightwallet.rs b/lib/src/lightwallet.rs index cc6af1b..083f20e 100644 --- a/lib/src/lightwallet.rs +++ b/lib/src/lightwallet.rs @@ -65,7 +65,7 @@ mod walletzkey; use data::{BlockData, WalletTx, Utxo, SaplingNoteData, SpendableNote, OutgoingTxMetadata}; use extended_key::{KeyIndex, ExtendedPrivKey}; -use walletzkey::{WalletZKey, WalletZKeyType}; +use walletzkey::{WalletZKey, WalletTKey, WalletZKeyType}; pub const MAX_REORG: usize = 100; pub const GAP_RULE_UNUSED_ADDRESSES: usize = 5; @@ -124,11 +124,8 @@ pub struct LightWallet { // viewing keys and imported spending keys. zkeys: Arc>>, - - // Transparent keys. If the wallet is locked, then the secret keys will be encrypted, - // but the addresses will be present. - tkeys: Arc>>, - pub taddresses: Arc>>, + // Transparent keys. + tkeys: Arc>>, blocks: Arc>>, pub txs: Arc>>, @@ -149,7 +146,7 @@ pub struct LightWallet { impl LightWallet { pub fn serialized_version() -> u64 { - return 7; + return 9; } fn get_taddr_from_bip39seed(config: &LightClientConfig, bip39_seed: &[u8], pos: u32) -> SecretKey { @@ -255,8 +252,7 @@ impl LightWallet { nonce: vec![], seed: seed_bytes, zkeys: Arc::new(RwLock::new(vec![WalletZKey::new_hdkey(hdkey_num, extsk)])), - tkeys: Arc::new(RwLock::new(vec![tpk])), - taddresses: Arc::new(RwLock::new(vec![taddr])), + tkeys: Arc::new(RwLock::new(vec![WalletTKey::new_hdkey(tpk, taddr)])), blocks: Arc::new(RwLock::new(vec![])), txs: Arc::new(RwLock::new(HashMap::new())), mempool_txs: Arc::new(RwLock::new(HashMap::new())), @@ -374,18 +370,29 @@ impl LightWallet { // Calculate the addresses - let tkeys = Vector::read(&mut reader, |r| { - let mut tpk_bytes = [0u8; 32]; - r.read_exact(&mut tpk_bytes)?; - secp256k1::SecretKey::from_slice(&tpk_bytes).map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) - })?; - - let taddresses = if version >= 4 { - // Read the addresses - Vector::read(&mut reader, |r| utils::read_string(r))? + let wallet_tkeys = if version >= 9 { + Vector::read(&mut reader, |r| { + WalletTKey::read(r) + })? } else { - // Calculate the addresses - tkeys.iter().map(|sk| LightWallet::address_from_prefix_sk(&config.base58_pubkey_address(), sk)).collect() + + let tkeys = Vector::read(&mut reader, |r| { + let mut tpk_bytes = [0u8; 32]; + r.read_exact(&mut tpk_bytes)?; + secp256k1::SecretKey::from_slice(&tpk_bytes).map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + })?; + + let taddresses = if version >= 4 { + // Read the addresses + Vector::read(&mut reader, |r| utils::read_string(r))? + } else { + // Calculate the addresses + tkeys.iter().map(|sk| LightWallet::address_from_prefix_sk(&config.base58_pubkey_address(), sk)).collect() + }; + + tkeys.iter().zip(taddresses.iter()).map(|(k, a)| + WalletTKey::new_hdkey(*k, a.clone()) + ).collect() }; let blocks = Vector::read(&mut reader, |r| BlockData::read(r))?; @@ -414,8 +421,7 @@ impl LightWallet { nonce: nonce, seed: seed_bytes, zkeys: Arc::new(RwLock::new(zkeys)), - tkeys: Arc::new(RwLock::new(tkeys)), - taddresses: Arc::new(RwLock::new(taddresses)), + tkeys: Arc::new(RwLock::new(wallet_tkeys)), blocks: Arc::new(RwLock::new(blocks)), txs: Arc::new(RwLock::new(txs)), mempool_txs: Arc::new(RwLock::new(HashMap::new())), @@ -456,12 +462,7 @@ impl LightWallet { // Write the transparent private keys Vector::write(&mut writer, &self.tkeys.read().unwrap(), - |w, pk| w.write_all(&pk[..]) - )?; - - // Write the transparent addresses - Vector::write(&mut writer, &self.taddresses.read().unwrap(), - |w, a| utils::write_string(w, a) + |w, tk| tk.write(w) )?; Vector::write(&mut writer, &self.blocks.read().unwrap(), |w, b| b.write(w))?; @@ -533,9 +534,14 @@ impl LightWallet { /// Get all t-address private keys. Returns a Vector of (address, secretkey) pub fn get_t_secret_keys(&self) -> Vec<(String, String)> { - self.tkeys.read().unwrap().iter().map(|sk| { - (self.address_from_sk(sk), - sk[..].to_base58check(&self.config.base58_secretkey_prefix(), &[0x01])) + self.tkeys.read().unwrap().iter().map(|wtk| { + let sk = if wtk.tkey.is_some() { + wtk.tkey.unwrap()[..].to_base58check(&self.config.base58_secretkey_prefix(), &[0x01]) + } else { + "".to_string() + }; + + (wtk.address.clone(), sk) }).collect::>() } @@ -547,6 +553,10 @@ impl LightWallet { return "Error: Can't add key while wallet is locked".to_string(); } + if self.encrypted { + return "Error: Can't add key while wallet is encrypted".to_string(); + } + // Find the highest pos we have let pos = self.zkeys.read().unwrap().iter() .filter(|zk| zk.hdkey_num.is_some()) @@ -603,14 +613,49 @@ impl LightWallet { return "Error: Can't add key while wallet is locked".to_string(); } + if self.encrypted { + return "Error: Can't add key while wallet is encrypted".to_string(); + } + let pos = self.tkeys.read().unwrap().len() as u32; let bip39_seed = bip39::Seed::new(&Mnemonic::from_entropy(&self.seed, Language::English).unwrap(), ""); let sk = LightWallet::get_taddr_from_bip39seed(&self.config, &bip39_seed.as_bytes(), pos); let address = self.address_from_sk(&sk); - self.tkeys.write().unwrap().push(sk); - self.taddresses.write().unwrap().push(address.clone()); + self.tkeys.write().unwrap().push(WalletTKey::new_hdkey(sk, address.clone())); + + address + } + + pub fn import_taddr(&self, sk: String) -> String { + if !self.unlocked { + return "Error: Can't add key while wallet is locked".to_string(); + } + + //// Decode Wif to base58 to hex + let sk_to_bs58 = bs58::decode(sk).into_vec().unwrap(); + + let bs58_to_hex = hex::encode(sk_to_bs58); + + //// Manipulate string, to exclude last 4 bytes (checksum bytes), first 2 bytes (secretkey prefix) and the compressed flag (works only for compressed Wifs!) + + let slice_sk = &bs58_to_hex[2..66]; + + //// Get the SecretKey from slice + let secret_key = SecretKey::from_slice(&hex::decode(slice_sk).unwrap()); + + let sk_raw = secret_key.unwrap(); + + //// Make sure the key doesn't already exist + if self.tkeys.read().unwrap().iter().find(|&wk| wk.tkey.is_some() && wk.tkey.as_ref().unwrap() == &sk_raw.clone()).is_some() { + return "Error: Key already exists".to_string(); + } + //// Get the taddr from key + let address = self.address_from_sk(&sk_raw); + + //// Add to tkeys + self.tkeys.write().unwrap().push(WalletTKey::import_hdkey(sk_raw , address.clone())); address } @@ -791,6 +836,12 @@ impl LightWallet { } } + pub fn get_all_taddresses(&self) -> Vec { + self.tkeys.read().unwrap() + .iter() + .map(|wtx| wtx.address.clone()).collect() + } + pub fn get_all_zaddresses(&self) -> Vec { self.zkeys.read().unwrap().iter().map( |zk| { encode_payment_address(self.config.hrp_sapling_address(), &zk.zaddress) @@ -863,6 +914,11 @@ impl LightWallet { self.nonce = nonce.as_ref().to_vec(); // Encrypt the individual keys + + self.tkeys.write().unwrap().iter_mut() + .map(|k| k.encrypt(&key)) + .collect::>>()?; + self.zkeys.write().unwrap().iter_mut() .map(|k| k.encrypt(&key)) .collect::>>()?; @@ -884,7 +940,12 @@ impl LightWallet { // Empty the seed and the secret keys self.seed.copy_from_slice(&[0u8; 32]); - self.tkeys = Arc::new(RwLock::new(vec![])); + + // Remove all the private key from the tkeys + self.tkeys.write().unwrap().iter_mut().map(|tk| { + tk.lock() + }).collect::>>()?; + // Remove all the private key from the zkeys self.zkeys.write().unwrap().iter_mut().map(|zk| { zk.lock() @@ -922,19 +983,10 @@ impl LightWallet { // we need to get the 64 byte bip39 entropy let bip39_seed = bip39::Seed::new(&Mnemonic::from_entropy(&seed, Language::English).unwrap(), ""); - // Transparent keys - let mut tkeys = vec![]; - for pos in 0..self.taddresses.read().unwrap().len() { - let sk = LightWallet::get_taddr_from_bip39seed(&self.config, &bip39_seed.as_bytes(), pos as u32); - let address = self.address_from_sk(&sk); - - if address != self.taddresses.read().unwrap()[pos] { - return Err(io::Error::new(ErrorKind::InvalidData, - format!("taddress mismatch at {}. {} vs {}", pos, address, self.taddresses.read().unwrap()[pos]))); - } - - tkeys.push(sk); - } + // Go over the tkeys, and add the keys again + self.tkeys.write().unwrap().iter_mut().map(|tk| { + tk.unlock(&key) + }).collect::>>()?; // Go over the zkeys, and add the spending keys again self.zkeys.write().unwrap().iter_mut().map(|zk| { @@ -942,7 +994,6 @@ impl LightWallet { }).collect::>>()?; // Everything checks out, so we'll update our wallet with the decrypted values - self.tkeys = Arc::new(RwLock::new(tkeys)); self.seed.copy_from_slice(&seed); self.encrypted = true; @@ -963,8 +1014,13 @@ impl LightWallet { self.unlock(passwd)?; } + // Remove encryption from individual tkeys + self.tkeys.write().unwrap().iter_mut().map(|tk| { + tk.remove_encryption() + }).collect::>>()?; + // Remove encryption from individual zkeys - self.zkeys.write().unwrap().iter_mut().map(|zk| { + self.zkeys.write().unwrap().iter_mut().map(|zk| { zk.remove_encryption() }).collect::>>()?; @@ -1159,7 +1215,12 @@ impl LightWallet { // If one of the last 'n' taddress was used, ensure we add the next HD taddress to the wallet. pub fn ensure_hd_taddresses(&self, address: &String) { let last_addresses = { - self.taddresses.read().unwrap().iter().rev().take(GAP_RULE_UNUSED_ADDRESSES).map(|s| s.clone()).collect::>() + self.tkeys.read().unwrap() + .iter() + .map(|t| t.address.clone()) + .rev().take(GAP_RULE_UNUSED_ADDRESSES).map(|s| + s.clone()) + .collect::>() }; match last_addresses.iter().position(|s| *s == *address) { @@ -1252,7 +1313,8 @@ impl LightWallet { } // Scan for t outputs - let all_taddresses = self.taddresses.read().unwrap().iter() + let all_taddresses = self.tkeys.read().unwrap().iter() + .map(|wtx| wtx.address.clone()) .map(|a| a.clone()) .collect::>(); for address in all_taddresses { @@ -1279,7 +1341,8 @@ impl LightWallet { // outgoing metadata // Collect our t-addresses - let wallet_taddrs = self.taddresses.read().unwrap().iter() + let wallet_taddrs = self.tkeys.read().unwrap().iter() + .map(|wtx| wtx.address.clone()) .map(|a| a.clone()) .collect::>(); @@ -2037,7 +2100,8 @@ impl LightWallet { // Create a map from address -> sk for all taddrs, so we can spend from the // right address let address_to_sk = self.tkeys.read().unwrap().iter() - .map(|sk| (self.address_from_sk(&sk), sk.clone())) + .filter(|wtk| wtk.tkey.is_some()) + .map(|wtk| (wtk.address.clone(), wtk.tkey.unwrap().clone())) .collect::>(); // Add all tinputs @@ -2050,15 +2114,11 @@ impl LightWallet { script_pubkey: Script { 0: utxo.script.clone() }, }; - match address_to_sk.get(&utxo.address) { - Some(sk) => builder.add_transparent_input(*sk, outpoint.clone(), coin.clone()), - None => { - // Something is very wrong - let e = format!("Couldn't find the secreykey for taddr {}", utxo.address); - error!("{}", e); - - Err(zcash_primitives::transaction::builder::Error::InvalidAddress) - } + if let Some(sk) = address_to_sk.get(&utxo.address) { + return builder.add_transparent_input(*sk, outpoint.clone(), coin.clone()) + } else { + info!("Not adding a UTXO because secret key is absent."); + return Ok(()) } }) diff --git a/lib/src/lightwallet/walletzkey.rs b/lib/src/lightwallet/walletzkey.rs index 8cc4e88..ee1dec7 100644 --- a/lib/src/lightwallet/walletzkey.rs +++ b/lib/src/lightwallet/walletzkey.rs @@ -13,7 +13,167 @@ use zcash_primitives::{ }; use crate::lightclient::{LightClientConfig}; -use crate::lightwallet::LightWallet; +use crate::lightwallet::{LightWallet, utils}; + +#[derive(PartialEq, Debug, Clone)] +pub enum WalletTKeyType { + HdKey = 0, + ImportedKey = 1, +} + + +// A struct that holds z-address private keys or view keys +#[derive(Clone, Debug, PartialEq)] +pub struct WalletTKey { + pub(super) keytype: WalletTKeyType, + locked: bool, + pub(super) address: String, + pub(super) tkey: Option, + + // If locked, the encrypted key is here + enc_key: Option>, + nonce: Option>, +} + +impl WalletTKey { + pub fn new_hdkey(key: secp256k1::SecretKey, address: String) -> Self { + WalletTKey { + keytype: WalletTKeyType::HdKey, + locked: false, + address, + tkey: Some(key), + + enc_key: None, + nonce: None, + } + } + + pub fn import_hdkey(key: secp256k1::SecretKey, address: String) -> Self { + WalletTKey { + keytype: WalletTKeyType::ImportedKey, + locked: false, + address, + tkey: Some(key), + + enc_key: None, + nonce: None, + } + } + + fn serialized_version() -> u8 { + return 1; + } + + pub fn read(mut inp: R) -> io::Result { + let version = inp.read_u8()?; + assert!(version <= Self::serialized_version()); + + let keytype: WalletTKeyType = match inp.read_u32::()? { + 0 => Ok(WalletTKeyType::HdKey), + 1 => Ok(WalletTKeyType::ImportedKey), + n => Err(io::Error::new(ErrorKind::InvalidInput, format!("Unknown tkey type {}", n))) + }?; + + let locked = inp.read_u8()? > 0; + + let address = utils::read_string(&mut inp)?; + let tkey = Optional::read(&mut inp, |r| { + let mut tpk_bytes = [0u8; 32]; + r.read_exact(&mut tpk_bytes)?; + secp256k1::SecretKey::from_slice(&tpk_bytes).map_err(|e| io::Error::new(ErrorKind::InvalidData, e)) + })?; + + let enc_key = Optional::read(&mut inp, |r| + Vector::read(r, |r| r.read_u8()))?; + let nonce = Optional::read(&mut inp, |r| + Vector::read(r, |r| r.read_u8()))?; + + Ok(WalletTKey { + keytype, + locked, + address, + tkey, + enc_key, + nonce, + }) + } + + pub fn write(&self, mut out: W) -> io::Result<()> { + out.write_u8(Self::serialized_version())?; + + out.write_u32::(self.keytype.clone() as u32)?; + + out.write_u8(self.locked as u8)?; + + utils::write_string(&mut out, &self.address)?; + Optional::write(&mut out, &self.tkey, |w, pk| + w.write_all(&pk[..]) + )?; + + // Write enc_key + Optional::write(&mut out, &self.enc_key, |o, v| + Vector::write(o, v, |o,n| o.write_u8(*n)))?; + + // Write nonce + Optional::write(&mut out, &self.nonce, |o, v| + Vector::write(o, v, |o,n| o.write_u8(*n))) + } + + + pub fn lock(&mut self) -> io::Result<()> { + // For keys, encrypt the key into enckey + // assert that we have the encrypted key. + if self.enc_key.is_none() { + return Err(Error::new(ErrorKind::InvalidInput, "Can't lock when t-addr private key is not encrypted")); + } + self.tkey = None; + self.locked = true; + + + Ok(()) + } + + pub fn unlock(&mut self, key: &secretbox::Key) -> io::Result<()> { + // For imported keys, we need to decrypt from the encrypted key + let nonce = secretbox::Nonce::from_slice(&self.nonce.as_ref().unwrap()).unwrap(); + let sk_bytes = match secretbox::open(&self.enc_key.as_ref().unwrap(), &nonce, &key) { + Ok(s) => s, + Err(_) => {return Err(io::Error::new(ErrorKind::InvalidData, "Decryption failed. Is your password correct?"));} + }; + + self.tkey = Some(secp256k1::SecretKey::from_slice(&sk_bytes[..]).map_err(|e| + io::Error::new(ErrorKind::InvalidData, format!("{}", e)) + )?); + + self.locked = false; + Ok(()) + } + + pub fn encrypt(&mut self, key: &secretbox::Key) -> io::Result<()> { + // For keys, encrypt the key into enckey + let nonce = secretbox::gen_nonce(); + + let sk_bytes = &self.tkey.unwrap()[..]; + + self.enc_key = Some(secretbox::seal(&sk_bytes, &nonce, &key)); + self.nonce = Some(nonce.as_ref().to_vec()); + + self.tkey = None; + + // Also lock after encrypt + self.lock() + } + + pub fn remove_encryption(&mut self) -> io::Result<()> { + if self.locked { + return Err(Error::new(ErrorKind::InvalidInput, "Can't remove encryption while locked")); + } + + self.enc_key = None; + self.nonce = None; + Ok(()) + } +} #[derive(PartialEq, Debug, Clone)] pub enum WalletZKeyType {