diff --git a/src/apps/common/address_type.py b/src/apps/common/address_type.py new file mode 100644 index 0000000000..53abbceff8 --- /dev/null +++ b/src/apps/common/address_type.py @@ -0,0 +1,27 @@ +def length(address_type): + if address_type <= 0xFF: + return 1 + if address_type <= 0xFFFF: + return 2 + if address_type <= 0xFFFFFF: + return 3 + # else + return 4 + + +def check(address_type, raw_address): + if address_type <= 0xFF: + return address_type == raw_address[0] + if address_type <= 0xFFFF: + return address_type == (raw_address[0] << 8) | raw_address[1] + if address_type <= 0xFFFFFF: + return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2] + # else + return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3] + + +def strip(address_type, raw_address): + if not check(address_type, raw_address): + raise ValueError('Invalid address') + l = length(address_type) + return raw_address[l:] diff --git a/src/apps/common/signtx.py b/src/apps/common/signtx.py index 1584c7da55..66c4120485 100644 --- a/src/apps/common/signtx.py +++ b/src/apps/common/signtx.py @@ -2,6 +2,7 @@ from trezor.crypto.hashlib import sha256, ripemd160 from trezor.crypto.curve import secp256k1 from trezor.crypto import base58, der +from . import address_type from . import coins from trezor.messages.CoinType import CoinType @@ -312,25 +313,15 @@ def get_tx_hash(w, double: bool, reverse: bool=False) -> bytes: # TX Outputs # === -def len_address_type(address_type): - if address_type <= 0xFF: - return 1 - if address_type <= 0xFFFF: - return 2 - if address_type <= 0xFFFFFF: - return 3 - # else - return 4 - def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: if o.script_type == OutputScriptType.PAYTOADDRESS: ra = output_paytoaddress_extract_raw_address(o, coin, root) - at = len_address_type(coin.address_type) - return script_paytoaddress_new(ra[at:]) + ra = address_type.strip(coin.address_type, ra) + return script_paytoaddress_new(ra) elif o.script_type == OutputScriptType.PAYTOSCRIPTHASH: ra = output_paytoaddress_extract_raw_address(o, coin, root, p2sh=True) - at = len_address_type(coin.address_type_p2sh) - return script_paytoscripthash_new(ra[at:]) + ra = address_type.strip(coin.address_type_p2sh, ra) + return script_paytoscripthash_new(ra) elif o.script_type == OutputScriptType.PAYTOOPRETURN: if o.amount == 0: return script_paytoopreturn_new(o.op_return_data) @@ -343,28 +334,18 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: return -def check_address_type(address_type, raw_address): - if address_type <= 0xFF: - return address_type == raw_address[0] - if address_type <= 0xFFFF: - return address_type == (raw_address[0] << 8) | raw_address[1] - if address_type <= 0xFFFFFF: - return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2] - # else - return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3] - def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root, p2sh=False) -> bytes: - address_type = coin.address_type_p2sh if p2sh else coin.address_type + addr_type = coin.address_type_p2sh if p2sh else coin.address_type # TODO: dont encode/decode more then necessary address_n = getattr(o, 'address_n', None) if address_n is not None: node = node_derive(root, address_n) - address = node.address(address_type) + address = node.address(addr_type) return base58.decode_check(address) address = getattr(o, 'address', None) if address: raw = base58.decode_check(address) - if not check_address_type(address_type, raw): + if not address_type.check(addr_type, raw): raise SigningError(FailureType.SyntaxError, 'Invalid address type') return raw