eth: modify sign_tx to accept strings as well

pull/25/head
Tomas Susanka 5 years ago
parent 1946a9f93e
commit 4e9ef09798

@ -84,7 +84,7 @@ def address_from_bytes(address_bytes: bytes, network=None) -> str:
return "0x" + h return "0x" + h
def bytes_from_address(address: str, network=None) -> bytes: def bytes_from_address(address: str) -> bytes:
if len(address) == 40: if len(address) == 40:
return unhexlify(address) return unhexlify(address)
@ -93,4 +93,7 @@ def bytes_from_address(address: str, network=None) -> bytes:
raise wire.ProcessError("Ethereum: invalid beginning of an address") raise wire.ProcessError("Ethereum: invalid beginning of an address")
return unhexlify(address[2:]) return unhexlify(address[2:])
elif len(address) == 0:
return bytes()
raise wire.ProcessError("Ethereum: Invalid address length") raise wire.ProcessError("Ethereum: Invalid address length")

@ -8,12 +8,12 @@ from trezor.utils import chunks, format_amount
from apps.common.confirm import require_confirm, require_hold_to_confirm from apps.common.confirm import require_confirm, require_hold_to_confirm
from apps.common.layout import split_address from apps.common.layout import split_address
from apps.ethereum import networks, tokens from apps.ethereum import networks, tokens
from apps.ethereum.address import ethereum_address_hex from apps.ethereum.address import address_from_bytes
async def require_confirm_tx(ctx, to, value, chain_id, token=None, tx_type=None): async def require_confirm_tx(ctx, to_bytes, value, chain_id, token=None, tx_type=None):
if to: if to_bytes:
to_str = ethereum_address_hex(to, networks.by_chain_id(chain_id)) to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
else: else:
to_str = "new contract?" to_str = "new contract?"
text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN, new_lines=False) text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.GREEN, new_lines=False)

@ -8,7 +8,7 @@ from trezor.messages.MessageType import EthereumTxAck
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
from apps.ethereum import tokens from apps.ethereum import address, tokens
from apps.ethereum.address import validate_full_path from apps.ethereum.address import validate_full_path
from apps.ethereum.layout import ( from apps.ethereum.layout import (
require_confirm_data, require_confirm_data,
@ -29,17 +29,17 @@ async def sign_tx(ctx, msg, keychain):
# detect ERC - 20 token # detect ERC - 20 token
token = None token = None
recipient = msg.to address_bytes = recipient = address.bytes_from_address(msg.to)
value = int.from_bytes(msg.value, "big") value = int.from_bytes(msg.value, "big")
if ( if (
len(msg.to) == 20 len(msg.to) in (40, 42)
and len(msg.value) == 0 and len(msg.value) == 0
and data_total == 68 and data_total == 68
and len(msg.data_initial_chunk) == 68 and len(msg.data_initial_chunk) == 68
and msg.data_initial_chunk[:16] and msg.data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" == b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
): ):
token = tokens.token_by_chain_address(msg.chain_id, msg.to) token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
recipient = msg.data_initial_chunk[16:36] recipient = msg.data_initial_chunk[16:36]
value = int.from_bytes(msg.data_initial_chunk[36:68], "big") value = int.from_bytes(msg.data_initial_chunk[36:68], "big")
@ -69,7 +69,7 @@ async def sign_tx(ctx, msg, keychain):
if msg.tx_type is not None: if msg.tx_type is not None:
sha.extend(rlp.encode(msg.tx_type)) sha.extend(rlp.encode(msg.tx_type))
for field in [msg.nonce, msg.gas_price, msg.gas_limit, msg.to, msg.value]: for field in (msg.nonce, msg.gas_price, msg.gas_limit, address_bytes, msg.value):
sha.extend(rlp.encode(field)) sha.extend(rlp.encode(field))
if data_left == 0: if data_left == 0:
@ -100,8 +100,12 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
if msg.tx_type is not None: if msg.tx_type is not None:
length += rlp.field_length(1, [msg.tx_type]) length += rlp.field_length(1, [msg.tx_type])
for field in [msg.nonce, msg.gas_price, msg.gas_limit, msg.to, msg.value]: length += rlp.field_length(len(msg.nonce), msg.nonce[:1])
length += rlp.field_length(len(field), field[:1]) length += rlp.field_length(len(msg.gas_price), msg.gas_price)
length += rlp.field_length(len(msg.gas_limit), msg.gas_limit)
to = address.bytes_from_address(msg.to)
length += rlp.field_length(len(to), to)
length += rlp.field_length(len(msg.value), msg.value)
if msg.chain_id: # forks replay protection if msg.chain_id: # forks replay protection
if msg.chain_id < 0x100: if msg.chain_id < 0x100:
@ -182,12 +186,12 @@ def check_gas(msg: EthereumSignTx) -> bool:
def check_to(msg: EthereumTxRequest) -> bool: def check_to(msg: EthereumTxRequest) -> bool:
if msg.to == b"": if msg.to == "":
if msg.data_length == 0: if msg.data_length == 0:
# sending transaction to address 0 (contract creation) without a data field # sending transaction to address 0 (contract creation) without a data field
return False return False
else: else:
if len(msg.to) != 20: if len(msg.to) not in (40, 42):
return False return False
return True return True
@ -200,7 +204,7 @@ def sanitize(msg):
if msg.data_length is None: if msg.data_length is None:
msg.data_length = 0 msg.data_length = 0
if msg.to is None: if msg.to is None:
msg.to = b"" msg.to = ""
if msg.nonce is None: if msg.nonce is None:
msg.nonce = b"" msg.nonce = b""
if msg.chain_id is None: if msg.chain_id is None:

Loading…
Cancel
Save