diff --git a/tests/common.py b/tests/common.py index c40838268..466ac651c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -6,8 +6,8 @@ from trezorlib.debuglink import DebugLink class TrezorTest(unittest.TestCase): def setUp(self): - self.debug_transport = config.DEBUG_TRANSPORT(*config.DEBUG_TRANSPORT_ARGS) - self.transport = config.TRANSPORT(*config.TRANSPORT_ARGS) + self.debug_transport = config.DEBUG_TRANSPORT(*config.DEBUG_TRANSPORT_ARGS, **config.DEBUG_TRANSPORT_KWARGS) + self.transport = config.TRANSPORT(*config.TRANSPORT_ARGS, **config.TRANSPORT_KWARGS) self.client = TrezorClient(self.transport, DebugLink(self.debug_transport), debug=True) # self.client = TrezorClient(self.transport, debug=False) diff --git a/tests/config.py b/tests/config.py index 31cf06ae5..db48f26e4 100644 --- a/tests/config.py +++ b/tests/config.py @@ -5,18 +5,18 @@ from trezorlib.transport_pipe import PipeTransport from trezorlib.transport_hid import HidTransport from trezorlib.transport_socket import SocketTransportClient -use_real = False -use_pipe = True +use_real = True +use_pipe = False if use_real: devices = HidTransport.enumerate() TRANSPORT = HidTransport - TRANSPORT_ARGS = (devices[0], ) + TRANSPORT_ARGS = (devices[0],) TRANSPORT_KWARGS = {'debug_link': False} DEBUG_TRANSPORT = HidTransport - DEBUG_TRANSPORT_ARGS = (devices[0], ) + DEBUG_TRANSPORT_ARGS = (devices[0],) DEBUG_TRANSPORT_KWARGS = {'debug_link': True} elif use_pipe: @@ -33,7 +33,7 @@ else: devices = HidTransport.enumerate() TRANSPORT = HidTransport - TRANSPORT_ARGS = (devices[0], ) + TRANSPORT_ARGS = (devices[0][0],) TRANSPORT_KWARGS = {'debug_link': False} DEBUG_TRANSPORT = SocketTransportClient diff --git a/tests/test_addresses.py b/tests/test_addresses.py index 6e0787c90..6f56131cf 100644 --- a/tests/test_addresses.py +++ b/tests/test_addresses.py @@ -5,6 +5,7 @@ from trezorlib import tools class TestAddresses(common.TrezorTest): def test_btc(self): + self.client.wipe_device() self.client.load_device_by_mnemonic(mnemonic=self.mnemonic1, pin='', passphrase_protection=False, @@ -18,6 +19,7 @@ class TestAddresses(common.TrezorTest): self.assertEqual(self.client.get_address('Bitcoin', [0, 9999999]), '1GS8X3yc7ntzwGw9vXwj9wqmBWZkTFewBV') def test_ltc(self): + self.client.wipe_device() self.client.load_device_by_mnemonic(mnemonic=self.mnemonic1, pin='', passphrase_protection=False, @@ -31,6 +33,7 @@ class TestAddresses(common.TrezorTest): self.assertEqual(self.client.get_address('Litecoin', [0, 9999999]), 'Laf5nGHSCT94C5dK6fw2RxuXPiw2ZuRR9S') def test_tbtc(self): + self.client.wipe_device() self.client.load_device_by_mnemonic(mnemonic=self.mnemonic1, pin='', passphrase_protection=False, @@ -40,6 +43,7 @@ class TestAddresses(common.TrezorTest): self.assertEqual(self.client.get_address('Testnet', [111, 42]), 'moN6aN6NP1KWgnPSqzrrRPvx2x1UtZJssa') def test_public_ckd(self): + self.client.wipe_device() self.client.load_device_by_mnemonic(mnemonic=self.mnemonic1, pin='', passphrase_protection=False, diff --git a/tests/test_bip32_speed.py b/tests/test_bip32_speed.py new file mode 100644 index 000000000..9080a4798 --- /dev/null +++ b/tests/test_bip32_speed.py @@ -0,0 +1,40 @@ +import unittest +import common +import time +from trezorlib import tools + +class TestAddresses(common.TrezorTest): + def test_public_ckd(self): + self.client.wipe_device() + self.client.load_device_by_mnemonic(mnemonic=self.mnemonic1, + pin='', + passphrase_protection=False, + label='test', + language='english') + + for depth in range(8): + start = time.time() + self.client.get_address('Bitcoin', range(depth)) + delay = time.time() - start + expected = (depth + 1) * 0.25 + print "DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay + self.assertLessEqual(delay, expected) + + def test_private_ckd(self): + self.client.wipe_device() + self.client.load_device_by_mnemonic(mnemonic=self.mnemonic1, + pin='', + passphrase_protection=False, + label='test', + language='english') + + for depth in range(8): + start = time.time() + self.client.get_address('Bitcoin', range(-depth, 0)) + delay = time.time() - start + expected = (depth + 1) * 0.25 + print "DEPTH", depth, "EXPECTED DELAY", expected, "REAL DELAY", delay + self.assertLessEqual(delay, expected) + +if __name__ == '__main__': + unittest.main()