1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-05 04:50:57 +00:00

core: Fix typing.

This commit is contained in:
Andrew Kozlik 2020-04-07 09:39:11 +02:00 committed by Andrew Kozlik
parent ba8b34b2d7
commit 79c60615de
8 changed files with 65 additions and 45 deletions

View File

@ -7,17 +7,20 @@ from trezor.utils import HashWriter, ensure
from apps.wallet.sign_tx.writers import write_bytes_fixed, write_uint32 from apps.wallet.sign_tx.writers import write_bytes_fixed, write_uint32
if False:
from typing import List, Optional
class MultisigError(ValueError): class MultisigError(ValueError):
pass pass
class MultisigFingerprint: class MultisigFingerprint:
def __init__(self): def __init__(self) -> None:
self.fingerprint = None # multisig fingerprint bytes self.fingerprint = None # type: Optional[bytes] # multisig fingerprint bytes
self.mismatch = False # flag if multisig input fingerprints are equal self.mismatch = False # flag if multisig input fingerprints are equal
def add(self, multisig: MultisigRedeemScriptType): def add(self, multisig: MultisigRedeemScriptType) -> None:
fp = multisig_fingerprint(multisig) fp = multisig_fingerprint(multisig)
ensure(fp is not None) ensure(fp is not None)
if self.fingerprint is None: if self.fingerprint is None:
@ -25,7 +28,7 @@ class MultisigFingerprint:
elif self.fingerprint != fp: elif self.fingerprint != fp:
self.mismatch = True self.mismatch = True
def matches(self, multisig: MultisigRedeemScriptType): def matches(self, multisig: MultisigRedeemScriptType) -> bool:
fp = multisig_fingerprint(multisig) fp = multisig_fingerprint(multisig)
ensure(fp is not None) ensure(fp is not None)
if self.mismatch is False and self.fingerprint == fp: if self.mismatch is False and self.fingerprint == fp:
@ -90,14 +93,14 @@ def multisig_get_pubkey(n: HDNodeType, p: list) -> bytes:
return node.public_key() return node.public_key()
def multisig_get_pubkeys(multisig: MultisigRedeemScriptType): def multisig_get_pubkeys(multisig: MultisigRedeemScriptType) -> List[bytes]:
if multisig.nodes: if multisig.nodes:
return [multisig_get_pubkey(hd, multisig.address_n) for hd in multisig.nodes] return [multisig_get_pubkey(hd, multisig.address_n) for hd in multisig.nodes]
else: else:
return [multisig_get_pubkey(hd.node, hd.address_n) for hd in multisig.pubkeys] return [multisig_get_pubkey(hd.node, hd.address_n) for hd in multisig.pubkeys]
def multisig_get_pubkey_count(multisig: MultisigRedeemScriptType): def multisig_get_pubkey_count(multisig: MultisigRedeemScriptType) -> int:
if multisig.nodes: if multisig.nodes:
return len(multisig.nodes) return len(multisig.nodes)
else: else:

View File

@ -2,6 +2,9 @@ from ustruct import unpack
from trezor.strings import format_amount from trezor.strings import format_amount
if False:
from typing import Optional
currencies = { currencies = {
1: ("OMNI", True), 1: ("OMNI", True),
2: ("tOMNI", True), 2: ("tOMNI", True),
@ -14,7 +17,7 @@ def is_valid(data: bytes) -> bool:
return len(data) >= 8 and data[:4] == b"omni" return len(data) >= 8 and data[:4] == b"omni"
def parse(data: bytes) -> bool: def parse(data: bytes) -> Optional[str]:
if not is_valid(data): if not is_valid(data):
return None return None
tx_version, tx_type = unpack(">HH", data[4:8]) tx_version, tx_type = unpack(">HH", data[4:8])

View File

@ -4,7 +4,7 @@ _progress = 0
_steps = 0 _steps = 0
def init(inputs, outputs): def init(inputs: int, outputs: int) -> None:
global _progress, _steps global _progress, _steps
_progress = 0 _progress = 0
_steps = inputs + inputs + outputs + inputs _steps = inputs + inputs + outputs + inputs
@ -12,18 +12,18 @@ def init(inputs, outputs):
report() report()
def advance(): def advance() -> None:
global _progress global _progress
_progress += 1 _progress += 1
report() report()
def report_init(): def report_init() -> None:
ui.display.clear() ui.display.clear()
ui.header("Signing transaction") ui.header("Signing transaction")
def report(): def report() -> None:
if utils.DISABLE_ANIMATION: if utils.DISABLE_ANIMATION:
return return
p = 1000 * _progress // _steps p = 1000 * _progress // _steps

View File

@ -132,7 +132,7 @@ def input_script_p2wsh_in_p2sh(script_hash: bytes) -> bytearray:
# === # ===
def witness_p2wpkh(signature: bytes, pubkey: bytes, sighash: int): def witness_p2wpkh(signature: bytes, pubkey: bytes, sighash: int) -> bytearray:
w = empty_bytearray(1 + 5 + len(signature) + 1 + 5 + len(pubkey)) w = empty_bytearray(1 + 5 + len(signature) + 1 + 5 + len(pubkey))
write_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2 write_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2
append_signature(w, signature, sighash) append_signature(w, signature, sighash)

View File

@ -47,7 +47,7 @@ def derive_script_code(txi: TxInputType, pubkeyhash: bytes) -> bytearray:
class Zip143: class Zip143:
def __init__(self, branch_id): def __init__(self, branch_id: int) -> None:
self.branch_id = branch_id self.branch_id = branch_id
self.h_prevouts = HashWriter(blake2b(outlen=32, personal=b"ZcashPrevoutHash")) self.h_prevouts = HashWriter(blake2b(outlen=32, personal=b"ZcashPrevoutHash"))
self.h_sequence = HashWriter(blake2b(outlen=32, personal=b"ZcashSequencHash")) self.h_sequence = HashWriter(blake2b(outlen=32, personal=b"ZcashSequencHash"))
@ -119,7 +119,7 @@ class Zip143:
class Zip243(Zip143): class Zip243(Zip143):
def __init__(self, branch_id): def __init__(self, branch_id) -> None:
super().__init__(branch_id) super().__init__(branch_id)
def preimage_hash( def preimage_hash(

View File

@ -13,11 +13,14 @@
# This module adds shiny packaging and support for python3. # This module adds shiny packaging and support for python3.
# #
if False:
from typing import Callable
# 58 character alphabet used # 58 character alphabet used
_alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" _alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
def encode(data: bytes, alphabet=_alphabet) -> str: def encode(data: bytes, alphabet: str = _alphabet) -> str:
""" """
Convert bytes to base58 encoded string. Convert bytes to base58 encoded string.
""" """
@ -38,7 +41,7 @@ def encode(data: bytes, alphabet=_alphabet) -> str:
return "".join((c for c in reversed(result + alphabet[0] * (origlen - newlen)))) return "".join((c for c in reversed(result + alphabet[0] * (origlen - newlen))))
def decode(string: str, alphabet=_alphabet) -> bytes: def decode(string: str, alphabet: str = _alphabet) -> bytes:
""" """
Convert base58 encoded string to bytes. Convert base58 encoded string to bytes.
""" """
@ -89,14 +92,16 @@ def ripemd160_32(data: bytes) -> bytes:
return ripemd160(data).digest()[:4] return ripemd160(data).digest()[:4]
def encode_check(data: bytes, digestfunc=sha256d_32) -> str: def encode_check(data: bytes, digestfunc: Callable[[bytes], bytes] = sha256d_32) -> str:
""" """
Convert bytes to base58 encoded string, append checksum. Convert bytes to base58 encoded string, append checksum.
""" """
return encode(data + digestfunc(data)) return encode(data + digestfunc(data))
def decode_check(string: str, digestfunc=sha256d_32) -> bytes: def decode_check(
string: str, digestfunc: Callable[[bytes], bytes] = sha256d_32
) -> bytes:
""" """
Convert base58 encoded string to bytes and verify checksum. Convert base58 encoded string to bytes and verify checksum.
""" """
@ -104,7 +109,7 @@ def decode_check(string: str, digestfunc=sha256d_32) -> bytes:
return verify_checksum(result, digestfunc) return verify_checksum(result, digestfunc)
def verify_checksum(data: bytes, digestfunc) -> bytes: def verify_checksum(data: bytes, digestfunc: Callable[[bytes], bytes]) -> bytes:
digestlen = len(digestfunc(b"")) digestlen = len(digestfunc(b""))
result, check = data[:-digestlen], data[-digestlen:] result, check = data[:-digestlen], data[-digestlen:]

View File

@ -20,11 +20,13 @@
"""Reference implementation for Bech32 and segwit addresses.""" """Reference implementation for Bech32 and segwit addresses."""
if False:
from typing import List, Optional, Tuple
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
def bech32_polymod(values): def bech32_polymod(values: List[int]) -> int:
"""Internal function that computes the Bech32 checksum.""" """Internal function that computes the Bech32 checksum."""
generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3] generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
chk = 1 chk = 1
@ -36,30 +38,30 @@ def bech32_polymod(values):
return chk return chk
def bech32_hrp_expand(hrp): def bech32_hrp_expand(hrp: str) -> List[int]:
"""Expand the HRP into values for checksum computation.""" """Expand the HRP into values for checksum computation."""
return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp] return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp]
def bech32_verify_checksum(hrp, data): def bech32_verify_checksum(hrp: str, data: List[int]) -> bool:
"""Verify a checksum given HRP and converted data characters.""" """Verify a checksum given HRP and converted data characters."""
return bech32_polymod(bech32_hrp_expand(hrp) + data) == 1 return bech32_polymod(bech32_hrp_expand(hrp) + data) == 1
def bech32_create_checksum(hrp, data): def bech32_create_checksum(hrp: str, data: List[int]) -> List[int]:
"""Compute the checksum values given HRP and data.""" """Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data values = bech32_hrp_expand(hrp) + data
polymod = bech32_polymod(values + [0, 0, 0, 0, 0, 0]) ^ 1 polymod = bech32_polymod(values + [0, 0, 0, 0, 0, 0]) ^ 1
return [(polymod >> 5 * (5 - i)) & 31 for i in range(6)] return [(polymod >> 5 * (5 - i)) & 31 for i in range(6)]
def bech32_encode(hrp, data): def bech32_encode(hrp: str, data: List[int]) -> str:
"""Compute a Bech32 string given HRP and data values.""" """Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data) combined = data + bech32_create_checksum(hrp, data)
return hrp + "1" + "".join([CHARSET[d] for d in combined]) return hrp + "1" + "".join([CHARSET[d] for d in combined])
def bech32_decode(bech): def bech32_decode(bech: str) -> Tuple[Optional[str], Optional[List[int]]]:
"""Validate a Bech32 string, and determine HRP and data.""" """Validate a Bech32 string, and determine HRP and data."""
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or ( if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
bech.lower() != bech and bech.upper() != bech bech.lower() != bech and bech.upper() != bech
@ -78,7 +80,9 @@ def bech32_decode(bech):
return (hrp, data[:-6]) return (hrp, data[:-6])
def convertbits(data, frombits, tobits, pad=True): def convertbits(
data: List[int], frombits: int, tobits: int, pad: bool = True
) -> List[int]:
"""General power-of-2 base conversion.""" """General power-of-2 base conversion."""
acc = 0 acc = 0
bits = 0 bits = 0
@ -87,7 +91,7 @@ def convertbits(data, frombits, tobits, pad=True):
max_acc = (1 << (frombits + tobits - 1)) - 1 max_acc = (1 << (frombits + tobits - 1)) - 1
for value in data: for value in data:
if value < 0 or (value >> frombits): if value < 0 or (value >> frombits):
return None raise ValueError
acc = ((acc << frombits) | value) & max_acc acc = ((acc << frombits) | value) & max_acc
bits += frombits bits += frombits
while bits >= tobits: while bits >= tobits:
@ -97,17 +101,17 @@ def convertbits(data, frombits, tobits, pad=True):
if bits: if bits:
ret.append((acc << (tobits - bits)) & maxv) ret.append((acc << (tobits - bits)) & maxv)
elif bits >= frombits or ((acc << (tobits - bits)) & maxv): elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
return None raise ValueError
return ret return ret
def decode(hrp, addr): def decode(hrp: str, addr: str) -> Tuple[Optional[int], Optional[List[int]]]:
"""Decode a segwit address.""" """Decode a segwit address."""
hrpgot, data = bech32_decode(addr) hrpgot, data = bech32_decode(addr)
if hrpgot != hrp: if data is None or hrpgot != hrp:
return (None, None) return (None, None)
decoded = convertbits(data[1:], 5, 8, False) decoded = convertbits(data[1:], 5, 8, False)
if decoded is None or len(decoded) < 2 or len(decoded) > 40: if len(decoded) < 2 or len(decoded) > 40:
return (None, None) return (None, None)
if data[0] > 16: if data[0] > 16:
return (None, None) return (None, None)
@ -116,7 +120,7 @@ def decode(hrp, addr):
return (data[0], decoded) return (data[0], decoded)
def encode(hrp, witver, witprog): def encode(hrp: str, witver: int, witprog: List[int]) -> Optional[str]:
"""Encode a segwit address.""" """Encode a segwit address."""
ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5)) ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5))
if decode(hrp, ret) == (None, None): if decode(hrp, ret) == (None, None):

View File

@ -20,12 +20,15 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
if False:
from typing import Iterable, List, Tuple
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
ADDRESS_TYPE_P2KH = 0 ADDRESS_TYPE_P2KH = 0
ADDRESS_TYPE_P2SH = 8 ADDRESS_TYPE_P2SH = 8
def cashaddr_polymod(values): def cashaddr_polymod(values: List[int]) -> int:
generator = [0x98F2BC8E61, 0x79B76D99E2, 0xF33E5FB3C4, 0xAE2EABE2A8, 0x1E4F43E470] generator = [0x98F2BC8E61, 0x79B76D99E2, 0xF33E5FB3C4, 0xAE2EABE2A8, 0x1E4F43E470]
chk = 1 chk = 1
for value in values: for value in values:
@ -36,11 +39,11 @@ def cashaddr_polymod(values):
return chk ^ 1 return chk ^ 1
def prefix_expand(prefix): def prefix_expand(prefix: str) -> List[int]:
return [ord(x) & 0x1F for x in prefix] + [0] return [ord(x) & 0x1F for x in prefix] + [0]
def calculate_checksum(prefix, payload): def calculate_checksum(prefix: str, payload: List[int]) -> List[int]:
poly = cashaddr_polymod(prefix_expand(prefix) + payload + [0, 0, 0, 0, 0, 0, 0, 0]) poly = cashaddr_polymod(prefix_expand(prefix) + payload + [0, 0, 0, 0, 0, 0, 0, 0])
out = list() out = list()
for i in range(8): for i in range(8):
@ -48,25 +51,27 @@ def calculate_checksum(prefix, payload):
return out return out
def verify_checksum(prefix, payload): def verify_checksum(prefix: str, payload: List[int]) -> bool:
return cashaddr_polymod(prefix_expand(prefix) + payload) == 0 return cashaddr_polymod(prefix_expand(prefix) + payload) == 0
def b32decode(inputs): def b32decode(inputs: str) -> List[int]:
out = list() out = list()
for letter in inputs: for letter in inputs:
out.append(CHARSET.find(letter)) out.append(CHARSET.find(letter))
return out return out
def b32encode(inputs): def b32encode(inputs: List[int]) -> str:
out = "" out = ""
for char_code in inputs: for char_code in inputs:
out += CHARSET[char_code] out += CHARSET[char_code]
return out return out
def convertbits(data, frombits, tobits, pad=True): def convertbits(
data: Iterable[int], frombits: int, tobits: int, pad: bool = True
) -> List[int]:
acc = 0 acc = 0
bits = 0 bits = 0
ret = [] ret = []
@ -74,7 +79,7 @@ def convertbits(data, frombits, tobits, pad=True):
max_acc = (1 << (frombits + tobits - 1)) - 1 max_acc = (1 << (frombits + tobits - 1)) - 1
for value in data: for value in data:
if value < 0 or (value >> frombits): if value < 0 or (value >> frombits):
return None raise ValueError
acc = ((acc << frombits) | value) & max_acc acc = ((acc << frombits) | value) & max_acc
bits += frombits bits += frombits
while bits >= tobits: while bits >= tobits:
@ -84,18 +89,18 @@ def convertbits(data, frombits, tobits, pad=True):
if bits: if bits:
ret.append((acc << (tobits - bits)) & maxv) ret.append((acc << (tobits - bits)) & maxv)
elif bits >= frombits or ((acc << (tobits - bits)) & maxv): elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
return None raise ValueError
return ret return ret
def encode(prefix, version, payload): def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
payload = bytes([version]) + payload payload_bytes = bytes([version]) + payload_bytes
payload = convertbits(payload, 8, 5) payload = convertbits(payload_bytes, 8, 5)
checksum = calculate_checksum(prefix, payload) checksum = calculate_checksum(prefix, payload)
return prefix + ":" + b32encode(payload + checksum) return prefix + ":" + b32encode(payload + checksum)
def decode(prefix, addr): def decode(prefix: str, addr: str) -> Tuple[int, bytes]:
addr = addr.lower() addr = addr.lower()
decoded = b32decode(addr) decoded = b32decode(addr)
if not verify_checksum(prefix, decoded): if not verify_checksum(prefix, decoded):