diff --git a/src/apps/ethereum/address.py b/src/apps/ethereum/address.py index d853256c0d..0cd4caa3bc 100644 --- a/src/apps/ethereum/address.py +++ b/src/apps/ethereum/address.py @@ -1,3 +1,7 @@ +from ubinascii import unhexlify + +from trezor import wire + from apps.common import HARDENED, paths from apps.ethereum import networks @@ -51,13 +55,17 @@ def validate_full_path(path: list) -> bool: return True -def ethereum_address_hex(address, network=None): +def address_from_bytes(address_bytes: bytes, network=None) -> str: + """ + Converts address in bytes to a checksummed string as defined + in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md + """ from ubinascii import hexlify from trezor.crypto.hashlib import sha3_256 rskip60 = network is not None and network.rskip60 - hx = hexlify(address).decode() + hx = hexlify(address_bytes).decode() prefix = str(network.chain_id) + "0x" if rskip60 else "" hs = sha3_256(prefix + hx, keccak=True).digest() @@ -74,3 +82,15 @@ def ethereum_address_hex(address, network=None): h += l return "0x" + h + + +def bytes_from_address(address: str, network=None) -> bytes: + if len(address) == 40: + return unhexlify(address) + + elif len(address) == 42: + if address[0:2] not in ("0x", "0X"): + raise wire.ProcessError("Ethereum: invalid beginning of an address") + return unhexlify(address[2:]) + + raise wire.ProcessError("Ethereum: Invalid address length") diff --git a/src/apps/ethereum/get_address.py b/src/apps/ethereum/get_address.py index 6531905f72..f2a3adac6d 100644 --- a/src/apps/ethereum/get_address.py +++ b/src/apps/ethereum/get_address.py @@ -5,7 +5,7 @@ from trezor.messages.EthereumAddress import EthereumAddress from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr from apps.ethereum import networks -from apps.ethereum.address import ethereum_address_hex, validate_full_path +from apps.ethereum.address import address_from_bytes, validate_full_path async def get_address(ctx, msg, keychain): @@ -20,7 +20,7 @@ async def get_address(ctx, msg, keychain): network = networks.by_slip44(msg.address_n[1] & 0x7FFFFFFF) else: network = None - address = ethereum_address_hex(address_bytes, network) + address = address_from_bytes(address_bytes, network) if msg.show_display: desc = address_n_to_str(msg.address_n) diff --git a/src/apps/ethereum/sign_message.py b/src/apps/ethereum/sign_message.py index ed6fe8dbe6..54b85320a3 100644 --- a/src/apps/ethereum/sign_message.py +++ b/src/apps/ethereum/sign_message.py @@ -7,7 +7,7 @@ from trezor.utils import HashWriter from apps.common import paths from apps.common.confirm import require_confirm from apps.common.signverify import split_message -from apps.ethereum.address import validate_full_path +from apps.ethereum import address def message_digest(message): @@ -20,7 +20,7 @@ def message_digest(message): async def sign_message(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n) await require_confirm_sign_message(ctx, msg.message) node = keychain.derive(msg.address_n) @@ -32,7 +32,7 @@ async def sign_message(ctx, msg, keychain): ) sig = EthereumMessageSignature() - sig.address = node.ethereum_pubkeyhash() + sig.address = address.address_from_bytes(node.ethereum_pubkeyhash()) sig.signature = signature[1:] + bytearray([signature[0]]) return sig diff --git a/src/apps/ethereum/verify_message.py b/src/apps/ethereum/verify_message.py index 790235b567..3750512f97 100644 --- a/src/apps/ethereum/verify_message.py +++ b/src/apps/ethereum/verify_message.py @@ -1,5 +1,3 @@ -from ubinascii import hexlify - from trezor import wire from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import sha3_256 @@ -9,6 +7,7 @@ from trezor.ui.text import Text from apps.common.confirm import require_confirm from apps.common.layout import split_address from apps.common.signverify import split_message +from apps.ethereum.address import address_from_bytes, bytes_from_address from apps.ethereum.sign_message import message_digest @@ -28,10 +27,11 @@ async def verify_message(ctx, msg): pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:] - if msg.address != pkh: + address_bytes = bytes_from_address(msg.address) + if address_bytes != pkh: raise wire.DataError("Invalid signature") - address = "0x" + hexlify(msg.address).decode() + address = address_from_bytes(address_bytes) await require_confirm_verify_message(ctx, address, msg.message) diff --git a/tests/test_apps.ethereum.address.py b/tests/test_apps.ethereum.address.py index 8b9d3ad51b..e4bf1829db 100644 --- a/tests/test_apps.ethereum.address.py +++ b/tests/test_apps.ethereum.address.py @@ -1,12 +1,12 @@ from common import * from apps.common.paths import HARDENED -from apps.ethereum.address import ethereum_address_hex, validate_full_path +from apps.ethereum.address import address_from_bytes, bytes_from_address, validate_full_path from apps.ethereum.networks import NetworkInfo, by_chain_id class TestEthereumGetAddress(unittest.TestCase): - def test_ethereum_address_hex_eip55(self): + def test_address_from_bytes_eip55(self): # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md eip55 = [ '0x52908400098527886E0F7030069857D2E4169EE7', @@ -21,10 +21,10 @@ class TestEthereumGetAddress(unittest.TestCase): for s in eip55: s = s[2:] b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) - h = ethereum_address_hex(b) + h = address_from_bytes(b) self.assertEqual(h, '0x' + s) - def test_ethereum_address_hex_rskip60(self): + def test_address_from_bytes_rskip60(self): # https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md rskip60_chain_30 = [ '0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD', @@ -42,13 +42,13 @@ class TestEthereumGetAddress(unittest.TestCase): for s in rskip60_chain_30: s = s[2:] b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) - h = ethereum_address_hex(b, n) + h = address_from_bytes(b, n) self.assertEqual(h, '0x' + s) n.chain_id = 31 for s in rskip60_chain_31: s = s[2:] b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) - h = ethereum_address_hex(b, n) + h = address_from_bytes(b, n) self.assertEqual(h, '0x' + s) def test_paths(self):