feat(core/ethereum): enable type-checking for Ethereum app

pull/1804/head
matejcik 3 years ago committed by matejcik
parent 0c2863fc8d
commit e80077e0a4

@ -0,0 +1 @@
Full type-checking for Ethereum app

@ -109,6 +109,7 @@ mypy:
src/main.py \
src/apps/bitcoin \
src/apps/cardano \
src/apps/ethereum \
src/apps/management \
src/apps/misc \
src/apps/webauthn \

@ -1,35 +1,42 @@
from ubinascii import unhexlify
from ubinascii import hexlify, unhexlify
from trezor import wire
if False:
from .networks import NetworkInfo
def address_from_bytes(address_bytes: bytes, network=None) -> str:
def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str:
"""
Converts address in bytes to a checksummed string as defined
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
"""
from ubinascii import hexlify
from trezor.crypto.hashlib import sha3_256
rskip60 = network is not None and network.rskip60
hx = hexlify(address_bytes).decode()
prefix = str(network.chain_id) + "0x" if rskip60 else ""
hs = sha3_256(prefix + hx, keccak=True).digest()
h = ""
for i in range(20):
l = hx[i * 2]
if hs[i] & 0x80 and l >= "a" and l <= "f":
l = l.upper()
h += l
l = hx[i * 2 + 1]
if hs[i] & 0x08 and l >= "a" and l <= "f":
l = l.upper()
h += l
return "0x" + h
if network is not None and network.rskip60:
prefix = str(network.chain_id) + "0x"
else:
prefix = ""
address_hex = hexlify(address_bytes).decode()
digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest()
def maybe_upper(i: int) -> str:
"""Uppercase i-th letter only if the corresponding nibble has high bit set."""
digest_byte = digest[i // 2]
hex_letter = address_hex[i]
if i % 2 == 0:
# even letter -> high nibble
bit = 0x80
else:
# odd letter -> low nibble
bit = 0x08
if digest_byte & bit:
return hex_letter.upper()
else:
return hex_letter
return "0x" + "".join(maybe_upper(i) for i in range(len(address_hex)))
def bytes_from_address(address: str) -> bytes:

@ -9,9 +9,17 @@ from . import networks
from .address import address_from_bytes
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if False:
from trezor.messages import EthereumGetAddress
from trezor.wire import Context
from apps.common.keychain import Keychain
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def get_address(ctx, msg, keychain):
async def get_address(
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
) -> EthereumAddress:
await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n)

@ -7,9 +7,17 @@ from apps.common import coins, paths
from .keychain import with_keychain_from_path
if False:
from trezor.messages import EthereumGetPublicKey
from trezor.wire import Context
from apps.common.keychain import Keychain
@with_keychain_from_path(paths.PATTERN_BIP44_PUBKEY)
async def get_public_key(ctx, msg, keychain):
async def get_public_key(
ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain
) -> EthereumPublicKey:
await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n)

@ -6,17 +6,31 @@ from apps.common.keychain import get_keychain
from . import CURVE, networks
if False:
from typing import Callable, Iterable
from typing_extensions import Protocol
from typing import Callable, Iterable, TypeVar, Union
from protobuf import MessageType
from trezor.messages import EthereumSignTx
from trezor.messages import (
EthereumGetAddress,
EthereumGetPublicKey,
EthereumSignMessage,
EthereumSignTx,
EthereumSignTxEIP1559,
)
from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain
class MsgWithAddressN(MessageType, Protocol):
address_n: paths.Bip32Path
EthereumMessages = Union[
EthereumGetAddress,
EthereumGetPublicKey,
EthereumSignTx,
EthereumSignMessage,
]
MsgIn = TypeVar("MsgIn", bound=EthereumMessages)
EthereumSignTxAny = Union[
EthereumSignTx,
EthereumSignTxEIP1559,
]
MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny)
# We believe Ethereum should use 44'/60'/a' for everything, because it is
@ -47,13 +61,9 @@ def _schemas_from_address_n(
def with_keychain_from_path(
*patterns: str,
) -> Callable[
[HandlerWithKeychain[MsgWithAddressN, MsgOut]], Handler[MsgWithAddressN, MsgOut]
]:
def decorator(
func: HandlerWithKeychain[MsgWithAddressN, MsgOut]
) -> Handler[MsgWithAddressN, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut:
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
schemas = _schemas_from_address_n(patterns, msg.address_n)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
@ -64,11 +74,9 @@ def with_keychain_from_path(
return decorator
def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]:
if msg.chain_id is None:
return _schemas_from_address_n(PATTERNS_ADDRESS, msg.address_n)
def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]:
info = networks.by_chain_id(msg.chain_id)
slip44_id: tuple[int, ...]
if info is None:
# allow Ethereum or testnet paths for unknown networks
slip44_id = (60, 1)
@ -85,10 +93,10 @@ def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]:
def with_keychain_from_chain_id(
func: HandlerWithKeychain[EthereumSignTx, MsgOut]
) -> Handler[EthereumSignTx, MsgOut]:
func: HandlerWithKeychain[MsgInChainId, MsgOut]
) -> Handler[MsgInChainId, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut:
async def wrapper(ctx: wire.Context, msg: MsgInChainId) -> MsgOut:
schemas = _schemas_from_chain_id(msg)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:

@ -14,13 +14,24 @@ from trezor.ui.layouts.tt.altcoin import confirm_total_ethereum
from . import networks, tokens
from .address import address_from_bytes
if False:
from typing import Awaitable
async def require_confirm_tx(ctx, to_bytes, value, chain_id, token=None):
from trezor.wire import Context
def require_confirm_tx(
ctx: Context,
to_bytes: bytes,
value: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
) -> Awaitable[None]:
if to_bytes:
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
else:
to_str = "new contract?"
await confirm_output(
return confirm_output(
ctx,
address=to_str,
amount=format_ethereum_amount(value, token, chain_id),
@ -30,10 +41,15 @@ async def require_confirm_tx(ctx, to_bytes, value, chain_id, token=None):
)
async def require_confirm_fee(
ctx, spending, gas_price, gas_limit, chain_id, token=None
):
await confirm_total_ethereum(
def require_confirm_fee(
ctx: Context,
spending: int,
gas_price: int,
gas_limit: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
) -> Awaitable[None]:
return confirm_total_ethereum(
ctx,
format_ethereum_amount(spending, token, chain_id),
format_ethereum_amount(gas_price, None, chain_id),
@ -42,8 +58,8 @@ async def require_confirm_fee(
async def require_confirm_eip1559_fee(
ctx, max_priority_fee, max_gas_fee, gas_limit, chain_id
):
ctx: Context, max_priority_fee: int, max_gas_fee: int, gas_limit: int, chain_id: int
) -> None:
await confirm_amount(
ctx,
title="Confirm fee",
@ -64,9 +80,11 @@ async def require_confirm_eip1559_fee(
)
async def require_confirm_unknown_token(ctx, address_bytes):
def require_confirm_unknown_token(
ctx: Context, address_bytes: bytes
) -> Awaitable[None]:
contract_address_hex = "0x" + hexlify(address_bytes).decode()
await confirm_address(
return confirm_address(
ctx,
"Unknown token",
contract_address_hex,
@ -77,8 +95,8 @@ async def require_confirm_unknown_token(ctx, address_bytes):
)
async def require_confirm_data(ctx, data, data_total):
await confirm_blob(
def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitable[None]:
return confirm_blob(
ctx,
"confirm_data",
title="Confirm data",
@ -88,13 +106,6 @@ async def require_confirm_data(ctx, data, data_total):
)
def format_ethereum_amount(value: int, token, chain_id: int):
if token is tokens.UNKNOWN_TOKEN:
suffix = "Wei UNKN"
decimals = 0
elif token:
suffix = token[2]
decimals = token[3]
def format_ethereum_amount(
value: int, token: tokens.TokenInfo | None, chain_id: int
) -> str:

@ -1,8 +1,6 @@
# generated from networks.py.mako
# do not edit manually!
from micropython import const
from apps.common.paths import HARDENED
if False:

@ -1,8 +1,6 @@
# generated from networks.py.mako
# do not edit manually!
from micropython import const
from apps.common.paths import HARDENED
if False:

@ -10,18 +10,26 @@ from apps.common.signverify import decode_message
from . import address
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if False:
from trezor.messages import EthereumSignMessage
from trezor.wire import Context
def message_digest(message):
from apps.common.keychain import Keychain
def message_digest(message: bytes) -> bytes:
h = HashWriter(sha3_256(keccak=True))
signed_message_header = "\x19Ethereum Signed Message:\n"
signed_message_header = b"\x19Ethereum Signed Message:\n"
h.extend(signed_message_header)
h.extend(str(len(message)))
h.extend(str(len(message)).encode())
h.extend(message)
return h.get_digest()
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_message(ctx, msg, keychain):
async def sign_message(
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
) -> EthereumMessageSignature:
await paths.validate_path(ctx, keychain, msg.address_n)
await confirm_signverify(ctx, "ETH", decode_message(msg.message))

@ -16,6 +16,14 @@ from .layout import (
require_confirm_unknown_token,
)
if False:
from typing import Tuple
from apps.common.keychain import Keychain
from .keychain import EthereumSignTxAny
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
# chain_ids larger than this will only return one bit and the caller must recalculate
# the full value: v = 2 * chain_id + 35 + v_bit
@ -23,9 +31,9 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id
async def sign_tx(ctx, msg, keychain):
msg = sanitize(msg)
async def sign_tx(
ctx: wire.Context, msg: EthereumSignTx, keychain: Keychain
) -> EthereumTxRequest:
check(msg)
await paths.validate_path(ctx, keychain, msg.address_n)
@ -84,7 +92,9 @@ async def sign_tx(ctx, msg, keychain):
return result
async def handle_erc20(ctx, msg):
async def handle_erc20(
ctx: wire.Context, msg: EthereumSignTxAny
) -> Tuple[tokens.TokenInfo | None, bytes, bytes, int]:
token = None
address_bytes = recipient = address.bytes_from_address(msg.to)
value = int.from_bytes(msg.value, "big")
@ -111,7 +121,7 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
if msg.tx_type is not None:
length += rlp.length(msg.tx_type)
for item in (
fields: Tuple[rlp.RLPItem, ...] = (
msg.nonce,
msg.gas_price,
msg.gas_limit,
@ -120,8 +130,10 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
msg.chain_id,
0,
0,
):
length += rlp.length(item)
)
for field in fields:
length += rlp.length(field)
length += rlp.header_length(data_total, msg.data_initial_chunk)
length += data_total
@ -129,7 +141,7 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
return length
async def send_request_chunk(ctx, data_left: int):
async def send_request_chunk(ctx: wire.Context, data_left: int) -> EthereumTxAck:
# TODO: layoutProgress ?
req = EthereumTxRequest()
if data_left <= 1024:
@ -140,7 +152,9 @@ async def send_request_chunk(ctx, data_left: int):
return await ctx.call(req, EthereumTxAck)
def sign_digest(msg: EthereumSignTx, keychain, digest):
def sign_digest(
msg: EthereumSignTx, keychain: Keychain, digest: bytes
) -> EthereumTxRequest:
node = keychain.derive(msg.address_n)
signature = secp256k1.sign(
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
@ -159,7 +173,7 @@ def sign_digest(msg: EthereumSignTx, keychain, digest):
return req
def check(msg: EthereumSignTx):
def check(msg: EthereumSignTx) -> None:
if msg.tx_type not in [1, 6, None]:
raise wire.DataError("tx_type out of bounds")
@ -170,7 +184,7 @@ def check(msg: EthereumSignTx):
raise wire.DataError("Safety check failed")
def check_data(msg: EthereumSignTx):
def check_data(msg: EthereumSignTxAny) -> None:
if msg.data_length > 0:
if not msg.data_initial_chunk:
raise wire.DataError("Data length provided, but no initial chunk")
@ -183,15 +197,13 @@ def check_data(msg: EthereumSignTx):
def check_gas(msg: EthereumSignTx) -> bool:
if msg.gas_price is None or msg.gas_limit is None:
return False
if len(msg.gas_price) + len(msg.gas_limit) > 30:
# sanity check that fee doesn't overflow
return False
return True
def check_to(msg: EthereumTxRequest) -> bool:
def check_to(msg: EthereumSignTxAny) -> bool:
if msg.to == "":
if msg.data_length == 0:
# sending transaction to address 0 (contract creation) without a data field
@ -200,17 +212,3 @@ def check_to(msg: EthereumTxRequest) -> bool:
if len(msg.to) not in (40, 42):
return False
return True
def sanitize(msg):
if msg.value is None:
msg.value = b""
if msg.data_initial_chunk is None:
msg.data_initial_chunk = b""
if msg.data_length is None:
msg.data_length = 0
if msg.to is None:
msg.to = ""
if msg.nonce is None:
msg.nonce = b""
return msg

@ -2,7 +2,7 @@ from trezor import wire
from trezor.crypto import rlp
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import EthereumAccessList, EthereumSignTxEIP1559, EthereumTxRequest
from trezor.messages import EthereumAccessList, EthereumTxRequest
from trezor.utils import HashWriter
from apps.common import paths
@ -14,7 +14,14 @@ from .layout import (
require_confirm_eip1559_fee,
require_confirm_tx,
)
from .sign_tx import check_data, check_to, handle_erc20, sanitize, send_request_chunk
from .sign_tx import check_data, check_to, handle_erc20, send_request_chunk
if False:
from typing import Tuple
from trezor.messages import EthereumSignTxEIP1559
from apps.common.keychain import Keychain
TX_TYPE = 2
@ -45,9 +52,9 @@ def write_access_list(w: HashWriter, access_list: list[EthereumAccessList]) -> N
@with_keychain_from_chain_id
async def sign_tx_eip1559(ctx, msg, keychain):
msg = sanitize(msg)
async def sign_tx_eip1559(
ctx: wire.Context, msg: EthereumSignTxEIP1559, keychain: Keychain
) -> EthereumTxRequest:
check(msg)
await paths.validate_path(ctx, keychain, msg.address_n)
@ -81,7 +88,7 @@ async def sign_tx_eip1559(ctx, msg, keychain):
rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE)
for field in (
fields: Tuple[rlp.RLPItem, ...] = (
msg.chain_id,
msg.nonce,
msg.max_priority_fee,
@ -89,7 +96,8 @@ async def sign_tx_eip1559(ctx, msg, keychain):
msg.gas_limit,
address_bytes,
msg.value,
):
)
for field in fields:
rlp.write(sha, field)
if data_left == 0:
@ -114,7 +122,7 @@ async def sign_tx_eip1559(ctx, msg, keychain):
def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
length = 0
for item in (
fields: Tuple[rlp.RLPItem, ...] = (
msg.nonce,
msg.gas_limit,
address.bytes_from_address(msg.to),
@ -122,8 +130,9 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
msg.chain_id,
msg.max_gas_fee,
msg.max_priority_fee,
):
length += rlp.length(item)
)
for field in fields:
length += rlp.length(field)
length += rlp.header_length(data_total, msg.data_initial_chunk)
length += data_total
@ -133,7 +142,9 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
return length
def sign_digest(msg: EthereumSignTxEIP1559, keychain, digest):
def sign_digest(
msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes
) -> EthereumTxRequest:
node = keychain.derive(msg.address_n)
signature = secp256k1.sign(
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
@ -147,7 +158,7 @@ def sign_digest(msg: EthereumSignTxEIP1559, keychain, digest):
return req
def check(msg: EthereumSignTxEIP1559):
def check(msg: EthereumSignTxEIP1559) -> None:
check_data(msg)
if not check_to(msg):

@ -9,8 +9,12 @@ from apps.common.signverify import decode_message
from .address import address_from_bytes, bytes_from_address
from .sign_message import message_digest
if False:
from trezor.messages import EthereumVerifyMessage
from trezor.wire import Context
async def verify_message(ctx, msg):
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
digest = message_digest(msg.message)
if len(msg.signature) != 65:
raise wire.DataError("Invalid signature")

@ -22,10 +22,9 @@ class TestEthereumGetAddress(unittest.TestCase):
'0xD1220A0cf47c7B9Be7A2E6BA89F429762e7b9aDb',
]
for s in eip55:
s = s[2:]
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
b = unhexlify(s[2:])
h = address_from_bytes(b)
self.assertEqual(h, '0x' + s)
self.assertEqual(h, s)
def test_address_from_bytes_rskip60(self):
# https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md
@ -43,16 +42,14 @@ class TestEthereumGetAddress(unittest.TestCase):
]
n = NetworkInfo(chain_id=30, slip44=1, shortcut='T', name='T', rskip60=True)
for s in rskip60_chain_30:
s = s[2:]
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
b = unhexlify(s[2:])
h = address_from_bytes(b, n)
self.assertEqual(h, '0x' + s)
self.assertEqual(h, s)
n.chain_id = 31
for s in rskip60_chain_31:
s = s[2:]
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
b = unhexlify(s[2:])
h = address_from_bytes(b, n)
self.assertEqual(h, '0x' + s)
self.assertEqual(h, s)
if __name__ == '__main__':

@ -118,6 +118,8 @@ class TestEthereumKeychain(unittest.TestCase):
EthereumSignTx(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
chain_id=1,
gas_price=b"",
gas_limit=b"",
),
)
)
@ -128,6 +130,8 @@ class TestEthereumKeychain(unittest.TestCase):
EthereumSignTx(
address_n=[44 | HARDENED, 61 | HARDENED, 0 | HARDENED],
chain_id=61,
gas_price=b"",
gas_limit=b"",
),
)
)
@ -140,6 +144,8 @@ class TestEthereumKeychain(unittest.TestCase):
EthereumSignTx(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
chain_id=61,
gas_price=b"",
gas_limit=b"",
),
)
)
@ -151,6 +157,8 @@ class TestEthereumKeychain(unittest.TestCase):
EthereumSignTx(
address_n=[44 | HARDENED, 61 | HARDENED, 0 | HARDENED],
chain_id=2,
gas_price=b"",
gas_limit=b"",
),
)
)

@ -6,7 +6,7 @@ if not utils.BITCOIN_ONLY:
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumLayout(unittest.TestCase):
class TestFormatEthereumAmount(unittest.TestCase):
def test_format(self):
text = format_ethereum_amount(1, None, 1)
@ -60,12 +60,14 @@ class TestEthereumLayout(unittest.TestCase):
text = format_ethereum_amount(1000000000000000001, None, 31)
self.assertEqual(text, '1.000000000000000001 tRBTC')
def test_unknown_chain(self):
# unknown chain
text = format_ethereum_amount(1, None, 9999)
self.assertEqual(text, '1 Wei UNKN')
text = format_ethereum_amount(10000000000000000001, None, 9999)
self.assertEqual(text, '10.000000000000000001 UNKN')
def test_tokens(self):
# tokens with low decimal values
# USDC has 6 decimals
usdc_token = token_by_chain_address(1, unhexlify("a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48"))
@ -87,6 +89,16 @@ class TestEthereumLayout(unittest.TestCase):
text = format_ethereum_amount(11, ico_token, 1)
self.assertEqual(text, '0.0000000011 ICO')
def test_unknown_token(self):
unknown_token = token_by_chain_address(1, b"hello")
text = format_ethereum_amount(1, unknown_token, 1)
self.assertEqual(text, '1 Wei UNKN')
text = format_ethereum_amount(0, unknown_token, 1)
self.assertEqual(text, '0 Wei UNKN')
# unknown token has 0 decimals so is always wei
text = format_ethereum_amount(1000000000000000000, unknown_token, 1)
self.assertEqual(text, '1000000000000000000 Wei UNKN')
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save