diff --git a/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h b/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h index d6ba9d1beb..306f60ea6b 100644 --- a/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h +++ b/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h @@ -422,7 +422,16 @@ STATIC mp_obj_t mod_trezorcrypto_HDNode_ethereum_pubkeyhash(mp_obj_t self) { } STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorcrypto_HDNode_ethereum_pubkeyhash_obj, mod_trezorcrypto_HDNode_ethereum_pubkeyhash); +STATIC mp_obj_t mod_trezorcrypto_HDNode___del__(mp_obj_t self) { + mp_obj_HDNode_t *o = MP_OBJ_TO_PTR(self); + o->fingerprint = 0; + memzero(&o->hdnode, sizeof(o->hdnode)); + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorcrypto_HDNode___del___obj, mod_trezorcrypto_HDNode___del__); + STATIC const mp_rom_map_elem_t mod_trezorcrypto_HDNode_locals_dict_table[] = { + { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mod_trezorcrypto_HDNode___del___obj) }, { MP_ROM_QSTR(MP_QSTR_derive), MP_ROM_PTR(&mod_trezorcrypto_HDNode_derive_obj) }, { MP_ROM_QSTR(MP_QSTR_derive_cardano), MP_ROM_PTR(&mod_trezorcrypto_HDNode_derive_cardano_obj) }, { MP_ROM_QSTR(MP_QSTR_derive_path), MP_ROM_PTR(&mod_trezorcrypto_HDNode_derive_path_obj) }, @@ -539,9 +548,7 @@ STATIC mp_obj_t mod_trezorcrypto_bip32_from_mnemonic_cardano(mp_obj_t mnemonic, return MP_OBJ_FROM_PTR(o); } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_bip32_from_mnemonic_cardano_obj, - mod_trezorcrypto_bip32_from_mnemonic_cardano); - +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_bip32_from_mnemonic_cardano_obj, mod_trezorcrypto_bip32_from_mnemonic_cardano); STATIC const mp_rom_map_elem_t mod_trezorcrypto_bip32_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_bip32) }, diff --git a/src/apps/cardano/__init__.py b/src/apps/cardano/__init__.py index 022a12f197..dcfa3b5f41 100644 --- a/src/apps/cardano/__init__.py +++ b/src/apps/cardano/__init__.py @@ -1,6 +1,10 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + +SEED_NAMESPACE = [[HARDENED | 44, HARDENED | 1815]] + def boot(): wire.add(MessageType.CardanoGetAddress, __name__, "get_address") diff --git a/src/apps/cardano/address.py b/src/apps/cardano/address.py index 1f8a412a08..078d309035 100644 --- a/src/apps/cardano/address.py +++ b/src/apps/cardano/address.py @@ -1,8 +1,27 @@ from trezor.crypto import base58, crc, hashlib -from . import cbor +from apps.cardano import cbor +from apps.common import HARDENED +from apps.common.seed import remove_ed25519_prefix -from apps.common import HARDENED, seed + +def derive_address_and_node(keychain, path: list): + node = keychain.derive(path) + + address_payload = None + address_attributes = {} + + address_root = _get_address_root(node, address_payload) + address_type = 0 + address_data = [address_root, address_attributes, address_type] + address_data_encoded = cbor.encode(address_data) + + address = base58.encode( + cbor.encode( + [cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)] + ) + ) + return (address, node) def validate_full_path(path: list) -> bool: @@ -36,31 +55,9 @@ def _address_hash(data) -> bytes: def _get_address_root(node, payload): - extpubkey = seed.remove_ed25519_prefix(node.public_key()) + node.chain_code() + extpubkey = remove_ed25519_prefix(node.public_key()) + node.chain_code() if payload: payload = {1: cbor.encode(payload)} else: payload = {} return _address_hash([0, [0, extpubkey], payload]) - - -def derive_address_and_node(root_node, path: list): - derived_node = root_node.clone() - - address_payload = None - address_attributes = {} - - for indice in path: - derived_node.derive_cardano(indice) - - address_root = _get_address_root(derived_node, address_payload) - address_type = 0 - address_data = [address_root, address_attributes, address_type] - address_data_encoded = cbor.encode(address_data) - - address = base58.encode( - cbor.encode( - [cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)] - ) - ) - return (address, derived_node) diff --git a/src/apps/cardano/get_address.py b/src/apps/cardano/get_address.py index 660826c801..9b03ebca47 100644 --- a/src/apps/cardano/get_address.py +++ b/src/apps/cardano/get_address.py @@ -1,28 +1,23 @@ from trezor import log, ui, wire -from trezor.crypto import bip32 from trezor.messages.CardanoAddress import CardanoAddress -from .address import derive_address_and_node, validate_full_path -from .layout import confirm_with_pagination - -from apps.common import paths, seed, storage +from apps.cardano import seed +from apps.cardano.address import derive_address_and_node, validate_full_path +from apps.cardano.layout import confirm_with_pagination +from apps.common import paths async def get_address(ctx, msg): + keychain = await seed.get_keychain(ctx) + await paths.validate_path(ctx, validate_full_path, path=msg.address_n) - mnemonic = storage.get_mnemonic() - passphrase = await seed._get_cached_passphrase(ctx) - root_node = bip32.from_mnemonic_cardano(mnemonic, passphrase) - try: - address, _ = derive_address_and_node(root_node, msg.address_n) + address, _ = derive_address_and_node(keychain, msg.address_n) except ValueError as e: if __debug__: log.exception(__name__, e) raise wire.ProcessError("Deriving address failed") - mnemonic = None - root_node = None if msg.show_display: if not await confirm_with_pagination( diff --git a/src/apps/cardano/get_public_key.py b/src/apps/cardano/get_public_key.py index 24b7417912..8ebe9e3ef7 100644 --- a/src/apps/cardano/get_public_key.py +++ b/src/apps/cardano/get_public_key.py @@ -1,42 +1,38 @@ from ubinascii import hexlify from trezor import log, wire -from trezor.crypto import bip32 from trezor.messages.CardanoPublicKey import CardanoPublicKey from trezor.messages.HDNodeType import HDNodeType -from .address import derive_address_and_node - -from apps.common import layout, paths, seed, storage +from apps.cardano import seed +from apps.cardano.address import derive_address_and_node +from apps.common import layout, paths +from apps.common.seed import remove_ed25519_prefix async def get_public_key(ctx, msg): + keychain = await seed.get_keychain(ctx) + await paths.validate_path( ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815 ) - mnemonic = storage.get_mnemonic() - passphrase = await seed._get_cached_passphrase(ctx) - root_node = bip32.from_mnemonic_cardano(mnemonic, passphrase) - try: - key = _get_public_key(root_node, msg.address_n) + key = _get_public_key(keychain, msg.address_n) except ValueError as e: if __debug__: log.exception(__name__, e) raise wire.ProcessError("Deriving public key failed") - mnemonic = None - root_node = None if msg.show_display: await layout.show_pubkey(ctx, key.node.public_key) return key -def _get_public_key(root_node, derivation_path: list): - _, node = derive_address_and_node(root_node, derivation_path) +def _get_public_key(keychain, derivation_path: list): + _, node = derive_address_and_node(keychain, derivation_path) - public_key = hexlify(seed.remove_ed25519_prefix(node.public_key())).decode() + public_key = hexlify(remove_ed25519_prefix(node.public_key())).decode() chain_code = hexlify(node.chain_code()).decode() xpub_key = public_key + chain_code @@ -45,7 +41,7 @@ def _get_public_key(root_node, derivation_path: list): child_num=node.child_num(), fingerprint=node.fingerprint(), chain_code=node.chain_code(), - public_key=seed.remove_ed25519_prefix(node.public_key()), + public_key=remove_ed25519_prefix(node.public_key()), ) return CardanoPublicKey(node=node_type, xpub=xpub_key) diff --git a/src/apps/cardano/seed.py b/src/apps/cardano/seed.py new file mode 100644 index 0000000000..102c19b69b --- /dev/null +++ b/src/apps/cardano/seed.py @@ -0,0 +1,43 @@ +from trezor import wire +from trezor.crypto import bip32 + +from apps.cardano import SEED_NAMESPACE +from apps.common import cache, storage +from apps.common.request_passphrase import protect_by_passphrase + + +class Keychain: + def __init__(self, path: list, root: bip32.HDNode): + self.path = path + self.root = root + + def derive(self, node_path: list) -> bip32.HDNode: + # check we are in the cardano namespace + prefix = node_path[: len(self.path)] + suffix = node_path[len(self.path) :] + if prefix != self.path: + raise wire.DataError("Forbidden key path") + # derive child node from the root + node = self.root.clone() + for i in suffix: + node.derive_cardano(i) + return node + + +async def get_keychain(ctx: wire.Context) -> Keychain: + if not storage.is_initialized(): + raise wire.ProcessError("Device is not initialized") + + # derive the root node from mnemonic and passphrase + passphrase = cache.get_passphrase() + if passphrase is None: + passphrase = await protect_by_passphrase(ctx) + cache.set_passphrase(passphrase) + root = bip32.from_mnemonic_cardano(storage.get_mnemonic(), passphrase) + + # derive the namespaced root node + for i in SEED_NAMESPACE[0]: + root.derive_cardano(i) + + keychain = Keychain(SEED_NAMESPACE[0], root) + return keychain diff --git a/src/apps/cardano/sign_tx.py b/src/apps/cardano/sign_tx.py index 970224d586..ae453af1e9 100644 --- a/src/apps/cardano/sign_tx.py +++ b/src/apps/cardano/sign_tx.py @@ -1,18 +1,17 @@ from trezor import log, ui, wire -from trezor.crypto import base58, bip32, hashlib +from trezor.crypto import base58, hashlib from trezor.crypto.curve import ed25519 from trezor.messages.CardanoSignedTx import CardanoSignedTx from trezor.messages.CardanoTxRequest import CardanoTxRequest from trezor.messages.MessageType import CardanoTxAck from trezor.ui.text import BR -from .address import derive_address_and_node, validate_full_path -from .layout import confirm_with_pagination, progress - -from apps.cardano import cbor -from apps.common import seed, storage +from apps.cardano import cbor, seed +from apps.cardano.address import derive_address_and_node, validate_full_path +from apps.cardano.layout import confirm_with_pagination, progress from apps.common.layout import address_n_to_str, split_address from apps.common.paths import validate_path +from apps.common.seed import remove_ed25519_prefix from apps.homescreen.homescreen import display_homescreen @@ -80,9 +79,7 @@ async def request_transaction(ctx, tx_req: CardanoTxRequest, index: int): async def sign_tx(ctx, msg): - mnemonic = storage.get_mnemonic() - passphrase = await seed._get_cached_passphrase(ctx) - root_node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + keychain = await seed.get_keychain(ctx) progress.init(msg.transactions_count, "Loading data") @@ -103,7 +100,7 @@ async def sign_tx(ctx, msg): # sign the transaction bundle and prepare the result transaction = Transaction( - msg.inputs, msg.outputs, transactions, root_node, msg.network + msg.inputs, msg.outputs, transactions, keychain, msg.network ) tx_body, tx_hash = transaction.serialise_tx() tx = CardanoSignedTx(tx_body=tx_body, tx_hash=tx_hash) @@ -135,12 +132,12 @@ def _micro_ada_to_ada(amount: float) -> float: class Transaction: def __init__( - self, inputs: list, outputs: list, transactions: list, root_node, network: int + self, inputs: list, outputs: list, transactions: list, keychain, network: int ): self.inputs = inputs self.outputs = outputs self.transactions = transactions - self.root_node = root_node + self.keychain = keychain # attributes have to be always empty in current Cardano self.attributes = {} if network == 1: @@ -170,7 +167,7 @@ class Transaction: nodes = [] for input in self.inputs: - _, node = derive_address_and_node(self.root_node, input.address_n) + _, node = derive_address_and_node(self.keychain, input.address_n) nodes.append(node) for index, output_index in enumerate(output_indexes): @@ -198,7 +195,7 @@ class Transaction: for output in self.outputs: if output.address_n: - address, _ = derive_address_and_node(self.root_node, output.address_n) + address, _ = derive_address_and_node(self.keychain, output.address_n) change_addresses.append(address) change_derivation_paths.append(output.address_n) change_coins.append(output.amount) @@ -225,7 +222,7 @@ class Transaction: node.private_key(), node.private_key_ext(), message ) extended_public_key = ( - seed.remove_ed25519_prefix(node.public_key()) + node.chain_code() + remove_ed25519_prefix(node.public_key()) + node.chain_code() ) witnesses.append( [ diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index a2a27aab02..1890b013b2 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -4,41 +4,75 @@ from trezor.crypto import bip32, bip39 from apps.common import cache, storage from apps.common.request_passphrase import protect_by_passphrase -_DEFAULT_CURVE = "secp256k1" +allow = list -async def derive_node( - ctx: wire.Context, path: list, curve_name: str = _DEFAULT_CURVE -) -> bip32.HDNode: - seed = await _get_cached_seed(ctx) - node = bip32.from_seed(seed, curve_name) - node.derive_path(path) - return node +class Keychain: + """ + Keychain provides an API for deriving HD keys from previously allowed + key-spaces. + """ + + def __init__(self, seed: bytes, namespaces: list): + self.seed = seed + self.namespaces = namespaces + self.roots = [None] * len(namespaces) + + def __del__(self): + for root in self.roots: + if root is not None: + root.__del__() + del self.roots + del self.seed + + def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode: + # find the root node index + root_index = 0 + for curve, *path in self.namespaces: + prefix = node_path[: len(path)] + suffix = node_path[len(path) :] + if curve == curve_name and path == prefix: + break + root_index += 1 + else: + raise wire.DataError("Forbidden key path") + + # create the root node if not cached + root = self.roots[root_index] + if root is None: + root = bip32.from_seed(self.seed, curve_name) + root.derive_path(path) + self.roots[root_index] = root + + # derive child node from the root + node = root.clone() + node.derive_path(suffix) + return node -async def _get_cached_seed(ctx: wire.Context) -> bytes: +async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain: if not storage.is_initialized(): raise wire.ProcessError("Device is not initialized") - if cache.get_seed() is None: - passphrase = await _get_cached_passphrase(ctx) + + seed = cache.get_seed() + if seed is None: + # derive seed from mnemonic and passphrase + passphrase = cache.get_passphrase() + if passphrase is None: + passphrase = await protect_by_passphrase(ctx) + cache.set_passphrase(passphrase) seed = bip39.seed(storage.get_mnemonic(), passphrase) cache.set_seed(seed) - return cache.get_seed() - -async def _get_cached_passphrase(ctx: wire.Context) -> str: - if cache.get_passphrase() is None: - passphrase = await protect_by_passphrase(ctx) - cache.set_passphrase(passphrase) - return cache.get_passphrase() + keychain = Keychain(seed, namespaces) + return keychain def derive_node_without_passphrase( - path: list, curve_name: str = _DEFAULT_CURVE + path: list, curve_name: str = "secp256k1" ) -> bip32.HDNode: if not storage.is_initialized(): raise Exception("Device is not initialized") - seed = bip39.seed(storage.get_mnemonic(), "") node = bip32.from_seed(seed, curve_name) node.derive_path(path) diff --git a/src/apps/common/signverify.py b/src/apps/common/signverify.py index 99476f740f..4b13731238 100644 --- a/src/apps/common/signverify.py +++ b/src/apps/common/signverify.py @@ -3,7 +3,7 @@ from ubinascii import hexlify from trezor.crypto.hashlib import blake256, sha256 from trezor.utils import HashWriter -from apps.wallet.sign_tx.signing import write_varint +from apps.wallet.sign_tx.writers import write_varint def message_digest(coin, message): diff --git a/src/apps/ethereum/__init__.py b/src/apps/ethereum/__init__.py index f0b0ff9394..04b07a773d 100644 --- a/src/apps/ethereum/__init__.py +++ b/src/apps/ethereum/__init__.py @@ -1,9 +1,12 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + def boot(): - wire.add(MessageType.EthereumGetAddress, __name__, "get_address") - wire.add(MessageType.EthereumSignTx, __name__, "sign_tx") - wire.add(MessageType.EthereumSignMessage, __name__, "sign_message") + ns = [["secp256k1", HARDENED | 44, HARDENED | 60]] + wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns) + wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns) + wire.add(MessageType.EthereumSignMessage, __name__, "sign_message", ns) wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message") diff --git a/src/apps/ethereum/get_address.py b/src/apps/ethereum/get_address.py index e0c7e7b155..39c7bacb63 100644 --- a/src/apps/ethereum/get_address.py +++ b/src/apps/ethereum/get_address.py @@ -1,20 +1,17 @@ -from .address import ethereum_address_hex, validate_full_path +from trezor.crypto.curve import secp256k1 +from trezor.crypto.hashlib import sha3_256 +from trezor.messages.EthereumAddress import EthereumAddress from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr from apps.ethereum import networks +from apps.ethereum.address import ethereum_address_hex, validate_full_path -async def get_address(ctx, msg): - from trezor.messages.EthereumAddress import EthereumAddress - from trezor.crypto.curve import secp256k1 - from trezor.crypto.hashlib import sha3_256 - from apps.common import seed - +async def get_address(ctx, msg, keychain): await paths.validate_path(ctx, validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n) - + node = keychain.derive(msg.address_n) seckey = node.private_key() public_key = secp256k1.publickey(seckey, False) # uncompressed address = sha3_256(public_key[1:], keccak=True).digest()[12:] diff --git a/src/apps/ethereum/networks.py b/src/apps/ethereum/networks.py index f5647095f6..4b72a84225 100644 --- a/src/apps/ethereum/networks.py +++ b/src/apps/ethereum/networks.py @@ -82,14 +82,14 @@ NETWORKS = [ NetworkInfo( chain_id=30, slip44=137, - shortcut="RSK", + shortcut="RBTC", name="RSK", rskip60=True, ), NetworkInfo( chain_id=31, slip44=37310, - shortcut="tRSK", + shortcut="tRBTC", name="RSK Testnet", rskip60=True, ), diff --git a/src/apps/ethereum/sign_message.py b/src/apps/ethereum/sign_message.py index c869f45076..ed6fe8dbe6 100644 --- a/src/apps/ethereum/sign_message.py +++ b/src/apps/ethereum/sign_message.py @@ -4,11 +4,10 @@ from trezor.messages.EthereumMessageSignature import EthereumMessageSignature from trezor.ui.text import Text from trezor.utils import HashWriter -from .address import validate_full_path - -from apps.common import paths, seed +from apps.common import paths from apps.common.confirm import require_confirm from apps.common.signverify import split_message +from apps.ethereum.address import validate_full_path def message_digest(message): @@ -20,12 +19,11 @@ def message_digest(message): return h.get_digest() -async def sign_message(ctx, msg): +async def sign_message(ctx, msg, keychain): await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await require_confirm_sign_message(ctx, msg.message) - node = await seed.derive_node(ctx, msg.address_n) - + node = keychain.derive(msg.address_n) signature = secp256k1.sign( node.private_key(), message_digest(msg.message), diff --git a/src/apps/ethereum/sign_tx.py b/src/apps/ethereum/sign_tx.py index 436f7a5c40..8f94f819c3 100644 --- a/src/apps/ethereum/sign_tx.py +++ b/src/apps/ethereum/sign_tx.py @@ -7,10 +7,9 @@ from trezor.messages.EthereumTxRequest import EthereumTxRequest from trezor.messages.MessageType import EthereumTxAck from trezor.utils import HashWriter -from .address import validate_full_path - -from apps.common import paths, seed +from apps.common import paths from apps.ethereum import tokens +from apps.ethereum.address import validate_full_path from apps.ethereum.layout import ( require_confirm_data, require_confirm_fee, @@ -21,7 +20,7 @@ from apps.ethereum.layout import ( MAX_CHAIN_ID = 2147483629 -async def sign_tx(ctx, msg): +async def sign_tx(ctx, msg, keychain): msg = sanitize(msg) check(msg) await paths.validate_path(ctx, validate_full_path, path=msg.address_n) @@ -91,7 +90,9 @@ async def sign_tx(ctx, msg): sha.extend(rlp.encode(0)) digest = sha.get_digest() - return await send_signature(ctx, msg, digest) + result = sign_digest(msg, keychain, digest) + + return result def get_total_length(msg: EthereumSignTx, data_total: int) -> int: @@ -130,9 +131,8 @@ async def send_request_chunk(ctx, data_left: int): return await ctx.call(req, EthereumTxAck) -async def send_signature(ctx, msg: EthereumSignTx, digest): - node = await seed.derive_node(ctx, msg.address_n) - +def sign_digest(msg: EthereumSignTx, keychain, digest): + node = keychain.derive(msg.address_n) signature = secp256k1.sign( node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM ) diff --git a/src/apps/lisk/__init__.py b/src/apps/lisk/__init__.py index acef4dd9da..f4a092cec2 100644 --- a/src/apps/lisk/__init__.py +++ b/src/apps/lisk/__init__.py @@ -1,10 +1,13 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + def boot(): - wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key") - wire.add(MessageType.LiskGetAddress, __name__, "get_address") - wire.add(MessageType.LiskSignMessage, __name__, "sign_message") + ns = [["ed25519", HARDENED | 44, HARDENED | 134]] + wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns) + wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns) + wire.add(MessageType.LiskSignTx, __name__, "sign_tx", ns) + wire.add(MessageType.LiskSignMessage, __name__, "sign_message", ns) wire.add(MessageType.LiskVerifyMessage, __name__, "verify_message") - wire.add(MessageType.LiskSignTx, __name__, "sign_tx") diff --git a/src/apps/lisk/get_address.py b/src/apps/lisk/get_address.py index 9217420c3c..c77d0045b5 100644 --- a/src/apps/lisk/get_address.py +++ b/src/apps/lisk/get_address.py @@ -2,14 +2,14 @@ from trezor.messages.LiskAddress import LiskAddress from .helpers import LISK_CURVE, get_address_from_public_key, validate_full_path -from apps.common import paths, seed +from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr -async def get_address(ctx, msg): +async def get_address(ctx, msg, keychain): await paths.validate_path(ctx, validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE) + node = keychain.derive(msg.address_n, LISK_CURVE) pubkey = node.public_key() pubkey = pubkey[1:] # skip ed25519 pubkey marker address = get_address_from_public_key(pubkey) diff --git a/src/apps/lisk/get_public_key.py b/src/apps/lisk/get_public_key.py index 37c613997f..93e3907cf0 100644 --- a/src/apps/lisk/get_public_key.py +++ b/src/apps/lisk/get_public_key.py @@ -2,13 +2,13 @@ from trezor.messages.LiskPublicKey import LiskPublicKey from .helpers import LISK_CURVE, validate_full_path -from apps.common import layout, paths, seed +from apps.common import layout, paths -async def get_public_key(ctx, msg): +async def get_public_key(ctx, msg, keychain): await paths.validate_path(ctx, validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE) + node = keychain.derive(msg.address_n, LISK_CURVE) pubkey = node.public_key() pubkey = pubkey[1:] # skip ed25519 pubkey marker diff --git a/src/apps/lisk/sign_message.py b/src/apps/lisk/sign_message.py index 8ccb88b72e..9b1c99834d 100644 --- a/src/apps/lisk/sign_message.py +++ b/src/apps/lisk/sign_message.py @@ -6,10 +6,10 @@ from trezor.utils import HashWriter from .helpers import LISK_CURVE, validate_full_path -from apps.common import paths, seed +from apps.common import paths from apps.common.confirm import require_confirm from apps.common.signverify import split_message -from apps.wallet.sign_tx.signing import write_varint +from apps.wallet.sign_tx.writers import write_varint def message_digest(message): @@ -22,11 +22,11 @@ def message_digest(message): return sha256(h.get_digest()).digest() -async def sign_message(ctx, msg): +async def sign_message(ctx, msg, keychain): await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await require_confirm_sign_message(ctx, msg.message) - node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE) + node = keychain.derive(msg.address_n, LISK_CURVE) seckey = node.private_key() pubkey = node.public_key() pubkey = pubkey[1:] # skip ed25519 pubkey marker diff --git a/src/apps/lisk/sign_tx.py b/src/apps/lisk/sign_tx.py index 9466620303..0ae7c0e732 100644 --- a/src/apps/lisk/sign_tx.py +++ b/src/apps/lisk/sign_tx.py @@ -7,16 +7,19 @@ from trezor.messages import LiskTransactionType from trezor.messages.LiskSignedTx import LiskSignedTx from trezor.utils import HashWriter -from . import layout -from .helpers import LISK_CURVE, get_address_from_public_key, validate_full_path - -from apps.common import paths, seed +from apps.common import paths +from apps.lisk import layout +from apps.lisk.helpers import ( + LISK_CURVE, + get_address_from_public_key, + validate_full_path, +) -async def sign_tx(ctx, msg): +async def sign_tx(ctx, msg, keychain): await paths.validate_path(ctx, validate_full_path, path=msg.address_n) - pubkey, seckey = await _get_keys(ctx, msg) + pubkey, seckey = _get_keys(keychain, msg) transaction = _update_raw_tx(msg.transaction, pubkey) try: @@ -37,8 +40,8 @@ async def sign_tx(ctx, msg): return LiskSignedTx(signature=signature) -async def _get_keys(ctx, msg): - node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE) +def _get_keys(keychain, msg): + node = keychain.derive(msg.address_n, LISK_CURVE) seckey = node.private_key() pubkey = node.public_key() diff --git a/src/apps/monero/__init__.py b/src/apps/monero/__init__.py index c2a8364cd3..d453140650 100644 --- a/src/apps/monero/__init__.py +++ b/src/apps/monero/__init__.py @@ -1,12 +1,20 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + def boot(): - wire.add(MessageType.MoneroGetAddress, __name__, "get_address") - wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only") - wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx") - wire.add(MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync") + ns = [ + ["secp256k1", HARDENED | 44, HARDENED | 128], + ["ed25519", HARDENED | 44, HARDENED | 128], + ] + wire.add(MessageType.MoneroGetAddress, __name__, "get_address", ns) + wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only", ns) + wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx", ns) + wire.add( + MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync", ns + ) if __debug__ and hasattr(MessageType, "DebugMoneroDiagRequest"): wire.add(MessageType.DebugMoneroDiagRequest, __name__, "diag") diff --git a/src/apps/monero/get_address.py b/src/apps/monero/get_address.py index 873420d3ca..aced5b7dc4 100644 --- a/src/apps/monero/get_address.py +++ b/src/apps/monero/get_address.py @@ -5,10 +5,10 @@ from apps.common.layout import address_n_to_str, show_address, show_qr from apps.monero import misc -async def get_address(ctx, msg): +async def get_address(ctx, msg, keychain): await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) - creds = await misc.get_creds(ctx, msg.address_n, msg.network_type) + creds = misc.get_creds(keychain, msg.address_n, msg.network_type) if msg.show_display: desc = address_n_to_str(msg.address_n) diff --git a/src/apps/monero/get_watch_only.py b/src/apps/monero/get_watch_only.py index 311a6fc8bd..8a636a0a86 100644 --- a/src/apps/monero/get_watch_only.py +++ b/src/apps/monero/get_watch_only.py @@ -7,12 +7,12 @@ from apps.monero.layout import confirms from apps.monero.xmr import crypto -async def get_watch_only(ctx, msg: MoneroGetWatchKey): +async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain): await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) await confirms.require_confirm_watchkey(ctx) - creds = await misc.get_creds(ctx, msg.address_n, msg.network_type) + creds = misc.get_creds(keychain, msg.address_n, msg.network_type) address = creds.address watch_key = crypto.encodeint(creds.view_key_private) diff --git a/src/apps/monero/key_image_sync.py b/src/apps/monero/key_image_sync.py index de265b8eaf..e025642310 100644 --- a/src/apps/monero/key_image_sync.py +++ b/src/apps/monero/key_image_sync.py @@ -14,10 +14,10 @@ from apps.monero.xmr import crypto, key_image, monero from apps.monero.xmr.crypto import chacha_poly -async def key_image_sync(ctx, msg): +async def key_image_sync(ctx, msg, keychain): state = KeyImageSync() - res = await _init_step(state, ctx, msg) + res = await _init_step(state, ctx, msg, keychain) while True: msg = await ctx.call( res, @@ -46,10 +46,10 @@ class KeyImageSync: self.hasher = crypto.get_keccak() -async def _init_step(s, ctx, msg): +async def _init_step(s, ctx, msg, keychain): await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) - s.creds = await misc.get_creds(ctx, msg.address_n, msg.network_type) + s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) await confirms.require_confirm_keyimage_sync(ctx) diff --git a/src/apps/monero/misc.py b/src/apps/monero/misc.py index 5cdd459c96..2e5cd0474a 100644 --- a/src/apps/monero/misc.py +++ b/src/apps/monero/misc.py @@ -1,8 +1,7 @@ from apps.common import HARDENED -async def get_creds(ctx, address_n=None, network_type=None): - from apps.common import seed +def get_creds(keychain, address_n=None, network_type=None): from apps.monero.xmr import crypto, monero from apps.monero.xmr.credentials import AccountCreds @@ -12,7 +11,7 @@ async def get_creds(ctx, address_n=None, network_type=None): curve = "ed25519" else: curve = "secp256k1" - node = await seed.derive_node(ctx, address_n, curve) + node = keychain.derive(address_n, curve) if use_slip0010: key_seed = node.private_key() diff --git a/src/apps/monero/sign_tx.py b/src/apps/monero/sign_tx.py index c85118681e..0382d28ac4 100644 --- a/src/apps/monero/sign_tx.py +++ b/src/apps/monero/sign_tx.py @@ -6,7 +6,7 @@ from trezor.messages import MessageType from apps.monero.signing.state import State -async def sign_tx(ctx, received_msg): +async def sign_tx(ctx, received_msg, keychain): state = State(ctx) mods = utils.unimport_begin() @@ -18,7 +18,7 @@ async def sign_tx(ctx, received_msg): gc.collect() gc.threshold(gc.mem_free() // 4 + gc.mem_alloc()) - result_msg, accept_msgs = await sign_tx_dispatch(state, received_msg) + result_msg, accept_msgs = await sign_tx_dispatch(state, received_msg, keychain) if accept_msgs is None: break @@ -32,13 +32,13 @@ async def sign_tx(ctx, received_msg): return result_msg -async def sign_tx_dispatch(state, msg): +async def sign_tx_dispatch(state, msg, keychain): if msg.MESSAGE_WIRE_TYPE == MessageType.MoneroTransactionInitRequest: from apps.monero.signing import step_01_init_transaction return ( await step_01_init_transaction.init_transaction( - state, msg.address_n, msg.network_type, msg.tsx_data + state, msg.address_n, msg.network_type, msg.tsx_data, keychain ), (MessageType.MoneroTransactionSetInputRequest,), ) diff --git a/src/apps/monero/signing/step_01_init_transaction.py b/src/apps/monero/signing/step_01_init_transaction.py index df7496c1e2..a5c25ea86e 100644 --- a/src/apps/monero/signing/step_01_init_transaction.py +++ b/src/apps/monero/signing/step_01_init_transaction.py @@ -16,14 +16,18 @@ if False: async def init_transaction( - state: State, address_n: list, network_type: int, tsx_data: MoneroTransactionData + state: State, + address_n: list, + network_type: int, + tsx_data: MoneroTransactionData, + keychain, ): from apps.monero.signing import offloading_keys from apps.common import paths await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n) - state.creds = await misc.get_creds(state.ctx, address_n, network_type) + state.creds = misc.get_creds(keychain, address_n, network_type) state.fee = state.fee if state.fee > 0 else 0 state.tx_priv = crypto.random_scalar() state.tx_pub = crypto.scalarmult_base(state.tx_priv) diff --git a/src/apps/nem/__init__.py b/src/apps/nem/__init__.py index f75b6ff3d6..8433f24482 100644 --- a/src/apps/nem/__init__.py +++ b/src/apps/nem/__init__.py @@ -1,7 +1,13 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + def boot(): - wire.add(MessageType.NEMGetAddress, __name__, "get_address") - wire.add(MessageType.NEMSignTx, __name__, "sign_tx") + ns = [ + ["ed25519-keccak", HARDENED | 44, HARDENED | 43], + ["ed25519-keccak", HARDENED | 44, HARDENED | 1], + ] + wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns) + wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/nem/get_address.py b/src/apps/nem/get_address.py index 2ca653e552..676e43657e 100644 --- a/src/apps/nem/get_address.py +++ b/src/apps/nem/get_address.py @@ -3,16 +3,15 @@ from trezor.messages.NEMAddress import NEMAddress from .helpers import NEM_CURVE, check_path, get_network_str from .validators import validate_network -from apps.common import seed from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.paths import validate_path -async def get_address(ctx, msg): +async def get_address(ctx, msg, keychain): network = validate_network(msg.network) - await validate_path(ctx, check_path, path=msg.address_n, network=msg.network) + await validate_path(ctx, check_path, path=msg.address_n, network=network) - node = await seed.derive_node(ctx, msg.address_n, NEM_CURVE) + node = keychain.derive(msg.address_n, NEM_CURVE) address = node.nem_address(network) if msg.show_display: diff --git a/src/apps/nem/sign_tx.py b/src/apps/nem/sign_tx.py index 7ca411f8bb..1a76f163b7 100644 --- a/src/apps/nem/sign_tx.py +++ b/src/apps/nem/sign_tx.py @@ -2,21 +2,21 @@ from trezor.crypto.curve import ed25519 from trezor.messages.NEMSignedTx import NEMSignedTx from trezor.messages.NEMSignTx import NEMSignTx -from . import mosaic, multisig, namespace, transfer -from .helpers import NEM_CURVE, NEM_HASH_ALG, check_path -from .validators import validate - from apps.common import seed from apps.common.paths import validate_path +from apps.nem import mosaic, multisig, namespace, transfer +from apps.nem.helpers import NEM_CURVE, NEM_HASH_ALG, check_path +from apps.nem.validators import validate -async def sign_tx(ctx, msg: NEMSignTx): +async def sign_tx(ctx, msg: NEMSignTx, keychain): validate(msg) + await validate_path( ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network ) - node = await seed.derive_node(ctx, msg.transaction.address_n, NEM_CURVE) + node = keychain.derive(msg.transaction.address_n, NEM_CURVE) if msg.multisig: public_key = msg.multisig.signer diff --git a/src/apps/ripple/__init__.py b/src/apps/ripple/__init__.py index f4a863edb6..2e24cf5816 100644 --- a/src/apps/ripple/__init__.py +++ b/src/apps/ripple/__init__.py @@ -1,7 +1,10 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + def boot(): - wire.add(MessageType.RippleGetAddress, __name__, "get_address") - wire.add(MessageType.RippleSignTx, __name__, "sign_tx") + ns = [["secp256k1", HARDENED | 44, HARDENED | 144]] + wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns) + wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/ripple/get_address.py b/src/apps/ripple/get_address.py index 5af586a747..a5b26ba762 100644 --- a/src/apps/ripple/get_address.py +++ b/src/apps/ripple/get_address.py @@ -1,16 +1,15 @@ from trezor.messages.RippleAddress import RippleAddress from trezor.messages.RippleGetAddress import RippleGetAddress -from . import helpers - -from apps.common import paths, seed +from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr +from apps.ripple import helpers -async def get_address(ctx, msg: RippleGetAddress): +async def get_address(ctx, msg: RippleGetAddress, keychain): await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n) + node = keychain.derive(msg.address_n) pubkey = node.public_key() address = helpers.address_from_public_key(pubkey) diff --git a/src/apps/ripple/sign_tx.py b/src/apps/ripple/sign_tx.py index b35227f78d..b3d0be7c6a 100644 --- a/src/apps/ripple/sign_tx.py +++ b/src/apps/ripple/sign_tx.py @@ -5,17 +5,17 @@ from trezor.messages.RippleSignedTx import RippleSignedTx from trezor.messages.RippleSignTx import RippleSignTx from trezor.wire import ProcessError -from . import helpers, layout -from .serialize import serialize - -from apps.common import paths, seed +from apps.common import paths +from apps.ripple import helpers, layout +from apps.ripple.serialize import serialize -async def sign_tx(ctx, msg: RippleSignTx): +async def sign_tx(ctx, msg: RippleSignTx, keychain): validate(msg) + await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n) + node = keychain.derive(msg.address_n) source_address = helpers.address_from_public_key(node.public_key()) set_canonical_flag(msg) diff --git a/src/apps/stellar/__init__.py b/src/apps/stellar/__init__.py index 92fba60b08..ae5e40c88a 100644 --- a/src/apps/stellar/__init__.py +++ b/src/apps/stellar/__init__.py @@ -1,7 +1,10 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + def boot(): - wire.add(MessageType.StellarGetAddress, __name__, "get_address") - wire.add(MessageType.StellarSignTx, __name__, "sign_tx") + ns = [["ed25519", HARDENED | 44, HARDENED | 148]] + wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns) + wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/stellar/get_address.py b/src/apps/stellar/get_address.py index 1e15f6b586..8b87d34f8b 100644 --- a/src/apps/stellar/get_address.py +++ b/src/apps/stellar/get_address.py @@ -6,10 +6,10 @@ from apps.common.layout import address_n_to_str, show_address, show_qr from apps.stellar import helpers -async def get_address(ctx, msg: StellarGetAddress): +async def get_address(ctx, msg: StellarGetAddress, keychain): await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n, helpers.STELLAR_CURVE) + node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE) pubkey = seed.remove_ed25519_prefix(node.public_key()) address = helpers.address_from_public_key(pubkey) diff --git a/src/apps/stellar/sign_tx.py b/src/apps/stellar/sign_tx.py index f6efa56d70..55152348b6 100644 --- a/src/apps/stellar/sign_tx.py +++ b/src/apps/stellar/sign_tx.py @@ -12,15 +12,15 @@ from apps.stellar import consts, helpers, layout, writers from apps.stellar.operations import process_operation -async def sign_tx(ctx, msg: StellarSignTx): - if msg.num_operations == 0: - raise ProcessError("Stellar: At least one operation is required") - +async def sign_tx(ctx, msg: StellarSignTx, keychain): await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n, consts.STELLAR_CURVE) + node = keychain.derive(msg.address_n, consts.STELLAR_CURVE) pubkey = seed.remove_ed25519_prefix(node.public_key()) + if msg.num_operations == 0: + raise ProcessError("Stellar: At least one operation is required") + w = bytearray() await _init(ctx, w, pubkey, msg) _timebounds(w, msg.timebounds_start, msg.timebounds_end) diff --git a/src/apps/tezos/__init__.py b/src/apps/tezos/__init__.py index ba29093bd4..1c5aa3699d 100644 --- a/src/apps/tezos/__init__.py +++ b/src/apps/tezos/__init__.py @@ -1,8 +1,11 @@ from trezor import wire from trezor.messages import MessageType +from apps.common import HARDENED + def boot(): - wire.add(MessageType.TezosGetAddress, __name__, "get_address") - wire.add(MessageType.TezosSignTx, __name__, "sign_tx") - wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key") + ns = [["ed25519", HARDENED | 44, HARDENED | 1729]] + wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns) + wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns) + wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key", ns) diff --git a/src/apps/tezos/get_address.py b/src/apps/tezos/get_address.py index 9640d2c971..4ff5635129 100644 --- a/src/apps/tezos/get_address.py +++ b/src/apps/tezos/get_address.py @@ -6,9 +6,10 @@ from apps.common.layout import address_n_to_str, show_address, show_qr from apps.tezos import helpers -async def get_address(ctx, msg): +async def get_address(ctx, msg, keychain): await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n, helpers.TEZOS_CURVE) + + node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) pk = seed.remove_ed25519_prefix(node.public_key()) pkh = hashlib.blake2b(pk, outlen=20).digest() diff --git a/src/apps/tezos/get_public_key.py b/src/apps/tezos/get_public_key.py index c96d468dde..c3f75ad97b 100644 --- a/src/apps/tezos/get_public_key.py +++ b/src/apps/tezos/get_public_key.py @@ -9,10 +9,10 @@ from apps.common.confirm import require_confirm from apps.tezos import helpers -async def get_public_key(ctx, msg): +async def get_public_key(ctx, msg, keychain): await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n, helpers.TEZOS_CURVE) + node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) pk = seed.remove_ed25519_prefix(node.public_key()) pk_prefixed = helpers.base58_encode_check(pk, prefix=helpers.TEZOS_PUBLICKEY_PREFIX) diff --git a/src/apps/tezos/sign_tx.py b/src/apps/tezos/sign_tx.py index 059acefe6a..5cf0e11d3b 100644 --- a/src/apps/tezos/sign_tx.py +++ b/src/apps/tezos/sign_tx.py @@ -4,14 +4,15 @@ from trezor.crypto.curve import ed25519 from trezor.messages import TezosContractType from trezor.messages.TezosSignedTx import TezosSignedTx -from apps.common import paths, seed +from apps.common import paths from apps.common.writers import write_bytes, write_uint8 from apps.tezos import helpers, layout -async def sign_tx(ctx, msg): +async def sign_tx(ctx, msg, keychain): await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) - node = await seed.derive_node(ctx, msg.address_n, helpers.TEZOS_CURVE) + + node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) if msg.transaction is not None: to = _get_address_from_contract(msg.transaction.destination) diff --git a/src/apps/wallet/__init__.py b/src/apps/wallet/__init__.py index 7fbf9b3976..fc70179cc1 100644 --- a/src/apps/wallet/__init__.py +++ b/src/apps/wallet/__init__.py @@ -3,12 +3,22 @@ from trezor.messages import MessageType def boot(): - wire.add(MessageType.GetPublicKey, __name__, "get_public_key") - wire.add(MessageType.GetAddress, __name__, "get_address") + ns = [ + ["curve25519"], + ["ed25519"], + ["ed25519-keccak"], + ["nist256p1"], + ["secp256k1"], + ["secp256k1-decred"], + ["secp256k1-groestl"], + ["secp256k1-smart"], + ] + wire.add(MessageType.GetPublicKey, __name__, "get_public_key", ns) + wire.add(MessageType.GetAddress, __name__, "get_address", ns) wire.add(MessageType.GetEntropy, __name__, "get_entropy") - wire.add(MessageType.SignTx, __name__, "sign_tx") - wire.add(MessageType.SignMessage, __name__, "sign_message") + wire.add(MessageType.SignTx, __name__, "sign_tx", ns) + wire.add(MessageType.SignMessage, __name__, "sign_message", ns) wire.add(MessageType.VerifyMessage, __name__, "verify_message") - wire.add(MessageType.SignIdentity, __name__, "sign_identity") - wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key") - wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value") + wire.add(MessageType.SignIdentity, __name__, "sign_identity", ns) + wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key", ns) + wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value", ns) diff --git a/src/apps/wallet/cipher_key_value.py b/src/apps/wallet/cipher_key_value.py index 44c04ad84e..b6c14663bf 100644 --- a/src/apps/wallet/cipher_key_value.py +++ b/src/apps/wallet/cipher_key_value.py @@ -4,11 +4,10 @@ from trezor.crypto.hashlib import sha512 from trezor.messages.CipheredKeyValue import CipheredKeyValue from trezor.ui.text import Text -from apps.common import seed from apps.common.confirm import require_confirm -async def cipher_key_value(ctx, msg): +async def cipher_key_value(ctx, msg, keychain): if len(msg.value) % 16 > 0: raise wire.DataError("Value length must be a multiple of 16") @@ -23,7 +22,7 @@ async def cipher_key_value(ctx, msg): text.normal(msg.key) await require_confirm(ctx, text) - node = await seed.derive_node(ctx, msg.address_n) + node = keychain.derive(msg.address_n) value = compute_cipher_key_value(msg, node.private_key()) return CipheredKeyValue(value=value) diff --git a/src/apps/wallet/get_address.py b/src/apps/wallet/get_address.py index 929aaf9d07..80096d37c5 100644 --- a/src/apps/wallet/get_address.py +++ b/src/apps/wallet/get_address.py @@ -1,13 +1,13 @@ from trezor.messages import InputScriptType from trezor.messages.Address import Address -from apps.common import coins, seed +from apps.common import coins from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.paths import validate_path from apps.wallet.sign_tx import addresses -async def get_address(ctx, msg): +async def get_address(ctx, msg, keychain): coin_name = msg.coin_name or "Bitcoin" coin = coins.by_name(coin_name) @@ -19,7 +19,7 @@ async def get_address(ctx, msg): script_type=msg.script_type, ) - node = await seed.derive_node(ctx, msg.address_n, curve_name=coin.curve_name) + node = keychain.derive(msg.address_n, coin.curve_name) address = addresses.get_address(msg.script_type, coin, node, msg.multisig) address_short = addresses.address_short(coin, address) diff --git a/src/apps/wallet/get_ecdh_session_key.py b/src/apps/wallet/get_ecdh_session_key.py index 0d15005132..22b5b6071c 100644 --- a/src/apps/wallet/get_ecdh_session_key.py +++ b/src/apps/wallet/get_ecdh_session_key.py @@ -5,7 +5,7 @@ from trezor.messages.ECDHSessionKey import ECDHSessionKey from trezor.ui.text import Text from trezor.utils import chunks -from apps.common import HARDENED, seed +from apps.common import HARDENED from apps.common.confirm import require_confirm from apps.wallet.sign_identity import ( serialize_identity, @@ -13,7 +13,7 @@ from apps.wallet.sign_identity import ( ) -async def get_ecdh_session_key(ctx, msg): +async def get_ecdh_session_key(ctx, msg, keychain): if msg.ecdsa_curve_name is None: msg.ecdsa_curve_name = "secp256k1" @@ -22,7 +22,7 @@ async def get_ecdh_session_key(ctx, msg): await require_confirm_ecdh_session_key(ctx, msg.identity) address_n = get_ecdh_path(identity, msg.identity.index or 0) - node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name) + node = keychain.derive(address_n, msg.ecdsa_curve_name) session_key = ecdh( seckey=node.private_key(), diff --git a/src/apps/wallet/get_entropy.py b/src/apps/wallet/get_entropy.py index f953235863..e35878b07a 100644 --- a/src/apps/wallet/get_entropy.py +++ b/src/apps/wallet/get_entropy.py @@ -7,7 +7,6 @@ from apps.common.confirm import require_confirm async def get_entropy(ctx, msg): - text = Text("Confirm entropy") text.bold("Do you really want", "to send entropy?") text.normal("Continue only if you", "know what you are doing!") diff --git a/src/apps/wallet/get_public_key.py b/src/apps/wallet/get_public_key.py index e809573441..e779ff02dc 100644 --- a/src/apps/wallet/get_public_key.py +++ b/src/apps/wallet/get_public_key.py @@ -3,18 +3,16 @@ from trezor.messages import InputScriptType from trezor.messages.HDNodeType import HDNodeType from trezor.messages.PublicKey import PublicKey -from apps.common import coins, layout, seed +from apps.common import coins, layout -async def get_public_key(ctx, msg): +async def get_public_key(ctx, msg, keychain): coin_name = msg.coin_name or "Bitcoin" coin = coins.by_name(coin_name) + curve_name = msg.ecdsa_curve_name or coin.curve_name script_type = msg.script_type or InputScriptType.SPENDADDRESS - curve_name = msg.ecdsa_curve_name - if not curve_name: - curve_name = coin.curve_name - node = await seed.derive_node(ctx, msg.address_n, curve_name=curve_name) + node = keychain.derive(msg.address_n, curve_name=curve_name) if ( script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG] diff --git a/src/apps/wallet/sign_identity.py b/src/apps/wallet/sign_identity.py index 98f2136ceb..6c71390d6b 100644 --- a/src/apps/wallet/sign_identity.py +++ b/src/apps/wallet/sign_identity.py @@ -6,11 +6,11 @@ from trezor.messages.SignedIdentity import SignedIdentity from trezor.ui.text import Text from trezor.utils import chunks -from apps.common import HARDENED, coins, seed +from apps.common import HARDENED, coins from apps.common.confirm import require_confirm -async def sign_identity(ctx, msg): +async def sign_identity(ctx, msg, keychain): if msg.ecdsa_curve_name is None: msg.ecdsa_curve_name = "secp256k1" @@ -19,7 +19,7 @@ async def sign_identity(ctx, msg): await require_confirm_sign_identity(ctx, msg.identity, msg.challenge_visual) address_n = get_identity_path(identity, msg.identity.index or 0) - node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name) + node = keychain.derive(address_n, msg.ecdsa_curve_name) coin = coins.by_name("Bitcoin") if msg.ecdsa_curve_name == "secp256k1": diff --git a/src/apps/wallet/sign_message.py b/src/apps/wallet/sign_message.py index 246307476b..ab9e58f879 100644 --- a/src/apps/wallet/sign_message.py +++ b/src/apps/wallet/sign_message.py @@ -4,14 +4,14 @@ from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPEN from trezor.messages.MessageSignature import MessageSignature from trezor.ui.text import Text -from apps.common import coins, seed +from apps.common import coins from apps.common.confirm import require_confirm from apps.common.paths import validate_path from apps.common.signverify import message_digest, split_message from apps.wallet.sign_tx.addresses import get_address, validate_full_path -async def sign_message(ctx, msg): +async def sign_message(ctx, msg, keychain): message = msg.message address_n = msg.address_n coin_name = msg.coin_name or "Bitcoin" @@ -19,7 +19,6 @@ async def sign_message(ctx, msg): coin = coins.by_name(coin_name) await require_confirm_sign_message(ctx, message) - await validate_path( ctx, validate_full_path, @@ -29,7 +28,7 @@ async def sign_message(ctx, msg): validate_script_type=False, ) - node = await seed.derive_node(ctx, address_n, curve_name=coin.curve_name) + node = keychain.derive(address_n, coin.curve_name) seckey = node.private_key() address = get_address(script_type, coin, node) diff --git a/src/apps/wallet/sign_tx/__init__.py b/src/apps/wallet/sign_tx/__init__.py index a32c2b8126..6bcc1aeebd 100644 --- a/src/apps/wallet/sign_tx/__init__.py +++ b/src/apps/wallet/sign_tx/__init__.py @@ -3,53 +3,51 @@ from trezor.messages.MessageType import TxAck from trezor.messages.RequestType import TXFINISHED from trezor.messages.TxRequest import TxRequest -from apps.common import coins, paths, seed -from apps.wallet.sign_tx.helpers import ( - UiConfirmFeeOverThreshold, - UiConfirmForeignAddress, - UiConfirmOutput, - UiConfirmTotal, +from apps.common import paths +from apps.wallet.sign_tx import ( + addresses, + helpers, + layout, + multisig, + progress, + scripts, + segwit_bip143, + signing, ) @ui.layout -async def sign_tx(ctx, msg): - from apps.wallet.sign_tx import layout, progress, signing +async def sign_tx(ctx, msg, keychain): + signer = signing.sign_tx(msg, keychain) - coin_name = msg.coin_name or "Bitcoin" - coin = coins.by_name(coin_name) - # TODO: rework this so we don't have to pass root to signing.sign_tx - root = await seed.derive_node(ctx, [], curve_name=coin.curve_name) - - signer = signing.sign_tx(msg, root) res = None while True: try: req = signer.send(res) except signing.SigningError as e: raise wire.Error(*e.args) - except signing.MultisigError as e: + except multisig.MultisigError as e: raise wire.Error(*e.args) - except signing.AddressError as e: + except addresses.AddressError as e: raise wire.Error(*e.args) - except signing.ScriptsError as e: + except scripts.ScriptsError as e: raise wire.Error(*e.args) - except signing.Bip143Error as e: + except segwit_bip143.Bip143Error as e: raise wire.Error(*e.args) if isinstance(req, TxRequest): if req.request_type == TXFINISHED: break res = await ctx.call(req, TxAck) - elif isinstance(req, UiConfirmOutput): + elif isinstance(req, helpers.UiConfirmOutput): res = await layout.confirm_output(ctx, req.output, req.coin) progress.report_init() - elif isinstance(req, UiConfirmTotal): + elif isinstance(req, helpers.UiConfirmTotal): res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin) progress.report_init() - elif isinstance(req, UiConfirmFeeOverThreshold): + elif isinstance(req, helpers.UiConfirmFeeOverThreshold): res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin) progress.report_init() - elif isinstance(req, UiConfirmForeignAddress): + elif isinstance(req, helpers.UiConfirmForeignAddress): res = await paths.show_path_warning(ctx, req.address_n) else: raise TypeError("Invalid signing instruction") diff --git a/src/apps/wallet/sign_tx/decred_prefix_hasher.py b/src/apps/wallet/sign_tx/decred.py similarity index 100% rename from src/apps/wallet/sign_tx/decred_prefix_hasher.py rename to src/apps/wallet/sign_tx/decred.py diff --git a/src/apps/wallet/sign_tx/helpers.py b/src/apps/wallet/sign_tx/helpers.py index b9bf83babe..e8f6a42d9b 100644 --- a/src/apps/wallet/sign_tx/helpers.py +++ b/src/apps/wallet/sign_tx/helpers.py @@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputBinType import TxOutputBinType from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxRequest import TxRequest +from trezor.utils import obj_eq from apps.common.coininfo import CoinInfo @@ -24,6 +25,8 @@ class UiConfirmOutput: self.output = output self.coin = coin + __eq__ = obj_eq + class UiConfirmTotal: def __init__(self, spending: int, fee: int, coin: CoinInfo): @@ -31,17 +34,23 @@ class UiConfirmTotal: self.fee = fee self.coin = coin + __eq__ = obj_eq + class UiConfirmFeeOverThreshold: def __init__(self, fee: int, coin: CoinInfo): self.fee = fee self.coin = coin + __eq__ = obj_eq + class UiConfirmForeignAddress: def __init__(self, address_n: list): self.address_n = address_n + __eq__ = obj_eq + def confirm_output(output: TxOutputType, coin: CoinInfo): return (yield UiConfirmOutput(output, coin)) diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index f72623fa22..ef37fbae38 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -1,35 +1,30 @@ from micropython import const +from trezor import utils from trezor.crypto import base58, bip32, cashaddr, der from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import blake256, sha256 -from trezor.messages import OutputScriptType +from trezor.messages import FailureType, InputScriptType, OutputScriptType +from trezor.messages.SignTx import SignTx +from trezor.messages.TxInputType import TxInputType +from trezor.messages.TxOutputBinType import TxOutputBinType +from trezor.messages.TxOutputType import TxOutputType +from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType -from trezor.utils import HashWriter -from apps.common import address_type, coins -from apps.common.coininfo import CoinInfo -from apps.common.writers import empty_bytearray -from apps.wallet.sign_tx import progress -from apps.wallet.sign_tx.addresses import * -from apps.wallet.sign_tx.decred_prefix_hasher import ( - DECRED_SERIALIZE_NO_WITNESS, - DECRED_SERIALIZE_WITNESS_SIGNING, - DECRED_SIGHASHALL, - DecredPrefixHasher, -) -from apps.wallet.sign_tx.helpers import * -from apps.wallet.sign_tx.multisig import * -from apps.wallet.sign_tx.scripts import * -from apps.wallet.sign_tx.segwit_bip143 import Bip143, Bip143Error # noqa:F401 -from apps.wallet.sign_tx.tx_weight_calculator import * -from apps.wallet.sign_tx.writers import * -from apps.wallet.sign_tx.zcash import ( # noqa:F401 - OVERWINTERED, - ZcashError, - Zip143, - Zip243, +from apps.common import address_type, coininfo, coins, seed +from apps.wallet.sign_tx import ( + addresses, + decred, + helpers, + multisig, + progress, + scripts, + segwit_bip143, + tx_weight, + writers, + zcash, ) # the number of bip32 levels used in a wallet (chain and address) @@ -58,32 +53,32 @@ class SigningError(ValueError): # - check inputs, previous transactions, and outputs # - ask for confirmations # - check fee -async def check_tx_fee(tx: SignTx, root: bip32.HDNode): +async def check_tx_fee(tx: SignTx, keychain: seed.Keychain): coin = coins.by_name(tx.coin_name) # h_first is used to make sure the inputs and outputs streamed in Phase 1 # are the same as in Phase 2. it is thus not required to fully hash the # tx, as the SignTx info is streamed only once - h_first = HashWriter(sha256()) # not a real tx hash + h_first = utils.HashWriter(sha256()) # not a real tx hash if coin.decred: - hash143 = DecredPrefixHasher(tx) # pseudo bip143 prefix hashing + hash143 = decred.DecredPrefixHasher(tx) # pseudo BIP-0143 prefix hashing tx_ser = TxRequestSerializedType() elif tx.overwintered: if tx.version == 3: - hash143 = Zip143() # ZIP-0143 transaction hashing + hash143 = zcash.Zip143() # ZIP-0143 transaction hashing elif tx.version == 4: - hash143 = Zip243() # ZIP-0243 transaction hashing + hash143 = zcash.Zip243() # ZIP-0243 transaction hashing else: raise SigningError( FailureType.DataError, "Unsupported version for overwintered transaction", ) else: - hash143 = Bip143() # BIP-0143 transaction hashing + hash143 = segwit_bip143.Bip143() # BIP-0143 transaction hashing - multifp = MultisigFingerprint() # control checksum of multisig inputs - weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count) + multifp = multisig.MultisigFingerprint() # control checksum of multisig inputs + weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count) total_in = 0 # sum of input amounts segwit_in = 0 # sum of segwit input amounts @@ -100,15 +95,15 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): for i in range(tx.inputs_count): progress.advance() # STAGE_REQUEST_1_INPUT - txi = await request_tx_input(tx_req, i) + txi = await helpers.request_tx_input(tx_req, i) wallet_path = input_extract_wallet_path(txi, wallet_path) - write_tx_input_check(h_first, txi) + writers.write_tx_input_check(h_first, txi) weight.add_input(txi) hash143.add_prevouts(txi) # all inputs are included (non-segwit as well) hash143.add_sequence(txi) - if not validate_full_path(txi.address_n, coin, txi.script_type): - await confirm_foreign_address(txi.address_n) + if not addresses.validate_full_path(txi.address_n, coin, txi.script_type): + await helpers.confirm_foreign_address(txi.address_n) if txi.multisig: multifp.add(txi.multisig) @@ -149,10 +144,10 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): raise SigningError(FailureType.DataError, "Wrong input script type") if coin.decred: - w_txi = empty_bytearray(8 if i == 0 else 0 + 9 + len(txi.prev_hash)) + w_txi = writers.empty_bytearray(8 if i == 0 else 0 + 9 + len(txi.prev_hash)) if i == 0: # serializing first input => prepend headers - write_bytes(w_txi, get_tx_header(coin, tx)) - write_tx_input_decred(w_txi, txi) + writers.write_bytes(w_txi, get_tx_header(coin, tx)) + writers.write_tx_input_decred(w_txi, txi) tx_ser.serialized_tx = w_txi tx_req.serialized = tx_ser @@ -161,15 +156,15 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): for o in range(tx.outputs_count): # STAGE_REQUEST_3_OUTPUT - txo = await request_tx_output(tx_req, o) + txo = await helpers.request_tx_output(tx_req, o) txo_bin.amount = txo.amount - txo_bin.script_pubkey = output_derive_script(txo, coin, root) + txo_bin.script_pubkey = output_derive_script(txo, coin, keychain) weight.add_output(txo_bin.script_pubkey) - if change_out == 0 and is_change(txo, wallet_path, segwit_in, multifp): + if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp): # output is change and does not need confirmation change_out = txo.amount - elif not await confirm_output(txo, coin): + elif not await helpers.confirm_output(txo, coin): raise SigningError(FailureType.ActionCancelled, "Output cancelled") if coin.decred: @@ -180,15 +175,17 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): ) txo_bin.decred_script_version = txo.decred_script_version - w_txo_bin = empty_bytearray(4 + 8 + 2 + 4 + len(txo_bin.script_pubkey)) + w_txo_bin = writers.empty_bytearray( + 4 + 8 + 2 + 4 + len(txo_bin.script_pubkey) + ) if o == 0: # serializing first output => prepend outputs count - write_varint(w_txo_bin, tx.outputs_count) - write_tx_output(w_txo_bin, txo_bin) + writers.write_varint(w_txo_bin, tx.outputs_count) + writers.write_tx_output(w_txo_bin, txo_bin) tx_ser.serialized_tx = w_txo_bin tx_req.serialized = tx_ser hash143.set_last_output_bytes(w_txo_bin) - write_tx_output(h_first, txo_bin) + writers.write_tx_output(h_first, txo_bin) hash143.add_output(txo_bin) total_out += txo_bin.amount @@ -198,10 +195,10 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): # fee > (coin.maxfee per byte * tx size) if fee > (coin.maxfee_kb / 1000) * (weight.get_total() / 4): - if not await confirm_feeoverthreshold(fee, coin): + if not await helpers.confirm_feeoverthreshold(fee, coin): raise SigningError(FailureType.ActionCancelled, "Signing cancelled") - if not await confirm_total(total_in - change_out, fee, coin): + if not await helpers.confirm_total(total_in - change_out, fee, coin): raise SigningError(FailureType.ActionCancelled, "Total cancelled") if coin.decred: @@ -210,14 +207,16 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): return h_first, hash143, segwit, total_in, wallet_path -async def sign_tx(tx: SignTx, root: bip32.HDNode): - tx = sanitize_sign_tx(tx) +async def sign_tx(tx: SignTx, keychain: seed.Keychain): + tx = helpers.sanitize_sign_tx(tx) progress.init(tx.inputs_count, tx.outputs_count) # Phase 1 - h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root) + h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee( + tx, keychain + ) # Phase 2 # - sign inputs @@ -242,34 +241,30 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): if segwit[i_sign]: # STAGE_REQUEST_SEGWIT_INPUT - txi_sign = await request_tx_input(tx_req, i_sign) + txi_sign = await helpers.request_tx_input(tx_req, i_sign) - is_segwit = ( - txi_sign.script_type == InputScriptType.SPENDWITNESS - or txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS - ) - if not is_segwit: + if not input_is_segwit(txi_sign): raise SigningError( FailureType.ProcessError, "Transaction has changed during signing" ) input_check_wallet_path(txi_sign, wallet_path) - key_sign = node_derive(root, txi_sign.address_n) + key_sign = keychain.derive(txi_sign.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub) - w_txi = empty_bytearray( + w_txi = writers.empty_bytearray( 7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4 ) if i_sign == 0: # serializing first input => prepend headers - write_bytes(w_txi, get_tx_header(coin, tx, True)) - write_tx_input(w_txi, txi_sign) + writers.write_bytes(w_txi, get_tx_header(coin, tx, True)) + writers.write_tx_input(w_txi, txi_sign) tx_ser.serialized_tx = w_txi tx_req.serialized = tx_ser elif coin.force_bip143 or tx.overwintered: # STAGE_REQUEST_SEGWIT_INPUT - txi_sign = await request_tx_input(tx_req, i_sign) + txi_sign = await helpers.request_tx_input(tx_req, i_sign) input_check_wallet_path(txi_sign, wallet_path) is_bip143 = ( @@ -282,19 +277,19 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): ) authorized_in -= txi_sign.amount - key_sign = node_derive(root, txi_sign.address_n) + key_sign = keychain.derive(txi_sign.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() hash143_hash = hash143.preimage_hash( coin, tx, txi_sign, - ecdsa_hash_pubkey(key_sign_pub, coin), + addresses.ecdsa_hash_pubkey(key_sign_pub, coin), get_hash_type(coin), ) - # if multisig, check if singing with a key that is included in multisig + # if multisig, check if signing with a key that is included in multisig if txi_sign.multisig: - multisig_pubkey_index(txi_sign.multisig, key_sign_pub) + multisig.multisig_pubkey_index(txi_sign.multisig, key_sign_pub) signature = ecdsa_sign(key_sign, hash143_hash) tx_ser.signature_index = i_sign @@ -304,56 +299,59 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): txi_sign.script_sig = input_derive_script( coin, txi_sign, key_sign_pub, signature ) - w_txi_sign = empty_bytearray( + w_txi_sign = writers.empty_bytearray( 5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4 ) if i_sign == 0: # serializing first input => prepend headers - write_bytes(w_txi_sign, get_tx_header(coin, tx)) - write_tx_input(w_txi_sign, txi_sign) + writers.write_bytes(w_txi_sign, get_tx_header(coin, tx)) + writers.write_tx_input(w_txi_sign, txi_sign) tx_ser.serialized_tx = w_txi_sign tx_req.serialized = tx_ser elif coin.decred: - txi_sign = await request_tx_input(tx_req, i_sign) + txi_sign = await helpers.request_tx_input(tx_req, i_sign) input_check_wallet_path(txi_sign, wallet_path) - key_sign = node_derive(root, txi_sign.address_n) + key_sign = keychain.derive(txi_sign.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() if txi_sign.script_type == InputScriptType.SPENDMULTISIG: - prev_pkscript = output_script_multisig( - multisig_get_pubkeys(txi_sign.multisig), txi_sign.multisig.m + prev_pkscript = scripts.output_script_multisig( + multisig.multisig_get_pubkeys(txi_sign.multisig), + txi_sign.multisig.m, ) elif txi_sign.script_type == InputScriptType.SPENDADDRESS: - prev_pkscript = output_script_p2pkh( - ecdsa_hash_pubkey(key_sign_pub, coin) + prev_pkscript = scripts.output_script_p2pkh( + addresses.ecdsa_hash_pubkey(key_sign_pub, coin) ) else: raise ValueError("Unknown input script type") - h_witness = HashWriter(blake256()) - write_uint32(h_witness, tx.version | DECRED_SERIALIZE_WITNESS_SIGNING) - write_varint(h_witness, tx.inputs_count) + h_witness = utils.HashWriter(blake256()) + writers.write_uint32( + h_witness, tx.version | decred.DECRED_SERIALIZE_WITNESS_SIGNING + ) + writers.write_varint(h_witness, tx.inputs_count) for ii in range(tx.inputs_count): if ii == i_sign: - write_varint(h_witness, len(prev_pkscript)) - write_bytes(h_witness, prev_pkscript) + writers.write_varint(h_witness, len(prev_pkscript)) + writers.write_bytes(h_witness, prev_pkscript) else: - write_varint(h_witness, 0) + writers.write_varint(h_witness, 0) - witness_hash = get_tx_hash( + witness_hash = writers.get_tx_hash( h_witness, double=coin.sign_hash_double, reverse=False ) - h_sign = HashWriter(blake256()) - write_uint32(h_sign, DECRED_SIGHASHALL) - write_bytes(h_sign, prefix_hash) - write_bytes(h_sign, witness_hash) + h_sign = utils.HashWriter(blake256()) + writers.write_uint32(h_sign, decred.DECRED_SIGHASHALL) + writers.write_bytes(h_sign, prefix_hash) + writers.write_bytes(h_sign, witness_hash) - sig_hash = get_tx_hash(h_sign, double=coin.sign_hash_double) + sig_hash = writers.get_tx_hash(h_sign, double=coin.sign_hash_double) signature = ecdsa_sign(key_sign, sig_hash) tx_ser.signature_index = i_sign tx_ser.signature = signature @@ -362,61 +360,62 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): txi_sign.script_sig = input_derive_script( coin, txi_sign, key_sign_pub, signature ) - w_txi_sign = empty_bytearray( + w_txi_sign = writers.empty_bytearray( 8 + 4 + len(hash143.get_last_output_bytes()) if i_sign == 0 else 0 + 16 + 4 + len(txi_sign.script_sig) ) if i_sign == 0: - write_bytes(w_txi_sign, hash143.get_last_output_bytes()) - write_uint32(w_txi_sign, tx.lock_time) - write_uint32(w_txi_sign, tx.expiry) - write_varint(w_txi_sign, tx.inputs_count) + writers.write_bytes(w_txi_sign, hash143.get_last_output_bytes()) + writers.write_uint32(w_txi_sign, tx.lock_time) + writers.write_uint32(w_txi_sign, tx.expiry) + writers.write_varint(w_txi_sign, tx.inputs_count) - write_tx_input_decred_witness(w_txi_sign, txi_sign) + writers.write_tx_input_decred_witness(w_txi_sign, txi_sign) tx_ser.serialized_tx = w_txi_sign tx_req.serialized = tx_ser else: # hash of what we are signing with this input - h_sign = HashWriter(sha256()) + h_sign = utils.HashWriter(sha256()) # same as h_first, checked before signing the digest - h_second = HashWriter(sha256()) + h_second = utils.HashWriter(sha256()) if tx.overwintered: - write_uint32( - h_sign, tx.version | OVERWINTERED + writers.write_uint32( + h_sign, tx.version | zcash.OVERWINTERED ) # nVersion | fOverwintered - write_uint32(h_sign, tx.version_group_id) # nVersionGroupId + writers.write_uint32(h_sign, tx.version_group_id) # nVersionGroupId else: - write_uint32(h_sign, tx.version) # nVersion + writers.write_uint32(h_sign, tx.version) # nVersion if tx.timestamp: - write_uint32(h_sign, tx.timestamp) + writers.write_uint32(h_sign, tx.timestamp) - write_varint(h_sign, tx.inputs_count) + writers.write_varint(h_sign, tx.inputs_count) for i in range(tx.inputs_count): # STAGE_REQUEST_4_INPUT - txi = await request_tx_input(tx_req, i) + txi = await helpers.request_tx_input(tx_req, i) input_check_wallet_path(txi, wallet_path) - write_tx_input_check(h_second, txi) + writers.write_tx_input_check(h_second, txi) if i == i_sign: txi_sign = txi - key_sign = node_derive(root, txi.address_n) + key_sign = keychain.derive(txi.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() # for the signing process the script_sig is equal # to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH) if txi_sign.script_type == InputScriptType.SPENDMULTISIG: - txi_sign.script_sig = output_script_multisig( - multisig_get_pubkeys(txi_sign.multisig), txi_sign.multisig.m + txi_sign.script_sig = scripts.output_script_multisig( + multisig.multisig_get_pubkeys(txi_sign.multisig), + txi_sign.multisig.m, ) elif txi_sign.script_type == InputScriptType.SPENDADDRESS: - txi_sign.script_sig = output_script_p2pkh( - ecdsa_hash_pubkey(key_sign_pub, coin) + txi_sign.script_sig = scripts.output_script_p2pkh( + addresses.ecdsa_hash_pubkey(key_sign_pub, coin) ) if coin.bip115: - txi_sign.script_sig += script_replay_protection_bip115( + txi_sign.script_sig += scripts.script_replay_protection_bip115( txi_sign.prev_block_hash_bip115, txi_sign.prev_block_height_bip115, ) @@ -426,38 +425,38 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): ) else: txi.script_sig = bytes() - write_tx_input(h_sign, txi) + writers.write_tx_input(h_sign, txi) - write_varint(h_sign, tx.outputs_count) + writers.write_varint(h_sign, tx.outputs_count) for o in range(tx.outputs_count): # STAGE_REQUEST_4_OUTPUT - txo = await request_tx_output(tx_req, o) + txo = await helpers.request_tx_output(tx_req, o) txo_bin.amount = txo.amount - txo_bin.script_pubkey = output_derive_script(txo, coin, root) - write_tx_output(h_second, txo_bin) - write_tx_output(h_sign, txo_bin) + txo_bin.script_pubkey = output_derive_script(txo, coin, keychain) + writers.write_tx_output(h_second, txo_bin) + writers.write_tx_output(h_sign, txo_bin) - write_uint32(h_sign, tx.lock_time) + writers.write_uint32(h_sign, tx.lock_time) if tx.overwintered: - write_uint32(h_sign, tx.expiry) # expiryHeight - write_varint(h_sign, 0) # nJoinSplit + writers.write_uint32(h_sign, tx.expiry) # expiryHeight + writers.write_varint(h_sign, 0) # nJoinSplit - write_uint32(h_sign, get_hash_type(coin)) + writers.write_uint32(h_sign, get_hash_type(coin)) # check the control digests - if get_tx_hash(h_first, False) != get_tx_hash(h_second): + if writers.get_tx_hash(h_first, False) != writers.get_tx_hash(h_second): raise SigningError( FailureType.ProcessError, "Transaction has changed during signing" ) - # if multisig, check if singing with a key that is included in multisig + # if multisig, check if signing with a key that is included in multisig if txi_sign.multisig: - multisig_pubkey_index(txi_sign.multisig, key_sign_pub) + multisig.multisig_pubkey_index(txi_sign.multisig, key_sign_pub) # compute the signature from the tx digest signature = ecdsa_sign( - key_sign, get_tx_hash(h_sign, double=coin.sign_hash_double) + key_sign, writers.get_tx_hash(h_sign, double=coin.sign_hash_double) ) tx_ser.signature_index = i_sign tx_ser.signature = signature @@ -466,31 +465,31 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): txi_sign.script_sig = input_derive_script( coin, txi_sign, key_sign_pub, signature ) - w_txi_sign = empty_bytearray( + w_txi_sign = writers.empty_bytearray( 5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4 ) if i_sign == 0: # serializing first input => prepend headers - write_bytes(w_txi_sign, get_tx_header(coin, tx)) - write_tx_input(w_txi_sign, txi_sign) + writers.write_bytes(w_txi_sign, get_tx_header(coin, tx)) + writers.write_tx_input(w_txi_sign, txi_sign) tx_ser.serialized_tx = w_txi_sign tx_req.serialized = tx_ser if coin.decred: - return await request_tx_finish(tx_req) + return await helpers.request_tx_finish(tx_req) for o in range(tx.outputs_count): progress.advance() # STAGE_REQUEST_5_OUTPUT - txo = await request_tx_output(tx_req, o) + txo = await helpers.request_tx_output(tx_req, o) txo_bin.amount = txo.amount - txo_bin.script_pubkey = output_derive_script(txo, coin, root) + txo_bin.script_pubkey = output_derive_script(txo, coin, keychain) # serialize output - w_txo_bin = empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4) + w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4) if o == 0: # serializing first output => prepend outputs count - write_varint(w_txo_bin, tx.outputs_count) - write_tx_output(w_txo_bin, txo_bin) + writers.write_varint(w_txo_bin, tx.outputs_count) + writers.write_tx_output(w_txo_bin, txo_bin) tx_ser.signature_index = None tx_ser.signature = None @@ -504,38 +503,38 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): progress.advance() if segwit[i]: # STAGE_REQUEST_SEGWIT_WITNESS - txi = await request_tx_input(tx_req, i) + txi = await helpers.request_tx_input(tx_req, i) input_check_wallet_path(txi, wallet_path) - is_segwit = ( - txi.script_type == InputScriptType.SPENDWITNESS - or txi.script_type == InputScriptType.SPENDP2SHWITNESS - ) - if not is_segwit or txi.amount > authorized_in: + if not input_is_segwit(txi) or txi.amount > authorized_in: raise SigningError( FailureType.ProcessError, "Transaction has changed during signing" ) authorized_in -= txi.amount - key_sign = node_derive(root, txi.address_n) + key_sign = keychain.derive(txi.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() hash143_hash = hash143.preimage_hash( coin, tx, txi, - ecdsa_hash_pubkey(key_sign_pub, coin), + addresses.ecdsa_hash_pubkey(key_sign_pub, coin), get_hash_type(coin), ) signature = ecdsa_sign(key_sign, hash143_hash) if txi.multisig: # find out place of our signature based on the pubkey - signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub) - witness = witness_p2wsh( + signature_index = multisig.multisig_pubkey_index( + txi.multisig, key_sign_pub + ) + witness = scripts.witness_p2wsh( txi.multisig, signature, signature_index, get_hash_type(coin) ) else: - witness = witness_p2wpkh(signature, key_sign_pub, get_hash_type(coin)) + witness = scripts.witness_p2wpkh( + signature, key_sign_pub, get_hash_type(coin) + ) tx_ser.serialized_tx = witness tx_ser.signature_index = i @@ -547,66 +546,68 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): tx_req.serialized = tx_ser - write_uint32(tx_ser.serialized_tx, tx.lock_time) + writers.write_uint32(tx_ser.serialized_tx, tx.lock_time) if tx.overwintered: if tx.version == 3: - write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight - write_varint(tx_ser.serialized_tx, 0) # nJoinSplit + writers.write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight + writers.write_varint(tx_ser.serialized_tx, 0) # nJoinSplit elif tx.version == 4: - write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight - write_uint64(tx_ser.serialized_tx, 0) # valueBalance - write_varint(tx_ser.serialized_tx, 0) # nShieldedSpend - write_varint(tx_ser.serialized_tx, 0) # nShieldedOutput - write_varint(tx_ser.serialized_tx, 0) # nJoinSplit + writers.write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight + writers.write_uint64(tx_ser.serialized_tx, 0) # valueBalance + writers.write_varint(tx_ser.serialized_tx, 0) # nShieldedSpend + writers.write_varint(tx_ser.serialized_tx, 0) # nShieldedOutput + writers.write_varint(tx_ser.serialized_tx, 0) # nJoinSplit else: raise SigningError( FailureType.DataError, "Unsupported version for overwintered transaction", ) - await request_tx_finish(tx_req) + await helpers.request_tx_finish(tx_req) async def get_prevtx_output_value( - coin: CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int + coin: coininfo.CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int ) -> int: total_out = 0 # sum of output amounts # STAGE_REQUEST_2_PREV_META - tx = await request_tx_meta(tx_req, prev_hash) + tx = await helpers.request_tx_meta(tx_req, prev_hash) if coin.decred: - txh = HashWriter(blake256()) + txh = utils.HashWriter(blake256()) else: - txh = HashWriter(sha256()) + txh = utils.HashWriter(sha256()) if tx.overwintered: - write_uint32(txh, tx.version | OVERWINTERED) # nVersion | fOverwintered - write_uint32(txh, tx.version_group_id) # nVersionGroupId + writers.write_uint32( + txh, tx.version | zcash.OVERWINTERED + ) # nVersion | fOverwintered + writers.write_uint32(txh, tx.version_group_id) # nVersionGroupId elif coin.decred: - write_uint32(txh, tx.version | DECRED_SERIALIZE_NO_WITNESS) + writers.write_uint32(txh, tx.version | decred.DECRED_SERIALIZE_NO_WITNESS) else: - write_uint32(txh, tx.version) # nVersion + writers.write_uint32(txh, tx.version) # nVersion if tx.timestamp: - write_uint32(txh, tx.timestamp) + writers.write_uint32(txh, tx.timestamp) - write_varint(txh, tx.inputs_cnt) + writers.write_varint(txh, tx.inputs_cnt) for i in range(tx.inputs_cnt): # STAGE_REQUEST_2_PREV_INPUT - txi = await request_tx_input(tx_req, i, prev_hash) + txi = await helpers.request_tx_input(tx_req, i, prev_hash) if coin.decred: - write_tx_input_decred(txh, txi) + writers.write_tx_input_decred(txh, txi) else: - write_tx_input(txh, txi) + writers.write_tx_input(txh, txi) - write_varint(txh, tx.outputs_cnt) + writers.write_varint(txh, tx.outputs_cnt) for o in range(tx.outputs_cnt): # STAGE_REQUEST_2_PREV_OUTPUT - txo_bin = await request_tx_output(tx_req, o, prev_hash) - write_tx_output(txh, txo_bin) + txo_bin = await helpers.request_tx_output(tx_req, o, prev_hash) + writers.write_tx_output(txh, txo_bin) if o == prev_index: total_out += txo_bin.amount if ( @@ -619,19 +620,22 @@ async def get_prevtx_output_value( "Cannot use utxo that has script_version != 0", ) - write_uint32(txh, tx.lock_time) + writers.write_uint32(txh, tx.lock_time) if tx.overwintered or coin.decred: - write_uint32(txh, tx.expiry) + writers.write_uint32(txh, tx.expiry) ofs = 0 while ofs < tx.extra_data_len: size = min(1024, tx.extra_data_len - ofs) - data = await request_tx_extra_data(tx_req, ofs, size, prev_hash) - write_bytes(txh, data) + data = await helpers.request_tx_extra_data(tx_req, ofs, size, prev_hash) + writers.write_bytes(txh, data) ofs += len(data) - if get_tx_hash(txh, double=coin.sign_hash_double, reverse=True) != prev_hash: + if ( + writers.get_tx_hash(txh, double=coin.sign_hash_double, reverse=True) + != prev_hash + ): raise SigningError(FailureType.ProcessError, "Encountered invalid prev_hash") return total_out @@ -641,7 +645,7 @@ async def get_prevtx_output_value( # === -def get_hash_type(coin: CoinInfo) -> int: +def get_hash_type(coin: coininfo.CoinInfo) -> int: SIGHASH_FORKID = const(0x40) SIGHASH_ALL = const(0x01) hashtype = SIGHASH_ALL @@ -650,19 +654,21 @@ def get_hash_type(coin: CoinInfo) -> int: return hashtype -def get_tx_header(coin: CoinInfo, tx: SignTx, segwit: bool = False): +def get_tx_header(coin: coininfo.CoinInfo, tx: SignTx, segwit: bool = False): w_txi = bytearray() if tx.overwintered: - write_uint32(w_txi, tx.version | OVERWINTERED) # nVersion | fOverwintered - write_uint32(w_txi, tx.version_group_id) # nVersionGroupId + writers.write_uint32( + w_txi, tx.version | zcash.OVERWINTERED + ) # nVersion | fOverwintered + writers.write_uint32(w_txi, tx.version_group_id) # nVersionGroupId else: - write_uint32(w_txi, tx.version) # nVersion + writers.write_uint32(w_txi, tx.version) # nVersion if tx.timestamp: - write_uint32(w_txi, tx.timestamp) + writers.write_uint32(w_txi, tx.timestamp) if segwit: - write_varint(w_txi, 0x00) # segwit witness marker - write_varint(w_txi, 0x01) # segwit witness flag - write_varint(w_txi, tx.inputs_count) + writers.write_varint(w_txi, 0x00) # segwit witness marker + writers.write_varint(w_txi, 0x01) # segwit witness flag + writers.write_varint(w_txi, tx.inputs_count) return w_txi @@ -670,7 +676,9 @@ def get_tx_header(coin: CoinInfo, tx: SignTx, segwit: bool = False): # === -def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) -> bytes: +def output_derive_script( + o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain +) -> bytes: if o.script_type == OutputScriptType.PAYTOOPRETURN: # op_return output @@ -678,21 +686,21 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) -> raise SigningError( FailureType.DataError, "OP_RETURN output with non-zero amount" ) - return output_script_paytoopreturn(o.op_return_data) + return scripts.output_script_paytoopreturn(o.op_return_data) if o.address_n: # change output if o.address: raise SigningError(FailureType.DataError, "Address in change output") - o.address = get_address_for_change(o, coin, root) + o.address = get_address_for_change(o, coin, keychain) else: if not o.address: raise SigningError(FailureType.DataError, "Missing address") if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh - witprog = decode_bech32_address(coin.bech32_prefix, o.address) - return output_script_native_p2wpkh_or_p2wsh(witprog) + witprog = addresses.decode_bech32_address(coin.bech32_prefix, o.address) + return scripts.output_script_native_p2wpkh_or_p2wsh(witprog) if coin.cashaddr_prefix is not None and o.address.startswith( coin.cashaddr_prefix + ":" @@ -712,9 +720,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) -> if address_type.check(coin.address_type, raw_address): # p2pkh pubkeyhash = address_type.strip(coin.address_type, raw_address) - script = output_script_p2pkh(pubkeyhash) + script = scripts.output_script_p2pkh(pubkeyhash) if coin.bip115: - script += script_replay_protection_bip115( + script += scripts.script_replay_protection_bip115( o.block_hash_bip115, o.block_height_bip115 ) return script @@ -722,9 +730,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) -> elif address_type.check(coin.address_type_p2sh, raw_address): # p2sh scripthash = address_type.strip(coin.address_type_p2sh, raw_address) - script = output_script_p2sh(scripthash) + script = scripts.output_script_p2sh(scripthash) if coin.bip115: - script += script_replay_protection_bip115( + script += scripts.script_replay_protection_bip115( o.block_hash_bip115, o.block_height_bip115 ) return script @@ -732,7 +740,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) -> raise SigningError(FailureType.DataError, "Invalid address type") -def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode): +def get_address_for_change( + o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain +): if o.script_type == OutputScriptType.PAYTOADDRESS: input_script_type = InputScriptType.SPENDADDRESS elif o.script_type == OutputScriptType.PAYTOMULTISIG: @@ -743,17 +753,19 @@ def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode): input_script_type = InputScriptType.SPENDP2SHWITNESS else: raise SigningError(FailureType.DataError, "Invalid script type") - return get_address( - input_script_type, coin, node_derive(root, o.address_n), o.multisig - ) + node = keychain.derive(o.address_n, coin.curve_name) + return addresses.get_address(input_script_type, coin, node, o.multisig) -def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool: - is_segwit = ( - o.script_type == OutputScriptType.PAYTOWITNESS - or o.script_type == OutputScriptType.PAYTOP2SHWITNESS - ) - if is_segwit and o.amount > segwit_in: +def output_is_change( + o: TxOutputType, + wallet_path: list, + segwit_in: int, + multifp: multisig.MultisigFingerprint, +) -> bool: + if o.multisig and not multifp.matches(o.multisig): + return False + if output_is_segwit(o) and o.amount > segwit_in: # if the output is segwit, make sure it doesn't spend more than what the # segwit inputs paid. this is to prevent user being tricked into # creating ANYONECANSPEND outputs before full segwit activation. @@ -766,38 +778,49 @@ def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool ) +def output_is_segwit(o: TxOutputType) -> bool: + return ( + o.script_type == OutputScriptType.PAYTOWITNESS + or o.script_type == OutputScriptType.PAYTOP2SHWITNESS + ) + + # Tx Inputs # === def input_derive_script( - coin: CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes = None + coin: coininfo.CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes = None ) -> bytes: if i.script_type == InputScriptType.SPENDADDRESS: # p2pkh or p2sh - return input_script_p2pkh_or_p2sh(pubkey, signature, get_hash_type(coin)) + return scripts.input_script_p2pkh_or_p2sh( + pubkey, signature, get_hash_type(coin) + ) if i.script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh or p2wsh using p2sh if i.multisig: # p2wsh in p2sh - pubkeys = multisig_get_pubkeys(i.multisig) - witness_script = output_script_multisig(pubkeys, i.multisig.m) + pubkeys = multisig.multisig_get_pubkeys(i.multisig) + witness_script = scripts.output_script_multisig(pubkeys, i.multisig.m) witness_script_hash = sha256(witness_script).digest() - return input_script_p2wsh_in_p2sh(witness_script_hash) + return scripts.input_script_p2wsh_in_p2sh(witness_script_hash) # p2wpkh in p2sh - return input_script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey, coin)) + return scripts.input_script_p2wpkh_in_p2sh( + addresses.ecdsa_hash_pubkey(pubkey, coin) + ) elif i.script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or p2wsh - return input_script_native_p2wpkh_or_p2wsh() + return scripts.input_script_native_p2wpkh_or_p2wsh() elif i.script_type == InputScriptType.SPENDMULTISIG: # p2sh multisig - signature_index = multisig_pubkey_index(i.multisig, pubkey) - return input_script_multisig( + signature_index = multisig.multisig_pubkey_index(i.multisig, pubkey) + return scripts.input_script_multisig( i.multisig, signature, signature_index, get_hash_type(coin), coin ) @@ -805,6 +828,13 @@ def input_derive_script( raise SigningError(FailureType.ProcessError, "Invalid script type") +def input_is_segwit(i: TxInputType) -> bool: + return ( + i.script_type == InputScriptType.SPENDWITNESS + or i.script_type == InputScriptType.SPENDP2SHWITNESS + ) + + def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list: if wallet_path is None: return None # there was a mismatch in previous inputs @@ -828,22 +858,7 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list: ) -def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode: - node = root.clone() - node.derive_path(address_n) - return node - - def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: sig = secp256k1.sign(node.private_key(), digest) sigder = der.encode_seq((sig[1:33], sig[33:65])) return sigder - - -def is_change( - txo: TxOutputType, wallet_path: list, segwit_in: int, multifp: MultisigFingerprint -) -> bool: - if txo.multisig: - if not multifp.matches(txo.multisig): - return False - return output_is_change(txo, wallet_path, segwit_in) diff --git a/src/apps/wallet/sign_tx/tx_weight_calculator.py b/src/apps/wallet/sign_tx/tx_weight.py similarity index 100% rename from src/apps/wallet/sign_tx/tx_weight_calculator.py rename to src/apps/wallet/sign_tx/tx_weight.py diff --git a/src/apps/wallet/sign_tx/writers.py b/src/apps/wallet/sign_tx/writers.py index 6d6b954e2a..e8214cd6b4 100644 --- a/src/apps/wallet/sign_tx/writers.py +++ b/src/apps/wallet/sign_tx/writers.py @@ -3,7 +3,8 @@ from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputBinType import TxOutputBinType from trezor.utils import ensure -from apps.common.writers import ( +from apps.common.writers import ( # noqa: F401 + empty_bytearray, write_bytes, write_bytes_reversed, write_uint8, diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index f48046a9e2..4aba1a3b40 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -3,12 +3,25 @@ from trezor import log, loop, messages, utils, workflow from trezor.wire import codec_v1 from trezor.wire.errors import * +from apps.common import seed + workflow_handlers = {} -def add(mtype, pkgname, modname, *args): +def add(mtype, pkgname, modname, namespace=None): """Shortcut for registering a dynamically-imported Protobuf workflow.""" - register(mtype, protobuf_workflow, import_workflow, pkgname, modname, *args) + if namespace is not None: + register( + mtype, + protobuf_workflow, + keychain_workflow, + namespace, + import_workflow, + pkgname, + modname, + ) + else: + register(mtype, protobuf_workflow, import_workflow, pkgname, modname) def register(mtype, handler, *args): @@ -133,10 +146,12 @@ async def session_handler(iface, sid): continue except Error as exc: # we log wire.Error as warning, not as exception - log.warning(__name__, "failure: %s", exc.message) + if __debug__: + log.warning(__name__, "failure: %s", exc.message) except Exception as exc: # sessions are never closed by raised exceptions - log.exception(__name__, exc) + if __debug__: + log.exception(__name__, exc) # read new message in next iteration reader = None @@ -155,7 +170,7 @@ async def protobuf_workflow(ctx, reader, handler, *args): # respond with specific code and message await ctx.write(Failure(code=exc.code, message=exc.message)) raise - except Exception: # as exc: + except Exception: # respond with a generic code and message await ctx.write( Failure(code=FailureType.FirmwareError, message="Firmware error") @@ -166,6 +181,15 @@ async def protobuf_workflow(ctx, reader, handler, *args): await ctx.write(res) +async def keychain_workflow(ctx, req, namespace, handler, *args): + keychain = await seed.get_keychain(ctx, namespace) + args += (keychain,) + try: + return await handler(ctx, req, *args) + finally: + keychain.__del__() + + def import_workflow(ctx, req, pkgname, modname, *args): modpath = "%s.%s" % (pkgname, modname) module = __import__(modpath, None, None, (modname,), 0) diff --git a/tests/test_apps.cardano.address.py b/tests/test_apps.cardano.address.py index 23cac3f621..ee6fd3e2a0 100644 --- a/tests/test_apps.cardano.address.py +++ b/tests/test_apps.cardano.address.py @@ -8,6 +8,7 @@ from apps.cardano.address import ( validate_full_path, derive_address_and_node ) +from apps.cardano.seed import Keychain from trezor.crypto import bip32 @@ -16,6 +17,9 @@ class TestCardanoAddress(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) addresses = [ "Ae2tdPwUPEZ98eHFwxSsPBDz73amioKpr58Vw85mP1tMkzq8siaftiejJ3j", @@ -25,7 +29,7 @@ class TestCardanoAddress(unittest.TestCase): for i, expected in enumerate(addresses): # 44'/1815'/0'/0/i' - address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) + address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) self.assertEqual(expected, address) nodes = [ @@ -50,7 +54,7 @@ class TestCardanoAddress(unittest.TestCase): ] for i, (priv, ext, pub, chain) in enumerate(nodes): - _, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) + _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) @@ -60,6 +64,9 @@ class TestCardanoAddress(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) addresses = [ "Ae2tdPwUPEZ5YUb8sM3eS8JqKgrRLzhiu71crfuH2MFtqaYr5ACNRdsswsZ", @@ -69,7 +76,7 @@ class TestCardanoAddress(unittest.TestCase): for i, expected in enumerate(addresses): # 44'/1815'/0'/0/i - address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) + address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) self.assertEqual(address, expected) nodes = [ @@ -94,7 +101,7 @@ class TestCardanoAddress(unittest.TestCase): ] for i, (priv, ext, pub, chain) in enumerate(nodes): - _, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) + _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) @@ -105,9 +112,12 @@ class TestCardanoAddress(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) # 44'/1815' - address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815]) + address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815]) self.assertEqual(address, "Ae2tdPwUPEZ2FGHX3yCKPSbSgyuuTYgMxNq652zKopxT4TuWvEd8Utd92w3") priv, ext, pub, chain = ( @@ -117,7 +127,7 @@ class TestCardanoAddress(unittest.TestCase): b"02ac67c59a8b0264724a635774ca2c242afa10d7ab70e2bf0a8f7d4bb10f1f7a" ) - _, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815]) + _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815]) self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) diff --git a/tests/test_apps.cardano.get_public_key.py b/tests/test_apps.cardano.get_public_key.py index e9b7fdfa49..33cdc4e8f6 100644 --- a/tests/test_apps.cardano.get_public_key.py +++ b/tests/test_apps.cardano.get_public_key.py @@ -1,5 +1,6 @@ from common import * +from apps.cardano.seed import Keychain from apps.cardano.get_public_key import _get_public_key from trezor.crypto import bip32 from ubinascii import hexlify @@ -10,6 +11,9 @@ class TestCardanoGetPublicKey(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) derivation_paths = [ [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000], @@ -40,7 +44,7 @@ class TestCardanoGetPublicKey(unittest.TestCase): ] for index, derivation_path in enumerate(derivation_paths): - key = _get_public_key(node, derivation_path) + key = _get_public_key(keychain, derivation_path) self.assertEqual(hexlify(key.node.public_key), public_keys[index]) self.assertEqual(hexlify(key.node.chain_code), chain_codes[index]) diff --git a/tests/test_apps.ethereum.layout.py b/tests/test_apps.ethereum.layout.py index c0e72b74f1..5da5746864 100644 --- a/tests/test_apps.ethereum.layout.py +++ b/tests/test_apps.ethereum.layout.py @@ -45,7 +45,7 @@ class TestEthereumLayout(unittest.TestCase): text = format_ethereum_amount(1000000000000000000, None, 61) self.assertEqual(text, '1 ETC') text = format_ethereum_amount(1000000000000000000, None, 31) - self.assertEqual(text, '1 tRSK') + self.assertEqual(text, '1 tRBTC') text = format_ethereum_amount(1000000000000000001, None, 1) self.assertEqual(text, '1.000000000000000001 ETH') @@ -54,7 +54,7 @@ class TestEthereumLayout(unittest.TestCase): text = format_ethereum_amount(10000000000000000001, None, 61) self.assertEqual(text, '10.000000000000000001 ETC') text = format_ethereum_amount(1000000000000000001, None, 31) - self.assertEqual(text, '1.000000000000000001 tRSK') + self.assertEqual(text, '1.000000000000000001 tRBTC') # unknown chain text = format_ethereum_amount(1, None, 9999) diff --git a/tests/test_apps.wallet.address.py b/tests/test_apps.wallet.address.py index d39067ddbb..2df72c70ea 100644 --- a/tests/test_apps.wallet.address.py +++ b/tests/test_apps.wallet.address.py @@ -1,10 +1,19 @@ from common import * from trezor.crypto import bip32, bip39 +from trezor.utils import HashWriter from apps.wallet.sign_tx.addresses import validate_full_path, validate_path_for_bitcoin_public_key from apps.common.paths import HARDENED from apps.common import coins +from apps.wallet.sign_tx.addresses import * from apps.wallet.sign_tx.signing import * +from apps.wallet.sign_tx.writers import * + + +def node_derive(root, path): + node = root.clone() + node.derive_path(path) + return node class TestAddress(unittest.TestCase): diff --git a/tests/test_apps.wallet.address_grs.py b/tests/test_apps.wallet.address_grs.py index 404ddff04b..8aeb173e0c 100644 --- a/tests/test_apps.wallet.address_grs.py +++ b/tests/test_apps.wallet.address_grs.py @@ -1,10 +1,17 @@ from common import * from apps.wallet.sign_tx.signing import * +from apps.wallet.sign_tx.addresses import * from apps.common import coins from trezor.crypto import bip32, bip39 +def node_derive(root, path): + node = root.clone() + node.derive_path(path) + return node + + class TestAddressGRS(unittest.TestCase): # pylint: disable=C0301 diff --git a/tests/test_apps.wallet.segwit.bip143.native_p2wpkh.py b/tests/test_apps.wallet.segwit.bip143.native_p2wpkh.py index f1e215bff1..5433367d3a 100644 --- a/tests/test_apps.wallet.segwit.bip143.native_p2wpkh.py +++ b/tests/test_apps.wallet.segwit.bip143.native_p2wpkh.py @@ -1,6 +1,7 @@ from common import * from apps.wallet.sign_tx.signing import * +from apps.wallet.sign_tx.segwit_bip143 import * from apps.common import coins from trezor.messages.SignTx import SignTx from trezor.messages.TxInputType import TxInputType diff --git a/tests/test_apps.wallet.segwit.bip143.p2wpkh_in_p2sh.py b/tests/test_apps.wallet.segwit.bip143.p2wpkh_in_p2sh.py index d8a9ba6478..652e82d199 100644 --- a/tests/test_apps.wallet.segwit.bip143.p2wpkh_in_p2sh.py +++ b/tests/test_apps.wallet.segwit.bip143.p2wpkh_in_p2sh.py @@ -1,6 +1,7 @@ from common import * from apps.wallet.sign_tx.signing import * +from apps.wallet.sign_tx.segwit_bip143 import * from apps.common import coins from trezor.messages.SignTx import SignTx from trezor.messages.TxInputType import TxInputType diff --git a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py index c09f7d48ed..a2dbb8404c 100644 --- a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py +++ b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py @@ -1,7 +1,7 @@ from common import * from trezor.utils import chunks -from trezor.crypto import bip32, bip39 +from trezor.crypto import bip39 from trezor.messages.SignTx import SignTx from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputType import TxOutputType @@ -15,7 +15,8 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins -from apps.wallet.sign_tx import signing +from apps.common.seed import Keychain +from apps.wallet.sign_tx import helpers, signing class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): @@ -24,9 +25,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): def test_send_native_p2wpkh(self): coin = coins.by_name('Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, 'secp256k1') inp1 = TxInputType( # 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s @@ -61,22 +60,22 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmOutput(out2, coin), + helpers.UiConfirmOutput(out2, coin), True, - signing.UiConfirmTotal(12300000, 11000, coin), + helpers.UiConfirmTotal(12300000, 11000, coin), True, # sign tx @@ -113,18 +112,17 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) def test_send_native_p2wpkh_change(self): coin = coins.by_name('Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, 'secp256k1') inp1 = TxInputType( # 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s @@ -159,19 +157,19 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmTotal(5000000 + 11000, 11000, coin), + helpers.UiConfirmTotal(5000000 + 11000, 11000, coin), True, # sign tx @@ -209,21 +207,13 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to signing.Ui* classes - if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or - (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal)) or - (isinstance(a, signing.UiConfirmForeignAddress) and isinstance(b, signing.UiConfirmForeignAddress))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py index 907cd81507..9ad58c17b6 100644 --- a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py +++ b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py @@ -1,7 +1,7 @@ from common import * from trezor.utils import chunks -from trezor.crypto import bip32, bip39 +from trezor.crypto import bip39 from trezor.messages.SignTx import SignTx from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputType import TxOutputType @@ -15,7 +15,8 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins -from apps.wallet.sign_tx import signing +from apps.common.seed import Keychain +from apps.wallet.sign_tx import helpers, signing # https://groestlsight-test.groestlcoin.org/api/tx/9b5c4859a8a31e69788cb4402812bb28f14ad71cbd8c60b09903478bc56f79a3 class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): @@ -24,9 +25,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): def test_send_native_p2wpkh(self): coin = coins.by_name('Groestlcoin Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, coin.curve_name) inp1 = TxInputType( # 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja @@ -64,16 +63,16 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmOutput(out2, coin), + helpers.UiConfirmOutput(out2, coin), True, - signing.UiConfirmTotal(12300000, 11000, coin), + helpers.UiConfirmTotal(12300000, 11000, coin), True, # sign tx @@ -110,18 +109,17 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) def test_send_native_p2wpkh_change(self): coin = coins.by_name('Groestlcoin Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, coin.curve_name) inp1 = TxInputType( # 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja @@ -159,13 +157,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmTotal(5000000 + 11000, 11000, coin), + helpers.UiConfirmTotal(5000000 + 11000, 11000, coin), True, # sign tx @@ -203,20 +201,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to signing.Ui* classes - if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or - (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py index 208e72a042..bdb6a5391b 100644 --- a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py +++ b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py @@ -1,7 +1,7 @@ from common import * from trezor.utils import chunks -from trezor.crypto import bip32, bip39 +from trezor.crypto import bip39 from trezor.messages.SignTx import SignTx from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputType import TxOutputType @@ -15,7 +15,8 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins -from apps.wallet.sign_tx import signing +from apps.common.seed import Keychain +from apps.wallet.sign_tx import helpers, signing class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): @@ -24,9 +25,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): def test_send_p2wpkh_in_p2sh(self): coin = coins.by_name('Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, 'secp256k1') inp1 = TxInputType( # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -64,16 +63,16 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmOutput(out2, coin), + helpers.UiConfirmOutput(out2, coin), True, - signing.UiConfirmTotal(123445789 + 11000, 11000, coin), + helpers.UiConfirmTotal(123445789 + 11000, 11000, coin), True, # sign tx @@ -110,18 +109,17 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) def test_send_p2wpkh_in_p2sh_change(self): coin = coins.by_name('Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, 'secp256k1') inp1 = TxInputType( # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -160,14 +158,14 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmTotal(12300000 + 11000, 11000, coin), + helpers.UiConfirmTotal(12300000 + 11000, 11000, coin), True, # sign tx @@ -213,9 +211,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) @@ -224,9 +223,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): def test_send_p2wpkh_in_p2sh_attack_amount(self): coin = coins.by_name('Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, 'secp256k1') inp1 = TxInputType( # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX @@ -275,14 +272,14 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmTotal(8, 0, coin), + helpers.UiConfirmTotal(8, 0, coin), True, TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), @@ -322,26 +319,19 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): TxRequest(request_type=TXFINISHED, details=None) ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) i = 0 messages_count = int(len(messages) / 2) for request, response in chunks(messages, 2): if i == messages_count - 1: # last message should throw SigningError self.assertRaises(signing.SigningError, signer.send, request) else: - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) i += 1 with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to signing.Ui* classes - if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or - (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py index 6e8431ac5c..152142c2bb 100644 --- a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py +++ b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py @@ -1,7 +1,7 @@ from common import * from trezor.utils import chunks -from trezor.crypto import bip32, bip39 +from trezor.crypto import bip39 from trezor.messages.SignTx import SignTx from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputType import TxOutputType @@ -15,7 +15,8 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins -from apps.wallet.sign_tx import signing +from apps.common.seed import Keychain +from apps.wallet.sign_tx import helpers, signing # https://groestlsight-test.groestlcoin.org/api/tx/4ce0220004bdfe14e3dd49fd8636bcb770a400c0c9e9bff670b6a13bb8f15c72 class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): @@ -24,9 +25,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): def test_send_p2wpkh_in_p2sh(self): coin = coins.by_name('Groestlcoin Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, coin.curve_name) inp1 = TxInputType( # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 @@ -64,16 +63,16 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmOutput(out2, coin), + helpers.UiConfirmOutput(out2, coin), True, - signing.UiConfirmTotal(123445789 + 11000, 11000, coin), + helpers.UiConfirmTotal(123445789 + 11000, 11000, coin), True, # sign tx @@ -110,18 +109,17 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) def test_send_p2wpkh_in_p2sh_change(self): coin = coins.by_name('Groestlcoin Testnet') - seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, coin.curve_name) inp1 = TxInputType( # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 @@ -160,14 +158,14 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out2])), - signing.UiConfirmTotal(12300000 + 11000, 11000, coin), + helpers.UiConfirmTotal(12300000 + 11000, 11000, coin), True, # sign tx @@ -212,20 +210,13 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to signing.Ui* classes - if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or - (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.signtx.fee_threshold.py b/tests/test_apps.wallet.signtx.fee_threshold.py index cecaa6da35..4b4f1935f1 100644 --- a/tests/test_apps.wallet.signtx.fee_threshold.py +++ b/tests/test_apps.wallet.signtx.fee_threshold.py @@ -14,7 +14,8 @@ from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages import OutputScriptType from apps.common import coins -from apps.wallet.sign_tx import signing +from apps.common.seed import Keychain +from apps.wallet.sign_tx import helpers, signing class TestSignTxFeeThreshold(unittest.TestCase): @@ -60,7 +61,7 @@ class TestSignTxFeeThreshold(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxAck(tx=ptx1), @@ -72,11 +73,11 @@ class TestSignTxFeeThreshold(unittest.TestCase): TxAck(tx=TransactionType(bin_outputs=[pout1])), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin_bitcoin), + helpers.UiConfirmOutput(out1, coin_bitcoin), True, - signing.UiConfirmFeeOverThreshold(100000, coin_bitcoin), + helpers.UiConfirmFeeOverThreshold(100000, coin_bitcoin), True, - signing.UiConfirmTotal(290000 + 100000, 100000, coin_bitcoin), + helpers.UiConfirmTotal(290000 + 100000, 100000, coin_bitcoin), True, TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), ] @@ -84,9 +85,10 @@ class TestSignTxFeeThreshold(unittest.TestCase): seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') root = bip32.from_seed(seed, 'secp256k1') - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin_bitcoin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) def test_under_threshold(self): coin_bitcoin = coins.by_name('Bitcoin') @@ -127,7 +129,7 @@ class TestSignTxFeeThreshold(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxAck(tx=ptx1), @@ -139,9 +141,9 @@ class TestSignTxFeeThreshold(unittest.TestCase): TxAck(tx=TransactionType(bin_outputs=[pout1])), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin_bitcoin), + helpers.UiConfirmOutput(out1, coin_bitcoin), True, - signing.UiConfirmTotal(300000 + 90000, 90000, coin_bitcoin), + helpers.UiConfirmTotal(300000 + 90000, 90000, coin_bitcoin), True, TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), ] @@ -149,19 +151,10 @@ class TestSignTxFeeThreshold(unittest.TestCase): seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') root = bip32.from_seed(seed, 'secp256k1') - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin_bitcoin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) - - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to signing.Ui* classes - if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or - (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal)) or - (isinstance(a, signing.UiConfirmForeignAddress) and isinstance(b, signing.UiConfirmForeignAddress)) or - (isinstance(a, signing.UiConfirmFeeOverThreshold) and isinstance(b, signing.UiConfirmFeeOverThreshold))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) + self.assertEqual(signer.send(request), response) if __name__ == '__main__': diff --git a/tests/test_apps.wallet.signtx.py b/tests/test_apps.wallet.signtx.py index 1c1869aba3..a7819d4ba5 100644 --- a/tests/test_apps.wallet.signtx.py +++ b/tests/test_apps.wallet.signtx.py @@ -15,7 +15,8 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages import OutputScriptType from apps.common import coins -from apps.wallet.sign_tx import signing +from apps.common.seed import Keychain +from apps.wallet.sign_tx import helpers, signing class TestSignTx(unittest.TestCase): @@ -61,7 +62,7 @@ class TestSignTx(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxAck(tx=ptx1), @@ -73,9 +74,9 @@ class TestSignTx(unittest.TestCase): TxAck(tx=TransactionType(bin_outputs=[pout1])), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin_bitcoin), + helpers.UiConfirmOutput(out1, coin_bitcoin), True, - signing.UiConfirmTotal(380000 + 10000, 10000, coin_bitcoin), + helpers.UiConfirmTotal(380000 + 10000, 10000, coin_bitcoin), True, # ButtonRequest(code=ButtonRequest_ConfirmOutput), # ButtonRequest(code=ButtonRequest_SignTx), @@ -96,26 +97,16 @@ class TestSignTx(unittest.TestCase): ] seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') - root = bip32.from_seed(seed, 'secp256k1') - - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin_bitcoin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): res = signer.send(request) - self.assertEqualEx(res, response) + self.assertEqual(res, response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to signing.Ui* classes - if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or - (isinstance(a, signing.UiConfirmForeignAddress) and isinstance(b, signing.UiConfirmForeignAddress)) or - (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.signtx_grs.py b/tests/test_apps.wallet.signtx_grs.py index e74d900088..ae33831d3e 100644 --- a/tests/test_apps.wallet.signtx_grs.py +++ b/tests/test_apps.wallet.signtx_grs.py @@ -1,7 +1,7 @@ from common import * from trezor.utils import chunks -from trezor.crypto import bip32, bip39 +from trezor.crypto import bip39 from trezor.messages.SignTx import SignTx from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputType import TxOutputType @@ -15,7 +15,8 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages import OutputScriptType from apps.common import coins -from apps.wallet.sign_tx import signing +from apps.common.seed import Keychain +from apps.wallet.sign_tx import helpers, signing class TestSignTx_GRS(unittest.TestCase): @@ -62,9 +63,9 @@ class TestSignTx_GRS(unittest.TestCase): TxAck(tx=TransactionType(bin_outputs=[pout1])), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), - signing.UiConfirmOutput(out1, coin), + helpers.UiConfirmOutput(out1, coin), True, - signing.UiConfirmTotal(210016, 192, coin), + helpers.UiConfirmTotal(210016, 192, coin), True, # ButtonRequest(code=ButtonRequest_ConfirmOutput), # ButtonRequest(code=ButtonRequest_SignTx), @@ -85,22 +86,13 @@ class TestSignTx_GRS(unittest.TestCase): ] seed = bip39.seed(' '.join(['all'] * 12), '') - root = bip32.from_seed(seed, coin.curve_name) - - signer = signing.sign_tx(tx, root) + keychain = Keychain(seed, [[coin.curve_name]]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to signing.Ui* classes - if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or - (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.txweight.py b/tests/test_apps.wallet.txweight.py index 84e92018dd..becd67ac29 100644 --- a/tests/test_apps.wallet.txweight.py +++ b/tests/test_apps.wallet.txweight.py @@ -5,7 +5,7 @@ from trezor.messages import OutputScriptType from trezor.crypto import bip32, bip39 from apps.common import coins -from apps.wallet.sign_tx.tx_weight_calculator import * +from apps.wallet.sign_tx.tx_weight import * from apps.wallet.sign_tx import signing diff --git a/vendor/trezor-common b/vendor/trezor-common index 71528b5260..8906ebf92c 160000 --- a/vendor/trezor-common +++ b/vendor/trezor-common @@ -1 +1 @@ -Subproject commit 71528b526020b5c6a95261b07336cff5d68ea66e +Subproject commit 8906ebf92cf754554f231d4341976c2cf5da9a22