diff --git a/src/apps/wallet/sign_tx/multisig.py b/src/apps/wallet/sign_tx/multisig.py index 213f2c9dc0..538c1bdec9 100644 --- a/src/apps/wallet/sign_tx/multisig.py +++ b/src/apps/wallet/sign_tx/multisig.py @@ -2,6 +2,7 @@ from trezor.crypto import bip32 from trezor.crypto.hashlib import sha256 from trezor.messages import FailureType from trezor.messages.HDNodePathType import HDNodePathType +from trezor.messages.HDNodeType import HDNodeType from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType from trezor.utils import HashWriter, ensure @@ -35,26 +36,27 @@ class MultisigFingerprint: def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes: - pubkeys = multisig.pubkeys + if multisig.nodes: + pubnodes = multisig.nodes + else: + pubnodes = [hd.node for hd in multisig.pubkeys] m = multisig.m - n = len(pubkeys) + n = len(pubnodes) if n < 1 or n > 15 or m < 1 or m > 15: raise MultisigError(FailureType.DataError, "Invalid multisig parameters") - for hd in pubkeys: - d = hd.node + for d in pubnodes: if len(d.public_key) != 33 or len(d.chain_code) != 32: raise MultisigError(FailureType.DataError, "Invalid multisig parameters") # casting to bytes(), sorting on bytearray() is not supported in MicroPython - pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key)) + pubnodes = sorted(pubnodes, key=lambda n: bytes(n.public_key)) h = HashWriter(sha256()) write_uint32(h, m) write_uint32(h, n) - for hd in pubkeys: - d = hd.node + for d in pubnodes: write_uint32(h, d.depth) write_uint32(h, d.fingerprint) write_uint32(h, d.child_num) @@ -65,15 +67,18 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes: def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int: - for i, hd in enumerate(multisig.pubkeys): - if multisig_get_pubkey(hd) == pubkey: - return i + if multisig.nodes: + for i, hd in enumerate(multisig.nodes): + if multisig_get_pubkey(hd, multisig.address_n) == pubkey: + return i + else: + for i, hd in enumerate(multisig.pubkeys): + if multisig_get_pubkey(hd.node, hd.address_n) == pubkey: + return i raise MultisigError(FailureType.DataError, "Pubkey not found in multisig script") -def multisig_get_pubkey(hd: HDNodePathType) -> bytes: - p = hd.address_n - n = hd.node +def multisig_get_pubkey(n: HDNodeType, p: list) -> bytes: node = bip32.HDNode( depth=n.depth, fingerprint=n.fingerprint, @@ -87,4 +92,16 @@ def multisig_get_pubkey(hd: HDNodePathType) -> bytes: def multisig_get_pubkeys(multisig: MultisigRedeemScriptType): - return [multisig_get_pubkey(hd) for hd in multisig.pubkeys] + if multisig.nodes: + return [ + multisig_get_pubkey(hd, multisig.address_n) for hd in multisig.nodes + ] + else: + return [multisig_get_pubkey(hd.node, hd.address_n) for hd in multisig.pubkeys] + + +def multisig_get_pubkey_count(multisig: MultisigRedeemScriptType): + if multisig.nodes: + return len(multisig.nodes) + else: + len(multisig.pubkeys) diff --git a/src/apps/wallet/sign_tx/scripts.py b/src/apps/wallet/sign_tx/scripts.py index 4c97f0d34c..7a8407e9f4 100644 --- a/src/apps/wallet/sign_tx/scripts.py +++ b/src/apps/wallet/sign_tx/scripts.py @@ -3,7 +3,7 @@ from trezor.utils import ensure from apps.common.coininfo import CoinInfo from apps.common.writers import empty_bytearray -from apps.wallet.sign_tx.multisig import multisig_get_pubkeys +from apps.wallet.sign_tx.multisig import multisig_get_pubkey_count, multisig_get_pubkeys from apps.wallet.sign_tx.writers import ( write_bytes, write_op_push, @@ -158,7 +158,7 @@ def witness_p2wsh( ): # get other signatures, stretch with None to the number of the pubkeys signatures = multisig.signatures + [None] * ( - len(multisig.pubkeys) - len(multisig.signatures) + multisig_get_pubkey_count(multisig) - len(multisig.signatures) ) # fill in our signature if signatures[signature_index]: diff --git a/src/trezor/messages/MultisigRedeemScriptType.py b/src/trezor/messages/MultisigRedeemScriptType.py index c1c6fd29de..e8985cabd1 100644 --- a/src/trezor/messages/MultisigRedeemScriptType.py +++ b/src/trezor/messages/MultisigRedeemScriptType.py @@ -3,6 +3,7 @@ import protobuf as p from .HDNodePathType import HDNodePathType +from .HDNodeType import HDNodeType if __debug__: try: @@ -18,10 +19,14 @@ class MultisigRedeemScriptType(p.MessageType): pubkeys: List[HDNodePathType] = None, signatures: List[bytes] = None, m: int = None, + nodes: List[HDNodeType] = None, + address_n: List[int] = None, ) -> None: self.pubkeys = pubkeys if pubkeys is not None else [] self.signatures = signatures if signatures is not None else [] self.m = m + self.nodes = nodes if nodes is not None else [] + self.address_n = address_n if address_n is not None else [] @classmethod def get_fields(cls): @@ -29,4 +34,6 @@ class MultisigRedeemScriptType(p.MessageType): 1: ('pubkeys', HDNodePathType, p.FLAG_REPEATED), 2: ('signatures', p.BytesType, p.FLAG_REPEATED), 3: ('m', p.UVarintType, 0), + 4: ('nodes', HDNodeType, p.FLAG_REPEATED), + 5: ('address_n', p.UVarintType, p.FLAG_REPEATED), } diff --git a/vendor/trezor-common b/vendor/trezor-common index 4b41d2e638..0735c7d6f5 160000 --- a/vendor/trezor-common +++ b/vendor/trezor-common @@ -1 +1 @@ -Subproject commit 4b41d2e63841517bf701618434c018acf4f1bca2 +Subproject commit 0735c7d6f524b4c5108d201c789612aad7ce7920