1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 07:28:10 +00:00

src/apps/wallet/sign_tx: implemented simplified API for MultisigRedeemScriptType

If address_n is the same for all nodes in the multisig, provide it just once
and supply nodes directly (not in the HDNodePathType structure)
This commit is contained in:
Pavol Rusnak 2019-02-04 01:15:13 +01:00
parent 6974d037a9
commit 4225fe7fa8
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
4 changed files with 41 additions and 17 deletions

View File

@ -2,6 +2,7 @@ from trezor.crypto import bip32
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.messages import FailureType from trezor.messages import FailureType
from trezor.messages.HDNodePathType import HDNodePathType from trezor.messages.HDNodePathType import HDNodePathType
from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.utils import HashWriter, ensure from trezor.utils import HashWriter, ensure
@ -35,26 +36,27 @@ class MultisigFingerprint:
def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes: def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
pubkeys = multisig.pubkeys if multisig.nodes:
pubnodes = multisig.nodes
else:
pubnodes = [hd.node for hd in multisig.pubkeys]
m = multisig.m m = multisig.m
n = len(pubkeys) n = len(pubnodes)
if n < 1 or n > 15 or m < 1 or m > 15: if n < 1 or n > 15 or m < 1 or m > 15:
raise MultisigError(FailureType.DataError, "Invalid multisig parameters") raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
for hd in pubkeys: for d in pubnodes:
d = hd.node
if len(d.public_key) != 33 or len(d.chain_code) != 32: if len(d.public_key) != 33 or len(d.chain_code) != 32:
raise MultisigError(FailureType.DataError, "Invalid multisig parameters") raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
# casting to bytes(), sorting on bytearray() is not supported in MicroPython # casting to bytes(), sorting on bytearray() is not supported in MicroPython
pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key)) pubnodes = sorted(pubnodes, key=lambda n: bytes(n.public_key))
h = HashWriter(sha256()) h = HashWriter(sha256())
write_uint32(h, m) write_uint32(h, m)
write_uint32(h, n) write_uint32(h, n)
for hd in pubkeys: for d in pubnodes:
d = hd.node
write_uint32(h, d.depth) write_uint32(h, d.depth)
write_uint32(h, d.fingerprint) write_uint32(h, d.fingerprint)
write_uint32(h, d.child_num) write_uint32(h, d.child_num)
@ -65,15 +67,18 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int: def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int:
for i, hd in enumerate(multisig.pubkeys): if multisig.nodes:
if multisig_get_pubkey(hd) == pubkey: for i, hd in enumerate(multisig.nodes):
return i if multisig_get_pubkey(hd, multisig.address_n) == pubkey:
return i
else:
for i, hd in enumerate(multisig.pubkeys):
if multisig_get_pubkey(hd.node, hd.address_n) == pubkey:
return i
raise MultisigError(FailureType.DataError, "Pubkey not found in multisig script") raise MultisigError(FailureType.DataError, "Pubkey not found in multisig script")
def multisig_get_pubkey(hd: HDNodePathType) -> bytes: def multisig_get_pubkey(n: HDNodeType, p: list) -> bytes:
p = hd.address_n
n = hd.node
node = bip32.HDNode( node = bip32.HDNode(
depth=n.depth, depth=n.depth,
fingerprint=n.fingerprint, fingerprint=n.fingerprint,
@ -87,4 +92,16 @@ def multisig_get_pubkey(hd: HDNodePathType) -> bytes:
def multisig_get_pubkeys(multisig: MultisigRedeemScriptType): def multisig_get_pubkeys(multisig: MultisigRedeemScriptType):
return [multisig_get_pubkey(hd) for hd in multisig.pubkeys] if multisig.nodes:
return [
multisig_get_pubkey(hd, multisig.address_n) for hd in multisig.nodes
]
else:
return [multisig_get_pubkey(hd.node, hd.address_n) for hd in multisig.pubkeys]
def multisig_get_pubkey_count(multisig: MultisigRedeemScriptType):
if multisig.nodes:
return len(multisig.nodes)
else:
len(multisig.pubkeys)

View File

@ -3,7 +3,7 @@ from trezor.utils import ensure
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.writers import empty_bytearray from apps.common.writers import empty_bytearray
from apps.wallet.sign_tx.multisig import multisig_get_pubkeys from apps.wallet.sign_tx.multisig import multisig_get_pubkey_count, multisig_get_pubkeys
from apps.wallet.sign_tx.writers import ( from apps.wallet.sign_tx.writers import (
write_bytes, write_bytes,
write_op_push, write_op_push,
@ -158,7 +158,7 @@ def witness_p2wsh(
): ):
# get other signatures, stretch with None to the number of the pubkeys # get other signatures, stretch with None to the number of the pubkeys
signatures = multisig.signatures + [None] * ( signatures = multisig.signatures + [None] * (
len(multisig.pubkeys) - len(multisig.signatures) multisig_get_pubkey_count(multisig) - len(multisig.signatures)
) )
# fill in our signature # fill in our signature
if signatures[signature_index]: if signatures[signature_index]:

View File

@ -3,6 +3,7 @@
import protobuf as p import protobuf as p
from .HDNodePathType import HDNodePathType from .HDNodePathType import HDNodePathType
from .HDNodeType import HDNodeType
if __debug__: if __debug__:
try: try:
@ -18,10 +19,14 @@ class MultisigRedeemScriptType(p.MessageType):
pubkeys: List[HDNodePathType] = None, pubkeys: List[HDNodePathType] = None,
signatures: List[bytes] = None, signatures: List[bytes] = None,
m: int = None, m: int = None,
nodes: List[HDNodeType] = None,
address_n: List[int] = None,
) -> None: ) -> None:
self.pubkeys = pubkeys if pubkeys is not None else [] self.pubkeys = pubkeys if pubkeys is not None else []
self.signatures = signatures if signatures is not None else [] self.signatures = signatures if signatures is not None else []
self.m = m self.m = m
self.nodes = nodes if nodes is not None else []
self.address_n = address_n if address_n is not None else []
@classmethod @classmethod
def get_fields(cls): def get_fields(cls):
@ -29,4 +34,6 @@ class MultisigRedeemScriptType(p.MessageType):
1: ('pubkeys', HDNodePathType, p.FLAG_REPEATED), 1: ('pubkeys', HDNodePathType, p.FLAG_REPEATED),
2: ('signatures', p.BytesType, p.FLAG_REPEATED), 2: ('signatures', p.BytesType, p.FLAG_REPEATED),
3: ('m', p.UVarintType, 0), 3: ('m', p.UVarintType, 0),
4: ('nodes', HDNodeType, p.FLAG_REPEATED),
5: ('address_n', p.UVarintType, p.FLAG_REPEATED),
} }

@ -1 +1 @@
Subproject commit 4b41d2e63841517bf701618434c018acf4f1bca2 Subproject commit 0735c7d6f524b4c5108d201c789612aad7ce7920