diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index b74535bbce..0d6f68ecf0 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -16,8 +16,8 @@ import logging import os -import sys import warnings +from typing import Optional from mnemonic import Mnemonic @@ -25,9 +25,6 @@ from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools from .log import DUMP_BYTES from .messages import Capability -if sys.version_info.major < 3: - raise Exception("Trezorlib does not support Python 2 anymore.") - LOG = logging.getLogger(__name__) VENDORS = ("bitcointrezor.com", "trezor.io") @@ -103,6 +100,7 @@ class TrezorClient: def close(self): self.session_counter = max(self.session_counter - 1, 0) if self.session_counter == 0: + # TODO call EndSession here? self.transport.end_session() def cancel(self): @@ -226,18 +224,12 @@ class TrezorClient: else: return resp - @tools.session - def init_device(self): - resp = self.call_raw(messages.Initialize(session_id=self.session_id)) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected initial response") - else: - self.features = resp - if self.features.vendor not in VENDORS: + def _refresh_features(self, features: messages.Features) -> None: + """Update internal fields based on passed-in Features message.""" + if features.vendor not in VENDORS: raise RuntimeError("Unsupported device") - # A side-effect of this is a sanity check for broken protobuf definitions. - # If the `vendor` field doesn't exist, you probably have a mismatched - # checkout of trezor-common. + + self.features = features self.version = ( self.features.major_version, self.features.minor_version, @@ -246,6 +238,72 @@ class TrezorClient: self.check_firmware_version(warn_only=True) if self.features.session_id is not None: self.session_id = self.features.session_id + self.features.session_id = None + + @tools.session + def refresh_features(self) -> messages.Features: + """Reload features from the device. + + Should be called after changing settings or performing operations that affect + device state. + """ + resp = self.call_raw(messages.GetFeatures()) + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._refresh_features(resp) + return resp + + @tools.session + def init_device( + self, *, session_id: bytes = None, new_session: bool = False + ) -> Optional[bytes]: + """Initialize the device and return a session ID. + + You can optionally specify a session ID. If the session still exists on the + device, the same session ID will be returned and the session is resumed. + Otherwise a different session ID is returned. + + Specify `new_session=True` to open a fresh session. Since firmware version + 1.9.0/2.3.0, the previous session will remain cached on the device, and can be + resumed by calling `init_device` again with the appropriate session ID. + + If neither `new_session` nor `session_id` is specified, the current session ID + will be reused. If no session ID was cached, a new session ID will be allocated + and returned. + + # Version notes: + + Trezor One older than 1.9.0 does not have session management. Optional arguments + have no effect and the function returns None + + Trezor T older than 2.3.0 does not have session cache. Requesting a new session + will overwrite the old one. In addition, this function will always return None. + A valid session_id can be obtained from the `session_id` attribute, but only + after a passphrase-protected call is performed. You can use the following code: + + >>> client.init_device() + >>> client.ensure_unlocked() + >>> valid_session_id = client.session_id + """ + if new_session: + self.session_id = None + elif session_id is not None: + self.session_id = session_id + + resp = self.call_raw(messages.Initialize(session_id=self.session_id)) + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to Initialize") + + # TT < 2.3.0 compatibility: + # _refresh_features will clear out the session_id field. We want this function + # to return its value, so that callers can rely on it being either a valid + # session_id, or None if we can't do that. + # Older TT FW does not report session_id in Features and self.session_id might + # be invalid because TT will not allocate a session_id until a passphrase + # exchange happens. + reported_session_id = resp.session_id + self._refresh_features(resp) + return reported_session_id def is_outdated(self): if self.features.bootloader_mode: @@ -283,12 +341,58 @@ class TrezorClient: def get_device_id(self): return self.features.device_id + @tools.session + def lock(self): + """Lock the device. + + If the device does not have a PIN configured, this will do nothing. + Otherwise, a lock screen will be shown and the device will prompt for PIN + before further actions. + + This call does _not_ invalidate passphrase cache. If passphrase is in use, + the device will not prompt for it after unlocking. + + To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate + passphrase cache, use `clear_session()`. + """ + self.call(messages.LockDevice()) + self.refresh_features() + + @tools.session + def ensure_unlocked(self): + """Ensure the device is unlocked and a passphrase is cached. + + If the device is locked, this will prompt for PIN. If passphrase is enabled + and no passphrase is cached for the current session, the device will also + prompt for passphrase. + + After calling this method, further actions on the device will not prompt for + PIN or passphrase until the device is locked or the session becomes invalid. + """ + from .btc import get_address + + get_address(self, "Testnet", PASSPHRASE_TEST_PATH) + self.refresh_features() + + def end_session(self): + """Close the current session and clear cached passphrase. + + The session will become invalid until `init_device()` is called again. + If passphrase is enabled, further actions will prompt for it again. + """ + # XXX self.call(messages.EndSession()) + self.session_id = None + @tools.session def clear_session(self): - resp = self.call_raw(messages.LockDevice()) # TODO fix this - if isinstance(resp, messages.Success): - self.session_id = None - self.init_device() - return resp.message - else: - return resp + """Lock the device and present a fresh session. + + The current session will be invalidated and a new one will be started. If the + device has PIN enabled, it will become locked. + + Equivalent to calling `lock()`, `end_session()` and `init_device()`. + """ + # call LockDevice manually to save one refresh_features() call + self.call(messages.LockDevice()) + self.end_session() + self.init_device() diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index 23d1ea1cba..60a1335b31 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -25,6 +25,7 @@ RECOVERY_BACK = "\x08" # backspace character, sent literally @expect(messages.Success, field="message") +@session def apply_settings( client, label=None, @@ -48,45 +49,51 @@ def apply_settings( ) out = client.call(settings) - client.init_device() # Reload Features + client.refresh_features() return out @expect(messages.Success, field="message") +@session def apply_flags(client, flags): out = client.call(messages.ApplyFlags(flags=flags)) - client.init_device() # Reload Features + client.refresh_features() return out @expect(messages.Success, field="message") +@session def change_pin(client, remove=False): ret = client.call(messages.ChangePin(remove=remove)) - client.init_device() # Re-read features + client.refresh_features() return ret @expect(messages.Success, field="message") +@session def change_wipe_code(client, remove=False): ret = client.call(messages.ChangeWipeCode(remove=remove)) - client.init_device() # Re-read features + client.refresh_features() return ret @expect(messages.Success, field="message") +@session def sd_protect(client, operation): ret = client.call(messages.SdProtect(operation=operation)) - client.init_device() + client.refresh_features() return ret @expect(messages.Success, field="message") +@session def wipe(client): ret = client.call(messages.WipeDevice()) client.init_device() return ret +@session def recover( client, word_count=24, @@ -190,8 +197,10 @@ def reset( @expect(messages.Success, field="message") +@session def backup(client): ret = client.call(messages.BackupDevice()) + client.refresh_features() return ret diff --git a/tests/device_tests/test_basic.py b/tests/device_tests/test_basic.py index ff59c181f3..282ac4fa2f 100644 --- a/tests/device_tests/test_basic.py +++ b/tests/device_tests/test_basic.py @@ -20,7 +20,9 @@ from trezorlib import device, messages class TestBasic: def test_features(self, client): f0 = client.features - f1 = client.call(messages.Initialize(f0.session_id)) + # client erases session_id from its features + f0.session_id = client.session_id + f1 = client.call(messages.Initialize(client.session_id)) assert f0 == f1 def test_ping(self, client): diff --git a/tests/device_tests/test_msg_backup_device.py b/tests/device_tests/test_msg_backup_device.py index f605b9ff49..bdaaf61453 100644 --- a/tests/device_tests/test_msg_backup_device.py +++ b/tests/device_tests/test_msg_backup_device.py @@ -57,6 +57,7 @@ def test_backup_bip39(client): messages.ButtonRequest(code=B.Success), messages.ButtonRequest(code=B.Success), messages.Success(), + messages.Features(), ] ) device.backup(client) @@ -119,6 +120,7 @@ def test_backup_slip39_basic(client): messages.ButtonRequest(code=B.Success), messages.ButtonRequest(code=B.Success), messages.Success(), + messages.Features(), ] ) device.backup(client) @@ -238,6 +240,7 @@ def test_backup_slip39_advanced(client): messages.ButtonRequest(code=B.Success), # show seeds ends here messages.ButtonRequest(code=B.Success), messages.Success(), + messages.Features(), ] ) device.backup(client) diff --git a/tests/device_tests/test_reset_backup.py b/tests/device_tests/test_reset_backup.py index 1e752a96da..75b087c205 100644 --- a/tests/device_tests/test_reset_backup.py +++ b/tests/device_tests/test_reset_backup.py @@ -58,6 +58,7 @@ def backup_flow_bip39(client): messages.ButtonRequest(code=B.Success), messages.ButtonRequest(code=B.Success), messages.Success(), + messages.Features(), ] ) client.set_input_flow(input_flow) @@ -99,7 +100,11 @@ def backup_flow_slip39_basic(client): messages.ButtonRequest(code=B.Success), ] * 5 # individual shares - + [messages.ButtonRequest(code=B.Success), messages.Success()] + + [ + messages.ButtonRequest(code=B.Success), + messages.Success(), + messages.Features(), + ] ) device.backup(client) @@ -158,7 +163,11 @@ def backup_flow_slip39_advanced(client): messages.ButtonRequest(code=B.Success), ] * 25 # individual shares - + [messages.ButtonRequest(code=B.Success), messages.Success()] + + [ + messages.ButtonRequest(code=B.Success), + messages.Success(), + messages.Features(), + ] ) device.backup(client)