mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 15:38:11 +00:00
eth: modify sign/verify functions to accept addresses as strings
This commit is contained in:
parent
3387b157a7
commit
1946a9f93e
@ -1,3 +1,7 @@
|
|||||||
|
from ubinascii import unhexlify
|
||||||
|
|
||||||
|
from trezor import wire
|
||||||
|
|
||||||
from apps.common import HARDENED, paths
|
from apps.common import HARDENED, paths
|
||||||
from apps.ethereum import networks
|
from apps.ethereum import networks
|
||||||
|
|
||||||
@ -51,13 +55,17 @@ def validate_full_path(path: list) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def ethereum_address_hex(address, network=None):
|
def address_from_bytes(address_bytes: bytes, network=None) -> str:
|
||||||
|
"""
|
||||||
|
Converts address in bytes to a checksummed string as defined
|
||||||
|
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
|
||||||
|
"""
|
||||||
from ubinascii import hexlify
|
from ubinascii import hexlify
|
||||||
from trezor.crypto.hashlib import sha3_256
|
from trezor.crypto.hashlib import sha3_256
|
||||||
|
|
||||||
rskip60 = network is not None and network.rskip60
|
rskip60 = network is not None and network.rskip60
|
||||||
|
|
||||||
hx = hexlify(address).decode()
|
hx = hexlify(address_bytes).decode()
|
||||||
|
|
||||||
prefix = str(network.chain_id) + "0x" if rskip60 else ""
|
prefix = str(network.chain_id) + "0x" if rskip60 else ""
|
||||||
hs = sha3_256(prefix + hx, keccak=True).digest()
|
hs = sha3_256(prefix + hx, keccak=True).digest()
|
||||||
@ -74,3 +82,15 @@ def ethereum_address_hex(address, network=None):
|
|||||||
h += l
|
h += l
|
||||||
|
|
||||||
return "0x" + h
|
return "0x" + h
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_from_address(address: str, network=None) -> bytes:
|
||||||
|
if len(address) == 40:
|
||||||
|
return unhexlify(address)
|
||||||
|
|
||||||
|
elif len(address) == 42:
|
||||||
|
if address[0:2] not in ("0x", "0X"):
|
||||||
|
raise wire.ProcessError("Ethereum: invalid beginning of an address")
|
||||||
|
return unhexlify(address[2:])
|
||||||
|
|
||||||
|
raise wire.ProcessError("Ethereum: Invalid address length")
|
||||||
|
@ -5,7 +5,7 @@ from trezor.messages.EthereumAddress import EthereumAddress
|
|||||||
from apps.common import paths
|
from apps.common import paths
|
||||||
from apps.common.layout import address_n_to_str, show_address, show_qr
|
from apps.common.layout import address_n_to_str, show_address, show_qr
|
||||||
from apps.ethereum import networks
|
from apps.ethereum import networks
|
||||||
from apps.ethereum.address import ethereum_address_hex, validate_full_path
|
from apps.ethereum.address import address_from_bytes, validate_full_path
|
||||||
|
|
||||||
|
|
||||||
async def get_address(ctx, msg, keychain):
|
async def get_address(ctx, msg, keychain):
|
||||||
@ -20,7 +20,7 @@ async def get_address(ctx, msg, keychain):
|
|||||||
network = networks.by_slip44(msg.address_n[1] & 0x7FFFFFFF)
|
network = networks.by_slip44(msg.address_n[1] & 0x7FFFFFFF)
|
||||||
else:
|
else:
|
||||||
network = None
|
network = None
|
||||||
address = ethereum_address_hex(address_bytes, network)
|
address = address_from_bytes(address_bytes, network)
|
||||||
|
|
||||||
if msg.show_display:
|
if msg.show_display:
|
||||||
desc = address_n_to_str(msg.address_n)
|
desc = address_n_to_str(msg.address_n)
|
||||||
|
@ -7,7 +7,7 @@ from trezor.utils import HashWriter
|
|||||||
from apps.common import paths
|
from apps.common import paths
|
||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
from apps.common.signverify import split_message
|
from apps.common.signverify import split_message
|
||||||
from apps.ethereum.address import validate_full_path
|
from apps.ethereum import address
|
||||||
|
|
||||||
|
|
||||||
def message_digest(message):
|
def message_digest(message):
|
||||||
@ -20,7 +20,7 @@ def message_digest(message):
|
|||||||
|
|
||||||
|
|
||||||
async def sign_message(ctx, msg, keychain):
|
async def sign_message(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n)
|
||||||
await require_confirm_sign_message(ctx, msg.message)
|
await require_confirm_sign_message(ctx, msg.message)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
@ -32,7 +32,7 @@ async def sign_message(ctx, msg, keychain):
|
|||||||
)
|
)
|
||||||
|
|
||||||
sig = EthereumMessageSignature()
|
sig = EthereumMessageSignature()
|
||||||
sig.address = node.ethereum_pubkeyhash()
|
sig.address = address.address_from_bytes(node.ethereum_pubkeyhash())
|
||||||
sig.signature = signature[1:] + bytearray([signature[0]])
|
sig.signature = signature[1:] + bytearray([signature[0]])
|
||||||
return sig
|
return sig
|
||||||
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from ubinascii import hexlify
|
|
||||||
|
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto.curve import secp256k1
|
from trezor.crypto.curve import secp256k1
|
||||||
from trezor.crypto.hashlib import sha3_256
|
from trezor.crypto.hashlib import sha3_256
|
||||||
@ -9,6 +7,7 @@ from trezor.ui.text import Text
|
|||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
from apps.common.layout import split_address
|
from apps.common.layout import split_address
|
||||||
from apps.common.signverify import split_message
|
from apps.common.signverify import split_message
|
||||||
|
from apps.ethereum.address import address_from_bytes, bytes_from_address
|
||||||
from apps.ethereum.sign_message import message_digest
|
from apps.ethereum.sign_message import message_digest
|
||||||
|
|
||||||
|
|
||||||
@ -28,10 +27,11 @@ async def verify_message(ctx, msg):
|
|||||||
|
|
||||||
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
|
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
|
||||||
|
|
||||||
if msg.address != pkh:
|
address_bytes = bytes_from_address(msg.address)
|
||||||
|
if address_bytes != pkh:
|
||||||
raise wire.DataError("Invalid signature")
|
raise wire.DataError("Invalid signature")
|
||||||
|
|
||||||
address = "0x" + hexlify(msg.address).decode()
|
address = address_from_bytes(address_bytes)
|
||||||
|
|
||||||
await require_confirm_verify_message(ctx, address, msg.message)
|
await require_confirm_verify_message(ctx, address, msg.message)
|
||||||
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
from common import *
|
from common import *
|
||||||
from apps.common.paths import HARDENED
|
from apps.common.paths import HARDENED
|
||||||
from apps.ethereum.address import ethereum_address_hex, validate_full_path
|
from apps.ethereum.address import address_from_bytes, bytes_from_address, validate_full_path
|
||||||
from apps.ethereum.networks import NetworkInfo, by_chain_id
|
from apps.ethereum.networks import NetworkInfo, by_chain_id
|
||||||
|
|
||||||
|
|
||||||
class TestEthereumGetAddress(unittest.TestCase):
|
class TestEthereumGetAddress(unittest.TestCase):
|
||||||
|
|
||||||
def test_ethereum_address_hex_eip55(self):
|
def test_address_from_bytes_eip55(self):
|
||||||
# https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
|
# https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
|
||||||
eip55 = [
|
eip55 = [
|
||||||
'0x52908400098527886E0F7030069857D2E4169EE7',
|
'0x52908400098527886E0F7030069857D2E4169EE7',
|
||||||
@ -21,10 +21,10 @@ class TestEthereumGetAddress(unittest.TestCase):
|
|||||||
for s in eip55:
|
for s in eip55:
|
||||||
s = s[2:]
|
s = s[2:]
|
||||||
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
||||||
h = ethereum_address_hex(b)
|
h = address_from_bytes(b)
|
||||||
self.assertEqual(h, '0x' + s)
|
self.assertEqual(h, '0x' + s)
|
||||||
|
|
||||||
def test_ethereum_address_hex_rskip60(self):
|
def test_address_from_bytes_rskip60(self):
|
||||||
# https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md
|
# https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md
|
||||||
rskip60_chain_30 = [
|
rskip60_chain_30 = [
|
||||||
'0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD',
|
'0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD',
|
||||||
@ -42,13 +42,13 @@ class TestEthereumGetAddress(unittest.TestCase):
|
|||||||
for s in rskip60_chain_30:
|
for s in rskip60_chain_30:
|
||||||
s = s[2:]
|
s = s[2:]
|
||||||
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
||||||
h = ethereum_address_hex(b, n)
|
h = address_from_bytes(b, n)
|
||||||
self.assertEqual(h, '0x' + s)
|
self.assertEqual(h, '0x' + s)
|
||||||
n.chain_id = 31
|
n.chain_id = 31
|
||||||
for s in rskip60_chain_31:
|
for s in rskip60_chain_31:
|
||||||
s = s[2:]
|
s = s[2:]
|
||||||
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)])
|
||||||
h = ethereum_address_hex(b, n)
|
h = address_from_bytes(b, n)
|
||||||
self.assertEqual(h, '0x' + s)
|
self.assertEqual(h, '0x' + s)
|
||||||
|
|
||||||
def test_paths(self):
|
def test_paths(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user