1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 06:48:16 +00:00

wallet: use keychain API in signing, fix tests

This commit is contained in:
Jan Pochyla 2018-12-04 14:27:06 +01:00
parent 172f3cb22f
commit 852bf8f4ef
12 changed files with 94 additions and 121 deletions

View File

@ -3,7 +3,7 @@ from trezor.messages.MessageType import TxAck
from trezor.messages.RequestType import TXFINISHED
from trezor.messages.TxRequest import TxRequest
from apps.common import coins, paths, seed
from apps.common import paths, seed
from apps.wallet.sign_tx import (
addresses,
helpers,
@ -18,13 +18,9 @@ from apps.wallet.sign_tx import (
@ui.layout
async def sign_tx(ctx, msg):
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
# TODO: rework this so we don't have to pass root to signing.sign_tx
keychain = await seed.get_keychain(ctx)
root = keychain.derive([], coin.curve_name)
signer = signing.sign_tx(msg, keychain)
signer = signing.sign_tx(msg, root)
res = None
while True:
try:

View File

@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxRequest import TxRequest
from trezor.utils import obj_eq
from apps.common.coininfo import CoinInfo
@ -24,6 +25,8 @@ class UiConfirmOutput:
self.output = output
self.coin = coin
__eq__ = obj_eq
class UiConfirmTotal:
def __init__(self, spending: int, fee: int, coin: CoinInfo):
@ -31,17 +34,23 @@ class UiConfirmTotal:
self.fee = fee
self.coin = coin
__eq__ = obj_eq
class UiConfirmFeeOverThreshold:
def __init__(self, fee: int, coin: CoinInfo):
self.fee = fee
self.coin = coin
__eq__ = obj_eq
class UiConfirmForeignAddress:
def __init__(self, address_n: list):
self.address_n = address_n
__eq__ = obj_eq
def confirm_output(output: TxOutputType, coin: CoinInfo):
return (yield UiConfirmOutput(output, coin))

View File

@ -13,7 +13,7 @@ from trezor.messages.TxRequest import TxRequest
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from apps.common import address_type, coininfo, coins
from apps.common import address_type, coininfo, coins, seed
from apps.wallet.sign_tx import (
addresses,
decred,
@ -53,7 +53,7 @@ class SigningError(ValueError):
# - check inputs, previous transactions, and outputs
# - ask for confirmations
# - check fee
async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
async def check_tx_fee(tx: SignTx, keychain: seed.Keychain):
coin = coins.by_name(tx.coin_name)
# h_first is used to make sure the inputs and outputs streamed in Phase 1
@ -103,7 +103,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
hash143.add_sequence(txi)
if not addresses.validate_full_path(txi.address_n, coin, txi.script_type):
await helpers.confirm_foreign_address(txi.address_n, coin)
await helpers.confirm_foreign_address(txi.address_n)
if txi.multisig:
multifp.add(txi.multisig)
@ -158,7 +158,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_3_OUTPUT
txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
weight.add_output(txo_bin.script_pubkey)
if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp):
@ -207,14 +207,16 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
return h_first, hash143, segwit, total_in, wallet_path
async def sign_tx(tx: SignTx, root: bip32.HDNode):
async def sign_tx(tx: SignTx, keychain: seed.Keychain):
tx = helpers.sanitize_sign_tx(tx)
progress.init(tx.inputs_count, tx.outputs_count)
# Phase 1
h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root)
h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(
tx, keychain
)
# Phase 2
# - sign inputs
@ -247,7 +249,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
)
input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n)
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub)
@ -275,7 +277,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
)
authorized_in -= txi_sign.amount
key_sign = node_derive(root, txi_sign.address_n)
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash(
coin,
@ -312,7 +314,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n)
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
@ -399,7 +401,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
writers.write_tx_input_check(h_second, txi)
if i == i_sign:
txi_sign = txi
key_sign = node_derive(root, txi.address_n)
key_sign = keychain.derive(txi.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
# for the signing process the script_sig is equal
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
@ -431,7 +433,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_4_OUTPUT
txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
writers.write_tx_output(h_second, txo_bin)
writers.write_tx_output(h_sign, txo_bin)
@ -481,7 +483,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# STAGE_REQUEST_5_OUTPUT
txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
# serialize output
w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
@ -510,7 +512,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
)
authorized_in -= txi.amount
key_sign = node_derive(root, txi.address_n)
key_sign = keychain.derive(txi.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash(
coin,
@ -675,7 +677,7 @@ def get_tx_header(coin: coininfo.CoinInfo, tx: SignTx, segwit: bool = False):
def output_derive_script(
o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode
o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
) -> bytes:
if o.script_type == OutputScriptType.PAYTOOPRETURN:
@ -690,7 +692,7 @@ def output_derive_script(
# change output
if o.address:
raise SigningError(FailureType.DataError, "Address in change output")
o.address = get_address_for_change(o, coin, root)
o.address = get_address_for_change(o, coin, keychain)
else:
if not o.address:
raise SigningError(FailureType.DataError, "Missing address")
@ -739,7 +741,7 @@ def output_derive_script(
def get_address_for_change(
o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode
o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
):
if o.script_type == OutputScriptType.PAYTOADDRESS:
input_script_type = InputScriptType.SPENDADDRESS
@ -751,9 +753,8 @@ def get_address_for_change(
input_script_type = InputScriptType.SPENDP2SHWITNESS
else:
raise SigningError(FailureType.DataError, "Invalid script type")
return addresses.get_address(
input_script_type, coin, node_derive(root, o.address_n), o.multisig
)
node = keychain.derive(o.address_n, coin.curve_name)
return addresses.get_address(input_script_type, coin, node, o.multisig)
def output_is_change(
@ -857,12 +858,6 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
)
def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode:
node = root.clone()
node.derive_path(address_n)
return node
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
sig = secp256k1.sign(node.private_key(), digest)
sigder = der.encode_seq((sig[1:33], sig[33:65]))

View File

@ -10,6 +10,12 @@ from apps.wallet.sign_tx.signing import *
from apps.wallet.sign_tx.writers import *
def node_derive(root, path):
node = root.clone()
node.derive_path(path)
return node
class TestAddress(unittest.TestCase):
# pylint: disable=C0301

View File

@ -6,6 +6,12 @@ from apps.common import coins
from trezor.crypto import bip32, bip39
def node_derive(root, path):
node = root.clone()
node.derive_path(path)
return node
class TestAddressGRS(unittest.TestCase):
# pylint: disable=C0301

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
@ -61,7 +62,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
@ -113,9 +114,10 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
@ -159,7 +161,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
@ -209,21 +211,13 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
# https://groestlsight-test.groestlcoin.org/api/tx/9b5c4859a8a31e69788cb4402812bb28f14ad71cbd8c60b09903478bc56f79a3
@ -110,9 +111,10 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
@ -203,20 +205,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
@ -213,9 +215,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
@ -322,26 +325,19 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
TxRequest(request_type=TXFINISHED, details=None)
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
i = 0
messages_count = int(len(messages) / 2)
for request, response in chunks(messages, 2):
if i == messages_count - 1: # last message should throw SigningError
self.assertRaises(signing.SigningError, signer.send, request)
else:
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
i += 1
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
# https://groestlsight-test.groestlcoin.org/api/tx/4ce0220004bdfe14e3dd49fd8636bcb770a400c0c9e9bff670b6a13bb8f15c72
@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
@ -212,20 +214,13 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

View File

@ -14,6 +14,7 @@ from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
@ -60,7 +61,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
@ -84,9 +85,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
def test_under_threshold(self):
coin_bitcoin = coins.by_name('Bitcoin')
@ -127,7 +129,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
@ -149,18 +151,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
self.assertEqual(signer.send(request), response)
if __name__ == '__main__':

View File

@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
@ -61,7 +62,7 @@ class TestSignTx(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
@ -98,24 +99,16 @@ class TestSignTx(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
res = signer.send(request)
self.assertEqualEx(res, response)
self.assertEqual(res, response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

View File

@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
@ -87,20 +88,13 @@ class TestSignTx_GRS(unittest.TestCase):
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to helpers.Ui* classes
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()