forked from hush/lightwalletd
Aditya Kulkarni
5 years ago
4 changed files with 0 additions and 668 deletions
@ -1 +0,0 @@ |
|||
|
@ -1,268 +0,0 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"context" |
|||
"database/sql" |
|||
"encoding/hex" |
|||
"encoding/json" |
|||
"flag" |
|||
"fmt" |
|||
"os" |
|||
"strconv" |
|||
"strings" |
|||
"time" |
|||
|
|||
"github.com/btcsuite/btcd/rpcclient" |
|||
"github.com/golang/protobuf/proto" |
|||
"github.com/pkg/errors" |
|||
"github.com/sirupsen/logrus" |
|||
|
|||
"github.com/adityapk00/lightwalletd/common" |
|||
"github.com/adityapk00/lightwalletd/frontend" |
|||
"github.com/adityapk00/lightwalletd/parser" |
|||
"github.com/adityapk00/lightwalletd/storage" |
|||
) |
|||
|
|||
var log *logrus.Entry |
|||
var logger = logrus.New() |
|||
var db *sql.DB |
|||
|
|||
// Options is a struct holding command line options
type Options struct {
	dbPath        string // path to the sqlite database file (required)
	logLevel      uint64 // logrus level (1-7)
	logPath       string // optional log file; empty means stderr text logging
	zcashConfPath string // zcash.conf to read RPC credentials from
}
|||
|
|||
func main() { |
|||
opts := &Options{} |
|||
flag.StringVar(&opts.dbPath, "db-path", "", "the path to a sqlite database file") |
|||
flag.Uint64Var(&opts.logLevel, "log-level", uint64(logrus.InfoLevel), "log level (logrus 1-7)") |
|||
flag.StringVar(&opts.logPath, "log-file", "", "log file to write to") |
|||
flag.StringVar(&opts.zcashConfPath, "conf-file", "", "conf file to pull RPC creds from") |
|||
// TODO prod metrics
|
|||
// TODO support config from file and env vars
|
|||
flag.Parse() |
|||
|
|||
if opts.dbPath == "" { |
|||
flag.Usage() |
|||
os.Exit(1) |
|||
} |
|||
|
|||
// Initialize logging
|
|||
logger.SetFormatter(&logrus.TextFormatter{ |
|||
//DisableColors: true,
|
|||
FullTimestamp: true, |
|||
DisableLevelTruncation: true, |
|||
}) |
|||
|
|||
if opts.logPath != "" { |
|||
// instead write parsable logs for logstash/splunk/etc
|
|||
output, err := os.OpenFile(opts.logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) |
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"error": err, |
|||
"path": opts.logPath, |
|||
}).Fatal("couldn't open log file") |
|||
} |
|||
defer output.Close() |
|||
logger.SetOutput(output) |
|||
logger.SetFormatter(&logrus.JSONFormatter{}) |
|||
} |
|||
|
|||
logger.SetLevel(logrus.Level(opts.logLevel)) |
|||
|
|||
log = logger.WithFields(logrus.Fields{ |
|||
"app": "lightwd", |
|||
}) |
|||
|
|||
// Initialize database
|
|||
db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=10000&cache=shared", opts.dbPath)) |
|||
db.SetMaxOpenConns(1) |
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"db_path": opts.dbPath, |
|||
"error": err, |
|||
}).Fatal("couldn't open SQL db") |
|||
} |
|||
|
|||
// Creates our tables if they don't already exist.
|
|||
err = storage.CreateTables(db) |
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"error": err, |
|||
}).Fatal("couldn't create SQL tables") |
|||
} |
|||
|
|||
//Initialize RPC connection with full node zcashd
|
|||
rpcClient, err := frontend.NewZRPCFromConf(opts.zcashConfPath) |
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"error": err, |
|||
}).Warn("zcash.conf failed, will try empty credentials for rpc") |
|||
|
|||
//Default to testnet, but user MUST specify rpcuser and rpcpassword in zcash.conf; no default
|
|||
rpcClient, err = frontend.NewZRPCFromCreds("127.0.0.1:18232", "", "") |
|||
|
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"error": err, |
|||
}).Fatal("couldn't start rpc connection") |
|||
} |
|||
} |
|||
|
|||
ctx := context.Background() |
|||
height, err := storage.GetCurrentHeight(ctx, db) |
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"error": err, |
|||
}).Warn("Unable to get current height from local db storage. This is OK if you're starting this for the first time.") |
|||
} |
|||
|
|||
// Get the sapling activation height from the RPC
|
|||
saplingHeight, chainName, err := common.GetSaplingInfo(rpcClient) |
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"error": err, |
|||
}).Warn("Unable to get sapling activation height") |
|||
} |
|||
|
|||
log.WithField("saplingHeight", saplingHeight).Info("Got sapling height ", saplingHeight, " chain ", chainName) |
|||
|
|||
//ingest from Sapling testnet height
|
|||
if height < saplingHeight { |
|||
height = saplingHeight |
|||
log.WithFields(logrus.Fields{ |
|||
"error": err, |
|||
}).Warn("invalid current height read from local db storage") |
|||
} |
|||
|
|||
timeoutCount := 0 |
|||
reorgCount := -1 |
|||
hash := "" |
|||
phash := "" |
|||
// Start listening for new blocks
|
|||
for { |
|||
if reorgCount > 0 { |
|||
reorgCount = -1 |
|||
height -= 10 |
|||
} |
|||
block, err := getBlock(rpcClient, height) |
|||
|
|||
if err != nil { |
|||
log.WithFields(logrus.Fields{ |
|||
"height": height, |
|||
"error": err, |
|||
}).Warn("error with getblock") |
|||
timeoutCount++ |
|||
if timeoutCount == 3 { |
|||
log.WithFields(logrus.Fields{ |
|||
"timeouts": timeoutCount, |
|||
}).Warn("unable to issue RPC call to zcashd node 3 times") |
|||
break |
|||
} |
|||
} |
|||
if block != nil { |
|||
handleBlock(db, block) |
|||
if timeoutCount > 0 { |
|||
timeoutCount-- |
|||
} |
|||
phash = hex.EncodeToString(block.GetPrevHash()) |
|||
//check for reorgs once we have inital block hash from startup
|
|||
if hash != phash && reorgCount != -1 { |
|||
reorgCount++ |
|||
log.WithFields(logrus.Fields{ |
|||
"height": height, |
|||
"hash": hash, |
|||
"phash": phash, |
|||
"reorg": reorgCount, |
|||
}).Warn("REORG") |
|||
} else { |
|||
hash = hex.EncodeToString(block.GetDisplayHash()) |
|||
} |
|||
if reorgCount == -1 { |
|||
hash = hex.EncodeToString(block.GetDisplayHash()) |
|||
reorgCount = 0 |
|||
} |
|||
height++ |
|||
} else { |
|||
//TODO implement blocknotify to minimize polling on corner cases
|
|||
time.Sleep(60 * time.Second) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func getBlock(rpcClient *rpcclient.Client, height int) (*parser.Block, error) { |
|||
params := make([]json.RawMessage, 2) |
|||
params[0] = json.RawMessage("\"" + strconv.Itoa(height) + "\"") |
|||
params[1] = json.RawMessage("0") |
|||
result, rpcErr := rpcClient.RawRequest("getblock", params) |
|||
|
|||
var err error |
|||
var errCode int64 |
|||
|
|||
// For some reason, the error responses are not JSON
|
|||
if rpcErr != nil { |
|||
errParts := strings.SplitN(rpcErr.Error(), ":", 2) |
|||
errCode, err = strconv.ParseInt(errParts[0], 10, 32) |
|||
//Check to see if we are requesting a height the zcashd doesn't have yet
|
|||
if err == nil && errCode == -8 { |
|||
return nil, nil |
|||
} |
|||
return nil, errors.Wrap(rpcErr, "error requesting block") |
|||
} |
|||
|
|||
var blockDataHex string |
|||
err = json.Unmarshal(result, &blockDataHex) |
|||
if err != nil { |
|||
return nil, errors.Wrap(err, "error reading JSON response") |
|||
} |
|||
|
|||
blockData, err := hex.DecodeString(blockDataHex) |
|||
if err != nil { |
|||
return nil, errors.Wrap(err, "error decoding getblock output") |
|||
} |
|||
|
|||
block := parser.NewBlock() |
|||
rest, err := block.ParseFromSlice(blockData) |
|||
if err != nil { |
|||
return nil, errors.Wrap(err, "error parsing block") |
|||
} |
|||
if len(rest) != 0 { |
|||
return nil, errors.New("received overlong message") |
|||
} |
|||
return block, nil |
|||
} |
|||
|
|||
func handleBlock(db *sql.DB, block *parser.Block) { |
|||
prevBlockHash := hex.EncodeToString(block.GetPrevHash()) |
|||
blockHash := hex.EncodeToString(block.GetEncodableHash()) |
|||
marshaledBlock, _ := proto.Marshal(block.ToCompact()) |
|||
|
|||
err := storage.StoreBlock( |
|||
db, |
|||
block.GetHeight(), |
|||
prevBlockHash, |
|||
blockHash, |
|||
block.HasSaplingTransactions(), |
|||
marshaledBlock, |
|||
) |
|||
|
|||
entry := log.WithFields(logrus.Fields{ |
|||
"block_height": block.GetHeight(), |
|||
"block_hash": hex.EncodeToString(block.GetDisplayHash()), |
|||
"prev_hash": hex.EncodeToString(block.GetDisplayPrevHash()), |
|||
"block_version": block.GetVersion(), |
|||
"tx_count": block.GetTxCount(), |
|||
"sapling": block.HasSaplingTransactions(), |
|||
"error": err, |
|||
}) |
|||
|
|||
if err != nil { |
|||
entry.Error("new block") |
|||
} else { |
|||
entry.Info("new block") |
|||
} |
|||
|
|||
} |
@ -1,164 +0,0 @@ |
|||
package storage |
|||
|
|||
import ( |
|||
"context" |
|||
"database/sql" |
|||
"fmt" |
|||
|
|||
"github.com/pkg/errors" |
|||
) |
|||
|
|||
// ErrLotsOfBlocks is returned by GetBlockRange when the caller requests more
// than 10k blocks in a single call.
var ErrLotsOfBlocks = errors.New("requested >10k blocks at once")
|||
|
|||
// CreateTables creates the two tables used by the ingester if they don't
// already exist: blocks (one row per block, keyed by height) and state
// (a single row recording the current chain tip).
func CreateTables(conn *sql.DB) error {
	stateTable := `
	CREATE TABLE IF NOT EXISTS state (
		current_height INTEGER,
		timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
		FOREIGN KEY (current_height) REFERENCES blocks (block_height)
	);
	`
	if _, err := conn.Exec(stateTable); err != nil {
		return err
	}

	blockTable := `
	CREATE TABLE IF NOT EXISTS blocks (
		block_height INTEGER PRIMARY KEY,
		prev_hash TEXT,
		block_hash TEXT,
		sapling BOOL,
		compact_encoding BLOB
	);
	`
	_, err := conn.Exec(blockTable)
	return err
}
|||
|
|||
// TODO consider max/count queries instead of state table. bit of a coupling assumption though.
|
|||
|
|||
// GetCurrentHeight returns the chain tip recorded in the state table. On
// error (e.g. sql.ErrNoRows on a fresh database) the returned height is -1.
func GetCurrentHeight(ctx context.Context, db *sql.DB) (int, error) {
	height := -1
	row := db.QueryRowContext(ctx, "SELECT current_height FROM state WHERE rowid = 1")
	err := row.Scan(&height)
	return height, err
}
|||
|
|||
// GetBlock returns the stored compact encoding of the block at the given
// height, or the raw sql error (sql.ErrNoRows when the height is absent).
func GetBlock(ctx context.Context, db *sql.DB, height int) ([]byte, error) {
	var encoded []byte // avoid a copy with *RawBytes
	row := db.QueryRowContext(ctx,
		"SELECT compact_encoding from blocks WHERE block_height = ?", height)
	if err := row.Scan(&encoded); err != nil {
		return nil, err
	}
	return encoded, nil
}
|||
|
|||
func GetBlockByHash(ctx context.Context, db *sql.DB, hash string) ([]byte, error) { |
|||
var blockBytes []byte // avoid a copy with *RawBytes
|
|||
query := "SELECT compact_encoding from blocks WHERE block_hash = ?" |
|||
err := db.QueryRowContext(ctx, query, hash).Scan(&blockBytes) |
|||
if err != nil { |
|||
return nil, errors.Wrap(err, fmt.Sprintf("getting block with hash %s", hash)) |
|||
} |
|||
return blockBytes, err |
|||
} |
|||
|
|||
// [start, end] inclusive
|
|||
func GetBlockRange(ctx context.Context, db *sql.DB, blockOut chan<- []byte, errOut chan<- error, start, end int) { |
|||
// TODO sanity check ranges, this limit, etc
|
|||
numBlocks := (end - start) + 1 |
|||
if numBlocks > 10000 { |
|||
errOut <- ErrLotsOfBlocks |
|||
return |
|||
} |
|||
|
|||
query := "SELECT compact_encoding from blocks WHERE (block_height BETWEEN ? AND ?)" |
|||
result, err := db.QueryContext(ctx, query, start, end) |
|||
if err != nil { |
|||
errOut <- err |
|||
return |
|||
} |
|||
defer result.Close() |
|||
|
|||
// My assumption here is that if the context is cancelled then result.Next() will fail.
|
|||
|
|||
var blockBytes []byte |
|||
for result.Next() { |
|||
err = result.Scan(&blockBytes) |
|||
if err != nil { |
|||
errOut <- err |
|||
return |
|||
} |
|||
blockOut <- blockBytes |
|||
} |
|||
|
|||
if err := result.Err(); err != nil { |
|||
errOut <- err |
|||
return |
|||
} |
|||
|
|||
// done
|
|||
errOut <- nil |
|||
} |
|||
|
|||
func StoreBlock(conn *sql.DB, height int, prev_hash string, hash string, sapling bool, encoded []byte) error { |
|||
insertBlock := "REPLACE INTO blocks (block_height, prev_hash, block_hash, sapling, compact_encoding) values ( ?, ?, ?, ?, ?)" |
|||
|
|||
tx, err := conn.Begin() |
|||
if err != nil { |
|||
return errors.Wrap(err, fmt.Sprintf("creating db tx %d", height)) |
|||
} |
|||
|
|||
_, err = tx.Exec(insertBlock, height, prev_hash, hash, sapling, encoded) |
|||
if err != nil { |
|||
return errors.Wrap(err, fmt.Sprintf("storing compact block %d", height)) |
|||
} |
|||
|
|||
var currentHeight int |
|||
query := "SELECT current_height FROM state WHERE rowid = 1" |
|||
err = tx.QueryRow(query).Scan(¤tHeight) |
|||
|
|||
if err != nil || height > currentHeight { |
|||
err = setCurrentHeight(tx, height) |
|||
} |
|||
|
|||
err = tx.Commit() |
|||
if err != nil { |
|||
return errors.Wrap(err, fmt.Sprintf("committing db tx %d", height)) |
|||
|
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func setCurrentHeight(tx *sql.Tx, height int) error { |
|||
update := "UPDATE state SET current_height=?, timestamp=CURRENT_TIMESTAMP WHERE rowid = 1" |
|||
result, err := tx.Exec(update, height) |
|||
if err != nil { |
|||
return errors.Wrap(err, "updating state row") |
|||
} |
|||
rowCount, err := result.RowsAffected() |
|||
if err != nil { |
|||
return errors.Wrap(err, "checking if state row exists after update") |
|||
} |
|||
if rowCount == 0 { |
|||
// row does not yet exist
|
|||
insert := "INSERT OR IGNORE INTO state (rowid, current_height) VALUES (1, ?)" |
|||
result, err = tx.Exec(insert, height) |
|||
if err != nil { |
|||
return errors.Wrap(err, "on state row insert") |
|||
} |
|||
rowCount, err = result.RowsAffected() |
|||
if err != nil { |
|||
return errors.Wrap(err, "checking if state row exists after insert") |
|||
} |
|||
if rowCount != 1 { |
|||
return errors.New("totally failed to update current height state") |
|||
} |
|||
} |
|||
return nil |
|||
} |
@ -1,235 +0,0 @@ |
|||
package storage |
|||
|
|||
import ( |
|||
"context" |
|||
"database/sql" |
|||
"encoding/hex" |
|||
"encoding/json" |
|||
"fmt" |
|||
"io/ioutil" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/golang/protobuf/proto" |
|||
_ "github.com/mattn/go-sqlite3" |
|||
"github.com/pkg/errors" |
|||
|
|||
"github.com/adityapk00/lightwalletd/parser" |
|||
"github.com/adityapk00/lightwalletd/walletrpc" |
|||
) |
|||
|
|||
type compactTest struct { |
|||
BlockHeight int `json:"block"` |
|||
BlockHash string `json:"hash"` |
|||
Full string `json:"full"` |
|||
Compact string `json:"compact"` |
|||
} |
|||
|
|||
var compactTests []compactTest |
|||
|
|||
func TestSqliteStorage(t *testing.T) { |
|||
blockJSON, err := ioutil.ReadFile("../testdata/compact_blocks.json") |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
err = json.Unmarshal(blockJSON, &compactTests) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
db, err := sql.Open("sqlite3", ":memory:") |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
defer db.Close() |
|||
|
|||
// Fill tables
|
|||
{ |
|||
err = CreateTables(db) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
for _, test := range compactTests { |
|||
blockData, _ := hex.DecodeString(test.Full) |
|||
block := parser.NewBlock() |
|||
blockData, err = block.ParseFromSlice(blockData) |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, fmt.Sprintf("parsing testnet block %d", test.BlockHeight))) |
|||
continue |
|||
} |
|||
|
|||
height := block.GetHeight() |
|||
hash := hex.EncodeToString(block.GetEncodableHash()) |
|||
hasSapling := block.HasSaplingTransactions() |
|||
protoBlock := block.ToCompact() |
|||
marshaled, _ := proto.Marshal(protoBlock) |
|||
|
|||
err = StoreBlock(db, height, hash, hasSapling, marshaled) |
|||
if err != nil { |
|||
t.Error(err) |
|||
continue |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Count the blocks
|
|||
{ |
|||
var count int |
|||
countBlocks := "SELECT count(*) FROM blocks" |
|||
err = db.QueryRow(countBlocks).Scan(&count) |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, fmt.Sprintf("counting compact blocks"))) |
|||
} |
|||
|
|||
if count != len(compactTests) { |
|||
t.Errorf("Wrong row count, want %d got %d", len(compactTests), count) |
|||
} |
|||
} |
|||
|
|||
ctx := context.Background() |
|||
|
|||
// Check height state is as expected
|
|||
{ |
|||
blockHeight, err := GetCurrentHeight(ctx, db) |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, fmt.Sprintf("checking current block height"))) |
|||
} |
|||
|
|||
lastBlockTest := compactTests[len(compactTests)-1] |
|||
|
|||
if blockHeight != lastBlockTest.BlockHeight { |
|||
t.Errorf("Wrong block height, got: %d", blockHeight) |
|||
} |
|||
|
|||
storedBlock, err := GetBlock(ctx, db, blockHeight) |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, "retrieving stored block")) |
|||
} |
|||
cblock := &walletrpc.CompactBlock{} |
|||
err = proto.Unmarshal(storedBlock, cblock) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
|
|||
if int(cblock.Height) != lastBlockTest.BlockHeight { |
|||
t.Error("incorrect retrieval") |
|||
} |
|||
} |
|||
|
|||
// Block ranges
|
|||
{ |
|||
blockOut := make(chan []byte) |
|||
errOut := make(chan error) |
|||
|
|||
count := 0 |
|||
go GetBlockRange(ctx, db, blockOut, errOut, 289460, 289465) |
|||
recvLoop0: |
|||
for { |
|||
select { |
|||
case <-blockOut: |
|||
count++ |
|||
case err := <-errOut: |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, "in full blockrange")) |
|||
} |
|||
break recvLoop0 |
|||
} |
|||
} |
|||
|
|||
if count != 6 { |
|||
t.Error("failed to retrieve full range") |
|||
} |
|||
|
|||
// Test timeout
|
|||
timeout, _ := context.WithTimeout(ctx, 0*time.Second) |
|||
go GetBlockRange(timeout, db, blockOut, errOut, 289460, 289465) |
|||
recvLoop1: |
|||
for { |
|||
select { |
|||
case err := <-errOut: |
|||
if err != context.DeadlineExceeded { |
|||
t.Errorf("got the wrong error: %v", err) |
|||
} |
|||
break recvLoop1 |
|||
} |
|||
} |
|||
|
|||
// Test a smaller range
|
|||
count = 0 |
|||
go GetBlockRange(ctx, db, blockOut, errOut, 289462, 289465) |
|||
recvLoop2: |
|||
for { |
|||
select { |
|||
case <-blockOut: |
|||
count++ |
|||
case err := <-errOut: |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, "in short blockrange")) |
|||
} |
|||
break recvLoop2 |
|||
} |
|||
} |
|||
|
|||
if count != 4 { |
|||
t.Errorf("failed to retrieve the shorter range") |
|||
} |
|||
|
|||
// Test a nonsense range
|
|||
count = 0 |
|||
go GetBlockRange(ctx, db, blockOut, errOut, 1, 2) |
|||
recvLoop3: |
|||
for { |
|||
select { |
|||
case <-blockOut: |
|||
count++ |
|||
case err := <-errOut: |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, "in invalid blockrange")) |
|||
} |
|||
break recvLoop3 |
|||
} |
|||
} |
|||
|
|||
if count > 0 { |
|||
t.Errorf("got some blocks that shouldn't be there") |
|||
} |
|||
} |
|||
|
|||
// Transaction storage
|
|||
{ |
|||
blockData, _ := hex.DecodeString(compactTests[0].Full) |
|||
block := parser.NewBlock() |
|||
_, _ = block.ParseFromSlice(blockData) |
|||
tx := block.Transactions()[0] |
|||
|
|||
blockHash := hex.EncodeToString(block.GetEncodableHash()) |
|||
txHash := hex.EncodeToString(tx.GetEncodableHash()) |
|||
err = StoreTransaction( |
|||
db, |
|||
block.GetHeight(), |
|||
blockHash, |
|||
0, |
|||
txHash, |
|||
tx.Bytes(), |
|||
) |
|||
|
|||
if err != nil { |
|||
t.Error(err) |
|||
} |
|||
|
|||
var storedBytes []byte |
|||
getTx := "SELECT tx_bytes FROM transactions WHERE tx_hash = ?" |
|||
err = db.QueryRow(getTx, txHash).Scan(&storedBytes) |
|||
if err != nil { |
|||
t.Error(errors.Wrap(err, fmt.Sprintf("error getting a full transaction"))) |
|||
} |
|||
|
|||
if len(storedBytes) != len(tx.Bytes()) { |
|||
t.Errorf("Wrong tx size, want %d got %d", len(tx.Bytes()), storedBytes) |
|||
} |
|||
|
|||
} |
|||
|
|||
} |
Loading…
Reference in new issue