From e6ad8d96172b5538112b70daef85ede3bed0dd42 Mon Sep 17 00:00:00 2001 From: DenioD <41270280+DenioD@users.noreply.github.com> Date: Wed, 29 Jul 2020 15:23:54 +0200 Subject: [PATCH] Multi Thread sync, ported from https://github.com/adityapk00/zecwallet-light-cli/commit/5d2b85c03a88e91be69115f38fde81efddef62b9 --- Cargo.lock | 22 +++- lib/Cargo.toml | 3 + lib/src/grpcconnector.rs | 84 +++++++----- lib/src/lightclient.rs | 113 +++++++++++----- lib/src/lightwallet.rs | 277 ++++++++++++++++++++++++++++++++++++--- 5 files changed, 414 insertions(+), 85 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3fc905..eba7ad6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -174,7 +174,7 @@ dependencies = [ "futures 0.1.29 (registry+https://github.com/rust-lang/crates.io-index)", "futures-cpupool 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "group 0.1.0 (git+https://github.com/MyHush/librustzcash.git?rev=1a0204113d487cdaaf183c2967010e5214ff9e37)", - "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)", "pairing 0.14.2 (git+https://github.com/MyHush/librustzcash.git?rev=1a0204113d487cdaaf183c2967010e5214ff9e37)", "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -559,7 +559,7 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "futures 0.1.29 (registry+https://github.com/rust-lang/crates.io-index)", - "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1000,7 +1000,7 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.11.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "hermit-abi 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1755,6 +1755,7 @@ dependencies = [ "libflate 0.1.27 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", "log4rs 0.8.3 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)", "pairing 0.14.2 (git+https://github.com/MyHush/librustzcash.git?rev=1a0204113d487cdaaf183c2967010e5214ff9e37)", "prost 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", "prost-types 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1766,7 +1767,9 @@ dependencies = [ "secp256k1 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)", "sha2 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "sodiumoxide 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)", + "subtle 2.2.2 (registry+https://github.com/rust-lang/crates.io-index)", "tempdir 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)", + "threadpool 1.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "tiny-bip39 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", "tokio 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", "tokio-rustls 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1933,6 +1936,14 @@ dependencies = [ "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "threadpool" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "time" version = "0.1.42" @@ -1972,7 +1983,7 @@ dependencies = [ "mio 0.6.21 (registry+https://github.com/rust-lang/crates.io-index)", "mio-named-pipes 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", "mio-uds 0.6.7 (registry+https://github.com/rust-lang/crates.io-index)", - "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)", "pin-project-lite 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", "signal-hook-registry 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", @@ -2707,7 +2718,7 @@ dependencies = [ "checksum num-bigint 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "f9c3f34cdd24f334cb265d9bf8bfa8a241920d026916785747a92f0e55541a1a" "checksum num-integer 0.1.41 (registry+https://github.com/rust-lang/crates.io-index)" = "b85e541ef8255f6cf42bbfe4ef361305c6c135d10919ecc26126c4e5ae94bc09" "checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4" -"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72" +"checksum num_cpus 1.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" "checksum once_cell 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "532c29a261168a45ce28948f9537ddd7a5dd272cc513b3017b1e82a88f962c37" "checksum opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" "checksum openssl-probe 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "77af24da69f9d9341038eba93a073b1fdaaa1b788221b00a69bce9e762cb32de" @@ -2809,6 +2820,7 @@ dependencies = [ "checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" "checksum thread-id 3.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c7fbf4c9d56b320106cd64fd024dadfa0be7cb4706725fc44a7d7ce952d820c1" "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b" +"checksum threadpool 1.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e8dae184447c15d5a6916d973c642aec485105a13cd238192a6927ae3e077d66" "checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f" "checksum tiny-bip39 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c1c5676413eaeb1ea35300a0224416f57abc3bd251657e0fafc12c47ff98c060" "checksum tokio 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0e1bef565a52394086ecac0a6fa3b8ace4cb3a138ee1d96bd2b93283b56824e3" diff --git a/lib/Cargo.toml b/lib/Cargo.toml index d554947..9967a4b 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -23,6 +23,9 @@ rand = "0.7.2" sodiumoxide = "0.2.5" ring = "0.16.9" libflate = "0.1" +subtle = "2" +threadpool = "1.8.0" +num_cpus = "1.13.0" tonic = { version = "0.1.1", features = ["tls", "tls-roots"] } bytes = "0.4" diff --git a/lib/src/grpcconnector.rs b/lib/src/grpcconnector.rs index 8e5f674..2492be3 100644 --- a/lib/src/grpcconnector.rs +++ b/lib/src/grpcconnector.rs @@ -2,12 +2,15 @@ use log::{error}; use std::sync::Arc; use zcash_primitives::transaction::{TxId}; -use crate::grpc_client::{ChainSpec, BlockId, BlockRange, RawTransaction, +use crate::grpc_client::{ChainSpec, BlockId, BlockRange, RawTransaction, CompactBlock, TransparentAddressBlockFilter, TxFilter, Empty, LightdInfo, Coinsupply}; use tonic::transport::{Channel, ClientTlsConfig}; use tokio_rustls::{rustls::ClientConfig}; use tonic::{Request}; +use threadpool::ThreadPool; +use std::sync::mpsc::channel; + use crate::PubCertificate; use crate::grpc_client::compact_tx_streamer_client::CompactTxStreamerClient; @@ -95,7 +98,7 @@ pub fn get_coinsupply(uri: http::Uri, no_cert: bool) -> Result(uri: &http::Uri, start_height: u64, end_height: u64, no_cert: bool, c: F) +async fn get_block_range(uri: &http::Uri, start_height: u64, end_height: u64, no_cert: bool, pool: ThreadPool, c: F) -> Result<(), Box> where F : Fn(&[u8], u64) { let mut client = get_client(uri, no_cert).await?; @@ -105,20 +108,40 @@ where F : Fn(&[u8], u64) { let request = Request::new(BlockRange{ start: Some(bs), end: Some(be) }); + // Channel where the blocks are sent. A None signifies end of all blocks + let (tx, rx) = channel::>(); + + // Channel that the processor signals it is done, so the method can return + let (ftx, frx) = channel(); + + // The processor runs on a different thread, so that the network calls don't + // block on this + pool.execute(move || { + while let Some(block) = rx.recv().unwrap() { + use prost::Message; + let mut encoded_buf = vec![]; + + block.encode(&mut encoded_buf).unwrap(); + c(&encoded_buf, block.height); + } + + ftx.send(Ok(())).unwrap(); + }); + let mut response = client.get_block_range(request).await?.into_inner(); //println!("{:?}", response); while let Some(block) = response.message().await? { - use prost::Message; - let mut encoded_buf = vec![]; - - block.encode(&mut encoded_buf).unwrap(); - c(&encoded_buf, block.height); + tx.send(Some(block)).unwrap(); } + tx.send(None).unwrap(); + + // Wait for the processor to exit + frx.iter().take(1).collect::, String>>()?; Ok(()) } -pub fn fetch_blocks(uri: &http::Uri, start_height: u64, end_height: u64, no_cert: bool, c: F) -> Result<(), String> +pub fn fetch_blocks(uri: &http::Uri, start_height: u64, end_height: u64, no_cert: bool, pool: ThreadPool, c: F) -> Result<(), String> where F : Fn(&[u8], u64) { let mut rt = match tokio::runtime::Runtime::new() { @@ -131,7 +154,7 @@ pub fn fetch_blocks(uri: &http::Uri, start_heig } }; - match rt.block_on(get_block_range(uri, start_height, end_height, no_cert, c)) { + match rt.block_on(get_block_range(uri, start_height, end_height, no_cert, pool, c)) { Ok(o) => Ok(o), Err(e) => { let e = format!("Error fetching blocks {:?}", e); @@ -202,26 +225,26 @@ async fn get_transaction(uri: &http::Uri, txid: TxId, no_cert: bool) Ok(response.into_inner()) } -pub fn fetch_full_tx(uri: &http::Uri, txid: TxId, no_cert: bool, c: F) - where F : Fn(&[u8]) { +pub fn fetch_full_tx(uri: &http::Uri, txid: TxId, no_cert: bool) -> Result, String> { let mut rt = match tokio::runtime::Runtime::new() { Ok(r) => r, Err(e) => { - error!("Error creating runtime {}", e.to_string()); - eprintln!("{}", e); - return; + let errstr = format!("Error creating runtime {}", e.to_string()); + error!("{}", errstr); + eprintln!("{}", errstr); + return Err(errstr); } }; match rt.block_on(get_transaction(uri, txid, no_cert)) { - Ok(rawtx) => c(&rawtx.data), + Ok(rawtx) => Ok(rawtx.data.to_vec()), Err(e) => { - error!("Error in get_transaction runtime {}", e.to_string()); - eprintln!("{}", e); + let errstr = format!("Error in get_transaction runtime {}", e.to_string()); + error!("{}", errstr); + eprintln!("{}", errstr); + Err(errstr) } - } - - + } } // send_transaction GRPC call @@ -262,22 +285,19 @@ async fn get_latest_block(uri: &http::Uri, no_cert: bool) -> Result(uri: &http::Uri, no_cert: bool, mut c : F) - where F : FnMut(BlockId) { +pub fn fetch_latest_block(uri: &http::Uri, no_cert: bool) -> Result { let mut rt = match tokio::runtime::Runtime::new() { Ok(r) => r, Err(e) => { - error!("Error creating runtime {}", e.to_string()); - eprintln!("{}", e); - return; + let errstr = format!("Error creating runtime {}", e.to_string()); + eprintln!("{}", errstr); + return Err(errstr); } }; - match rt.block_on(get_latest_block(uri, no_cert)) { - Ok(b) => c(b), - Err(e) => { - error!("Error getting latest block {}", e.to_string()); - eprintln!("{}", e); - } - }; + rt.block_on(get_latest_block(uri, no_cert)).map_err(|e| { + let errstr = format!("Error getting latest block {}", e.to_string()); + eprintln!("{}", errstr); + errstr + }) } diff --git a/lib/src/lightclient.rs b/lib/src/lightclient.rs index 8177b0a..3b5d0cf 100644 --- a/lib/src/lightclient.rs +++ b/lib/src/lightclient.rs @@ -2,17 +2,21 @@ use crate::lightwallet::LightWallet; use rand::{rngs::OsRng, seq::SliceRandom}; -use std::sync::{Arc, RwLock, Mutex}; -use std::sync::atomic::{AtomicU64, AtomicI32, AtomicUsize, Ordering}; +use std::sync::{Arc, RwLock, Mutex, mpsc::channel}; +use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering}; use std::path::{Path, PathBuf}; use std::fs::File; use std::collections::HashMap; +use std::cmp::{max, min}; use std::io; use std::io::prelude::*; use std::io::{BufReader, BufWriter, Error, ErrorKind}; use protobuf::parse_from_bytes; + +use threadpool::ThreadPool; + use json::{object, array, JsonValue}; use zcash_primitives::transaction::{TxId, Transaction}; use zcash_client_backend::{ @@ -30,7 +34,6 @@ use log4rs::append::rolling_file::policy::compound::{ roll::fixed_window::FixedWindowRoller, }; -use crate::grpc_client::{BlockId}; use crate::grpcconnector::{self, *}; use crate::SaplingParams; @@ -991,15 +994,9 @@ impl LightClient { let mut last_scanned_height = self.wallet.read().unwrap().last_scanned_height() as u64; // This will hold the latest block fetched from the RPC - let latest_block_height = Arc::new(AtomicU64::new(0)); - let lbh = latest_block_height.clone(); - fetch_latest_block(&self.get_server_uri(), self.config.no_cert_verification, - move |block: BlockId| { - lbh.store(block.height, Ordering::SeqCst); - }); - let latest_block = latest_block_height.load(Ordering::SeqCst); - + let latest_block = fetch_latest_block(&self.get_server_uri(), self.config.no_cert_verification)?.height; + if latest_block < last_scanned_height { let w = format!("Server's latest block({}) is behind ours({})", latest_block, last_scanned_height); warn!("{}", w); @@ -1035,6 +1032,9 @@ impl LightClient { // belong to us. let all_new_txs = Arc::new(RwLock::new(vec![])); + // Create a new threadpool (upto 8, atleast 2 threads) to scan with + let pool = ThreadPool::new(max(2, min(8, num_cpus::get()))); + // Fetch CompactBlocks in increments let mut pass = 0; loop { @@ -1070,7 +1070,8 @@ impl LightClient { let last_invalid_height = Arc::new(AtomicI32::new(0)); let last_invalid_height_inner = last_invalid_height.clone(); - fetch_blocks(&self.get_server_uri(), start_height, end_height, self.config.no_cert_verification, + let tpool = pool.clone(); + fetch_blocks(&self.get_server_uri(), start_height, end_height, self.config.no_cert_verification, pool.clone(), move |encoded_block: &[u8], height: u64| { // Process the block only if there were no previous errors if last_invalid_height_inner.load(Ordering::SeqCst) > 0 { @@ -1088,7 +1089,7 @@ impl LightClient { Err(_) => {} } - match local_light_wallet.read().unwrap().scan_block(encoded_block) { + match local_light_wallet.read().unwrap().scan_block_with_pool(encoded_block, &tpool) { Ok(block_txns) => { // Add to global tx list all_txs.write().unwrap().extend_from_slice(&block_txns.iter().map(|txid| (txid.clone(), height as i32)).collect::>()[..]); @@ -1102,6 +1103,16 @@ impl LightClient { local_bytes_downloaded.fetch_add(encoded_block.len(), Ordering::SeqCst); })?; + + { + // println!("Total scan duration: {:?}", self.wallet.read().unwrap().total_scan_duration.read().unwrap().get(0).unwrap().as_millis()); + + let t = self.wallet.read().unwrap(); + let mut d = t.total_scan_duration.write().unwrap(); + d.clear(); + d.push(std::time::Duration::new(0, 0)); + } + // Check if there was any invalid block, which means we might have to do a reorg let invalid_height = last_invalid_height.load(Ordering::SeqCst); if invalid_height > 0 { @@ -1136,11 +1147,16 @@ impl LightClient { let addresses = self.wallet.read().unwrap() .taddresses.read().unwrap().iter().map(|a| a.clone()) .collect::>(); + + // Create a channel so the fetch_transparent_txids can send the results back + let (ctx, crx) = channel(); + let num_addresses = addresses.len(); + for address in addresses { let wallet = self.wallet.clone(); let block_times_inner = block_times.clone(); - // If this is the first pass after a retry, fetch older t address txids too, becuse + // If this is the first pass after a retry, fetch older t address txids too, becuse // they might have been missed last time. let transparent_start_height = if pass == 1 && retry_count > 0 { start_height - scan_batch_size @@ -1148,16 +1164,29 @@ impl LightClient { start_height }; - fetch_transparent_txids(&self.get_server_uri(), address, transparent_start_height, end_height, self.config.no_cert_verification, - move |tx_bytes: &[u8], height: u64| { - let tx = Transaction::read(tx_bytes).unwrap(); - - // Scan this Tx for transparent inputs and outputs - let datetime = block_times_inner.read().unwrap().get(&height).map(|v| *v).unwrap_or(0); - wallet.read().unwrap().scan_full_tx(&tx, height as i32, datetime as u64); - } - )?; + let pool = pool.clone(); + let server_uri = self.get_server_uri(); + let ctx = ctx.clone(); + let no_cert = self.config.no_cert_verification; + + pool.execute(move || { + // Fetch the transparent transactions for this address, and send the results + // via the channel + let r = fetch_transparent_txids(&server_uri, address, transparent_start_height, end_height, no_cert, + move |tx_bytes: &[u8], height: u64| { + let tx = Transaction::read(tx_bytes).unwrap(); + + // Scan this Tx for transparent inputs and outputs + let datetime = block_times_inner.read().unwrap().get(&height).map(|v| *v).unwrap_or(0); + wallet.read().unwrap().scan_full_tx(&tx, height as i32, datetime as u64); + }); + ctx.send(r).unwrap(); + }); } + + // Collect all results from the transparent fetches, and make sure everything was OK. + // If it was not, we return an error, which will go back to the retry + crx.iter().take(num_addresses).collect::, String>>()?; } // Do block height accounting @@ -1200,24 +1229,44 @@ impl LightClient { let mut rng = OsRng; txids_to_fetch.shuffle(&mut rng); + let num_fetches = txids_to_fetch.len(); + let (ctx, crx) = channel(); + // And go and fetch the txids, getting the full transaction, so we can // read the memos for (txid, height) in txids_to_fetch { let light_wallet_clone = self.wallet.clone(); - info!("Fetching full Tx: {}", txid); - fetch_full_tx(&self.get_server_uri(), txid, self.config.no_cert_verification, move |tx_bytes: &[u8] | { - let tx = Transaction::read(tx_bytes).unwrap(); + let pool = pool.clone(); + let server_uri = self.get_server_uri(); + let ctx = ctx.clone(); + let no_cert = self.config.no_cert_verification; + + pool.execute(move || { + info!("Fetching full Tx: {}", txid); - light_wallet_clone.read().unwrap().scan_full_tx(&tx, height, 0); + match fetch_full_tx(&server_uri, txid, no_cert) { + Ok(tx_bytes) => { + let tx = Transaction::read(&tx_bytes[..]).unwrap(); + + light_wallet_clone.read().unwrap().scan_full_tx(&tx, height, 0); + ctx.send(Ok(())).unwrap(); + }, + Err(e) => ctx.send(Err(e)).unwrap() + }; }); }; - Ok(object!{ - "result" => "success", - "latest_block" => latest_block, - "downloaded_bytes" => bytes_downloaded.load(Ordering::SeqCst) - }) + // Wait for all the fetches to finish. + let result = crx.iter().take(num_fetches).collect::, String>>(); + match result { + Ok(_) => Ok(object!{ + "result" => "success", + "latest_block" => latest_block, + "downloaded_bytes" => bytes_downloaded.load(Ordering::SeqCst) + }), + Err(e) => Err(format!("Error fetching all txns for memos: {}", e)) + } } pub fn do_send(&self, addrs: Vec<(&str, u64, Option)>) -> Result { diff --git a/lib/src/lightwallet.rs b/lib/src/lightwallet.rs index 9ff92a8..71ec362 100644 --- a/lib/src/lightwallet.rs +++ b/lib/src/lightwallet.rs @@ -1,11 +1,15 @@ -use std::time::SystemTime; +use std::time::{SystemTime, Duration}; use std::io::{self, Read, Write}; use std::cmp; use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock}; use std::io::{Error, ErrorKind}; +use threadpool::ThreadPool; +use std::sync::mpsc::{channel}; + use rand::{Rng, rngs::OsRng}; +use subtle::{ConditionallySelectable, ConstantTimeEq, CtOption}; use log::{info, warn, error}; @@ -21,20 +25,23 @@ use sha2::{Sha256, Digest}; use zcash_client_backend::{ encoding::{encode_payment_address, encode_extended_spending_key}, - proto::compact_formats::CompactBlock, welding_rig::scan_block, + proto::compact_formats::{CompactBlock, CompactOutput}, + wallet::{WalletShieldedOutput, WalletShieldedSpend} }; use zcash_primitives::{ + jubjub::fs::Fs, block::BlockHash, - merkle_tree::{CommitmentTree}, serialize::{Vector}, transaction::{ builder::{Builder}, components::{Amount, OutPoint, TxOut}, components::amount::DEFAULT_FEE, TxId, Transaction, }, - legacy::{Script, TransparentAddress}, - note_encryption::{Memo, try_sapling_note_decryption, try_sapling_output_recovery}, + sapling::Node, + merkle_tree::{CommitmentTree, IncrementalWitness}, + legacy::{Script, TransparentAddress}, + note_encryption::{Memo, try_sapling_note_decryption, try_sapling_output_recovery, try_sapling_compact_note_decryption}, zip32::{ExtendedFullViewingKey, ExtendedSpendingKey, ChildIndex}, JUBJUB, primitives::{PaymentAddress}, @@ -136,6 +143,8 @@ pub struct LightWallet { // Non-serialized fields config: LightClientConfig, + + pub total_scan_duration: Arc>>, } impl LightWallet { @@ -254,6 +263,7 @@ impl LightWallet { mempool_txs: Arc::new(RwLock::new(HashMap::new())), config: config.clone(), birthday: latest_block, + total_scan_duration: Arc::new(RwLock::new(vec![Duration::new(0, 0)])), }; // If restoring from seed, make sure we are creating 50 addresses for users @@ -375,6 +385,7 @@ impl LightWallet { mempool_txs: Arc::new(RwLock::new(HashMap::new())), config: config.clone(), birthday, + total_scan_duration: Arc::new(RwLock::new(vec![Duration::new(0, 0)])), }) } @@ -424,12 +435,19 @@ impl LightWallet { Vector::write(&mut writer, &self.blocks.read().unwrap(), |w, b| b.write(w))?; - // The hashmap, write as a set of tuples - Vector::write(&mut writer, &self.txs.read().unwrap().iter().collect::>(), - |w, (k, v)| { - w.write_all(&k.0)?; - v.write(w) - })?; + // The hashmap, write as a set of tuples. Store them sorted so that wallets are + // deterministically saved + { + let txlist = self.txs.read().unwrap(); + let mut txns = txlist.iter().collect::>(); + txns.sort_by(|a, b| a.0.partial_cmp(b.0).unwrap()); + + Vector::write(&mut writer, &txns, + |w, (k, v)| { + w.write_all(&k.0)?; + v.write(w) + })?; + } utils::write_string(&mut writer, &self.config.chain_name)?; // While writing the birthday, get it from the fn so we recalculate it properly @@ -1272,7 +1290,7 @@ impl LightWallet { // Trim all witnesses for the invalidated blocks for tx in txs.values_mut() { for nd in tx.notes.iter_mut() { - nd.witnesses.split_off(nd.witnesses.len().saturating_sub(num_invalidated)); + let _discard = nd.witnesses.split_off(nd.witnesses.len().saturating_sub(num_invalidated)); } } } @@ -1280,8 +1298,234 @@ impl LightWallet { num_invalidated as u64 } - // Scan a block. Will return an error with the block height that failed to scan + /// Scans a [`CompactOutput`] with a set of [`ExtendedFullViewingKey`]s. + /// + /// Returns a [`WalletShieldedOutput`] and corresponding [`IncrementalWitness`] if this + /// output belongs to any of the given [`ExtendedFullViewingKey`]s. + /// + /// The given [`CommitmentTree`] and existing [`IncrementalWitness`]es are incremented + /// with this output's commitment. + fn scan_output_internal( + &self, + (index, output): (usize, CompactOutput), + ivks: &[Fs], + tree: &mut CommitmentTree, + existing_witnesses: &mut [&mut IncrementalWitness], + block_witnesses: &mut [&mut IncrementalWitness], + new_witnesses: &mut [&mut IncrementalWitness], + pool: &ThreadPool + ) -> Option { + let cmu = output.cmu().ok()?; + let epk = output.epk().ok()?; + let ct = output.ciphertext; + + let (tx, rx) = channel(); + ivks.iter().enumerate().for_each(|(account, ivk)| { + // Clone all values for passing to the closure + let ivk = ivk.clone(); + let epk = epk.clone(); + let ct = ct.clone(); + let tx = tx.clone(); + + pool.execute(move || { + let m = try_sapling_compact_note_decryption(&ivk, &epk, &cmu, &ct); + let r = match m { + Some((note, to)) => { + tx.send(Some(Some((note, to, account)))) + }, + None => { + tx.send(Some(None)) + } + }; + + match r { + Ok(_) => {}, + Err(e) => println!("Send error {:?}", e) + } + }); + }); + + // Increment tree and witnesses + let node = Node::new(cmu.into()); + for witness in existing_witnesses { + witness.append(node).unwrap(); + } + for witness in block_witnesses { + witness.append(node).unwrap(); + } + for witness in new_witnesses { + witness.append(node).unwrap(); + } + tree.append(node).unwrap(); + + // Collect all the RXs and fine if there was a valid result somewhere + let mut wsos = vec![]; + for _i in 0..ivks.len() { + let n = rx.recv().unwrap(); + let epk = epk.clone(); + + let wso = match n { + None => panic!("Got a none!"), + Some(None) => None, + Some(Some((note, to, account))) => { + // A note is marked as "change" if the account that received it + // also spent notes in the same transaction. This will catch, + // for instance: + // - Change created by spending fractions of notes. + // - Notes created by consolidation transactions. + // - Notes sent from one account to itself. + //let is_change = spent_from_accounts.contains(&account); + + Some(WalletShieldedOutput { + index, cmu, epk, account, note, to, is_change: false, + witness: IncrementalWitness::from_tree(tree), + }) + } + }; + wsos.push(wso); + } + + match wsos.into_iter().find(|wso| wso.is_some()) { + Some(Some(wso)) => Some(wso), + _ => None + } + } + + /// Scans a [`CompactBlock`] with a set of [`ExtendedFullViewingKey`]s. + /// + /// Returns a vector of [`WalletTx`]s belonging to any of the given + /// [`ExtendedFullViewingKey`]s, and the corresponding new [`IncrementalWitness`]es. + /// + /// The given [`CommitmentTree`] and existing [`IncrementalWitness`]es are + /// incremented appropriately. + pub fn scan_block_internal( + &self, + block: CompactBlock, + extfvks: &[ExtendedFullViewingKey], + nullifiers: Vec<(Vec, usize)>, + tree: &mut CommitmentTree, + existing_witnesses: &mut [&mut IncrementalWitness], + pool: &ThreadPool + ) -> Vec { + let mut wtxs: Vec = vec![]; + let ivks = extfvks.iter().map(|extfvk| extfvk.fvk.vk.ivk()).collect::>(); + + for tx in block.vtx.into_iter() { + let num_spends = tx.spends.len(); + let num_outputs = tx.outputs.len(); + + let (ctx, crx) = channel(); + { + let nullifiers = nullifiers.clone(); + let tx = tx.clone(); + pool.execute(move || { + // Check for spent notes + // The only step that is not constant-time is the filter() at the end. + let shielded_spends: Vec<_> = tx + .spends + .into_iter() + .enumerate() + .map(|(index, spend)| { + // Find the first tracked nullifier that matches this spend, and produce + // a WalletShieldedSpend if there is a match, in constant time. + nullifiers + .iter() + .map(|(nf, account)| CtOption::new(*account as u64, nf.ct_eq(&spend.nf[..]))) + .fold(CtOption::new(0, 0.into()), |first, next| { + CtOption::conditional_select(&next, &first, first.is_some()) + }) + .map(|account| WalletShieldedSpend { + index, + nf: spend.nf, + account: account as usize, + }) + }) + .filter(|spend| spend.is_some().into()) + .map(|spend| spend.unwrap()) + .collect(); + + // Collect the set of accounts that were spent from in this transaction + let spent_from_accounts: HashSet<_> = + shielded_spends.iter().map(|spend| spend.account).collect(); + + ctx.send((shielded_spends, spent_from_accounts)).unwrap(); + + drop(ctx); + }); + } + + + // Check for incoming notes while incrementing tree and witnesses + let mut shielded_outputs: Vec = vec![]; + { + // Grab mutable references to new witnesses from previous transactions + // in this block so that we can update them. Scoped so we don't hold + // mutable references to wtxs for too long. + let mut block_witnesses: Vec<_> = wtxs + .iter_mut() + .map(|tx| { + tx.shielded_outputs + .iter_mut() + .map(|output| &mut output.witness) + }) + .flatten() + .collect(); + + for to_scan in tx.outputs.into_iter().enumerate() { + // Grab mutable references to new witnesses from previous outputs + // in this transaction so that we can update them. Scoped so we + // don't hold mutable references to shielded_outputs for too long. + let mut new_witnesses: Vec<_> = shielded_outputs + .iter_mut() + .map(|output| &mut output.witness) + .collect(); + + if let Some(output) = self.scan_output_internal( + to_scan, + &ivks, + tree, + existing_witnesses, + &mut block_witnesses, + &mut new_witnesses, + pool + ) { + shielded_outputs.push(output); + } + } + } + + let (shielded_spends, spent_from_accounts) = crx.recv().unwrap(); + + // Identify change outputs + shielded_outputs.iter_mut().for_each(|output| { + if spent_from_accounts.contains(&output.account) { + output.is_change = true; + } + }); + + // Update wallet tx + if !(shielded_spends.is_empty() && shielded_outputs.is_empty()) { + let mut txid = TxId([0u8; 32]); + txid.0.copy_from_slice(&tx.hash); + wtxs.push(zcash_client_backend::wallet::WalletTx { + txid, + index: tx.index as usize, + num_spends, + num_outputs, + shielded_spends, + shielded_outputs, + }); + } + } + + wtxs + } pub fn scan_block(&self, block_bytes: &[u8]) -> Result, i32> { + self.scan_block_with_pool(&block_bytes, &ThreadPool::new(1)) + } + + // Scan a block. Will return an error with the block height that failed to scan + pub fn scan_block_with_pool(&self, block_bytes: &[u8], pool: &ThreadPool) -> Result, i32> { let block: CompactBlock = match parse_from_bytes(block_bytes) { Ok(block) => block, Err(e) => { @@ -1372,7 +1616,7 @@ impl LightWallet { } new_txs = { - let nf_refs: Vec<_> = nfs.iter().map(|(nf, acc, _)| (&nf[..], *acc)).collect(); + let nf_refs = nfs.iter().map(|(nf, account, _)| (nf.to_vec(), *account)).collect::>(); // Create a single mutable slice of all the newly-added witnesses. let mut witness_refs: Vec<_> = txs @@ -1381,12 +1625,13 @@ impl LightWallet { .flatten() .collect(); - scan_block( + self.scan_block_internal( block.clone(), &self.extfvks.read().unwrap(), - &nf_refs[..], + nf_refs, &mut block_data.tree, &mut witness_refs[..], + pool, ) }; }