diff --git a/cmd.py b/cmd.py index b845795f3e..d8857cde39 100755 --- a/cmd.py +++ b/cmd.py @@ -110,6 +110,9 @@ class Commands(object): def change_pin(self, args): return self.client.change_pin(args.remove) + def wipe_device(self, args): + return self.client.wipe_device() + def load_device(self, args): if not args.mnemonic and not args.xprv: raise Exception("Please provide mnemonic or xprv") @@ -151,8 +154,9 @@ class Commands(object): set_label.help = 'Set new wallet label' change_pin.help = 'Change new PIN or remove existing' list_coins.help = 'List all supported coin types by the device' + wipe_device.help = 'Reset device to factory defaults and remove all private data.' load_device.help = 'Load custom configuration to the device' - reset_device.help = 'Perform factory reset of the device and generate new seed' + reset_device.help = 'Perform device setup and generate new seed' sign_message.help = 'Sign message using address of given path' verify_message.help = 'Verify message' firmware_update.help = 'Upload new firmware to device (must be in bootloader mode)' @@ -184,6 +188,8 @@ class Commands(object): (('-r', '--remove'), {'action': 'store_true', 'default': False}), ) + wipe_device.arguments = () + load_device.arguments = ( (('-m', '--mnemonic'), {'type': str, 'nargs': '+'}), (('-x', '--xprv'), {'type': str}), diff --git a/tests/common.py b/tests/common.py index c1454a44ef..d2c7ddc79b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -17,6 +17,7 @@ class TrezorTest(unittest.TestCase): self.client.setup_debuglink(button=True, pin_correct=True) + self.client.wipe_device() self.client.load_device_by_mnemonic( mnemonic=self.mnemonic1, pin=self.pin1, diff --git a/tests/test_basic.py b/tests/test_basic.py index e8e71e8a9a..16817df01c 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -25,16 +25,24 @@ class TestBasic(common.TrezorTest): # Ping results in Success(message='Ahoj!') self.assertEqual(ping, messages.Success(message='ahoj!')) - def test_uuid(self): - uuid1 = self.client.get_device_id() + def test_device_id_same(self): + id1 = self.client.get_device_id() self.client.init_device() - uuid2 = self.client.get_device_id() + id2 = self.client.get_device_id() - # UUID must be at least 12 characters - self.assertTrue(len(uuid1) >= 12) + # ID must be at least 12 characters + self.assertTrue(len(id1) >= 12) # Every resulf of UUID must be the same - self.assertEqual(uuid1, uuid2) + self.assertEqual(id1, id2) + + def test_device_id_different(self): + id1 = self.client.get_device_id() + self.client.wipe_device() + id2 = self.client.get_device_id() + + # Device ID must be fresh after every reset + self.assertNotEqual(id1, id2) if __name__ == '__main__': unittest.main() diff --git a/trezorlib/client.py b/trezorlib/client.py index 2943e48df2..9e8d90215c 100755 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -320,7 +320,15 @@ class TrezorClient(object): return (signatures, serialized_tx) + def wipe_device(self): + ret = self.call(proto.WipeDevice()) + self.init_device() + return ret + def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language): + if self.features.initialized: + raise Exception("Device is initialized already. Call wipe_device() and try again.") + # Begin with device reset workflow msg = proto.ResetDevice(display_random=display_random, strength=strength, @@ -341,6 +349,9 @@ class TrezorClient(object): return isinstance(resp, proto.Success) def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language): + if self.features.initialized: + raise Exception("Device is initialized already. Call wipe_device() and try again.") + resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin, passphrase_protection=passphrase_protection, language=language, @@ -349,6 +360,9 @@ class TrezorClient(object): return isinstance(resp, proto.Success) def load_device_by_xprv(self, xprv, pin, passphrase_protection, label): + if self.features.initialized: + raise Exception("Device is initialized already. Call wipe_device() and try again.") + if xprv[0:4] not in ('xprv', 'tprv'): raise Exception("Unknown type of xprv")