mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-03 12:00:59 +00:00
apps.common: refactor address_type functions
This commit is contained in:
parent
cfdd517bf4
commit
315440fc18
27
src/apps/common/address_type.py
Normal file
27
src/apps/common/address_type.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
def length(address_type):
|
||||||
|
if address_type <= 0xFF:
|
||||||
|
return 1
|
||||||
|
if address_type <= 0xFFFF:
|
||||||
|
return 2
|
||||||
|
if address_type <= 0xFFFFFF:
|
||||||
|
return 3
|
||||||
|
# else
|
||||||
|
return 4
|
||||||
|
|
||||||
|
|
||||||
|
def check(address_type, raw_address):
|
||||||
|
if address_type <= 0xFF:
|
||||||
|
return address_type == raw_address[0]
|
||||||
|
if address_type <= 0xFFFF:
|
||||||
|
return address_type == (raw_address[0] << 8) | raw_address[1]
|
||||||
|
if address_type <= 0xFFFFFF:
|
||||||
|
return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2]
|
||||||
|
# else
|
||||||
|
return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3]
|
||||||
|
|
||||||
|
|
||||||
|
def strip(address_type, raw_address):
|
||||||
|
if not check(address_type, raw_address):
|
||||||
|
raise ValueError('Invalid address')
|
||||||
|
l = length(address_type)
|
||||||
|
return raw_address[l:]
|
@ -2,6 +2,7 @@ from trezor.crypto.hashlib import sha256, ripemd160
|
|||||||
from trezor.crypto.curve import secp256k1
|
from trezor.crypto.curve import secp256k1
|
||||||
from trezor.crypto import base58, der
|
from trezor.crypto import base58, der
|
||||||
|
|
||||||
|
from . import address_type
|
||||||
from . import coins
|
from . import coins
|
||||||
|
|
||||||
from trezor.messages.CoinType import CoinType
|
from trezor.messages.CoinType import CoinType
|
||||||
@ -312,25 +313,15 @@ def get_tx_hash(w, double: bool, reverse: bool=False) -> bytes:
|
|||||||
# TX Outputs
|
# TX Outputs
|
||||||
# ===
|
# ===
|
||||||
|
|
||||||
def len_address_type(address_type):
|
|
||||||
if address_type <= 0xFF:
|
|
||||||
return 1
|
|
||||||
if address_type <= 0xFFFF:
|
|
||||||
return 2
|
|
||||||
if address_type <= 0xFFFFFF:
|
|
||||||
return 3
|
|
||||||
# else
|
|
||||||
return 4
|
|
||||||
|
|
||||||
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
|
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
|
||||||
if o.script_type == OutputScriptType.PAYTOADDRESS:
|
if o.script_type == OutputScriptType.PAYTOADDRESS:
|
||||||
ra = output_paytoaddress_extract_raw_address(o, coin, root)
|
ra = output_paytoaddress_extract_raw_address(o, coin, root)
|
||||||
at = len_address_type(coin.address_type)
|
ra = address_type.strip(coin.address_type, ra)
|
||||||
return script_paytoaddress_new(ra[at:])
|
return script_paytoaddress_new(ra)
|
||||||
elif o.script_type == OutputScriptType.PAYTOSCRIPTHASH:
|
elif o.script_type == OutputScriptType.PAYTOSCRIPTHASH:
|
||||||
ra = output_paytoaddress_extract_raw_address(o, coin, root, p2sh=True)
|
ra = output_paytoaddress_extract_raw_address(o, coin, root, p2sh=True)
|
||||||
at = len_address_type(coin.address_type_p2sh)
|
ra = address_type.strip(coin.address_type_p2sh, ra)
|
||||||
return script_paytoscripthash_new(ra[at:])
|
return script_paytoscripthash_new(ra)
|
||||||
elif o.script_type == OutputScriptType.PAYTOOPRETURN:
|
elif o.script_type == OutputScriptType.PAYTOOPRETURN:
|
||||||
if o.amount == 0:
|
if o.amount == 0:
|
||||||
return script_paytoopreturn_new(o.op_return_data)
|
return script_paytoopreturn_new(o.op_return_data)
|
||||||
@ -343,28 +334,18 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def check_address_type(address_type, raw_address):
|
|
||||||
if address_type <= 0xFF:
|
|
||||||
return address_type == raw_address[0]
|
|
||||||
if address_type <= 0xFFFF:
|
|
||||||
return address_type == (raw_address[0] << 8) | raw_address[1]
|
|
||||||
if address_type <= 0xFFFFFF:
|
|
||||||
return address_type == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2]
|
|
||||||
# else
|
|
||||||
return address_type == (raw_address[0] << 24) | (raw_address[1] << 16) | (raw_address[2] << 8) | raw_address[3]
|
|
||||||
|
|
||||||
def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root, p2sh=False) -> bytes:
|
def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root, p2sh=False) -> bytes:
|
||||||
address_type = coin.address_type_p2sh if p2sh else coin.address_type
|
addr_type = coin.address_type_p2sh if p2sh else coin.address_type
|
||||||
# TODO: dont encode/decode more then necessary
|
# TODO: dont encode/decode more then necessary
|
||||||
address_n = getattr(o, 'address_n', None)
|
address_n = getattr(o, 'address_n', None)
|
||||||
if address_n is not None:
|
if address_n is not None:
|
||||||
node = node_derive(root, address_n)
|
node = node_derive(root, address_n)
|
||||||
address = node.address(address_type)
|
address = node.address(addr_type)
|
||||||
return base58.decode_check(address)
|
return base58.decode_check(address)
|
||||||
address = getattr(o, 'address', None)
|
address = getattr(o, 'address', None)
|
||||||
if address:
|
if address:
|
||||||
raw = base58.decode_check(address)
|
raw = base58.decode_check(address)
|
||||||
if not check_address_type(address_type, raw):
|
if not address_type.check(addr_type, raw):
|
||||||
raise SigningError(FailureType.SyntaxError,
|
raise SigningError(FailureType.SyntaxError,
|
||||||
'Invalid address type')
|
'Invalid address type')
|
||||||
return raw
|
return raw
|
||||||
|
Loading…
Reference in New Issue
Block a user