diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d4e4e14..4f8992a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -59,12 +59,15 @@ jobs: - name: Install test dependencies run: python -m pip install --upgrade tox codecov + - name: Check style and typing + if: startsWith(matrix.os, 'ubuntu-') && matrix.python-version == env.PYTHON_VERSION_BUILD_EXTRA + run: tox -e lint,typing + - name: Run tests - run: tox -e ${PYTHON_VERSION},bench + run: tox -e ${PYTHON_VERSION} - - name: Lint - if: startsWith(matrix.os, 'ubuntu-') && matrix.python-version == env.PYTHON_VERSION_BUILD_EXTRA && always() - run: tox -e lint + - name: Run benchmarks + run: tox -e bench - name: Upload coverage run: codecov -X gcov diff --git a/.gitignore b/.gitignore index 0124316..8d65801 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ /.coverage /.eggs /.idea +/.mypy_cache /.tox /coincurve.egg-info /build diff --git a/coincurve/context.py b/coincurve/context.py index b01bd80..c88be3d 100644 --- a/coincurve/context.py +++ b/coincurve/context.py @@ -7,7 +7,7 @@ from ._libsecp256k1 import ffi, lib class Context: - def __init__(self, seed=None, flag=CONTEXT_ALL): + def __init__(self, seed: bytes = None, flag=CONTEXT_ALL): if flag not in CONTEXT_FLAGS: raise ValueError('{} is an invalid context flag.'.format(flag)) self._lock = Lock() @@ -15,7 +15,7 @@ class Context: self.ctx = ffi.gc(lib.secp256k1_context_create(flag), lib.secp256k1_context_destroy) self.reseed(seed) - def reseed(self, seed=None): + def reseed(self, seed: bytes = None): """ Protects against certain possible future side-channel timing attacks. """ diff --git a/coincurve/ecdsa.py b/coincurve/ecdsa.py index 8d14bd0..5dd0b91 100644 --- a/coincurve/ecdsa.py +++ b/coincurve/ecdsa.py @@ -1,4 +1,5 @@ -from coincurve.context import GLOBAL_CONTEXT +from coincurve.context import GLOBAL_CONTEXT, Context +from coincurve.types import Hasher from coincurve.utils import bytes_to_int, int_to_bytes, sha256 from ._libsecp256k1 import ffi, lib @@ -7,7 +8,7 @@ MAX_SIG_LENGTH = 72 CDATA_SIG_LENGTH = 64 -def cdata_to_der(cdata, context=GLOBAL_CONTEXT): +def cdata_to_der(cdata, context: Context = GLOBAL_CONTEXT) -> bytes: der = ffi.new('unsigned char[%d]' % MAX_SIG_LENGTH) der_length = ffi.new('size_t *', MAX_SIG_LENGTH) @@ -16,7 +17,7 @@ def cdata_to_der(cdata, context=GLOBAL_CONTEXT): return bytes(ffi.buffer(der, der_length[0])) -def der_to_cdata(der, context=GLOBAL_CONTEXT): +def der_to_cdata(der: bytes, context: Context = GLOBAL_CONTEXT): cdata = ffi.new('secp256k1_ecdsa_signature *') parsed = lib.secp256k1_ecdsa_signature_parse_der(context.ctx, cdata, der, len(der)) @@ -26,7 +27,7 @@ def der_to_cdata(der, context=GLOBAL_CONTEXT): return cdata -def recover(message, recover_sig, hasher=sha256, context=GLOBAL_CONTEXT): +def recover(message: bytes, recover_sig, hasher: Hasher = sha256, context: Context = GLOBAL_CONTEXT): msg_hash = hasher(message) if hasher is not None else message if len(msg_hash) != 32: raise ValueError('Message hash must be 32 bytes long.') @@ -38,7 +39,7 @@ def recover(message, recover_sig, hasher=sha256, context=GLOBAL_CONTEXT): raise Exception('failed to recover ECDSA public key') -def serialize_recoverable(recover_sig, context=GLOBAL_CONTEXT): +def serialize_recoverable(recover_sig, context: Context = GLOBAL_CONTEXT) -> bytes: output = ffi.new('unsigned char[%d]' % CDATA_SIG_LENGTH) recid = ffi.new('int *') @@ -47,7 +48,7 @@ def serialize_recoverable(recover_sig, context=GLOBAL_CONTEXT): return bytes(ffi.buffer(output, CDATA_SIG_LENGTH)) + int_to_bytes(recid[0]) -def deserialize_recoverable(serialized, context=GLOBAL_CONTEXT): +def deserialize_recoverable(serialized: bytes, context: Context = GLOBAL_CONTEXT): if len(serialized) != 65: raise ValueError('Serialized signature must be 65 bytes long.') @@ -71,7 +72,7 @@ Warning: """ -def serialize_compact(raw_sig, context=GLOBAL_CONTEXT): # no cov +def serialize_compact(raw_sig, context: Context = GLOBAL_CONTEXT): # no cov output = ffi.new('unsigned char[%d]' % CDATA_SIG_LENGTH) res = lib.secp256k1_ecdsa_signature_serialize_compact(context.ctx, output, raw_sig) @@ -80,7 +81,7 @@ def serialize_compact(raw_sig, context=GLOBAL_CONTEXT): # no cov return bytes(ffi.buffer(output, CDATA_SIG_LENGTH)) -def deserialize_compact(ser_sig, context=GLOBAL_CONTEXT): # no cov +def deserialize_compact(ser_sig: bytes, context: Context = GLOBAL_CONTEXT): # no cov if len(ser_sig) != 64: raise Exception('invalid signature length') @@ -91,15 +92,13 @@ def deserialize_compact(ser_sig, context=GLOBAL_CONTEXT): # no cov return raw_sig -def signature_normalize(raw_sig, context=GLOBAL_CONTEXT): # no cov +def signature_normalize(raw_sig, context: Context = GLOBAL_CONTEXT): # no cov """ Check and optionally convert a signature to a normalized lower-S form. - If check_only is True then the normalized signature is not returned. This function always return a tuple containing a boolean (True if not previously normalized or False if signature was already - normalized), and the normalized signature. When check_only is True, - the normalized signature returned is always None. + normalized), and the normalized signature. """ sigout = ffi.new('secp256k1_ecdsa_signature *') @@ -108,7 +107,7 @@ def signature_normalize(raw_sig, context=GLOBAL_CONTEXT): # no cov return not not res, sigout -def recoverable_convert(recover_sig, context=GLOBAL_CONTEXT): # no cov +def recoverable_convert(recover_sig, context: Context = GLOBAL_CONTEXT): # no cov normal_sig = ffi.new('secp256k1_ecdsa_signature *') lib.secp256k1_ecdsa_recoverable_signature_convert(context.ctx, normal_sig, recover_sig) diff --git a/coincurve/keys.py b/coincurve/keys.py index 38ddc6a..e6cacf4 100644 --- a/coincurve/keys.py +++ b/coincurve/keys.py @@ -1,8 +1,11 @@ +from typing import Tuple + from asn1crypto.keys import ECDomainParameters, ECPointBitString, ECPrivateKey, PrivateKeyAlgorithm, PrivateKeyInfo -from coincurve.context import GLOBAL_CONTEXT +from coincurve.context import GLOBAL_CONTEXT, Context from coincurve.ecdsa import cdata_to_der, der_to_cdata, deserialize_recoverable, recover, serialize_recoverable from coincurve.flags import EC_COMPRESSED, EC_UNCOMPRESSED +from coincurve.types import Hasher from coincurve.utils import ( bytes_to_int, der_to_pem, @@ -21,12 +24,12 @@ DEFAULT_NONCE = (ffi.NULL, ffi.NULL) class PrivateKey: - def __init__(self, secret=None, context=GLOBAL_CONTEXT): - self.secret = validate_secret(secret) if secret is not None else get_valid_secret() + def __init__(self, secret: bytes = None, context: Context = GLOBAL_CONTEXT): + self.secret: bytes = validate_secret(secret) if secret is not None else get_valid_secret() self.context = context - self.public_key = PublicKey.from_valid_secret(self.secret, self.context) + self.public_key: PublicKey = PublicKey.from_valid_secret(self.secret, self.context) - def sign(self, message, hasher=sha256, custom_nonce=None): + def sign(self, message: bytes, hasher: Hasher = sha256, custom_nonce=None) -> bytes: msg_hash = hasher(message) if hasher is not None else message if len(msg_hash) != 32: raise ValueError('Message hash must be 32 bytes long.') @@ -41,7 +44,7 @@ class PrivateKey: return cdata_to_der(signature, self.context) - def sign_recoverable(self, message, hasher=sha256): + def sign_recoverable(self, message, hasher: Hasher = sha256): msg_hash = hasher(message) if hasher is not None else message if len(msg_hash) != 32: raise ValueError('Message hash must be 32 bytes long.') @@ -57,14 +60,14 @@ class PrivateKey: return serialize_recoverable(signature, self.context) - def ecdh(self, public_key): + def ecdh(self, public_key: bytes) -> bytes: secret = ffi.new('unsigned char [32]') lib.secp256k1_ecdh(self.context.ctx, secret, PublicKey(public_key).public_key, self.secret, ffi.NULL, ffi.NULL) return bytes(ffi.buffer(secret, 32)) - def add(self, scalar, update=False): + def add(self, scalar: bytes, update=False): scalar = pad_scalar(scalar) secret = ffi.new('unsigned char [32]', self.secret) @@ -83,7 +86,7 @@ class PrivateKey: return PrivateKey(secret, self.context) - def multiply(self, scalar, update=False): + def multiply(self, scalar: bytes, update=False): scalar = validate_secret(scalar) secret = ffi.new('unsigned char [32]', self.secret) @@ -99,16 +102,16 @@ class PrivateKey: return PrivateKey(secret, self.context) - def to_hex(self): + def to_hex(self) -> str: return self.secret.hex() - def to_int(self): + def to_int(self) -> int: return bytes_to_int(self.secret) - def to_pem(self): + def to_pem(self) -> bytes: return der_to_pem(self.to_der()) - def to_der(self): + def to_der(self) -> bytes: pk = ECPrivateKey( { 'version': 'ecPrivkeyVer1', @@ -131,21 +134,21 @@ class PrivateKey: ).dump() @classmethod - def from_hex(cls, hexed, context=GLOBAL_CONTEXT): + def from_hex(cls, hexed: str, context: Context = GLOBAL_CONTEXT): return PrivateKey(hex_to_bytes(hexed), context) @classmethod - def from_int(cls, num, context=GLOBAL_CONTEXT): + def from_int(cls, num: int, context: Context = GLOBAL_CONTEXT): return PrivateKey(int_to_bytes_padded(num), context) @classmethod - def from_pem(cls, pem, context=GLOBAL_CONTEXT): + def from_pem(cls, pem: bytes, context: Context = GLOBAL_CONTEXT): return PrivateKey( int_to_bytes_padded(PrivateKeyInfo.load(pem_to_der(pem)).native['private_key']['private_key']), context ) @classmethod - def from_der(cls, der, context=GLOBAL_CONTEXT): + def from_der(cls, der: bytes, context: Context = GLOBAL_CONTEXT): return PrivateKey(int_to_bytes_padded(PrivateKeyInfo.load(der).native['private_key']['private_key']), context) def _update_public_key(self): @@ -154,12 +157,12 @@ class PrivateKey: if not created: raise ValueError('Invalid secret.') - def __eq__(self, other): + def __eq__(self, other) -> bool: return self.secret == other.secret class PublicKey: - def __init__(self, data, context=GLOBAL_CONTEXT): + def __init__(self, data, context: Context = GLOBAL_CONTEXT): if not isinstance(data, bytes): self.public_key = data else: @@ -175,7 +178,7 @@ class PublicKey: self.context = context @classmethod - def from_secret(cls, secret, context=GLOBAL_CONTEXT): + def from_secret(cls, secret: bytes, context: Context = GLOBAL_CONTEXT): public_key = ffi.new('secp256k1_pubkey *') created = lib.secp256k1_ec_pubkey_create(context.ctx, public_key, validate_secret(secret)) @@ -190,7 +193,7 @@ class PublicKey: return PublicKey(public_key, context) @classmethod - def from_valid_secret(cls, secret, context=GLOBAL_CONTEXT): + def from_valid_secret(cls, secret: bytes, context: Context = GLOBAL_CONTEXT): public_key = ffi.new('secp256k1_pubkey *') created = lib.secp256k1_ec_pubkey_create(context.ctx, public_key, secret) @@ -201,17 +204,19 @@ class PublicKey: return PublicKey(public_key, context) @classmethod - def from_point(cls, x, y, context=GLOBAL_CONTEXT): + def from_point(cls, x: int, y: int, context: Context = GLOBAL_CONTEXT): return PublicKey(b'\x04' + int_to_bytes_padded(x) + int_to_bytes_padded(y), context) @classmethod - def from_signature_and_message(cls, serialized_sig, message, hasher=sha256, context=GLOBAL_CONTEXT): + def from_signature_and_message( + cls, serialized_sig: bytes, message: bytes, hasher: Hasher = sha256, context: Context = GLOBAL_CONTEXT + ): return PublicKey( recover(message, deserialize_recoverable(serialized_sig, context=context), hasher=hasher, context=context) ) @classmethod - def combine_keys(cls, public_keys, context=GLOBAL_CONTEXT): + def combine_keys(cls, public_keys, context: Context = GLOBAL_CONTEXT): public_key = ffi.new('secp256k1_pubkey *') combined = lib.secp256k1_ec_pubkey_combine( @@ -223,7 +228,7 @@ class PublicKey: return PublicKey(public_key, context) - def format(self, compressed=True): + def format(self, compressed=True) -> bytes: length = 33 if compressed else 65 serialized = ffi.new('unsigned char [%d]' % length) output_len = ffi.new('size_t *', length) @@ -234,11 +239,11 @@ class PublicKey: return bytes(ffi.buffer(serialized, length)) - def point(self): + def point(self) -> Tuple[int, int]: public_key = self.format(compressed=False) return bytes_to_int(public_key[1:33]), bytes_to_int(public_key[33:]) - def verify(self, signature, message, hasher=sha256): + def verify(self, signature: bytes, message: bytes, hasher: Hasher = sha256) -> bool: msg_hash = hasher(message) if hasher is not None else message if len(msg_hash) != 32: raise ValueError('Message hash must be 32 bytes long.') @@ -248,7 +253,7 @@ class PublicKey: # A performance hack to avoid global bool() lookup. return not not verified - def add(self, scalar, update=False): + def add(self, scalar: bytes, update=False): scalar = pad_scalar(scalar) new_key = ffi.new('secp256k1_pubkey *', self.public_key[0]) @@ -264,7 +269,7 @@ class PublicKey: return PublicKey(new_key, self.context) - def multiply(self, scalar, update=False): + def multiply(self, scalar: bytes, update=False): scalar = validate_secret(scalar) new_key = ffi.new('secp256k1_pubkey *', self.public_key[0]) @@ -293,5 +298,5 @@ class PublicKey: return PublicKey(new_key, self.context) - def __eq__(self, other): + def __eq__(self, other) -> bool: return self.format(compressed=False) == other.format(compressed=False) diff --git a/coincurve/types.py b/coincurve/types.py new file mode 100644 index 0000000..123baed --- /dev/null +++ b/coincurve/types.py @@ -0,0 +1,10 @@ +import sys +from typing import Optional + +# https://bugs.python.org/issue42965 +if sys.version_info >= (3, 9, 2): + from collections.abc import Callable +else: + from typing import Callable + +Hasher = Optional[Callable[[bytes], bytes]] diff --git a/coincurve/utils.py b/coincurve/utils.py index d803048..44e728a 100644 --- a/coincurve/utils.py +++ b/coincurve/utils.py @@ -1,8 +1,10 @@ from base64 import b64decode, b64encode from hashlib import sha256 as _sha256 from os import urandom +from typing import Generator -from coincurve.context import GLOBAL_CONTEXT +from coincurve.context import GLOBAL_CONTEXT, Context +from coincurve.types import Hasher from ._libsecp256k1 import ffi, lib @@ -17,61 +19,63 @@ PEM_HEADER = b'-----BEGIN PRIVATE KEY-----\n' PEM_FOOTER = b'-----END PRIVATE KEY-----\n' -def pad_hex(hexed): +def pad_hex(hexed: str) -> str: # Pad odd-length hex strings. return hexed if not len(hexed) & 1 else f'0{hexed}' -def bytes_to_int(bytestr): +def bytes_to_int(bytestr: bytes) -> int: return int.from_bytes(bytestr, 'big') -def int_to_bytes(num): +def int_to_bytes(num: int) -> bytes: return num.to_bytes((num.bit_length() + 7) // 8 or 1, 'big') -def int_to_bytes_padded(num): +def int_to_bytes_padded(num: int) -> bytes: return pad_scalar(num.to_bytes((num.bit_length() + 7) // 8 or 1, 'big')) -def hex_to_bytes(hexed): +def hex_to_bytes(hexed: str) -> bytes: return pad_scalar(bytes.fromhex(pad_hex(hexed))) -def sha256(bytestr): +def sha256(bytestr: bytes) -> bytes: return _sha256(bytestr).digest() -def chunk_data(data, size): +def chunk_data(data: bytes, size: int) -> Generator[bytes, None, None]: return (data[i : i + size] for i in range(0, len(data), size)) -def der_to_pem(der): +def der_to_pem(der: bytes) -> bytes: return b''.join([PEM_HEADER, b'\n'.join(chunk_data(b64encode(der), 64)), b'\n', PEM_FOOTER]) -def pem_to_der(pem): +def pem_to_der(pem: bytes) -> bytes: return b64decode(b''.join(pem.strip().splitlines()[1:-1])) -def get_valid_secret(): +def get_valid_secret() -> bytes: while True: secret = urandom(KEY_SIZE) if ZERO < secret < GROUP_ORDER: return secret -def pad_scalar(scalar): +def pad_scalar(scalar: bytes) -> bytes: return (ZERO * (KEY_SIZE - len(scalar))) + scalar -def validate_secret(secret): +def validate_secret(secret: bytes) -> bytes: if not 0 < bytes_to_int(secret) < GROUP_ORDER_INT: raise ValueError('Secret scalar must be greater than 0 and less than {}.'.format(GROUP_ORDER_INT)) return pad_scalar(secret) -def verify_signature(signature, message, public_key, hasher=sha256, context=GLOBAL_CONTEXT): +def verify_signature( + signature: bytes, message: bytes, public_key: bytes, hasher: Hasher = sha256, context: Context = GLOBAL_CONTEXT +) -> bool: pubkey = ffi.new('secp256k1_pubkey *') pubkey_parsed = lib.secp256k1_ec_pubkey_parse(context.ctx, pubkey, public_key, len(public_key)) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..1772231 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +disallow_untyped_defs = false +follow_imports = normal +ignore_missing_imports = true +pretty = true +show_column_numbers = true +warn_no_return = false +warn_unused_ignores = true diff --git a/tox.ini b/tox.ini index 2094575..e094bb8 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ envlist = bench lint fmt + typing [testenv] passenv = * @@ -51,3 +52,10 @@ commands = isort . black . {[testenv:lint]commands} + +[testenv:typing] +skip_install = true +deps = + mypy==0.790 +commands = + mypy coincurve