mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-03 13:22:33 +00:00
python: clarify session and feature management API
init_device() should be used to initialize a session. Reuses existing session if available. end_session() explicitly closes any existing session and requests a new one lock() enables soft-lock clear_session() is the equivalent of lock() + end_session() A new function ensure_unlocked() can be used to open a session and prompt for PIN and passphrase before further operations.
This commit is contained in:
parent
95f33a77c7
commit
e585d35f34
@ -16,8 +16,8 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
@ -25,9 +25,6 @@ from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
|
|||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import Capability
|
from .messages import Capability
|
||||||
|
|
||||||
if sys.version_info.major < 3:
|
|
||||||
raise Exception("Trezorlib does not support Python 2 anymore.")
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
VENDORS = ("bitcointrezor.com", "trezor.io")
|
VENDORS = ("bitcointrezor.com", "trezor.io")
|
||||||
@ -103,6 +100,7 @@ class TrezorClient:
|
|||||||
def close(self):
|
def close(self):
|
||||||
self.session_counter = max(self.session_counter - 1, 0)
|
self.session_counter = max(self.session_counter - 1, 0)
|
||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
|
# TODO call EndSession here?
|
||||||
self.transport.end_session()
|
self.transport.end_session()
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
@ -226,18 +224,12 @@ class TrezorClient:
|
|||||||
else:
|
else:
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@tools.session
|
def _refresh_features(self, features: messages.Features) -> None:
|
||||||
def init_device(self):
|
"""Update internal fields based on passed-in Features message."""
|
||||||
resp = self.call_raw(messages.Initialize(session_id=self.session_id))
|
if features.vendor not in VENDORS:
|
||||||
if not isinstance(resp, messages.Features):
|
|
||||||
raise exceptions.TrezorException("Unexpected initial response")
|
|
||||||
else:
|
|
||||||
self.features = resp
|
|
||||||
if self.features.vendor not in VENDORS:
|
|
||||||
raise RuntimeError("Unsupported device")
|
raise RuntimeError("Unsupported device")
|
||||||
# A side-effect of this is a sanity check for broken protobuf definitions.
|
|
||||||
# If the `vendor` field doesn't exist, you probably have a mismatched
|
self.features = features
|
||||||
# checkout of trezor-common.
|
|
||||||
self.version = (
|
self.version = (
|
||||||
self.features.major_version,
|
self.features.major_version,
|
||||||
self.features.minor_version,
|
self.features.minor_version,
|
||||||
@ -246,6 +238,72 @@ class TrezorClient:
|
|||||||
self.check_firmware_version(warn_only=True)
|
self.check_firmware_version(warn_only=True)
|
||||||
if self.features.session_id is not None:
|
if self.features.session_id is not None:
|
||||||
self.session_id = self.features.session_id
|
self.session_id = self.features.session_id
|
||||||
|
self.features.session_id = None
|
||||||
|
|
||||||
|
@tools.session
|
||||||
|
def refresh_features(self) -> messages.Features:
|
||||||
|
"""Reload features from the device.
|
||||||
|
|
||||||
|
Should be called after changing settings or performing operations that affect
|
||||||
|
device state.
|
||||||
|
"""
|
||||||
|
resp = self.call_raw(messages.GetFeatures())
|
||||||
|
if not isinstance(resp, messages.Features):
|
||||||
|
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||||
|
self._refresh_features(resp)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@tools.session
|
||||||
|
def init_device(
|
||||||
|
self, *, session_id: bytes = None, new_session: bool = False
|
||||||
|
) -> Optional[bytes]:
|
||||||
|
"""Initialize the device and return a session ID.
|
||||||
|
|
||||||
|
You can optionally specify a session ID. If the session still exists on the
|
||||||
|
device, the same session ID will be returned and the session is resumed.
|
||||||
|
Otherwise a different session ID is returned.
|
||||||
|
|
||||||
|
Specify `new_session=True` to open a fresh session. Since firmware version
|
||||||
|
1.9.0/2.3.0, the previous session will remain cached on the device, and can be
|
||||||
|
resumed by calling `init_device` again with the appropriate session ID.
|
||||||
|
|
||||||
|
If neither `new_session` nor `session_id` is specified, the current session ID
|
||||||
|
will be reused. If no session ID was cached, a new session ID will be allocated
|
||||||
|
and returned.
|
||||||
|
|
||||||
|
# Version notes:
|
||||||
|
|
||||||
|
Trezor One older than 1.9.0 does not have session management. Optional arguments
|
||||||
|
have no effect and the function returns None
|
||||||
|
|
||||||
|
Trezor T older than 2.3.0 does not have session cache. Requesting a new session
|
||||||
|
will overwrite the old one. In addition, this function will always return None.
|
||||||
|
A valid session_id can be obtained from the `session_id` attribute, but only
|
||||||
|
after a passphrase-protected call is performed. You can use the following code:
|
||||||
|
|
||||||
|
>>> client.init_device()
|
||||||
|
>>> client.ensure_unlocked()
|
||||||
|
>>> valid_session_id = client.session_id
|
||||||
|
"""
|
||||||
|
if new_session:
|
||||||
|
self.session_id = None
|
||||||
|
elif session_id is not None:
|
||||||
|
self.session_id = session_id
|
||||||
|
|
||||||
|
resp = self.call_raw(messages.Initialize(session_id=self.session_id))
|
||||||
|
if not isinstance(resp, messages.Features):
|
||||||
|
raise exceptions.TrezorException("Unexpected response to Initialize")
|
||||||
|
|
||||||
|
# TT < 2.3.0 compatibility:
|
||||||
|
# _refresh_features will clear out the session_id field. We want this function
|
||||||
|
# to return its value, so that callers can rely on it being either a valid
|
||||||
|
# session_id, or None if we can't do that.
|
||||||
|
# Older TT FW does not report session_id in Features and self.session_id might
|
||||||
|
# be invalid because TT will not allocate a session_id until a passphrase
|
||||||
|
# exchange happens.
|
||||||
|
reported_session_id = resp.session_id
|
||||||
|
self._refresh_features(resp)
|
||||||
|
return reported_session_id
|
||||||
|
|
||||||
def is_outdated(self):
|
def is_outdated(self):
|
||||||
if self.features.bootloader_mode:
|
if self.features.bootloader_mode:
|
||||||
@ -284,11 +342,57 @@ class TrezorClient:
|
|||||||
return self.features.device_id
|
return self.features.device_id
|
||||||
|
|
||||||
@tools.session
|
@tools.session
|
||||||
def clear_session(self):
|
def lock(self):
|
||||||
resp = self.call_raw(messages.LockDevice()) # TODO fix this
|
"""Lock the device.
|
||||||
if isinstance(resp, messages.Success):
|
|
||||||
|
If the device does not have a PIN configured, this will do nothing.
|
||||||
|
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
||||||
|
before further actions.
|
||||||
|
|
||||||
|
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
||||||
|
the device will not prompt for it after unlocking.
|
||||||
|
|
||||||
|
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
||||||
|
passphrase cache, use `clear_session()`.
|
||||||
|
"""
|
||||||
|
self.call(messages.LockDevice())
|
||||||
|
self.refresh_features()
|
||||||
|
|
||||||
|
@tools.session
|
||||||
|
def ensure_unlocked(self):
|
||||||
|
"""Ensure the device is unlocked and a passphrase is cached.
|
||||||
|
|
||||||
|
If the device is locked, this will prompt for PIN. If passphrase is enabled
|
||||||
|
and no passphrase is cached for the current session, the device will also
|
||||||
|
prompt for passphrase.
|
||||||
|
|
||||||
|
After calling this method, further actions on the device will not prompt for
|
||||||
|
PIN or passphrase until the device is locked or the session becomes invalid.
|
||||||
|
"""
|
||||||
|
from .btc import get_address
|
||||||
|
|
||||||
|
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||||
|
self.refresh_features()
|
||||||
|
|
||||||
|
def end_session(self):
|
||||||
|
"""Close the current session and clear cached passphrase.
|
||||||
|
|
||||||
|
The session will become invalid until `init_device()` is called again.
|
||||||
|
If passphrase is enabled, further actions will prompt for it again.
|
||||||
|
"""
|
||||||
|
# XXX self.call(messages.EndSession())
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
|
|
||||||
|
@tools.session
|
||||||
|
def clear_session(self):
|
||||||
|
"""Lock the device and present a fresh session.
|
||||||
|
|
||||||
|
The current session will be invalidated and a new one will be started. If the
|
||||||
|
device has PIN enabled, it will become locked.
|
||||||
|
|
||||||
|
Equivalent to calling `lock()`, `end_session()` and `init_device()`.
|
||||||
|
"""
|
||||||
|
# call LockDevice manually to save one refresh_features() call
|
||||||
|
self.call(messages.LockDevice())
|
||||||
|
self.end_session()
|
||||||
self.init_device()
|
self.init_device()
|
||||||
return resp.message
|
|
||||||
else:
|
|
||||||
return resp
|
|
||||||
|
@ -25,6 +25,7 @@ RECOVERY_BACK = "\x08" # backspace character, sent literally
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message")
|
||||||
|
@session
|
||||||
def apply_settings(
|
def apply_settings(
|
||||||
client,
|
client,
|
||||||
label=None,
|
label=None,
|
||||||
@ -48,45 +49,51 @@ def apply_settings(
|
|||||||
)
|
)
|
||||||
|
|
||||||
out = client.call(settings)
|
out = client.call(settings)
|
||||||
client.init_device() # Reload Features
|
client.refresh_features()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message")
|
||||||
|
@session
|
||||||
def apply_flags(client, flags):
|
def apply_flags(client, flags):
|
||||||
out = client.call(messages.ApplyFlags(flags=flags))
|
out = client.call(messages.ApplyFlags(flags=flags))
|
||||||
client.init_device() # Reload Features
|
client.refresh_features()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message")
|
||||||
|
@session
|
||||||
def change_pin(client, remove=False):
|
def change_pin(client, remove=False):
|
||||||
ret = client.call(messages.ChangePin(remove=remove))
|
ret = client.call(messages.ChangePin(remove=remove))
|
||||||
client.init_device() # Re-read features
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message")
|
||||||
|
@session
|
||||||
def change_wipe_code(client, remove=False):
|
def change_wipe_code(client, remove=False):
|
||||||
ret = client.call(messages.ChangeWipeCode(remove=remove))
|
ret = client.call(messages.ChangeWipeCode(remove=remove))
|
||||||
client.init_device() # Re-read features
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message")
|
||||||
|
@session
|
||||||
def sd_protect(client, operation):
|
def sd_protect(client, operation):
|
||||||
ret = client.call(messages.SdProtect(operation=operation))
|
ret = client.call(messages.SdProtect(operation=operation))
|
||||||
client.init_device()
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message")
|
||||||
|
@session
|
||||||
def wipe(client):
|
def wipe(client):
|
||||||
ret = client.call(messages.WipeDevice())
|
ret = client.call(messages.WipeDevice())
|
||||||
client.init_device()
|
client.init_device()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@session
|
||||||
def recover(
|
def recover(
|
||||||
client,
|
client,
|
||||||
word_count=24,
|
word_count=24,
|
||||||
@ -190,8 +197,10 @@ def reset(
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message")
|
@expect(messages.Success, field="message")
|
||||||
|
@session
|
||||||
def backup(client):
|
def backup(client):
|
||||||
ret = client.call(messages.BackupDevice())
|
ret = client.call(messages.BackupDevice())
|
||||||
|
client.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,9 @@ from trezorlib import device, messages
|
|||||||
class TestBasic:
|
class TestBasic:
|
||||||
def test_features(self, client):
|
def test_features(self, client):
|
||||||
f0 = client.features
|
f0 = client.features
|
||||||
f1 = client.call(messages.Initialize(f0.session_id))
|
# client erases session_id from its features
|
||||||
|
f0.session_id = client.session_id
|
||||||
|
f1 = client.call(messages.Initialize(client.session_id))
|
||||||
assert f0 == f1
|
assert f0 == f1
|
||||||
|
|
||||||
def test_ping(self, client):
|
def test_ping(self, client):
|
||||||
|
@ -57,6 +57,7 @@ def test_backup_bip39(client):
|
|||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
messages.Success(),
|
messages.Success(),
|
||||||
|
messages.Features(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
device.backup(client)
|
device.backup(client)
|
||||||
@ -119,6 +120,7 @@ def test_backup_slip39_basic(client):
|
|||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
messages.Success(),
|
messages.Success(),
|
||||||
|
messages.Features(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
device.backup(client)
|
device.backup(client)
|
||||||
@ -238,6 +240,7 @@ def test_backup_slip39_advanced(client):
|
|||||||
messages.ButtonRequest(code=B.Success), # show seeds ends here
|
messages.ButtonRequest(code=B.Success), # show seeds ends here
|
||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
messages.Success(),
|
messages.Success(),
|
||||||
|
messages.Features(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
device.backup(client)
|
device.backup(client)
|
||||||
|
@ -58,6 +58,7 @@ def backup_flow_bip39(client):
|
|||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
messages.Success(),
|
messages.Success(),
|
||||||
|
messages.Features(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
client.set_input_flow(input_flow)
|
client.set_input_flow(input_flow)
|
||||||
@ -99,7 +100,11 @@ def backup_flow_slip39_basic(client):
|
|||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
]
|
]
|
||||||
* 5 # individual shares
|
* 5 # individual shares
|
||||||
+ [messages.ButtonRequest(code=B.Success), messages.Success()]
|
+ [
|
||||||
|
messages.ButtonRequest(code=B.Success),
|
||||||
|
messages.Success(),
|
||||||
|
messages.Features(),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
device.backup(client)
|
device.backup(client)
|
||||||
|
|
||||||
@ -158,7 +163,11 @@ def backup_flow_slip39_advanced(client):
|
|||||||
messages.ButtonRequest(code=B.Success),
|
messages.ButtonRequest(code=B.Success),
|
||||||
]
|
]
|
||||||
* 25 # individual shares
|
* 25 # individual shares
|
||||||
+ [messages.ButtonRequest(code=B.Success), messages.Success()]
|
+ [
|
||||||
|
messages.ButtonRequest(code=B.Success),
|
||||||
|
messages.Success(),
|
||||||
|
messages.Features(),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
device.backup(client)
|
device.backup(client)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user