diff --git a/src/apps/wallet/sign_tx/__init__.py b/src/apps/wallet/sign_tx/__init__.py index 1bca695ad..de198a25c 100644 --- a/src/apps/wallet/sign_tx/__init__.py +++ b/src/apps/wallet/sign_tx/__init__.py @@ -3,7 +3,7 @@ from trezor.messages.MessageType import TxAck from trezor.messages.RequestType import TXFINISHED from trezor.messages.TxRequest import TxRequest -from apps.common import coins, paths, seed +from apps.common import paths, seed from apps.wallet.sign_tx import ( addresses, helpers, @@ -18,13 +18,9 @@ from apps.wallet.sign_tx import ( @ui.layout async def sign_tx(ctx, msg): - coin_name = msg.coin_name or "Bitcoin" - coin = coins.by_name(coin_name) - # TODO: rework this so we don't have to pass root to signing.sign_tx keychain = await seed.get_keychain(ctx) - root = keychain.derive([], coin.curve_name) + signer = signing.sign_tx(msg, keychain) - signer = signing.sign_tx(msg, root) res = None while True: try: diff --git a/src/apps/wallet/sign_tx/helpers.py b/src/apps/wallet/sign_tx/helpers.py index b9bf83bab..e8f6a42d9 100644 --- a/src/apps/wallet/sign_tx/helpers.py +++ b/src/apps/wallet/sign_tx/helpers.py @@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputBinType import TxOutputBinType from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxRequest import TxRequest +from trezor.utils import obj_eq from apps.common.coininfo import CoinInfo @@ -24,6 +25,8 @@ class UiConfirmOutput: self.output = output self.coin = coin + __eq__ = obj_eq + class UiConfirmTotal: def __init__(self, spending: int, fee: int, coin: CoinInfo): @@ -31,17 +34,23 @@ class UiConfirmTotal: self.fee = fee self.coin = coin + __eq__ = obj_eq + class UiConfirmFeeOverThreshold: def __init__(self, fee: int, coin: CoinInfo): self.fee = fee self.coin = coin + __eq__ = obj_eq + class UiConfirmForeignAddress: def __init__(self, address_n: list): self.address_n = address_n + __eq__ = obj_eq + def confirm_output(output: TxOutputType, coin: CoinInfo): return (yield UiConfirmOutput(output, coin)) diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index 32d829edb..ef37fbae3 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -13,7 +13,7 @@ from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType -from apps.common import address_type, coininfo, coins +from apps.common import address_type, coininfo, coins, seed from apps.wallet.sign_tx import ( addresses, decred, @@ -53,7 +53,7 @@ class SigningError(ValueError): # - check inputs, previous transactions, and outputs # - ask for confirmations # - check fee -async def check_tx_fee(tx: SignTx, root: bip32.HDNode): +async def check_tx_fee(tx: SignTx, keychain: seed.Keychain): coin = coins.by_name(tx.coin_name) # h_first is used to make sure the inputs and outputs streamed in Phase 1 @@ -103,7 +103,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): hash143.add_sequence(txi) if not addresses.validate_full_path(txi.address_n, coin, txi.script_type): - await helpers.confirm_foreign_address(txi.address_n, coin) + await helpers.confirm_foreign_address(txi.address_n) if txi.multisig: multifp.add(txi.multisig) @@ -158,7 +158,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): # STAGE_REQUEST_3_OUTPUT txo = await helpers.request_tx_output(tx_req, o) txo_bin.amount = txo.amount - txo_bin.script_pubkey = output_derive_script(txo, coin, root) + txo_bin.script_pubkey = output_derive_script(txo, coin, keychain) weight.add_output(txo_bin.script_pubkey) if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp): @@ -207,14 +207,16 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): return h_first, hash143, segwit, total_in, wallet_path -async def sign_tx(tx: SignTx, root: bip32.HDNode): +async def sign_tx(tx: SignTx, keychain: seed.Keychain): tx = helpers.sanitize_sign_tx(tx) progress.init(tx.inputs_count, tx.outputs_count) # Phase 1 - h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root) + h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee( + tx, keychain + ) # Phase 2 # - sign inputs @@ -247,7 +249,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): ) input_check_wallet_path(txi_sign, wallet_path) - key_sign = node_derive(root, txi_sign.address_n) + key_sign = keychain.derive(txi_sign.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub) @@ -275,7 +277,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): ) authorized_in -= txi_sign.amount - key_sign = node_derive(root, txi_sign.address_n) + key_sign = keychain.derive(txi_sign.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() hash143_hash = hash143.preimage_hash( coin, @@ -312,7 +314,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): input_check_wallet_path(txi_sign, wallet_path) - key_sign = node_derive(root, txi_sign.address_n) + key_sign = keychain.derive(txi_sign.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() if txi_sign.script_type == InputScriptType.SPENDMULTISIG: @@ -399,7 +401,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): writers.write_tx_input_check(h_second, txi) if i == i_sign: txi_sign = txi - key_sign = node_derive(root, txi.address_n) + key_sign = keychain.derive(txi.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() # for the signing process the script_sig is equal # to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH) @@ -431,7 +433,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): # STAGE_REQUEST_4_OUTPUT txo = await helpers.request_tx_output(tx_req, o) txo_bin.amount = txo.amount - txo_bin.script_pubkey = output_derive_script(txo, coin, root) + txo_bin.script_pubkey = output_derive_script(txo, coin, keychain) writers.write_tx_output(h_second, txo_bin) writers.write_tx_output(h_sign, txo_bin) @@ -481,7 +483,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): # STAGE_REQUEST_5_OUTPUT txo = await helpers.request_tx_output(tx_req, o) txo_bin.amount = txo.amount - txo_bin.script_pubkey = output_derive_script(txo, coin, root) + txo_bin.script_pubkey = output_derive_script(txo, coin, keychain) # serialize output w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4) @@ -510,7 +512,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): ) authorized_in -= txi.amount - key_sign = node_derive(root, txi.address_n) + key_sign = keychain.derive(txi.address_n, coin.curve_name) key_sign_pub = key_sign.public_key() hash143_hash = hash143.preimage_hash( coin, @@ -675,7 +677,7 @@ def get_tx_header(coin: coininfo.CoinInfo, tx: SignTx, segwit: bool = False): def output_derive_script( - o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode + o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain ) -> bytes: if o.script_type == OutputScriptType.PAYTOOPRETURN: @@ -690,7 +692,7 @@ def output_derive_script( # change output if o.address: raise SigningError(FailureType.DataError, "Address in change output") - o.address = get_address_for_change(o, coin, root) + o.address = get_address_for_change(o, coin, keychain) else: if not o.address: raise SigningError(FailureType.DataError, "Missing address") @@ -739,7 +741,7 @@ def output_derive_script( def get_address_for_change( - o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode + o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain ): if o.script_type == OutputScriptType.PAYTOADDRESS: input_script_type = InputScriptType.SPENDADDRESS @@ -751,9 +753,8 @@ def get_address_for_change( input_script_type = InputScriptType.SPENDP2SHWITNESS else: raise SigningError(FailureType.DataError, "Invalid script type") - return addresses.get_address( - input_script_type, coin, node_derive(root, o.address_n), o.multisig - ) + node = keychain.derive(o.address_n, coin.curve_name) + return addresses.get_address(input_script_type, coin, node, o.multisig) def output_is_change( @@ -857,12 +858,6 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list: ) -def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode: - node = root.clone() - node.derive_path(address_n) - return node - - def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: sig = secp256k1.sign(node.private_key(), digest) sigder = der.encode_seq((sig[1:33], sig[33:65])) diff --git a/tests/test_apps.wallet.address.py b/tests/test_apps.wallet.address.py index 3d59580de..2df72c70e 100644 --- a/tests/test_apps.wallet.address.py +++ b/tests/test_apps.wallet.address.py @@ -10,6 +10,12 @@ from apps.wallet.sign_tx.signing import * from apps.wallet.sign_tx.writers import * +def node_derive(root, path): + node = root.clone() + node.derive_path(path) + return node + + class TestAddress(unittest.TestCase): # pylint: disable=C0301 diff --git a/tests/test_apps.wallet.address_grs.py b/tests/test_apps.wallet.address_grs.py index 6c5cd0e4e..8aeb173e0 100644 --- a/tests/test_apps.wallet.address_grs.py +++ b/tests/test_apps.wallet.address_grs.py @@ -6,6 +6,12 @@ from apps.common import coins from trezor.crypto import bip32, bip39 +def node_derive(root, path): + node = root.clone() + node.derive_path(path) + return node + + class TestAddressGRS(unittest.TestCase): # pylint: disable=C0301 diff --git a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py index 1ba571804..c960b531d 100644 --- a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py +++ b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py @@ -15,6 +15,7 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins +from apps.common.seed import Keychain from apps.wallet.sign_tx import helpers, signing @@ -61,7 +62,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), @@ -113,9 +114,10 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) @@ -159,7 +161,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), @@ -209,21 +211,13 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to helpers.Ui* classes - if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or - (isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or - (isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py index c773acf80..b862f7efa 100644 --- a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py +++ b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh_grs.py @@ -15,6 +15,7 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins +from apps.common.seed import Keychain from apps.wallet.sign_tx import helpers, signing # https://groestlsight-test.groestlcoin.org/api/tx/9b5c4859a8a31e69788cb4402812bb28f14ad71cbd8c60b09903478bc56f79a3 @@ -110,9 +111,10 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) @@ -203,20 +205,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to helpers.Ui* classes - if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or - (isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py index d757d23e1..c2e042a42 100644 --- a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py +++ b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh.py @@ -15,6 +15,7 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins +from apps.common.seed import Keychain from apps.wallet.sign_tx import helpers, signing @@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) @@ -213,9 +215,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) @@ -322,26 +325,19 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): TxRequest(request_type=TXFINISHED, details=None) ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) i = 0 messages_count = int(len(messages) / 2) for request, response in chunks(messages, 2): if i == messages_count - 1: # last message should throw SigningError self.assertRaises(signing.SigningError, signer.send, request) else: - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) i += 1 with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to helpers.Ui* classes - if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or - (isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py index 4d98eb633..69dab1bab 100644 --- a/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py +++ b/tests/test_apps.wallet.segwit.signtx.p2wpkh_in_p2sh_grs.py @@ -15,6 +15,7 @@ from trezor.messages import InputScriptType from trezor.messages import OutputScriptType from apps.common import coins +from apps.common.seed import Keychain from apps.wallet.sign_tx import helpers, signing # https://groestlsight-test.groestlcoin.org/api/tx/4ce0220004bdfe14e3dd49fd8636bcb770a400c0c9e9bff670b6a13bb8f15c72 @@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) @@ -212,20 +214,13 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase): )), ] - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to helpers.Ui* classes - if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or - (isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.signtx.fee_threshold.py b/tests/test_apps.wallet.signtx.fee_threshold.py index b8154105b..4b4f1935f 100644 --- a/tests/test_apps.wallet.signtx.fee_threshold.py +++ b/tests/test_apps.wallet.signtx.fee_threshold.py @@ -14,6 +14,7 @@ from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages import OutputScriptType from apps.common import coins +from apps.common.seed import Keychain from apps.wallet.sign_tx import helpers, signing @@ -60,7 +61,7 @@ class TestSignTxFeeThreshold(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxAck(tx=ptx1), @@ -84,9 +85,10 @@ class TestSignTxFeeThreshold(unittest.TestCase): seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') root = bip32.from_seed(seed, 'secp256k1') - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin_bitcoin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) def test_under_threshold(self): coin_bitcoin = coins.by_name('Bitcoin') @@ -127,7 +129,7 @@ class TestSignTxFeeThreshold(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxAck(tx=ptx1), @@ -149,18 +151,10 @@ class TestSignTxFeeThreshold(unittest.TestCase): seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') root = bip32.from_seed(seed, 'secp256k1') - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin_bitcoin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) - - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to helpers.Ui* classes - if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or - (isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or - (isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) + self.assertEqual(signer.send(request), response) if __name__ == '__main__': diff --git a/tests/test_apps.wallet.signtx.py b/tests/test_apps.wallet.signtx.py index fa6a0ea4a..f4b8c34d0 100644 --- a/tests/test_apps.wallet.signtx.py +++ b/tests/test_apps.wallet.signtx.py @@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages import OutputScriptType from apps.common import coins +from apps.common.seed import Keychain from apps.wallet.sign_tx import helpers, signing @@ -61,7 +62,7 @@ class TestSignTx(unittest.TestCase): TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), TxAck(tx=TransactionType(inputs=[inp1])), - signing.UiConfirmForeignAddress(address_n=inp1.address_n), + helpers.UiConfirmForeignAddress(address_n=inp1.address_n), True, TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), TxAck(tx=ptx1), @@ -98,24 +99,16 @@ class TestSignTx(unittest.TestCase): seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') root = bip32.from_seed(seed, 'secp256k1') - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin_bitcoin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): res = signer.send(request) - self.assertEqualEx(res, response) + self.assertEqual(res, response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to helpers.Ui* classes - if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or - (isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or - (isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_apps.wallet.signtx_grs.py b/tests/test_apps.wallet.signtx_grs.py index 1dbe69e89..8547b3db0 100644 --- a/tests/test_apps.wallet.signtx_grs.py +++ b/tests/test_apps.wallet.signtx_grs.py @@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages import OutputScriptType from apps.common import coins +from apps.common.seed import Keychain from apps.wallet.sign_tx import helpers, signing @@ -87,20 +88,13 @@ class TestSignTx_GRS(unittest.TestCase): seed = bip39.seed(' '.join(['all'] * 12), '') root = bip32.from_seed(seed, coin.curve_name) - signer = signing.sign_tx(tx, root) + keychain = Keychain([[coin.curve_name]], [root]) + signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): - self.assertEqualEx(signer.send(request), response) + self.assertEqual(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) - def assertEqualEx(self, a, b): - # hack to avoid adding __eq__ to helpers.Ui* classes - if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or - (isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))): - return self.assertEqual(a.__dict__, b.__dict__) - else: - return self.assertEqual(a, b) - if __name__ == '__main__': unittest.main()