diff --git a/tests/test_protect_call.py b/tests/test_protect_call.py index db32a2f42..0ad07fb75 100644 --- a/tests/test_protect_call.py +++ b/tests/test_protect_call.py @@ -1,3 +1,4 @@ +import time import unittest import common @@ -10,7 +11,7 @@ class TestProtectCall(common.TrezorTest): entropy_len = 10 entropy = self.client.get_entropy(entropy_len) self.assertEqual(len(entropy), entropy_len) - + def test_no_protection(self): self.client.load_device(seed=self.mnemonic1, pin='') @@ -30,6 +31,37 @@ class TestProtectCall(common.TrezorTest): def test_cancelled_pin(self): self.client.setup_debuglink(button=True, pin_correct=-1) # PIN cancel self.assertRaises(PinException, self._some_protected_call) + + def test_exponential_backoff_with_reboot(self): + self.client.setup_debuglink(button=True, pin_correct=False) + def test_backoff(attempts, start): + expected = 1.8 ** attempts + got = time.time() - start + + msg = "Pin delay expected to be at least %s seconds, got %s" % (expected, got) + print msg + self.assertLessEqual(expected, got, msg) + + for attempt in range(1, 6): + start = time.time() + self.assertRaises(PinException, self._some_protected_call) + test_backoff(attempt, start) + + # Unplug Trezor now + self.client.debuglink.stop() + self.client.close() + + # Give it some time to reboot (it may take some time on RPi) + boot_delay = 5 + start = time.time() + time.sleep(boot_delay) + + # Connect to Trezor again + self.setUp() + print "Expected reboot time %s seconds" % (1.8 ** attempt) + print "Rebooted in %s seconds" % (time.time() - start) + self.assertLessEqual(1.8 ** attempt, time.time() - start, "Bootup took less than expected!") + if __name__ == '__main__': unittest.main() diff --git a/trezorlib/client.py b/trezorlib/client.py index ccb058c53..0d28cf1a6 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -41,7 +41,12 @@ class TrezorClient(object): def init_device(self): self.master_public_key = None self.features = self.call(proto.Initialize()) - + + def close(self): + self.transport.close() + if self.debuglink: + self.debuglink.transport.close() + def get_master_public_key(self): if self.master_public_key: return self.master_public_key diff --git a/trezorlib/debuglink.py b/trezorlib/debuglink.py index d32f7cbcf..0c7dc550d 100644 --- a/trezorlib/debuglink.py +++ b/trezorlib/debuglink.py @@ -45,4 +45,7 @@ class DebugLink(object): self.press_button(True) def press_no(self): - self.press_button(False) \ No newline at end of file + self.press_button(False) + + def stop(self): + self.transport.write(proto.DebugLinkStop())