1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 15:38:11 +00:00

eth: modify sign/verify functions to accept addresses as strings

This commit is contained in:
Tomas Susanka 2019-01-28 15:44:18 +01:00
parent 3387b157a7
commit 1946a9f93e
5 changed files with 37 additions and 17 deletions

View File

@ -1,3 +1,7 @@
from ubinascii import unhexlify
from trezor import wire
from apps.common import HARDENED, paths from apps.common import HARDENED, paths
from apps.ethereum import networks from apps.ethereum import networks
@ -51,13 +55,17 @@ def validate_full_path(path: list) -> bool:
return True return True
def ethereum_address_hex(address, network=None): def address_from_bytes(address_bytes: bytes, network=None) -> str:
"""
Converts address in bytes to a checksummed string as defined
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
"""
from ubinascii import hexlify from ubinascii import hexlify
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
rskip60 = network is not None and network.rskip60 rskip60 = network is not None and network.rskip60
hx = hexlify(address).decode() hx = hexlify(address_bytes).decode()
prefix = str(network.chain_id) + "0x" if rskip60 else "" prefix = str(network.chain_id) + "0x" if rskip60 else ""
hs = sha3_256(prefix + hx, keccak=True).digest() hs = sha3_256(prefix + hx, keccak=True).digest()
@ -74,3 +82,15 @@ def ethereum_address_hex(address, network=None):
h += l h += l
return "0x" + h return "0x" + h
def bytes_from_address(address: str, network=None) -> bytes:
if len(address) == 40:
return unhexlify(address)
elif len(address) == 42:
if address[0:2] not in ("0x", "0X"):
raise wire.ProcessError("Ethereum: invalid beginning of an address")
return unhexlify(address[2:])
raise wire.ProcessError("Ethereum: Invalid address length")

View File

@ -5,7 +5,7 @@ from trezor.messages.EthereumAddress import EthereumAddress
from apps.common import paths from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.ethereum import networks from apps.ethereum import networks
from apps.ethereum.address import ethereum_address_hex, validate_full_path from apps.ethereum.address import address_from_bytes, validate_full_path
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
@ -20,7 +20,7 @@ async def get_address(ctx, msg, keychain):
network = networks.by_slip44(msg.address_n[1] & 0x7FFFFFFF) network = networks.by_slip44(msg.address_n[1] & 0x7FFFFFFF)
else: else:
network = None network = None
address = ethereum_address_hex(address_bytes, network) address = address_from_bytes(address_bytes, network)
if msg.show_display: if msg.show_display:
desc = address_n_to_str(msg.address_n) desc = address_n_to_str(msg.address_n)

View File

@ -7,7 +7,7 @@ from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.signverify import split_message from apps.common.signverify import split_message
from apps.ethereum.address import validate_full_path from apps.ethereum import address
def message_digest(message): def message_digest(message):
@ -20,7 +20,7 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain): async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n)
await require_confirm_sign_message(ctx, msg.message) await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
@ -32,7 +32,7 @@ async def sign_message(ctx, msg, keychain):
) )
sig = EthereumMessageSignature() sig = EthereumMessageSignature()
sig.address = node.ethereum_pubkeyhash() sig.address = address.address_from_bytes(node.ethereum_pubkeyhash())
sig.signature = signature[1:] + bytearray([signature[0]]) sig.signature = signature[1:] + bytearray([signature[0]])
return sig return sig

View File

@ -1,5 +1,3 @@
from ubinascii import hexlify
from trezor import wire from trezor import wire
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
@ -9,6 +7,7 @@ from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.layout import split_address from apps.common.layout import split_address
from apps.common.signverify import split_message from apps.common.signverify import split_message
from apps.ethereum.address import address_from_bytes, bytes_from_address
from apps.ethereum.sign_message import message_digest from apps.ethereum.sign_message import message_digest
@ -28,10 +27,11 @@ async def verify_message(ctx, msg):
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:] pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
if msg.address != pkh: address_bytes = bytes_from_address(msg.address)
if address_bytes != pkh:
raise wire.DataError("Invalid signature") raise wire.DataError("Invalid signature")
address = "0x" + hexlify(msg.address).decode() address = address_from_bytes(address_bytes)
await require_confirm_verify_message(ctx, address, msg.message) await require_confirm_verify_message(ctx, address, msg.message)

View File

@ -1,12 +1,12 @@
from common import * from common import *
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
from apps.ethereum.address import ethereum_address_hex, validate_full_path from apps.ethereum.address import address_from_bytes, bytes_from_address, validate_full_path
from apps.ethereum.networks import NetworkInfo, by_chain_id from apps.ethereum.networks import NetworkInfo, by_chain_id
class TestEthereumGetAddress(unittest.TestCase): class TestEthereumGetAddress(unittest.TestCase):
def test_ethereum_address_hex_eip55(self): def test_address_from_bytes_eip55(self):
# https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
eip55 = [ eip55 = [
'0x52908400098527886E0F7030069857D2E4169EE7', '0x52908400098527886E0F7030069857D2E4169EE7',
@ -21,10 +21,10 @@ class TestEthereumGetAddress(unittest.TestCase):
for s in eip55: for s in eip55:
s = s[2:] s = s[2:]
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
h = ethereum_address_hex(b) h = address_from_bytes(b)
self.assertEqual(h, '0x' + s) self.assertEqual(h, '0x' + s)
def test_ethereum_address_hex_rskip60(self): def test_address_from_bytes_rskip60(self):
# https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md # https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md
rskip60_chain_30 = [ rskip60_chain_30 = [
'0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD', '0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD',
@ -42,13 +42,13 @@ class TestEthereumGetAddress(unittest.TestCase):
for s in rskip60_chain_30: for s in rskip60_chain_30:
s = s[2:] s = s[2:]
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
h = ethereum_address_hex(b, n) h = address_from_bytes(b, n)
self.assertEqual(h, '0x' + s) self.assertEqual(h, '0x' + s)
n.chain_id = 31 n.chain_id = 31
for s in rskip60_chain_31: for s in rskip60_chain_31:
s = s[2:] s = s[2:]
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
h = ethereum_address_hex(b, n) h = address_from_bytes(b, n)
self.assertEqual(h, '0x' + s) self.assertEqual(h, '0x' + s)
def test_paths(self): def test_paths(self):