@ -14,6 +14,7 @@ if TYPE_CHECKING:
from . definitions import Definitions
from . keychain import MsgInSignTx
from typing import Any
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
@ -32,7 +33,6 @@ async def sign_tx(
from trezor . utils import HashWriter
from apps . common import paths
from . layout import require_confirm_data , require_confirm_tx
# check
@ -44,16 +44,11 @@ async def sign_tx(
await paths . validate_path ( keychain , msg . address_n )
# Handle ERC20s
token , address_bytes , recipient , value = await handle_erc20 ( msg , defs )
data_total = msg . data_length # local_cache_attribute
if token is None and data_total > 0 :
await require_confirm_data ( msg . data_initial_chunk , data_total )
address_bytes = bytes_from_address ( msg . to )
token , value = await sign_tx_common ( msg , defs , address_bytes )
await require_confirm_tx (
recipient ,
address_bytes ,
value ,
int . from_bytes ( msg . gas_price , " big " ) ,
int . from_bytes ( msg . gas_limit , " big " ) ,
@ -62,6 +57,7 @@ async def sign_tx(
bool ( msg . chunkify ) ,
)
data_total = msg . data_length
data = bytearray ( )
data + = msg . data_initial_chunk
data_left = data_total - len ( msg . data_initial_chunk )
@ -99,33 +95,136 @@ async def sign_tx(
return result
async def handle_erc20 (
async def sign_tx_common (
msg : MsgInSignTx ,
definitions : Definitions ,
) - > tuple [ EthereumTokenInfo | None , bytes , bytes , int ] :
address_bytes : bytes ,
) - > tuple [ EthereumTokenInfo | None , int ] :
from . layout import (
require_confirm_unknown_token ,
require_confirm_smart_contract ,
require_confirm_tx ,
require_confirm_data ,
)
from . import tokens
from . layout import require_confirm_unknown_token
data_initial_chunk = msg . data_initial_chunk # local_cache_attribute
token = None
address_bytes = recipient = bytes_from_address ( msg . to )
value = int . from_bytes ( msg . value , " big " )
if (
len ( msg . to ) in ( 40 , 42 )
and len ( msg . value ) == 0
and msg . data_length == 68
and len ( data_initial_chunk ) == 68
and data_initial_chunk [ : 16 ]
== b " \xa9 \x05 \x9c \xbb \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x00 \x00 "
) :
if len ( msg . to ) in ( 40 , 42 ) and value == 0 :
# Smart Contract
func_name , func_args , transfer = _resolve_tx_data_field ( data_initial_chunk )
# TODO handle ValueError from _resolve_tx_data_field, presently it fails `test_data_streaming`
token = definitions . get_token ( address_bytes )
recipient = data_initial_chunk [ 16 : 36 ]
value = int . from_bytes ( data_initial_chunk [ 36 : 68 ] , " big " )
if token is tokens . UNKNOWN_TOKEN :
await require_confirm_unknown_token ( address_bytes )
return token , address_bytes , recipient , value
if transfer [ 0 ] :
# 'transfer' functions should override the value
arg_val_idx = transfer [ 1 ]
value = int ( func_args [ arg_val_idx ] [ 1 ] )
else :
# we want to show default network at summary screen (i.e. 0 ETH)
token = None
await require_confirm_smart_contract ( func_name , func_args )
else :
# Regular transaction
await require_confirm_tx ( address_bytes , value , definitions . network , token )
if msg . data_length > 0 :
await require_confirm_data ( data_initial_chunk , msg . data_length )
return token , value
def _resolve_type ( val : memoryview , type_str : str ) - > str :
from ubinascii import hexlify
from . helpers import address_from_bytes
if type_str == " int " :
return str ( int . from_bytes ( val , " big " ) )
elif type_str == " str " :
# TODO improve shown text
return bytes ( val ) . decode ( )
elif type_str == " bytes " :
return hexlify ( val ) . decode ( )
elif type_str == " address " :
return address_from_bytes ( val [ - 20 : ] )
else :
raise ValueError
FUNCTIONS_DEF : dict [ bytes , dict [ str , Any ] ] = {
b " \xa9 \x05 \x9c \xbb " : {
" name " : " transfer " ,
" args " : [
( " Recipient " , " address " ) ,
( " Amount " , " int " ) ,
] ,
" transfer " : ( True , 1 ) ,
} ,
b " \x09 \x5e \xa7 \xb3 " : {
" name " : " approve " ,
" args " : [
( " Address " , " address " ) ,
( " Amount " , " int " ) ,
] ,
" transfer " : ( False , 0 ) ,
} ,
b " \x00 \x00 \x00 \x42 " : {
" name " : " args_test " ,
" args " : [
( " Arg0_int " , " int " ) ,
( " Arg1_str " , " str " ) ,
( " Arg2_bytes " , " bytes " ) ,
( " Arg3_address " , " address " ) ,
] ,
" transfer " : ( False , 0 ) ,
# TODO token to address assignment is done by additional entry, where:
# - 1st value is the idx of the value of a token
# - 2nd value is the idx of the address of the token -> to be used in the `definitions.get_token(addr)` call
# - if 1st == 2nd: token is assigned based on contract address
# "token_assign": (0, 3),
} ,
}
def _resolve_tx_data_field (
data_bytes : bytes ,
) - > tuple [ str , list [ tuple [ str , str | bytes ] ] , tuple [ bool , int ] ] :
from ubinascii import hexlify
data_args_len = len ( data_bytes )
N_BYTES_FUNC = 4
N_BYTES_ARG = 32
n_args = ( data_args_len - N_BYTES_FUNC ) / / N_BYTES_ARG
data = memoryview ( data_bytes )
def _data_field_aligned ( data_args_len : int , n_args : int ) - > bool :
# checks if "Data" field doesn't have trailing bytes
return data_args_len == ( n_args * N_BYTES_ARG + N_BYTES_FUNC )
def _get_nth_arg ( data_mv : memoryview , n : int ) - > memoryview :
# returns slice of the nth argument in "Data" field
beg = ( n + 0 ) * N_BYTES_ARG + N_BYTES_FUNC
end = ( n + 1 ) * N_BYTES_ARG + N_BYTES_FUNC
return data_mv [ beg : end ]
if data_args_len < N_BYTES_FUNC or not _data_field_aligned ( data_args_len , n_args ) :
raise ValueError
func_signature_bytes = data_bytes [ : N_BYTES_FUNC ]
func_def = FUNCTIONS_DEF . get ( func_signature_bytes , None )
if func_def is not None and n_args == len ( func_def [ " args " ] ) :
func_name = func_def [ " name " ]
func_args = [
( f " { name } : " , _resolve_type ( _get_nth_arg ( data , i ) , type_str ) )
for i , ( name , type_str ) in enumerate ( func_def [ " args " ] )
]
transfer = func_def [ " transfer " ]
else :
func_name = hexlify ( func_signature_bytes ) . decode ( )
func_args = [ ( f " Input { i } : " , _get_nth_arg ( data , i ) ) for i in range ( n_args ) ]
transfer = ( False , 0 )
return ( func_name , func_args , transfer )
def _get_total_length ( msg : EthereumSignTx , data_total : int ) - > int :