1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 12:00:59 +00:00

apps.common: refactor address_type functions

This commit is contained in:
Pavol Rusnak 2016-11-16 12:42:11 +01:00
parent cfdd517bf4
commit 315440fc18
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
2 changed files with 35 additions and 27 deletions

View File

@ -0,0 +1,27 @@
def length(address_type):
if address_type <= 0xFF:
return 1
if address_type <= 0xFFFF:
return 2
if address_type <= 0xFFFFFF:
return 3
# else
return 4
def check(address_type, raw_address):
if address_type <= 0xFF:
return address_type == raw_address[0]
if address_type <= 0xFFFF:
return address_type == (raw_address[0] << 8) | raw_address[1]
if address_type <= 0xFFFFFF:
return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2]
# else
return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3]
def strip(address_type, raw_address):
if not check(address_type, raw_address):
raise ValueError('Invalid address')
l = length(address_type)
return raw_address[l:]

View File

@ -2,6 +2,7 @@ from trezor.crypto.hashlib import sha256, ripemd160
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto import base58, der from trezor.crypto import base58, der
from . import address_type
from . import coins from . import coins
from trezor.messages.CoinType import CoinType from trezor.messages.CoinType import CoinType
@ -312,25 +313,15 @@ def get_tx_hash(w, double: bool, reverse: bool=False) -> bytes:
# TX Outputs # TX Outputs
# === # ===
def len_address_type(address_type):
if address_type <= 0xFF:
return 1
if address_type <= 0xFFFF:
return 2
if address_type <= 0xFFFFFF:
return 3
# else
return 4
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
if o.script_type == OutputScriptType.PAYTOADDRESS: if o.script_type == OutputScriptType.PAYTOADDRESS:
ra = output_paytoaddress_extract_raw_address(o, coin, root) ra = output_paytoaddress_extract_raw_address(o, coin, root)
at = len_address_type(coin.address_type) ra = address_type.strip(coin.address_type, ra)
return script_paytoaddress_new(ra[at:]) return script_paytoaddress_new(ra)
elif o.script_type == OutputScriptType.PAYTOSCRIPTHASH: elif o.script_type == OutputScriptType.PAYTOSCRIPTHASH:
ra = output_paytoaddress_extract_raw_address(o, coin, root, p2sh=True) ra = output_paytoaddress_extract_raw_address(o, coin, root, p2sh=True)
at = len_address_type(coin.address_type_p2sh) ra = address_type.strip(coin.address_type_p2sh, ra)
return script_paytoscripthash_new(ra[at:]) return script_paytoscripthash_new(ra)
elif o.script_type == OutputScriptType.PAYTOOPRETURN: elif o.script_type == OutputScriptType.PAYTOOPRETURN:
if o.amount == 0: if o.amount == 0:
return script_paytoopreturn_new(o.op_return_data) return script_paytoopreturn_new(o.op_return_data)
@ -343,28 +334,18 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
return return
def check_address_type(address_type, raw_address):
if address_type <= 0xFF:
return address_type == raw_address[0]
if address_type <= 0xFFFF:
return address_type == (raw_address[0] << 8) | raw_address[1]
if address_type <= 0xFFFFFF:
return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2]
# else
return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3]
def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root, p2sh=False) -> bytes: def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root, p2sh=False) -> bytes:
address_type = coin.address_type_p2sh if p2sh else coin.address_type addr_type = coin.address_type_p2sh if p2sh else coin.address_type
# TODO: dont encode/decode more then necessary # TODO: dont encode/decode more then necessary
address_n = getattr(o, 'address_n', None) address_n = getattr(o, 'address_n', None)
if address_n is not None: if address_n is not None:
node = node_derive(root, address_n) node = node_derive(root, address_n)
address = node.address(address_type) address = node.address(addr_type)
return base58.decode_check(address) return base58.decode_check(address)
address = getattr(o, 'address', None) address = getattr(o, 'address', None)
if address: if address:
raw = base58.decode_check(address) raw = base58.decode_check(address)
if not check_address_type(address_type, raw): if not address_type.check(addr_type, raw):
raise SigningError(FailureType.SyntaxError, raise SigningError(FailureType.SyntaxError,
'Invalid address type') 'Invalid address type')
return raw return raw