mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
wallet: use keychain API in signing, fix tests
This commit is contained in:
parent
172f3cb22f
commit
852bf8f4ef
@ -3,7 +3,7 @@ from trezor.messages.MessageType import TxAck
|
||||
from trezor.messages.RequestType import TXFINISHED
|
||||
from trezor.messages.TxRequest import TxRequest
|
||||
|
||||
from apps.common import coins, paths, seed
|
||||
from apps.common import paths, seed
|
||||
from apps.wallet.sign_tx import (
|
||||
addresses,
|
||||
helpers,
|
||||
@ -18,13 +18,9 @@ from apps.wallet.sign_tx import (
|
||||
|
||||
@ui.layout
|
||||
async def sign_tx(ctx, msg):
|
||||
coin_name = msg.coin_name or "Bitcoin"
|
||||
coin = coins.by_name(coin_name)
|
||||
# TODO: rework this so we don't have to pass root to signing.sign_tx
|
||||
keychain = await seed.get_keychain(ctx)
|
||||
root = keychain.derive([], coin.curve_name)
|
||||
signer = signing.sign_tx(msg, keychain)
|
||||
|
||||
signer = signing.sign_tx(msg, root)
|
||||
res = None
|
||||
while True:
|
||||
try:
|
||||
|
@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType
|
||||
from trezor.messages.TxOutputBinType import TxOutputBinType
|
||||
from trezor.messages.TxOutputType import TxOutputType
|
||||
from trezor.messages.TxRequest import TxRequest
|
||||
from trezor.utils import obj_eq
|
||||
|
||||
from apps.common.coininfo import CoinInfo
|
||||
|
||||
@ -24,6 +25,8 @@ class UiConfirmOutput:
|
||||
self.output = output
|
||||
self.coin = coin
|
||||
|
||||
__eq__ = obj_eq
|
||||
|
||||
|
||||
class UiConfirmTotal:
|
||||
def __init__(self, spending: int, fee: int, coin: CoinInfo):
|
||||
@ -31,17 +34,23 @@ class UiConfirmTotal:
|
||||
self.fee = fee
|
||||
self.coin = coin
|
||||
|
||||
__eq__ = obj_eq
|
||||
|
||||
|
||||
class UiConfirmFeeOverThreshold:
|
||||
def __init__(self, fee: int, coin: CoinInfo):
|
||||
self.fee = fee
|
||||
self.coin = coin
|
||||
|
||||
__eq__ = obj_eq
|
||||
|
||||
|
||||
class UiConfirmForeignAddress:
|
||||
def __init__(self, address_n: list):
|
||||
self.address_n = address_n
|
||||
|
||||
__eq__ = obj_eq
|
||||
|
||||
|
||||
def confirm_output(output: TxOutputType, coin: CoinInfo):
|
||||
return (yield UiConfirmOutput(output, coin))
|
||||
|
@ -13,7 +13,7 @@ from trezor.messages.TxRequest import TxRequest
|
||||
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
||||
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||
|
||||
from apps.common import address_type, coininfo, coins
|
||||
from apps.common import address_type, coininfo, coins, seed
|
||||
from apps.wallet.sign_tx import (
|
||||
addresses,
|
||||
decred,
|
||||
@ -53,7 +53,7 @@ class SigningError(ValueError):
|
||||
# - check inputs, previous transactions, and outputs
|
||||
# - ask for confirmations
|
||||
# - check fee
|
||||
async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
|
||||
async def check_tx_fee(tx: SignTx, keychain: seed.Keychain):
|
||||
coin = coins.by_name(tx.coin_name)
|
||||
|
||||
# h_first is used to make sure the inputs and outputs streamed in Phase 1
|
||||
@ -103,7 +103,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
|
||||
hash143.add_sequence(txi)
|
||||
|
||||
if not addresses.validate_full_path(txi.address_n, coin, txi.script_type):
|
||||
await helpers.confirm_foreign_address(txi.address_n, coin)
|
||||
await helpers.confirm_foreign_address(txi.address_n)
|
||||
|
||||
if txi.multisig:
|
||||
multifp.add(txi.multisig)
|
||||
@ -158,7 +158,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
|
||||
# STAGE_REQUEST_3_OUTPUT
|
||||
txo = await helpers.request_tx_output(tx_req, o)
|
||||
txo_bin.amount = txo.amount
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
|
||||
weight.add_output(txo_bin.script_pubkey)
|
||||
|
||||
if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp):
|
||||
@ -207,14 +207,16 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
|
||||
return h_first, hash143, segwit, total_in, wallet_path
|
||||
|
||||
|
||||
async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
async def sign_tx(tx: SignTx, keychain: seed.Keychain):
|
||||
tx = helpers.sanitize_sign_tx(tx)
|
||||
|
||||
progress.init(tx.inputs_count, tx.outputs_count)
|
||||
|
||||
# Phase 1
|
||||
|
||||
h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root)
|
||||
h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(
|
||||
tx, keychain
|
||||
)
|
||||
|
||||
# Phase 2
|
||||
# - sign inputs
|
||||
@ -247,7 +249,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
)
|
||||
input_check_wallet_path(txi_sign, wallet_path)
|
||||
|
||||
key_sign = node_derive(root, txi_sign.address_n)
|
||||
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub)
|
||||
|
||||
@ -275,7 +277,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
)
|
||||
authorized_in -= txi_sign.amount
|
||||
|
||||
key_sign = node_derive(root, txi_sign.address_n)
|
||||
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
hash143_hash = hash143.preimage_hash(
|
||||
coin,
|
||||
@ -312,7 +314,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
|
||||
input_check_wallet_path(txi_sign, wallet_path)
|
||||
|
||||
key_sign = node_derive(root, txi_sign.address_n)
|
||||
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
|
||||
if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
|
||||
@ -399,7 +401,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
writers.write_tx_input_check(h_second, txi)
|
||||
if i == i_sign:
|
||||
txi_sign = txi
|
||||
key_sign = node_derive(root, txi.address_n)
|
||||
key_sign = keychain.derive(txi.address_n, coin.curve_name)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
# for the signing process the script_sig is equal
|
||||
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
|
||||
@ -431,7 +433,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
# STAGE_REQUEST_4_OUTPUT
|
||||
txo = await helpers.request_tx_output(tx_req, o)
|
||||
txo_bin.amount = txo.amount
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
|
||||
writers.write_tx_output(h_second, txo_bin)
|
||||
writers.write_tx_output(h_sign, txo_bin)
|
||||
|
||||
@ -481,7 +483,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
# STAGE_REQUEST_5_OUTPUT
|
||||
txo = await helpers.request_tx_output(tx_req, o)
|
||||
txo_bin.amount = txo.amount
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
|
||||
|
||||
# serialize output
|
||||
w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
|
||||
@ -510,7 +512,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
|
||||
)
|
||||
authorized_in -= txi.amount
|
||||
|
||||
key_sign = node_derive(root, txi.address_n)
|
||||
key_sign = keychain.derive(txi.address_n, coin.curve_name)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
hash143_hash = hash143.preimage_hash(
|
||||
coin,
|
||||
@ -675,7 +677,7 @@ def get_tx_header(coin: coininfo.CoinInfo, tx: SignTx, segwit: bool = False):
|
||||
|
||||
|
||||
def output_derive_script(
|
||||
o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode
|
||||
o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
|
||||
) -> bytes:
|
||||
|
||||
if o.script_type == OutputScriptType.PAYTOOPRETURN:
|
||||
@ -690,7 +692,7 @@ def output_derive_script(
|
||||
# change output
|
||||
if o.address:
|
||||
raise SigningError(FailureType.DataError, "Address in change output")
|
||||
o.address = get_address_for_change(o, coin, root)
|
||||
o.address = get_address_for_change(o, coin, keychain)
|
||||
else:
|
||||
if not o.address:
|
||||
raise SigningError(FailureType.DataError, "Missing address")
|
||||
@ -739,7 +741,7 @@ def output_derive_script(
|
||||
|
||||
|
||||
def get_address_for_change(
|
||||
o: TxOutputType, coin: coininfo.CoinInfo, root: bip32.HDNode
|
||||
o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
|
||||
):
|
||||
if o.script_type == OutputScriptType.PAYTOADDRESS:
|
||||
input_script_type = InputScriptType.SPENDADDRESS
|
||||
@ -751,9 +753,8 @@ def get_address_for_change(
|
||||
input_script_type = InputScriptType.SPENDP2SHWITNESS
|
||||
else:
|
||||
raise SigningError(FailureType.DataError, "Invalid script type")
|
||||
return addresses.get_address(
|
||||
input_script_type, coin, node_derive(root, o.address_n), o.multisig
|
||||
)
|
||||
node = keychain.derive(o.address_n, coin.curve_name)
|
||||
return addresses.get_address(input_script_type, coin, node, o.multisig)
|
||||
|
||||
|
||||
def output_is_change(
|
||||
@ -857,12 +858,6 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
|
||||
)
|
||||
|
||||
|
||||
def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode:
|
||||
node = root.clone()
|
||||
node.derive_path(address_n)
|
||||
return node
|
||||
|
||||
|
||||
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
|
||||
sig = secp256k1.sign(node.private_key(), digest)
|
||||
sigder = der.encode_seq((sig[1:33], sig[33:65]))
|
||||
|
@ -10,6 +10,12 @@ from apps.wallet.sign_tx.signing import *
|
||||
from apps.wallet.sign_tx.writers import *
|
||||
|
||||
|
||||
def node_derive(root, path):
|
||||
node = root.clone()
|
||||
node.derive_path(path)
|
||||
return node
|
||||
|
||||
|
||||
class TestAddress(unittest.TestCase):
|
||||
# pylint: disable=C0301
|
||||
|
||||
|
@ -6,6 +6,12 @@ from apps.common import coins
|
||||
from trezor.crypto import bip32, bip39
|
||||
|
||||
|
||||
def node_derive(root, path):
|
||||
node = root.clone()
|
||||
node.derive_path(path)
|
||||
return node
|
||||
|
||||
|
||||
class TestAddressGRS(unittest.TestCase):
|
||||
# pylint: disable=C0301
|
||||
|
||||
|
@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import coins
|
||||
from apps.common.seed import Keychain
|
||||
from apps.wallet.sign_tx import helpers, signing
|
||||
|
||||
|
||||
@ -61,7 +62,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
|
||||
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
|
||||
TxAck(tx=TransactionType(inputs=[inp1])),
|
||||
|
||||
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
True,
|
||||
|
||||
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
|
||||
@ -113,9 +114,10 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
@ -159,7 +161,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
|
||||
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
|
||||
TxAck(tx=TransactionType(inputs=[inp1])),
|
||||
|
||||
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
True,
|
||||
|
||||
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
|
||||
@ -209,21 +211,13 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to helpers.Ui* classes
|
||||
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
|
||||
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
|
||||
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import coins
|
||||
from apps.common.seed import Keychain
|
||||
from apps.wallet.sign_tx import helpers, signing
|
||||
|
||||
# https://groestlsight-test.groestlcoin.org/api/tx/9b5c4859a8a31e69788cb4402812bb28f14ad71cbd8c60b09903478bc56f79a3
|
||||
@ -110,9 +111,10 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
@ -203,20 +205,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to helpers.Ui* classes
|
||||
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
|
||||
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import coins
|
||||
from apps.common.seed import Keychain
|
||||
from apps.wallet.sign_tx import helpers, signing
|
||||
|
||||
|
||||
@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
@ -213,9 +215,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
@ -322,26 +325,19 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
|
||||
TxRequest(request_type=TXFINISHED, details=None)
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
i = 0
|
||||
messages_count = int(len(messages) / 2)
|
||||
for request, response in chunks(messages, 2):
|
||||
if i == messages_count - 1: # last message should throw SigningError
|
||||
self.assertRaises(signing.SigningError, signer.send, request)
|
||||
else:
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
i += 1
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to helpers.Ui* classes
|
||||
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
|
||||
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -15,6 +15,7 @@ from trezor.messages import InputScriptType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import coins
|
||||
from apps.common.seed import Keychain
|
||||
from apps.wallet.sign_tx import helpers, signing
|
||||
|
||||
# https://groestlsight-test.groestlcoin.org/api/tx/4ce0220004bdfe14e3dd49fd8636bcb770a400c0c9e9bff670b6a13bb8f15c72
|
||||
@ -110,9 +111,10 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
@ -212,20 +214,13 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
|
||||
)),
|
||||
]
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to helpers.Ui* classes
|
||||
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
|
||||
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -14,6 +14,7 @@ from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import coins
|
||||
from apps.common.seed import Keychain
|
||||
from apps.wallet.sign_tx import helpers, signing
|
||||
|
||||
|
||||
@ -60,7 +61,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
|
||||
|
||||
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
|
||||
TxAck(tx=TransactionType(inputs=[inp1])),
|
||||
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
True,
|
||||
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
|
||||
TxAck(tx=ptx1),
|
||||
@ -84,9 +85,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
|
||||
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
|
||||
root = bip32.from_seed(seed, 'secp256k1')
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
|
||||
def test_under_threshold(self):
|
||||
coin_bitcoin = coins.by_name('Bitcoin')
|
||||
@ -127,7 +129,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
|
||||
|
||||
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
|
||||
TxAck(tx=TransactionType(inputs=[inp1])),
|
||||
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
True,
|
||||
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
|
||||
TxAck(tx=ptx1),
|
||||
@ -149,18 +151,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
|
||||
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
|
||||
root = bip32.from_seed(seed, 'secp256k1')
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to helpers.Ui* classes
|
||||
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
|
||||
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
|
||||
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import coins
|
||||
from apps.common.seed import Keychain
|
||||
from apps.wallet.sign_tx import helpers, signing
|
||||
|
||||
|
||||
@ -61,7 +62,7 @@ class TestSignTx(unittest.TestCase):
|
||||
|
||||
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
|
||||
TxAck(tx=TransactionType(inputs=[inp1])),
|
||||
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
|
||||
True,
|
||||
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
|
||||
TxAck(tx=ptx1),
|
||||
@ -98,24 +99,16 @@ class TestSignTx(unittest.TestCase):
|
||||
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
|
||||
root = bip32.from_seed(seed, 'secp256k1')
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
|
||||
for request, response in chunks(messages, 2):
|
||||
res = signer.send(request)
|
||||
self.assertEqualEx(res, response)
|
||||
self.assertEqual(res, response)
|
||||
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to helpers.Ui* classes
|
||||
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
|
||||
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal)) or
|
||||
(isinstance(a, helpers.UiConfirmForeignAddress) and isinstance(b, helpers.UiConfirmForeignAddress))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -15,6 +15,7 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import coins
|
||||
from apps.common.seed import Keychain
|
||||
from apps.wallet.sign_tx import helpers, signing
|
||||
|
||||
|
||||
@ -87,20 +88,13 @@ class TestSignTx_GRS(unittest.TestCase):
|
||||
seed = bip39.seed(' '.join(['all'] * 12), '')
|
||||
root = bip32.from_seed(seed, coin.curve_name)
|
||||
|
||||
signer = signing.sign_tx(tx, root)
|
||||
keychain = Keychain([[coin.curve_name]], [root])
|
||||
signer = signing.sign_tx(tx, keychain)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
self.assertEqual(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to helpers.Ui* classes
|
||||
if ((isinstance(a, helpers.UiConfirmOutput) and isinstance(b, helpers.UiConfirmOutput)) or
|
||||
(isinstance(a, helpers.UiConfirmTotal) and isinstance(b, helpers.UiConfirmTotal))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user