Browse Source

refactor: Replace struct.pack/unpack.

pull/52/merge
tecnovert 4 months ago
parent
commit
8318961f0b
No known key found for this signature in database GPG Key ID: 8ED6D8750C4E3F93
  1. 33
      basicswap/network.py

33
basicswap/network.py

@ -24,7 +24,6 @@ import queue
import random
import select
import socket
import struct
import hashlib
import logging
import secrets
@ -41,7 +40,7 @@ from basicswap.contrib.rfc6979 import (
START_TOKEN = 0xabcd
MSG_START_TOKEN = struct.pack('>H', START_TOKEN)
MSG_START_TOKEN = START_TOKEN.to_bytes(2, 'big')
MSG_MAX_SIZE = 0x200000 # 2MB
@ -83,8 +82,8 @@ class MsgHandshake:
pass
def encode_aad(self): # Additional Authenticated Data
return struct.pack('>H', NetMessageTypes.HANDSHAKE) + \
struct.pack('>Q', self._timestamp) + \
return int(NetMessageTypes.HANDSHAKE).to_bytes(2, 'big') + \
self._timestamp.to_bytes(8, 'big') + \
self._ephem_pk
def encode(self):
@ -92,7 +91,7 @@ class MsgHandshake:
def decode(self, msg_mv):
o = 2
self._timestamp = struct.unpack('>Q', msg_mv[o: o + 8])[0]
self._timestamp = int.from_bytes(msg_mv[o: o + 8], 'big')
o += 8
self._ephem_pk = bytes(msg_mv[o: o + 33])
o += 33
@ -333,7 +332,7 @@ class Network:
ss = k.ecdh(peer._pubkey)
hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
peer._ke = hashed[:32]
peer._km = hashed[32:]
@ -386,7 +385,7 @@ class Network:
nk = PrivateKey(self._network_key)
ss = nk.ecdh(msg._ephem_pk)
hashed = hashlib.sha512(ss + struct.pack('>Q', msg._timestamp)).digest()
hashed = hashlib.sha512(ss + msg._timestamp.to_bytes(8, 'big')).digest()
peer._ke = hashed[:32]
peer._km = hashed[32:]
@ -427,7 +426,7 @@ class Network:
mac = msg_mv[-16:]
plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
ping_nonce = struct.unpack('>I', plaintext[:4])[0]
ping_nonce = int.from_bytes(plaintext[:4], 'big')
# Version is added to a ping following a handshake message
if len(plaintext) >= 10:
peer._ready = True
@ -450,7 +449,7 @@ class Network:
mac = msg_mv[-16:]
plaintext = cipher.decrypt_and_verify(msg_mv[2: -16], mac)
pong_nonce = struct.unpack('>I', plaintext[:4])[0]
pong_nonce = int.from_bytes(plaintext[:4], 'big')
if pong_nonce == peer._ping_nonce:
peer._last_ping_rtt = (time.time_ns() // 1000) - peer._last_ping_at
@ -462,14 +461,14 @@ class Network:
def send_ping(self, peer):
ping_nonce = random.getrandbits(32)
msg_bytes = struct.pack('>H', NetMessageTypes.PING)
msg_bytes = int(NetMessageTypes.PING).to_bytes(2, 'big')
nonce = peer._sent_nonce[:24]
cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
cipher.update(msg_bytes)
cipher.update(nonce)
payload = struct.pack('>I', ping_nonce)
payload = ping_nonce.to_bytes(4, 'big')
if peer._last_ping_at == 0:
payload += self._sc._version
ct, mac = cipher.encrypt_and_digest(payload)
@ -484,14 +483,14 @@ class Network:
self.send_msg(peer, msg_bytes)
def send_pong(self, peer, ping_nonce):
msg_bytes = struct.pack('>H', NetMessageTypes.PONG)
msg_bytes = int(NetMessageTypes.PONG).to_bytes(2, 'big')
nonce = peer._sent_nonce[:24]
cipher = ChaCha20_Poly1305.new(key=peer._ke, nonce=nonce)
cipher.update(msg_bytes)
cipher.update(nonce)
payload = struct.pack('>I', ping_nonce)
payload = ping_nonce.to_bytes(4, 'big')
ct, mac = cipher.encrypt_and_digest(payload)
msg_bytes += ct + mac
@ -503,7 +502,7 @@ class Network:
msg_encoded = msg if isinstance(msg, bytes) else msg.encode()
len_encoded = len(msg_encoded)
msg_packed = bytearray(MSG_START_TOKEN) + struct.pack('>I', len_encoded) + msg_encoded
msg_packed = bytearray(MSG_START_TOKEN) + len_encoded.to_bytes(4, 'big') + msg_encoded
peer._socket.sendall(msg_packed)
peer._bytes_sent += len_encoded
@ -515,7 +514,7 @@ class Network:
try:
mv = memoryview(msg_bytes)
o = 0
msg_type = struct.unpack('>H', mv[o: o + 2])[0]
msg_type = int.from_bytes(mv[o: o + 2], 'big')
if msg_type == NetMessageTypes.HANDSHAKE:
self.process_handshake(peer, mv)
elif msg_type == NetMessageTypes.PING:
@ -548,13 +547,13 @@ class Network:
raise ValueError('Invalid start token')
o += 2
msg_len = struct.unpack('>I', mv[o: o + 4])[0]
msg_len = int.from_bytes(mv[o: o + 4], 'big')
o += 4
if msg_len < 2 or msg_len > MSG_MAX_SIZE:
raise ValueError('Invalid data length')
# Precheck msg_type
msg_type = struct.unpack('>H', mv[o: o + 2])[0]
msg_type = int.from_bytes(mv[o: o + 2], 'big')
# o += 2 # Don't inc offset, msg includes type
if not NetMessageTypes.has_value(msg_type):
raise ValueError('Invalid msg type')

Loading…
Cancel
Save