diff --git a/src/apps/common/address_type.py b/src/apps/common/address_type.py index 43737d1e9d..41ba68b6d0 100644 --- a/src/apps/common/address_type.py +++ b/src/apps/common/address_type.py @@ -9,49 +9,18 @@ def length(address_type): return 4 -def addrtype_bytes(address_type: int): - if address_type <= 0xFF: - return bytes([address_type]) - if address_type <= 0xFFFF: - return bytes([(address_type >> 8), (address_type & 0xFF)]) - if address_type <= 0xFFFFFF: - return bytes([(address_type >> 16), (address_type >> 8), (address_type & 0xFF)]) - # else - return bytes( - [ - (address_type >> 24), - (address_type >> 16), - (address_type >> 8), - (address_type & 0xFF), - ] - ) +def tobytes(address_type: int): + return address_type.to_bytes(length(address_type), "big") def check(address_type, raw_address): - if address_type <= 0xFF: - return address_type == raw_address[0] - if address_type <= 0xFFFF: - return address_type == (raw_address[0] << 8) | raw_address[1] - if address_type <= 0xFFFFFF: - return ( - address_type - == (raw_address[0] << 16) | (raw_address[1] << 8) | raw_address[2] - ) - # else - return ( - address_type - == (raw_address[0] << 24) - | (raw_address[1] << 16) - | (raw_address[2] << 8) - | raw_address[3] - ) + return raw_address.startswith(tobytes(address_type)) def strip(address_type, raw_address): if not check(address_type, raw_address): raise ValueError("Invalid address") - l = length(address_type) - return raw_address[l:] + return raw_address[length(address_type) :] def split(coin, raw_address): diff --git a/src/apps/wallet/sign_tx/addresses.py b/src/apps/wallet/sign_tx/addresses.py index 184124589f..0c62a39703 100644 --- a/src/apps/wallet/sign_tx/addresses.py +++ b/src/apps/wallet/sign_tx/addresses.py @@ -5,7 +5,7 @@ from trezor.crypto.hashlib import ripemd160, sha256 from trezor.messages import FailureType, InputScriptType from trezor.utils import ensure -from apps.common.address_type import addrtype_bytes +from apps.common import address_type from apps.common.coininfo import CoinInfo from apps.wallet.sign_tx.multisig import multisig_get_pubkeys, multisig_pubkey_index from apps.wallet.sign_tx.scripts import ( @@ -117,12 +117,12 @@ def address_multisig_p2wsh(pubkeys: bytes, m: int, hrp: str): def address_pkh(pubkey: bytes, coin: CoinInfo) -> str: - s = addrtype_bytes(coin.address_type) + sha256_ripemd160_digest(pubkey) + s = address_type.tobytes(coin.address_type) + sha256_ripemd160_digest(pubkey) return base58.encode_check(bytes(s), coin.b58_hash) def address_p2sh(redeem_script_hash: bytes, coin: CoinInfo) -> str: - s = addrtype_bytes(coin.address_type_p2sh) + redeem_script_hash + s = address_type.tobytes(coin.address_type_p2sh) + redeem_script_hash return base58.encode_check(bytes(s), coin.b58_hash)