1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 01:18:28 +00:00

feat(core):: support sortedmulti

This commit is contained in:
Ondřej Vejpustek 2024-11-07 14:06:09 +01:00
parent 5c6198cf13
commit d19cbfa3cb
6 changed files with 160 additions and 37 deletions

View File

@ -0,0 +1 @@
Added support for sortedmulti.

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.enums import MultisigPubkeysOrder
from trezor.wire import DataError from trezor.wire import DataError
if TYPE_CHECKING: if TYPE_CHECKING:
@ -28,7 +29,8 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
if len(d.public_key) != 33 or len(d.chain_code) != 32: if len(d.public_key) != 33 or len(d.chain_code) != 32:
raise DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
pubnodes = sorted(pubnodes, key=lambda n: n.public_key + n.chain_code) if multisig.pubkeys_order == MultisigPubkeysOrder.LEXICOGRAPHIC:
pubnodes = sorted(pubnodes, key=lambda n: n.public_key + n.chain_code)
h = HashWriter(sha256()) h = HashWriter(sha256())
write_uint32(h, m) write_uint32(h, m)
@ -84,9 +86,14 @@ def multisig_get_pubkey(n: HDNodeType, p: paths.Bip32Path) -> bytes:
def multisig_get_pubkeys(multisig: MultisigRedeemScriptType) -> list[bytes]: def multisig_get_pubkeys(multisig: MultisigRedeemScriptType) -> list[bytes]:
validate_multisig(multisig) validate_multisig(multisig)
if multisig.nodes: if multisig.nodes:
return [multisig_get_pubkey(hd, multisig.address_n) for hd in multisig.nodes] pubkeys = [multisig_get_pubkey(hd, multisig.address_n) for hd in multisig.nodes]
else: else:
return [multisig_get_pubkey(hd.node, hd.address_n) for hd in multisig.pubkeys] pubkeys = [
multisig_get_pubkey(hd.node, hd.address_n) for hd in multisig.pubkeys
]
if multisig.pubkeys_order == MultisigPubkeysOrder.LEXICOGRAPHIC:
pubkeys = sorted(pubkeys)
return pubkeys
def multisig_get_pubkey_count(multisig: MultisigRedeemScriptType) -> int: def multisig_get_pubkey_count(multisig: MultisigRedeemScriptType) -> int:

View File

@ -19,6 +19,7 @@ class ChangeDetector:
from .matchcheck import ( from .matchcheck import (
MultisigChecker, MultisigChecker,
MultisigFingerprintChecker, MultisigFingerprintChecker,
PubkeysOrderChecker,
ScriptTypeChecker, ScriptTypeChecker,
WalletPathChecker, WalletPathChecker,
) )
@ -29,6 +30,9 @@ class ChangeDetector:
# Checksum of multisig inputs, used to validate change-output. # Checksum of multisig inputs, used to validate change-output.
self.multisig_fingerprint = MultisigFingerprintChecker() self.multisig_fingerprint = MultisigFingerprintChecker()
# Whether all inputs use sorted pubkeys or all inputs use unsorted pubkeys, used to validate change-output.
self.pubkeys_order = PubkeysOrderChecker()
# Common prefix of input paths, used to validate change-output. # Common prefix of input paths, used to validate change-output.
self.wallet_path = WalletPathChecker() self.wallet_path = WalletPathChecker()
@ -41,19 +45,25 @@ class ChangeDetector:
self.script_type.add_input(txi) self.script_type.add_input(txi)
self.multisig_fingerprint.add_input(txi) self.multisig_fingerprint.add_input(txi)
self.multisig.add_input(txi) self.multisig.add_input(txi)
self.pubkeys_order.add_input(txi)
def check_input(self, txi: TxInput) -> None: def check_input(self, txi: TxInput) -> None:
self.wallet_path.check_input(txi) self.wallet_path.check_input(txi)
self.script_type.check_input(txi) self.script_type.check_input(txi)
self.multisig_fingerprint.check_input(txi) self.multisig_fingerprint.check_input(txi)
self.multisig.check_input(txi) self.multisig.check_input(txi)
self.pubkeys_order.check_input(txi)
def output_is_change(self, txo: TxOutput) -> bool: def output_is_change(self, txo: TxOutput) -> bool:
if txo.script_type not in common.CHANGE_OUTPUT_SCRIPT_TYPES: if txo.script_type not in common.CHANGE_OUTPUT_SCRIPT_TYPES:
return False return False
if txo.multisig and not self.multisig_fingerprint.output_matches(txo): if txo.multisig:
return False if not (
self.pubkeys_order.output_matches(txo)
and self.multisig_fingerprint.output_matches(txo)
):
return False
return ( return (
self.multisig.output_matches(txo) self.multisig.output_matches(txo)

View File

@ -107,6 +107,13 @@ class MultisigFingerprintChecker(MatchChecker):
return multisig.multisig_fingerprint(txio.multisig) return multisig.multisig_fingerprint(txio.multisig)
class PubkeysOrderChecker(MatchChecker):
def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any:
if not txio.multisig:
return None
return txio.multisig is not None and txio.multisig.pubkeys_order
class MultisigChecker(MatchChecker): class MultisigChecker(MatchChecker):
def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any: def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any:
return txio.multisig is not None return txio.multisig is not None

View File

@ -1,7 +1,8 @@
from common import H_, await_result, unittest # isort:skip from common import H_, await_result, unittest # isort:skip
from ubinascii import hexlify, unhexlify from ubinascii import hexlify, unhexlify
from trezor.enums import InputScriptType, OutputScriptType from trezor.enums import InputScriptType, OutputScriptType, MultisigPubkeysOrder
from trezor.enums.MultisigPubkeysOrder import LEXICOGRAPHIC, PRESERVED
from trezor.messages import ( from trezor.messages import (
TxInput, TxInput,
TxOutput, TxOutput,
@ -39,12 +40,13 @@ xpub3 = HDNodeType(
) )
def get_multisig(path: list[int], xpubs: list[HDNodeType]) -> MultisigRedeemScriptType: def get_multisig(path: list[int], xpubs: list[HDNodeType], pubkeys_order: MultisigPubkeysOrder) -> MultisigRedeemScriptType:
return MultisigRedeemScriptType( return MultisigRedeemScriptType(
nodes=xpubs, nodes=xpubs,
signatures=b"" * len(xpubs), signatures=b"" * len(xpubs),
address_n=path[-2:], address_n=path[-2:],
m=2, m=2,
pubkeys_order=pubkeys_order,
) )
@ -58,23 +60,23 @@ def get_singlesig_input(path: list[int]) -> TxInput:
) )
def get_multisig_input(path: list[int], xpubs: list[HDNodeType]) -> TxInput: def get_multisig_input(path: list[int], xpubs: list[HDNodeType], pubkeys_order: MultisigPubkeysOrder) -> TxInput:
return TxInput( return TxInput(
address_n=path, address_n=path,
amount=1_000_000, amount=1_000_000,
prev_hash=bytes(32), prev_hash=bytes(32),
prev_index=0, prev_index=0,
script_type=InputScriptType.SPENDMULTISIG, script_type=InputScriptType.SPENDMULTISIG,
multisig=get_multisig(path, xpubs), multisig=get_multisig(path, xpubs, pubkeys_order),
) )
def get_internal_multisig_output(path: list[int], xpubs: list[HDNodeType]) -> TxOutput: def get_internal_multisig_output(path: list[int], xpubs: list[HDNodeType], pubkeys_order: MultisigPubkeysOrder) -> TxOutput:
return TxOutput( return TxOutput(
address_n=path, address_n=path,
amount=1_000_000, amount=1_000_000,
script_type=OutputScriptType.PAYTOMULTISIG, script_type=OutputScriptType.PAYTOMULTISIG,
multisig=get_multisig(path, xpubs), multisig=get_multisig(path, xpubs, pubkeys_order),
) )
@ -115,7 +117,8 @@ class TestChangeDetector(unittest.TestCase):
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 1, 0, 0])) == False assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 1, 0, 0])) == False
# Multisig instead of singlesig # Multisig instead of singlesig
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1])) == False assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1], PRESERVED)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1], LEXICOGRAPHIC)) == False
# External output # External output
assert self.d.output_is_change(get_external_singlesig_output()) == False assert self.d.output_is_change(get_external_singlesig_output()) == False
@ -130,23 +133,24 @@ class TestChangeDetector(unittest.TestCase):
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 1, 0, 0])) == False assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 1, 0, 0])) == False
# Multisig instead of singlesig # Multisig instead of singlesig
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1])) == False assert self.d.output_is_change(get_internal_multisig_output([H_(45), 1, 0, 0], [xpub1], PRESERVED)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 1, 0, 0], [xpub1], LEXICOGRAPHIC)) == False
# External output # External output
assert self.d.output_is_change(get_external_singlesig_output()) == False assert self.d.output_is_change(get_external_singlesig_output()) == False
def test_multisig(self): def test_unsorted_multisig(self):
# Different change and account index # Different change and account index
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2])) self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED))
self.d.add_input(get_multisig_input([H_(45), 0, 0, 1], [xpub1, xpub2])) self.d.add_input(get_multisig_input([H_(45), 0, 0, 1], [xpub1, xpub2], PRESERVED))
self.d.add_input(get_multisig_input([H_(45), 0, 1, 0], [xpub1, xpub2])) self.d.add_input(get_multisig_input([H_(45), 0, 1, 0], [xpub1, xpub2], PRESERVED))
self.d.add_input(get_multisig_input([H_(45), 0, 1, 1], [xpub1, xpub2])) self.d.add_input(get_multisig_input([H_(45), 0, 1, 1], [xpub1, xpub2], PRESERVED))
# Same outputs as inputs # Same outputs as inputs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2])) == True assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED)) == True
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 1], [xpub1, xpub2])) == True assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 1], [xpub1, xpub2], PRESERVED)) == True
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 1, 0], [xpub1, xpub2])) == True assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 1, 0], [xpub1, xpub2], PRESERVED)) == True
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 1, 1], [xpub1, xpub2])) == True assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 1, 1], [xpub1, xpub2], PRESERVED)) == True
# Singlesig instead of multisig # Singlesig instead of multisig
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False
@ -155,22 +159,60 @@ class TestChangeDetector(unittest.TestCase):
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 1, 0, 0])) == False assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 1, 0, 0])) == False
# Different order of xpubs # Different order of xpubs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1])) == True assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1], PRESERVED)) == False
# Sorted instead of unsorted
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC)) == False
# Different xpubs # Different xpubs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3])) == False assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3], PRESERVED)) == False
# External output # External output
assert self.d.output_is_change(get_external_singlesig_output()) == False assert self.d.output_is_change(get_external_singlesig_output()) == False
def test_multisig_different_xpubs_order(self): def test_sorted_multisig(self):
# Different change and account index
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC))
self.d.add_input(get_multisig_input([H_(45), 0, 0, 1], [xpub1, xpub2], LEXICOGRAPHIC))
self.d.add_input(get_multisig_input([H_(45), 0, 1, 0], [xpub1, xpub2], LEXICOGRAPHIC))
self.d.add_input(get_multisig_input([H_(45), 0, 1, 1], [xpub1, xpub2], LEXICOGRAPHIC))
# Same outputs as inputs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC)) == True
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 1], [xpub1, xpub2], LEXICOGRAPHIC)) == True
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 1, 0], [xpub1, xpub2], LEXICOGRAPHIC)) == True
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 1, 1], [xpub1, xpub2], LEXICOGRAPHIC)) == True
# Singlesig instead of multisig
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False
# Different account index
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 1, 0, 0])) == False
# Different order of xpubs # Different order of xpubs
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2])) assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1], LEXICOGRAPHIC)) == True
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub2, xpub1]))
# Unsorted instead of sorted
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED)) == False
# Different xpubs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3], LEXICOGRAPHIC)) == False
# External output
assert self.d.output_is_change(get_external_singlesig_output()) == False
def test_unsorted_multisig_different_xpubs_order(self):
# Different order of xpubs
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED))
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub2, xpub1], PRESERVED))
# Same ouputs as inputs # Same ouputs as inputs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2])) == True assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1])) == True assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1], PRESERVED)) == False
# Sorted instead of unsorted
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1], LEXICOGRAPHIC)) == False
# Singlesig instead of multisig # Singlesig instead of multisig
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False
@ -178,14 +220,70 @@ class TestChangeDetector(unittest.TestCase):
# External output # External output
assert self.d.output_is_change(get_external_singlesig_output()) == False assert self.d.output_is_change(get_external_singlesig_output()) == False
def test_multisig_different_xpubs(self): def test_sorted_multisig_different_xpubs_order(self):
# Different xpubs # Different order of xpubs
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2])) self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC))
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub3])) self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub2, xpub1], LEXICOGRAPHIC))
# Same ouputs as inputs # Same ouputs as inputs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2])) == False assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC)) == True
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3])) == False assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1], LEXICOGRAPHIC)) == True
# Sorted instead of unsorted
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub2, xpub1], PRESERVED)) == False
# Singlesig instead of multisig
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False
# External output
assert self.d.output_is_change(get_external_singlesig_output()) == False
def test_unsorted_multisig_different_xpubs(self):
# Different xpubs
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED))
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub3], PRESERVED))
# Same ouputs as inputs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3], PRESERVED)) == False
# Sorted instead of unsorted
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3], LEXICOGRAPHIC)) == False
# Singlesig instead of multisig
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False
# External output
assert self.d.output_is_change(get_external_singlesig_output()) == False
def test_sorted_multisig_different_xpubs(self):
# Different xpubs
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC))
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1, xpub3], LEXICOGRAPHIC))
# Same ouputs as inputs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], LEXICOGRAPHIC)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3], LEXICOGRAPHIC)) == False
# Sorted instead of unsorted
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub2], PRESERVED)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1, xpub3], PRESERVED)) == False
# Singlesig instead of multisig
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False
# External output
assert self.d.output_is_change(get_external_singlesig_output()) == False
def test_mixed_sorted_and_unsorted_multisig_1(self):
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1], PRESERVED))
self.d.add_input(get_multisig_input([H_(45), 0, 0, 0], [xpub1], LEXICOGRAPHIC))
# Same ouputs as inputs
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1], PRESERVED)) == False
assert self.d.output_is_change(get_internal_multisig_output([H_(45), 0, 0, 0], [xpub1], LEXICOGRAPHIC)) == False
# Singlesig instead of multisig # Singlesig instead of multisig
assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False assert self.d.output_is_change(get_internal_singlesig_output([H_(45), 0, 0, 0])) == False

View File

@ -63,7 +63,7 @@ NODE_INT = bip32.deserialize(
# m/2 => 038caebd6f753bbbd2bb1f3346a43cd32140648583673a31d62f2dfb56ad0ab9e3 # m/2 => 038caebd6f753bbbd2bb1f3346a43cd32140648583673a31d62f2dfb56ad0ab9e3
multisig_in1 = messages.MultisigRedeemScriptType( multisig_in1 = messages.MultisigRedeemScriptType(
nodes=[NODE_EXT2, NODE_EXT1, NODE_INT], nodes=[NODE_EXT1, NODE_EXT2, NODE_INT],
address_n=[0, 0], address_n=[0, 0],
signatures=[b"", b"", b""], signatures=[b"", b"", b""],
m=2, m=2,
@ -84,7 +84,7 @@ multisig_in3 = messages.MultisigRedeemScriptType(
) )
prev_hash_1, prev_tx_1 = forge_prevtx( prev_hash_1, prev_tx_1 = forge_prevtx(
[("3HwrvQEfYw4wUvGHpGmixWB15HPgqrvTh1", 50_000_000)] [("3Ltgk5WPUMLcT2QvwRXKj9CWsYuAKqeHJ8", 50_000_000)]
) )
INP1 = messages.TxInputType( INP1 = messages.TxInputType(
address_n=[H_(45), 0, 0, 0], address_n=[H_(45), 0, 0, 0],