mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 07:28:10 +00:00
eth: modify sign/verify functions to accept addresses as strings
This commit is contained in:
parent
3387b157a7
commit
1946a9f93e
@ -1,3 +1,7 @@
|
||||
from ubinascii import unhexlify
|
||||
|
||||
from trezor import wire
|
||||
|
||||
from apps.common import HARDENED, paths
|
||||
from apps.ethereum import networks
|
||||
|
||||
@ -51,13 +55,17 @@ def validate_full_path(path: list) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def ethereum_address_hex(address, network=None):
|
||||
def address_from_bytes(address_bytes: bytes, network=None) -> str:
|
||||
"""
|
||||
Converts address in bytes to a checksummed string as defined
|
||||
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
|
||||
"""
|
||||
from ubinascii import hexlify
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
|
||||
rskip60 = network is not None and network.rskip60
|
||||
|
||||
hx = hexlify(address).decode()
|
||||
hx = hexlify(address_bytes).decode()
|
||||
|
||||
prefix = str(network.chain_id) + "0x" if rskip60 else ""
|
||||
hs = sha3_256(prefix + hx, keccak=True).digest()
|
||||
@ -74,3 +82,15 @@ def ethereum_address_hex(address, network=None):
|
||||
h += l
|
||||
|
||||
return "0x" + h
|
||||
|
||||
|
||||
def bytes_from_address(address: str, network=None) -> bytes:
|
||||
if len(address) == 40:
|
||||
return unhexlify(address)
|
||||
|
||||
elif len(address) == 42:
|
||||
if address[0:2] not in ("0x", "0X"):
|
||||
raise wire.ProcessError("Ethereum: invalid beginning of an address")
|
||||
return unhexlify(address[2:])
|
||||
|
||||
raise wire.ProcessError("Ethereum: Invalid address length")
|
||||
|
@ -5,7 +5,7 @@ from trezor.messages.EthereumAddress import EthereumAddress
|
||||
from apps.common import paths
|
||||
from apps.common.layout import address_n_to_str, show_address, show_qr
|
||||
from apps.ethereum import networks
|
||||
from apps.ethereum.address import ethereum_address_hex, validate_full_path
|
||||
from apps.ethereum.address import address_from_bytes, validate_full_path
|
||||
|
||||
|
||||
async def get_address(ctx, msg, keychain):
|
||||
@ -20,7 +20,7 @@ async def get_address(ctx, msg, keychain):
|
||||
network = networks.by_slip44(msg.address_n[1] & 0x7FFFFFFF)
|
||||
else:
|
||||
network = None
|
||||
address = ethereum_address_hex(address_bytes, network)
|
||||
address = address_from_bytes(address_bytes, network)
|
||||
|
||||
if msg.show_display:
|
||||
desc = address_n_to_str(msg.address_n)
|
||||
|
@ -7,7 +7,7 @@ from trezor.utils import HashWriter
|
||||
from apps.common import paths
|
||||
from apps.common.confirm import require_confirm
|
||||
from apps.common.signverify import split_message
|
||||
from apps.ethereum.address import validate_full_path
|
||||
from apps.ethereum import address
|
||||
|
||||
|
||||
def message_digest(message):
|
||||
@ -20,7 +20,7 @@ def message_digest(message):
|
||||
|
||||
|
||||
async def sign_message(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n)
|
||||
await require_confirm_sign_message(ctx, msg.message)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
@ -32,7 +32,7 @@ async def sign_message(ctx, msg, keychain):
|
||||
)
|
||||
|
||||
sig = EthereumMessageSignature()
|
||||
sig.address = node.ethereum_pubkeyhash()
|
||||
sig.address = address.address_from_bytes(node.ethereum_pubkeyhash())
|
||||
sig.signature = signature[1:] + bytearray([signature[0]])
|
||||
return sig
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
from ubinascii import hexlify
|
||||
|
||||
from trezor import wire
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
@ -9,6 +7,7 @@ from trezor.ui.text import Text
|
||||
from apps.common.confirm import require_confirm
|
||||
from apps.common.layout import split_address
|
||||
from apps.common.signverify import split_message
|
||||
from apps.ethereum.address import address_from_bytes, bytes_from_address
|
||||
from apps.ethereum.sign_message import message_digest
|
||||
|
||||
|
||||
@ -28,10 +27,11 @@ async def verify_message(ctx, msg):
|
||||
|
||||
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
|
||||
|
||||
if msg.address != pkh:
|
||||
address_bytes = bytes_from_address(msg.address)
|
||||
if address_bytes != pkh:
|
||||
raise wire.DataError("Invalid signature")
|
||||
|
||||
address = "0x" + hexlify(msg.address).decode()
|
||||
address = address_from_bytes(address_bytes)
|
||||
|
||||
await require_confirm_verify_message(ctx, address, msg.message)
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
from common import *
|
||||
from apps.common.paths import HARDENED
|
||||
from apps.ethereum.address import ethereum_address_hex, validate_full_path
|
||||
from apps.ethereum.address import address_from_bytes, bytes_from_address, validate_full_path
|
||||
from apps.ethereum.networks import NetworkInfo, by_chain_id
|
||||
|
||||
|
||||
class TestEthereumGetAddress(unittest.TestCase):
|
||||
|
||||
def test_ethereum_address_hex_eip55(self):
|
||||
def test_address_from_bytes_eip55(self):
|
||||
# https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
|
||||
eip55 = [
|
||||
'0x52908400098527886E0F7030069857D2E4169EE7',
|
||||
@ -21,10 +21,10 @@ class TestEthereumGetAddress(unittest.TestCase):
|
||||
for s in eip55:
|
||||
s = s[2:]
|
||||
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
||||
h = ethereum_address_hex(b)
|
||||
h = address_from_bytes(b)
|
||||
self.assertEqual(h, '0x' + s)
|
||||
|
||||
def test_ethereum_address_hex_rskip60(self):
|
||||
def test_address_from_bytes_rskip60(self):
|
||||
# https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md
|
||||
rskip60_chain_30 = [
|
||||
'0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD',
|
||||
@ -42,13 +42,13 @@ class TestEthereumGetAddress(unittest.TestCase):
|
||||
for s in rskip60_chain_30:
|
||||
s = s[2:]
|
||||
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
||||
h = ethereum_address_hex(b, n)
|
||||
h = address_from_bytes(b, n)
|
||||
self.assertEqual(h, '0x' + s)
|
||||
n.chain_id = 31
|
||||
for s in rskip60_chain_31:
|
||||
s = s[2:]
|
||||
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
||||
h = ethereum_address_hex(b, n)
|
||||
h = address_from_bytes(b, n)
|
||||
self.assertEqual(h, '0x' + s)
|
||||
|
||||
def test_paths(self):
|
||||
|
Loading…
Reference in New Issue
Block a user