1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-16 11:38:12 +00:00

wallet: use keychain API in signing, fix tests

This commit is contained in:
Jan Pochyla 2018-12-04 14:27:06 +01:00
parent 172f3cb22f
commit 852bf8f4ef
12 changed files with 94 additions and 121 deletions

View File

@ -3,7 +3,7 @@ from trezor.messages.MessageType import TxAck
from trezor.messages.RequestType import TXFINISHED from trezor.messages.RequestType import TXFINISHED
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
from apps.common import coins, paths, seed from apps.common import paths, seed
from apps.wallet.sign_tx import ( from apps.wallet.sign_tx import (
addresses, addresses,
helpers, helpers,
@ -18,13 +18,9 @@ from apps.wallet.sign_tx import (
@ui.layout @ui.layout
async def sign_tx(ctx, msg): async def sign_tx(ctx, msg):
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
# TODO: rework this so we don't have to pass root to signing.sign_tx
keychain = await seed.get_keychain(ctx) keychain = await seed.get_keychain(ctx)
root = keychain.derive([], coin.curve_name) signer = signing.sign_tx(msg, keychain)
signer = signing.sign_tx(msg, root)
res = None res = None
while True: while True:
try: try:

View File

@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputBinType import TxOutputBinType from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
from trezor.utils import obj_eq
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
@ -24,6 +25,8 @@ class UiConfirmOutput:
self.output = output self.output = output
self.coin = coin self.coin = coin
__eq__ = obj_eq
class UiConfirmTotal: class UiConfirmTotal:
def __init__(self, spending: int, fee: int, coin: CoinInfo): def __init__(self, spending: int, fee: int, coin: CoinInfo):
@ -31,17 +34,23 @@ class UiConfirmTotal:
self.fee = fee self.fee = fee
self.coin = coin self.coin = coin
__eq__ = obj_eq
class UiConfirmFeeOverThreshold: class UiConfirmFeeOverThreshold:
def __init__(self, fee: int, coin: CoinInfo): def __init__(self, fee: int, coin: CoinInfo):
self.fee = fee self.fee = fee
self.coin = coin self.coin = coin
__eq__ = obj_eq
class UiConfirmForeignAddress: class UiConfirmForeignAddress:
def __init__(self, address_n: list): def __init__(self, address_n: list):
self.address_n = address_n self.address_n = address_n
__eq__ = obj_eq
def confirm_output(output: TxOutputType, coin: CoinInfo): def confirm_output(output: TxOutputType, coin: CoinInfo):
return (yield UiConfirmOutput(output, coin)) return (yield UiConfirmOutput(output, coin))

View File

@ -13,7 +13,7 @@ from trezor.messages.TxRequest import TxRequest
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from apps.common import address_type, coininfo, coins from apps.common import address_type, coininfo, coins, seed
from apps.wallet.sign_tx import ( from apps.wallet.sign_tx import (
addresses, addresses,
decred, decred,
@ -53,7 +53,7 @@ class SigningError(ValueError):
# - check inputs, previous transactions, and outputs # - check inputs, previous transactions, and outputs
# - ask for confirmations # - ask for confirmations
# - check fee # - check fee
async def check_tx_fee(tx: SignTx, root: bip32.HDNode): async def check_tx_fee(tx: SignTx, keychain: seed.Keychain):
coin = coins.by_name(tx.coin_name) coin = coins.by_name(tx.coin_name)
# h_first is used to make sure the inputs and outputs streamed in Phase 1 # h_first is used to make sure the inputs and outputs streamed in Phase 1
@ -103,7 +103,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
hash143.add_sequence(txi) hash143.add_sequence(txi)
if not addresses.validate_full_path(txi.address_n, coin, txi.script_type): if not addresses.validate_full_path(txi.address_n, coin, txi.script_type):
await helpers.confirm_foreign_address(txi.address_n, coin) await helpers.confirm_foreign_address(txi.address_n)
if txi.multisig: if txi.multisig:
multifp.add(txi.multisig) multifp.add(txi.multisig)
@ -158,7 +158,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_3_OUTPUT # STAGE_REQUEST_3_OUTPUT
txo = await helpers.request_tx_output(tx_req, o) txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
weight.add_output(txo_bin.script_pubkey) weight.add_output(txo_bin.script_pubkey)
if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp): if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp):
@ -207,14 +207,16 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
return h_first, hash143, segwit, total_in, wallet_path return h_first, hash143, segwit, total_in, wallet_path
async def sign_tx(tx: SignTx, root: bip32.HDNode): async def sign_tx(tx: SignTx, keychain: seed.Keychain):
tx = helpers.sanitize_sign_tx(tx) tx = helpers.sanitize_sign_tx(tx)
progress.init(tx.inputs_count, tx.outputs_count) progress.init(tx.inputs_count, tx.outputs_count)
# Phase 1 # Phase 1
h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root) h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(
tx, keychain
)
# Phase 2 # Phase 2
# - sign inputs # - sign inputs
@ -247,7 +249,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
) )
input_check_wallet_path(txi_sign, wallet_path) input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n) key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub) txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub)
@ -275,7 +277,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
) )
authorized_in -= txi_sign.amount authorized_in -= txi_sign.amount
key_sign = node_derive(root, txi_sign.address_n) key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash( hash143_hash = hash143.preimage_hash(
coin, coin,
@ -312,7 +314,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
input_check_wallet_path(txi_sign, wallet_path) input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n) key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
if txi_sign.script_type == InputScriptType.SPENDMULTISIG: if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
@ -399,7 +401,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
writers.write_tx_input_check(h_second, txi) writers.write_tx_input_check(h_second, txi)
if i == i_sign: if i == i_sign:
txi_sign = txi txi_sign = txi
key_sign = node_derive(root, txi.address_n) key_sign = keychain.derive(txi.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
# for the signing process the script_sig is equal # for the signing process the script_sig is equal
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH) # to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
@ -431,7 +433,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_4_OUTPUT # STAGE_REQUEST_4_OUTPUT
txo = await helpers.request_tx_output(tx_req, o) txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
writers.write_tx_output(h_second, txo_bin) writers.write_tx_output(h_second, txo_bin)
writers.write_tx_output(h_sign, txo_bin) writers.write_tx_output(h_sign, txo_bin)
@ -481,7 +483,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_5_OUTPUT # STAGE_REQUEST_5_OUTPUT
txo = await helpers.request_tx_output(tx_req, o) txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
# serialize output # serialize output
w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4) w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
@ -510,7 +512,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
) )
authorized_in -= txi.amount authorized_in -= txi.amount
key_sign = node_derive(root, txi.address_n) key_sign = keychain.derive(txi.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash( hash143_hash = hash143.preimage_hash(
coin, coin,
@ -675,7 +677,7 @@ def get_tx_header(coin: coininfo.CoinInfo, tx: SignTx, segwit: bool = False):
def output_derive_script( def output_derive_script(
o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
) -> bytes: ) -> bytes:
if o.script_type == OutputScriptType.PAYTOOPRETURN: if o.script_type == OutputScriptType.PAYTOOPRETURN:
@ -690,7 +692,7 @@ def output_derive_script(
# change output # change output
if o.address: if o.address:
raise SigningError(FailureType.DataError, "Address in change output") raise SigningError(FailureType.DataError, "Address in change output")
o.address = get_address_for_change(o, coin, root) o.address = get_address_for_change(o, coin, keychain)
else: else:
if not o.address: if not o.address:
raise SigningError(FailureType.DataError, "Missing address") raise SigningError(FailureType.DataError, "Missing address")
@ -739,7 +741,7 @@ def output_derive_script(
def get_address_for_change( def get_address_for_change(
o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
): ):
if o.script_type == OutputScriptType.PAYTOADDRESS: if o.script_type == OutputScriptType.PAYTOADDRESS:
input_script_type = InputScriptType.SPENDADDRESS input_script_type = InputScriptType.SPENDADDRESS
@ -751,9 +753,8 @@ def get_address_for_change(
input_script_type = InputScriptType.SPENDP2SHWITNESS input_script_type = InputScriptType.SPENDP2SHWITNESS
else: else:
raise SigningError(FailureType.DataError, "Invalid script type") raise SigningError(FailureType.DataError, "Invalid script type")
return addresses.get_address( node = keychain.derive(o.address_n, coin.curve_name)
input_script_type, coin, node_derive(root, o.address_n), o.multisig return addresses.get_address(input_script_type, coin, node, o.multisig)
)
def output_is_change( def output_is_change(
@ -857,12 +858,6 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
) )
def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode:
node = root.clone()
node.derive_path(address_n)
return node
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
sig = secp256k1.sign(node.private_key(), digest) sig = secp256k1.sign(node.private_key(), digest)
sigder = der.encode_seq((sig[1:33], sig[33:65])) sigder = der.encode_seq((sig[1:33], sig[33:65]))

View File

@ -10,6 +10,12 @@ from apps.wallet.sign_tx.signing import *
from apps.wallet.sign_tx.writers import * from apps.wallet.sign_tx.writers import *
def node_derive(root, path):
node = root.clone()
node.derive_path(path)
return node
class TestAddress(unittest.TestCase): class TestAddress(unittest.TestCase):
# pylint: disable=C0301 # pylint: disable=C0301

View File

@ -6,6 +6,12 @@ from apps.common import coins
from trezor.crypto import bip32, bip39 from trezor.crypto import bip32, bip39
def node_derive(root, path):
node = root.clone()
node.derive_path(path)
return node
class TestAddressGRS(unittest.TestCase): class TestAddressGRS(unittest.TestCase):
# pylint: disable=C0301 # pylint: disable=C0301

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing from apps.wallet.sign_tx import helpers, signing
@ -61,7 +62,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])), TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n), helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True, True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
@ -113,9 +114,10 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
@ -159,7 +161,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])), TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n), helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True, True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
@ -209,21 +211,13 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing from apps.wallet.sign_tx import helpers, signing
# https://groestlsight-test.groestlcoin.org/api/tx/9b5c4859a8a31e69788cb4402812bb28f14ad71cbd8c60b09903478bc56f79a3 # https://groestlsight-test.groestlcoin.org/api/tx/9b5c4859a8a31e69788cb4402812bb28f14ad71cbd8c60b09903478bc56f79a3
@ -110,9 +111,10 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
@ -203,20 +205,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing from apps.wallet.sign_tx import helpers, signing
@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
@ -213,9 +215,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
@ -322,26 +325,19 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
TxRequest(request_type=TXFINISHED, details=None) TxRequest(request_type=TXFINISHED, details=None)
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
i = 0 i = 0
messages_count = int(len(messages) / 2) messages_count = int(len(messages) / 2)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
if i == messages_count - 1: # last message should throw SigningError if i == messages_count - 1: # last message should throw SigningError
self.assertRaises(signing.SigningError, signer.send, request) self.assertRaises(signing.SigningError, signer.send, request)
else: else:
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
i += 1 i += 1
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing from apps.wallet.sign_tx import helpers, signing
# https://groestlsight-test.groestlcoin.org/api/tx/4ce0220004bdfe14e3dd49fd8636bcb770a400c0c9e9bff670b6a13bb8f15c72 # https://groestlsight-test.groestlcoin.org/api/tx/4ce0220004bdfe14e3dd49fd8636bcb770a400c0c9e9bff670b6a13bb8f15c72
@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
@ -212,20 +214,13 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)), )),
] ]
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -14,6 +14,7 @@ from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing from apps.wallet.sign_tx import helpers, signing
@ -60,7 +61,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])), TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n), helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True, True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1), TxAck(tx=ptx1),
@ -84,9 +85,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1') root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
def test_under_threshold(self): def test_under_threshold(self):
coin_bitcoin = coins.by_name('Bitcoin') coin_bitcoin = coins.by_name('Bitcoin')
@ -127,7 +129,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])), TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n), helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True, True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1), TxAck(tx=ptx1),
@ -149,18 +151,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1') root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing from apps.wallet.sign_tx import helpers, signing
@ -61,7 +62,7 @@ class TestSignTx(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])), TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n), helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True, True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1), TxAck(tx=ptx1),
@ -98,24 +99,16 @@ class TestSignTx(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1') root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
res = signer.send(request) res = signer.send(request)
self.assertEqualEx(res, response) self.assertEqual(res, response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing from apps.wallet.sign_tx import helpers, signing
@ -87,20 +88,13 @@ class TestSignTx_GRS(unittest.TestCase):
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name) root = bip32.from_seed(seed, coin.curve_name)
signer = signing.sign_tx(tx, root) keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response) self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()