src/apps/wallet/sign_tx: implemented simplified API for MultisigRedeemScriptType

If address_n is the same for all nodes in the multisig, provide it just once
and supply nodes directly (not in the HDNodePathType structure)
pull/25/head
Pavol Rusnak 5 years ago
parent 6974d037a9
commit 4225fe7fa8
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

@ -2,6 +2,7 @@ from trezor.crypto import bip32
from trezor.crypto.hashlib import sha256
from trezor.messages import FailureType
from trezor.messages.HDNodePathType import HDNodePathType
from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.utils import HashWriter, ensure
@ -35,26 +36,27 @@ class MultisigFingerprint:
def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
pubkeys = multisig.pubkeys
if multisig.nodes:
pubnodes = multisig.nodes
else:
pubnodes = [hd.node for hd in multisig.pubkeys]
m = multisig.m
n = len(pubkeys)
n = len(pubnodes)
if n < 1 or n > 15 or m < 1 or m > 15:
raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
for hd in pubkeys:
d = hd.node
for d in pubnodes:
if len(d.public_key) != 33 or len(d.chain_code) != 32:
raise MultisigError(FailureType.DataError, "Invalid multisig parameters")
# casting to bytes(), sorting on bytearray() is not supported in MicroPython
pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key))
pubnodes = sorted(pubnodes, key=lambda n: bytes(n.public_key))
h = HashWriter(sha256())
write_uint32(h, m)
write_uint32(h, n)
for hd in pubkeys:
d = hd.node
for d in pubnodes:
write_uint32(h, d.depth)
write_uint32(h, d.fingerprint)
write_uint32(h, d.child_num)
@ -65,15 +67,18 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int:
for i, hd in enumerate(multisig.pubkeys):
if multisig_get_pubkey(hd) == pubkey:
return i
if multisig.nodes:
for i, hd in enumerate(multisig.nodes):
if multisig_get_pubkey(hd, multisig.address_n) == pubkey:
return i
else:
for i, hd in enumerate(multisig.pubkeys):
if multisig_get_pubkey(hd.node, hd.address_n) == pubkey:
return i
raise MultisigError(FailureType.DataError, "Pubkey not found in multisig script")
def multisig_get_pubkey(hd: HDNodePathType) -> bytes:
p = hd.address_n
n = hd.node
def multisig_get_pubkey(n: HDNodeType, p: list) -> bytes:
node = bip32.HDNode(
depth=n.depth,
fingerprint=n.fingerprint,
@ -87,4 +92,16 @@ def multisig_get_pubkey(hd: HDNodePathType) -> bytes:
def multisig_get_pubkeys(multisig: MultisigRedeemScriptType):
return [multisig_get_pubkey(hd) for hd in multisig.pubkeys]
if multisig.nodes:
return [
multisig_get_pubkey(hd, multisig.address_n) for hd in multisig.nodes
]
else:
return [multisig_get_pubkey(hd.node, hd.address_n) for hd in multisig.pubkeys]
def multisig_get_pubkey_count(multisig: MultisigRedeemScriptType):
if multisig.nodes:
return len(multisig.nodes)
else:
len(multisig.pubkeys)

@ -3,7 +3,7 @@ from trezor.utils import ensure
from apps.common.coininfo import CoinInfo
from apps.common.writers import empty_bytearray
from apps.wallet.sign_tx.multisig import multisig_get_pubkeys
from apps.wallet.sign_tx.multisig import multisig_get_pubkey_count, multisig_get_pubkeys
from apps.wallet.sign_tx.writers import (
write_bytes,
write_op_push,
@ -158,7 +158,7 @@ def witness_p2wsh(
):
# get other signatures, stretch with None to the number of the pubkeys
signatures = multisig.signatures + [None] * (
len(multisig.pubkeys) - len(multisig.signatures)
multisig_get_pubkey_count(multisig) - len(multisig.signatures)
)
# fill in our signature
if signatures[signature_index]:

@ -3,6 +3,7 @@
import protobuf as p
from .HDNodePathType import HDNodePathType
from .HDNodeType import HDNodeType
if __debug__:
try:
@ -18,10 +19,14 @@ class MultisigRedeemScriptType(p.MessageType):
pubkeys: List[HDNodePathType] = None,
signatures: List[bytes] = None,
m: int = None,
nodes: List[HDNodeType] = None,
address_n: List[int] = None,
) -> None:
self.pubkeys = pubkeys if pubkeys is not None else []
self.signatures = signatures if signatures is not None else []
self.m = m
self.nodes = nodes if nodes is not None else []
self.address_n = address_n if address_n is not None else []
@classmethod
def get_fields(cls):
@ -29,4 +34,6 @@ class MultisigRedeemScriptType(p.MessageType):
1: ('pubkeys', HDNodePathType, p.FLAG_REPEATED),
2: ('signatures', p.BytesType, p.FLAG_REPEATED),
3: ('m', p.UVarintType, 0),
4: ('nodes', HDNodeType, p.FLAG_REPEATED),
5: ('address_n', p.UVarintType, p.FLAG_REPEATED),
}

@ -1 +1 @@
Subproject commit 4b41d2e63841517bf701618434c018acf4f1bca2
Subproject commit 0735c7d6f524b4c5108d201c789612aad7ce7920
Loading…
Cancel
Save