Browse Source

Auto merge of #3148 - str4d:DOS-mitigation-tx-expiry, r=str4d

Don't increase banscore for expired transactions if they only just expired

Closes #3141.
pull/4/head
Homu 6 years ago
parent
commit
f5b1082f9c
  1. 1
      qa/pull-tester/rpc-tests.sh
  2. 133
      qa/rpc-tests/p2p_txexpiry_dos.py
  3. 40
      qa/rpc-tests/test_framework/mininode.py
  4. 4
      src/main.cpp

1
qa/pull-tester/rpc-tests.sh

@ -56,6 +56,7 @@ testScripts=(
'bipdersig-p2p.py'
'overwinter_peer_management.py'
'rewind_index.py'
'p2p_txexpiry_dos.py'
);
testScriptsExt=(
'getblocktemplate_longpoll.py'

133
qa/rpc-tests/p2p_txexpiry_dos.py

@ -0,0 +1,133 @@
#!/usr/bin/env python2
# Copyright (c) 2018 The Zcash developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
from test_framework.mininode import NodeConn, NodeConnCB, NetworkThread, \
CTransaction, msg_tx, mininode_lock, OVERWINTER_PROTO_VERSION
from test_framework.test_framework import BitcoinTestFramework
from test_framework.util import initialize_chain_clean, start_nodes, \
p2p_port, assert_equal
import time, cStringIO
from binascii import hexlify, unhexlify
class TestNode(NodeConnCB):
def __init__(self):
NodeConnCB.__init__(self)
self.create_callback_map()
self.connection = None
def add_connection(self, conn):
self.connection = conn
# Spin until verack message is received from the node.
# We use this to signal that our test can begin. This
# is called from the testing thread, so it needs to acquire
# the global lock.
def wait_for_verack(self):
while True:
with mininode_lock:
if self.verack_received:
return
time.sleep(0.05)
# Wrapper for the NodeConn's send_message function
def send_message(self, message):
self.connection.send_message(message)
def on_close(self, conn):
pass
def on_reject(self, conn, message):
conn.rejectMessage = message
class TxExpiryDoSTest(BitcoinTestFramework):
def setup_chain(self):
print "Initializing test directory "+self.options.tmpdir
initialize_chain_clean(self.options.tmpdir, 1)
def setup_network(self):
self.nodes = start_nodes(1, self.options.tmpdir,
extra_args=[['-nuparams=5ba81b19:10']])
def create_transaction(self, node, coinbase, to_address, amount, txModifier=None):
from_txid = node.getblock(coinbase)['tx'][0]
inputs = [{ "txid" : from_txid, "vout" : 0}]
outputs = { to_address : amount }
rawtx = node.createrawtransaction(inputs, outputs)
tx = CTransaction()
if txModifier:
f = cStringIO.StringIO(unhexlify(rawtx))
tx.deserialize(f)
txModifier(tx)
rawtx = hexlify(tx.serialize())
signresult = node.signrawtransaction(rawtx)
f = cStringIO.StringIO(unhexlify(signresult['hex']))
tx.deserialize(f)
return tx
def run_test(self):
test_node = TestNode()
connections = []
connections.append(NodeConn('127.0.0.1', p2p_port(0), self.nodes[0],
test_node, "regtest", True))
test_node.add_connection(connections[0])
# Start up network handling in another thread
NetworkThread().start()
test_node.wait_for_verack()
# Verify mininodes are connected to zcashd nodes
peerinfo = self.nodes[0].getpeerinfo()
versions = [x["version"] for x in peerinfo]
assert_equal(1, versions.count(OVERWINTER_PROTO_VERSION))
assert_equal(0, peerinfo[0]["banscore"])
self.coinbase_blocks = self.nodes[0].generate(1)
self.nodes[0].generate(100)
self.nodeaddress = self.nodes[0].getnewaddress()
# Mininodes send transaction to zcashd node.
def setExpiryHeight(tx):
tx.nExpiryHeight = 101
spendtx = self.create_transaction(self.nodes[0],
self.coinbase_blocks[0],
self.nodeaddress, 1.0,
txModifier=setExpiryHeight)
test_node.send_message(msg_tx(spendtx))
time.sleep(3)
# Verify test mininode has not been dropped
# and still has a banscore of 0.
peerinfo = self.nodes[0].getpeerinfo()
versions = [x["version"] for x in peerinfo]
assert_equal(1, versions.count(OVERWINTER_PROTO_VERSION))
assert_equal(0, peerinfo[0]["banscore"])
# Mine a block and resend the transaction
self.nodes[0].generate(1)
test_node.send_message(msg_tx(spendtx))
time.sleep(3)
# Verify test mininode has not been dropped
# but has a banscore of 10.
peerinfo = self.nodes[0].getpeerinfo()
versions = [x["version"] for x in peerinfo]
assert_equal(1, versions.count(OVERWINTER_PROTO_VERSION))
assert_equal(10, peerinfo[0]["banscore"])
[ c.disconnect_node() for c in connections ]
if __name__ == '__main__':
TxExpiryDoSTest().main()

40
qa/rpc-tests/test_framework/mininode.py

@ -44,6 +44,8 @@ BIP0031_VERSION = 60000
MY_VERSION = 170002 # past bip-31 for ping/pong
MY_SUBVERSION = "/python-mininode-tester:0.0.1/"
OVERWINTER_VERSION_GROUP_ID = 0x03C48270
MAX_INV_SZ = 50000
@ -565,20 +567,26 @@ class CTxOut(object):
class CTransaction(object):
def __init__(self, tx=None):
if tx is None:
self.fOverwintered = False
self.nVersion = 1
self.nVersionGroupId = 0
self.vin = []
self.vout = []
self.nLockTime = 0
self.nExpiryHeight = 0
self.vjoinsplit = []
self.joinSplitPubKey = None
self.joinSplitSig = None
self.sha256 = None
self.hash = None
else:
self.fOverwintered = tx.fOverwintered
self.nVersion = tx.nVersion
self.nVersionGroupId = tx.nVersionGroupId
self.vin = copy.deepcopy(tx.vin)
self.vout = copy.deepcopy(tx.vout)
self.nLockTime = tx.nLockTime
self.nExpiryHeight = tx.nExpiryHeight
self.vjoinsplit = copy.deepcopy(tx.vjoinsplit)
self.joinSplitPubKey = tx.joinSplitPubKey
self.joinSplitSig = tx.joinSplitSig
@ -586,24 +594,46 @@ class CTransaction(object):
self.hash = None
def deserialize(self, f):
self.nVersion = struct.unpack("<i", f.read(4))[0]
header = struct.unpack("<I", f.read(4))[0]
self.fOverwintered = bool(header >> 31)
self.nVersion = header & 0x7FFFFFFF
self.nVersionGroupId = (struct.unpack("<I", f.read(4))[0]
if self.fOverwintered else 0)
isOverwinterV3 = (self.fOverwintered and
self.nVersionGroupId == OVERWINTER_VERSION_GROUP_ID and
self.nVersion == 3)
self.vin = deser_vector(f, CTxIn)
self.vout = deser_vector(f, CTxOut)
self.nLockTime = struct.unpack("<I", f.read(4))[0]
if isOverwinterV3:
self.nExpiryHeight = struct.unpack("<I", f.read(4))[0]
if self.nVersion >= 2:
self.vjoinsplit = deser_vector(f, JSDescription)
if len(self.vjoinsplit) > 0:
self.joinSplitPubKey = deser_uint256(f)
self.joinSplitSig = f.read(64)
self.sha256 = None
self.hash = None
def serialize(self):
header = (int(self.fOverwintered)<<31) | self.nVersion
isOverwinterV3 = (self.fOverwintered and
self.nVersionGroupId == OVERWINTER_VERSION_GROUP_ID and
self.nVersion == 3)
r = ""
r += struct.pack("<i", self.nVersion)
r += struct.pack("<I", header)
if self.fOverwintered:
r += struct.pack("<I", self.nVersionGroupId)
r += ser_vector(self.vin)
r += ser_vector(self.vout)
r += struct.pack("<I", self.nLockTime)
if isOverwinterV3:
r += struct.pack("<I", self.nExpiryHeight)
if self.nVersion >= 2:
r += ser_vector(self.vjoinsplit)
if len(self.vjoinsplit) > 0:
@ -628,8 +658,10 @@ class CTransaction(object):
return True
def __repr__(self):
r = "CTransaction(nVersion=%i vin=%s vout=%s nLockTime=%i" \
% (self.nVersion, repr(self.vin), repr(self.vout), self.nLockTime)
r = ("CTransaction(fOverwintered=%r nVersion=%i nVersionGroupId=0x%08x "
"vin=%s vout=%s nLockTime=%i nExpiryHeight=%i"
% (self.fOverwintered, self.nVersion, self.nVersionGroupId,
repr(self.vin), repr(self.vout), self.nLockTime, self.nExpiryHeight))
if self.nVersion >= 2:
r += " vjoinsplit=%s" % repr(self.vjoinsplit)
if len(self.vjoinsplit) > 0:

4
src/main.cpp

@ -897,7 +897,9 @@ bool ContextualCheckTransaction(const CTransaction& tx, CValidationState &state,
// Check that all transactions are unexpired
if (IsExpiredTx(tx, nHeight)) {
return state.DoS(dosLevel, error("ContextualCheckTransaction(): transaction is expired"), REJECT_INVALID, "tx-overwinter-expired");
// Don't increase banscore if the transaction only just expired
int expiredDosLevel = IsExpiredTx(tx, nHeight - 1) ? dosLevel : 0;
return state.DoS(expiredDosLevel, error("ContextualCheckTransaction(): transaction is expired"), REJECT_INVALID, "tx-overwinter-expired");
}
}

Loading…
Cancel
Save