1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-08-02 03:48:58 +00:00

wip trezorlib TEMP

(note: has many warnings, not all things will work correctly)
This commit is contained in:
M1nd3r 2024-09-16 18:32:16 +02:00
parent 34d3dbdeb1
commit f9a0124e0f
40 changed files with 1135 additions and 1125 deletions

View File

@ -7,7 +7,7 @@ import typing as t
from importlib import metadata from importlib import metadata
from . import device from . import device
from .client import TrezorClient from .transport.new.session import Session
try: try:
cryptography_version = metadata.version("cryptography") cryptography_version = metadata.version("cryptography")
@ -361,7 +361,7 @@ def verify_authentication_response(
def authenticate_device( def authenticate_device(
client: TrezorClient, session: Session,
challenge: bytes | None = None, challenge: bytes | None = None,
*, *,
whitelist: t.Collection[bytes] | None = None, whitelist: t.Collection[bytes] | None = None,
@ -371,7 +371,7 @@ def authenticate_device(
if challenge is None: if challenge is None:
challenge = secrets.token_bytes(16) challenge = secrets.token_bytes(16)
resp = device.authenticate(client, challenge) resp = device.authenticate(session, challenge)
return verify_authentication_response( return verify_authentication_response(
challenge, challenge,

View File

@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
from .protobuf import dict_to_proto from .protobuf import dict_to_proto
from .tools import expect, session from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
@expect(messages.BinanceAddress, field="address", ret_type=str) @expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.BinanceGetAddress( messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -42,16 +42,15 @@ def get_address(
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) @expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key( def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False session: "Session", address_n: "Address", show_display: bool = False
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
) )
@session
def sign_tx( def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
) -> messages.BinanceSignedTx: ) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0] msg = tx_json["msgs"][0]
tx_msg = tx_json.copy() tx_msg = tx_json.copy()
@ -60,7 +59,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
response = client.call(envelope) response = session.call(envelope)
if not isinstance(response, messages.BinanceTxRequest): if not isinstance(response, messages.BinanceTxRequest):
raise RuntimeError( raise RuntimeError(
@ -77,7 +76,7 @@ def sign_tx(
else: else:
raise ValueError("can not determine msg type") raise ValueError("can not determine msg type")
response = client.call(msg) response = session.call(msg)
if not isinstance(response, messages.BinanceSignedTx): if not isinstance(response, messages.BinanceSignedTx):
raise RuntimeError( raise RuntimeError(

View File

@ -13,7 +13,6 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import warnings import warnings
from copy import copy from copy import copy
from decimal import Decimal from decimal import Decimal
@ -24,12 +23,11 @@ from typing_extensions import Protocol, TypedDict
from . import exceptions, messages from . import exceptions, messages
from .tools import expect, prepare_message_bytes from .tools import expect, prepare_message_bytes
from .transport.new.session import Session
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
class ScriptSig(TypedDict): class ScriptSig(TypedDict):
asm: str asm: str
@ -176,13 +174,13 @@ def get_authenticated_address(
# TODO this is used by tests only # TODO this is used by tests only
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id( def get_ownership_id(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.GetOwnershipId( messages.GetOwnershipId(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -194,7 +192,7 @@ def get_ownership_id(
# TODO this is used by tests only # TODO this is used by tests only
def get_ownership_proof( def get_ownership_proof(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
@ -205,11 +203,11 @@ def get_ownership_proof(
preauthorized: bool = False, preauthorized: bool = False,
) -> Tuple[bytes, bytes]: ) -> Tuple[bytes, bytes]:
if preauthorized: if preauthorized:
res = client.call(messages.DoPreauthorized()) res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest): if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
res = client.call( res = session.call(
messages.GetOwnershipProof( messages.GetOwnershipProof(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -435,7 +433,7 @@ def sign_tx(
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin( def authorize_coinjoin(
client: "TrezorClient", session: "Session",
coordinator: str, coordinator: str,
max_rounds: int, max_rounds: int,
max_coordinator_fee_rate: int, max_coordinator_fee_rate: int,
@ -444,7 +442,7 @@ def authorize_coinjoin(
coin_name: str, coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.AuthorizeCoinJoin( messages.AuthorizeCoinJoin(
coordinator=coordinator, coordinator=coordinator,
max_rounds=max_rounds, max_rounds=max_rounds,

View File

@ -35,8 +35,8 @@ from . import exceptions, messages, tools
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.new.session import Session
PROTOCOL_MAGICS = { PROTOCOL_MAGICS = {
"mainnet": 764824073, "mainnet": 764824073,
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
@expect(messages.CardanoAddress, field="address", ret_type=str) @expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_parameters: messages.CardanoAddressParametersType, address_parameters: messages.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
@ -833,7 +833,7 @@ def get_address(
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CardanoGetAddress( messages.CardanoGetAddress(
address_parameters=address_parameters, address_parameters=address_parameters,
protocol_magic=protocol_magic, protocol_magic=protocol_magic,
@ -847,12 +847,12 @@ def get_address(
@expect(messages.CardanoPublicKey) @expect(messages.CardanoPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
show_display: bool = False, show_display: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CardanoGetPublicKey( messages.CardanoGetPublicKey(
address_n=address_n, address_n=address_n,
derivation_type=derivation_type, derivation_type=derivation_type,
@ -863,12 +863,12 @@ def get_public_key(
@expect(messages.CardanoNativeScriptHash) @expect(messages.CardanoNativeScriptHash)
def get_native_script_hash( def get_native_script_hash(
client: "TrezorClient", session: "Session",
native_script: messages.CardanoNativeScript, native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CardanoGetNativeScriptHash( messages.CardanoGetNativeScriptHash(
script=native_script, script=native_script,
display_format=display_format, display_format=display_format,
@ -878,7 +878,7 @@ def get_native_script_hash(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
signing_mode: messages.CardanoTxSigningMode, signing_mode: messages.CardanoTxSigningMode,
inputs: List[InputWithPath], inputs: List[InputWithPath],
outputs: List[OutputWithData], outputs: List[OutputWithData],
@ -915,7 +915,7 @@ def sign_tx(
signing_mode, signing_mode,
) )
response = client.call( response = session.call(
messages.CardanoSignTxInit( messages.CardanoSignTxInit(
signing_mode=signing_mode, signing_mode=signing_mode,
inputs_count=len(inputs), inputs_count=len(inputs),
@ -951,14 +951,14 @@ def sign_tx(
_get_certificates_items(certificates), _get_certificates_items(certificates),
withdrawals, withdrawals,
): ):
response = client.call(tx_item) response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response: Dict[str, Any] = {} sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None: if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data) auxiliary_data_supplement = session.call(auxiliary_data)
if not isinstance( if not isinstance(
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
): ):
@ -971,7 +971,7 @@ def sign_tx(
auxiliary_data_supplement.__dict__ auxiliary_data_supplement.__dict__
) )
response = client.call(messages.CardanoTxHostAck()) response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
@ -980,24 +980,24 @@ def sign_tx(
_get_collateral_inputs_items(collateral_inputs), _get_collateral_inputs_items(collateral_inputs),
required_signers, required_signers,
): ):
response = client.call(tx_item) response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
if collateral_return is not None: if collateral_return is not None:
for tx_item in _get_output_items(collateral_return): for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item) response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
for reference_input in reference_inputs: for reference_input in reference_inputs:
response = client.call(reference_input) response = session.call(reference_input)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"] = [] sign_tx_response["witnesses"] = []
for witness_request in witness_requests: for witness_request in witness_requests:
response = client.call(witness_request) response = session.call(witness_request)
if not isinstance(response, messages.CardanoTxWitnessResponse): if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"].append( sign_tx_response["witnesses"].append(
@ -1009,12 +1009,12 @@ def sign_tx(
} }
) )
response = client.call(messages.CardanoTxHostAck()) response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxBodyHash): if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["tx_hash"] = response.tx_hash sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck()) response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoSignTxFinished): if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR

View File

@ -28,9 +28,7 @@ from .. import exceptions, transport, ui
from ..client import TrezorClient from ..client import TrezorClient
from ..messages import Capability from ..messages import Capability
from ..transport.new import channel_database from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient
from ..transport.new.transport import NewTransport from ..transport.new.transport import NewTransport
from ..ui import ClickUI, ScriptUI
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators # Needed to enforce a return value from decorators
@ -39,9 +37,6 @@ if t.TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@ -117,19 +112,17 @@ class NewTrezorConnection:
self.passphrase_on_host = passphrase_on_host self.passphrase_on_host = passphrase_on_host
self.script = script self.script = script
def get_session( def get_session(self, derive_cardano: bool = False):
self,
):
client = self.get_client() client = self.get_client()
if self.session_id is not None: if self.session_id is not None:
pass # TODO Try resume pass # TODO Try resume - be careful of cardano derivation settings!
features = client.protocol.get_features() features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here? passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled: if not passphrase_enabled:
return client.get_session(derive_cardano=True) return client.get_session(derive_cardano=derive_cardano)
# TODO Passphrase empty by default - ??? # TODO Passphrase empty by default - ???
available_on_device = Capability.PassphraseEntry in features.capabilities available_on_device = Capability.PassphraseEntry in features.capabilities
@ -137,7 +130,9 @@ class NewTrezorConnection:
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str): if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str") raise RuntimeError("Passphrase must be a str")
session = client.get_session(passphrase=passphrase, derive_cardano=True) session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
return session return session
def get_transport(self) -> "NewTransport": def get_transport(self) -> "NewTransport":
@ -152,7 +147,7 @@ class NewTrezorConnection:
# if this fails, we want the exception to bubble up to the caller # if this fails, we want the exception to bubble up to the caller
return transport.new_get_transport(self.path, prefix_search=True) return transport.new_get_transport(self.path, prefix_search=True)
def get_client(self) -> NewTrezorClient: def get_client(self) -> TrezorClient:
transport = self.get_transport() transport = self.get_transport()
stored_channels = channel_database.load_stored_channels() stored_channels = channel_database.load_stored_channels()
@ -162,11 +157,11 @@ class NewTrezorConnection:
stored_channel_with_correct_transport_path = next( stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path ch for ch in stored_channels if ch.transport_path == path
) )
client = NewTrezorClient.resume( client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path transport, stored_channel_with_correct_transport_path
) )
else: else:
client = NewTrezorClient(transport) client = TrezorClient(transport)
return client return client
@ -202,91 +197,96 @@ class NewTrezorConnection:
# other exceptions may cause a traceback # other exceptions may cause a traceback
class TrezorConnection: # class TrezorConnection:
def __init__( # def __init__(
self, # self,
path: str, # path: str,
session_id: bytes | None, # session_id: bytes | None,
passphrase_on_host: bool, # passphrase_on_host: bool,
script: bool, # script: bool,
) -> None: # ) -> None:
self.path = path # self.path = path
self.session_id = session_id # self.session_id = session_id
self.passphrase_on_host = passphrase_on_host # self.passphrase_on_host = passphrase_on_host
self.script = script # self.script = script
def get_transport(self) -> "Transport": # def get_transport(self) -> "Transport":
try: # try:
# look for transport without prefix search # # look for transport without prefix search
return transport.get_transport(self.path, prefix_search=False) # return transport.get_transport(self.path, prefix_search=False)
except Exception: # except Exception:
# most likely not found. try again below. # # most likely not found. try again below.
pass # pass
# look for transport with prefix search # # look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller # # if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True) # return transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI": # def get_ui(self) -> "TrezorClientUI":
if self.script: # if self.script:
# It is alright to return just the class object instead of instance, # # It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI # # as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods) # # (ScriptUI is just a set of staticmethods)
return ScriptUI # return ScriptUI
else: # else:
return ClickUI(passphrase_on_host=self.passphrase_on_host) # return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self) -> TrezorClient: # def get_client(self) -> TrezorClient:
transport = self.get_transport() # transport = self.get_transport()
ui = self.get_ui() # ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id) # return TrezorClient(transport, ui=ui, session_id=self.session_id)
@contextmanager # @contextmanager
def client_context(self): # def client_context(self):
"""Get a client instance as a context manager. Handle errors in a manner # """Get a client instance as a context manager. Handle errors in a manner
appropriate for end-users. # appropriate for end-users.
Usage: # Usage:
>>> with obj.client_context() as client: # >>> with obj.client_context() as client:
>>> do_your_actions_here() # >>> do_your_actions_here()
""" # """
try: # try:
client = self.get_client() # client = self.get_client()
except transport.DeviceIsBusy: # except transport.DeviceIsBusy:
click.echo("Device is in use by another process.") # click.echo("Device is in use by another process.")
sys.exit(1) # sys.exit(1)
except Exception: # except Exception:
click.echo("Failed to find a Trezor device.") # click.echo("Failed to find a Trezor device.")
if self.path is not None: # if self.path is not None:
click.echo(f"Using path: {self.path}") # click.echo(f"Using path: {self.path}")
sys.exit(1) # sys.exit(1)
try:
yield client
except exceptions.Cancelled:
# handle cancel action
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
# handle any Trezor-sent exceptions as user-readable
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
# try:
# yield client
# except exceptions.Cancelled:
# # handle cancel action
# click.echo("Action was cancelled.")
# sys.exit(1)
# except exceptions.TrezorException as e:
# # handle any Trezor-sent exceptions as user-readable
# raise click.ClickException(str(e)) from e
# # other exceptions may cause a traceback
from ..transport.new.session import Session from ..transport.new.session import Session
def with_session( def with_cardano_session(
func: "t.Callable[Concatenate[Session, P], R]", func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]": ) -> "t.Callable[P, R]":
return with_session(func=func, derive_cardano=True)
def with_session(
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
) -> "t.Callable[P, R]":
@click.pass_obj @click.pass_obj
@functools.wraps(func) @functools.wraps(func)
def function_with_session( def function_with_session(
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs" obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R": ) -> "R":
session = obj.get_session() session = obj.get_session(derive_cardano)
try: try:
return func(session, *args, **kwargs) return func(session, *args, **kwargs)
finally: finally:
@ -298,8 +298,8 @@ def with_session(
return function_with_session # type: ignore [is incompatible with return type] return function_with_session # type: ignore [is incompatible with return type]
def new_with_client( def with_client(
func: "t.Callable[Concatenate[NewTrezorClient, P], R]", func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]": ) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`. """Wrap a Click command in `with obj.client_context() as client`.
@ -336,39 +336,39 @@ def new_with_client(
return trezorctl_command_with_client # type: ignore [is incompatible with return type] return trezorctl_command_with_client # type: ignore [is incompatible with return type]
def with_client( # def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]", # func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]": # ) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`. # """Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume # Sessions are handled transparently. The user is warned when session did not resume
cleanly. The session is closed after the command completes - unless the session # cleanly. The session is closed after the command completes - unless the session
was resumed, in which case it should remain open. # was resumed, in which case it should remain open.
""" # """
@click.pass_obj # @click.pass_obj
@functools.wraps(func) # @functools.wraps(func)
def trezorctl_command_with_client( # def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" # obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R": # ) -> "R":
with obj.client_context() as client: # with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id # session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None: # if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed # # tried to resume but failed
click.echo("Warning: failed to resume session.", err=True) # click.echo("Warning: failed to resume session.", err=True)
try: # try:
return func(client, *args, **kwargs) # return func(client, *args, **kwargs)
finally: # finally:
if not session_was_resumed: # if not session_was_resumed:
try: # try:
client.end_session() # client.end_session()
except Exception: # except Exception:
pass # pass
# the return type of @click.pass_obj is improperly specified and pyright doesn't # # the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) # # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return trezorctl_command_with_client # type: ignore [is incompatible with return type] # return trezorctl_command_with_client
class AliasedGroup(click.Group): class AliasedGroup(click.Group):

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import binance, tools from .. import binance, tools
from . import with_client from ..transport.new.session import Session
from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import messages from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Binance address for specified path.""" """Get Binance address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify) return binance.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key.""" """Get Binance public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex() return binance.get_public_key(session, address_n, show_display).hex()
@cli.command() @cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx": ) -> "messages.BinanceSignedTx":
"""Sign Binance transaction. """Sign Binance transaction.
Transaction must be provided as a JSON file. Transaction must be provided as a JSON file.
""" """
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click import click
from .. import cardano, messages, tools from .. import cardano, messages, tools
from . import ChoiceType, with_client from . import ChoiceType, with_cardano_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True) @click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "chunkify", is_flag=True) @click.option("-C", "chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client @with_cardano_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
file: TextIO, file: TextIO,
signing_mode: messages.CardanoTxSigningMode, signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int, protocol_magic: int,
@ -123,9 +123,9 @@ def sign_tx(
for p in transaction["additional_witness_requests"] for p in transaction["additional_witness_requests"]
] ]
client.init_device(derive_cardano=True) session.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx( sign_tx_response = cardano.sign_tx(
client, session,
signing_mode, signing_mode,
inputs, inputs,
outputs, outputs,
@ -209,9 +209,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_cardano_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
address_type: messages.CardanoAddressType, address_type: messages.CardanoAddressType,
staking_address: str, staking_address: str,
@ -262,9 +262,9 @@ def get_address(
script_staking_hash_bytes, script_staking_hash_bytes,
) )
client.init_device(derive_cardano=True) session.init_device(derive_cardano=True)
return cardano.get_address( return cardano.get_address(
client, session,
address_parameters, address_parameters,
protocol_magic, protocol_magic,
network_id, network_id,
@ -283,18 +283,18 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_cardano_session
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address: str, address: str,
derivation_type: messages.CardanoDerivationType, derivation_type: messages.CardanoDerivationType,
show_display: bool, show_display: bool,
) -> messages.CardanoPublicKey: ) -> messages.CardanoPublicKey:
"""Get Cardano public key.""" """Get Cardano public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
client.init_device(derive_cardano=True) session.init_device(derive_cardano=True)
return cardano.get_public_key( return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display session, address_n, derivation_type=derivation_type, show_display=show_display
) )
@ -312,9 +312,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@with_client @with_cardano_session
def get_native_script_hash( def get_native_script_hash(
client: "TrezorClient", session: "Session",
file: TextIO, file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat, display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType, derivation_type: messages.CardanoDerivationType,
@ -323,7 +323,7 @@ def get_native_script_hash(
native_script_json = json.load(file) native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json) native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True) session.init_device(derive_cardano=True)
return cardano.get_native_script_hash( return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type session, native_script, display_format, derivation_type=derivation_type
) )

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click import click
from .. import misc, tools from .. import misc, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PROMPT_TYPE = ChoiceType( PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("size", type=int) @click.argument("size", type=int)
@with_client @with_session
def get_entropy(client: "TrezorClient", size: int) -> str: def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device.""" """Get random bytes from device."""
return misc.get_entropy(client, size).hex() return misc.get_entropy(session, size).hex()
@cli.command() @cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_session
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", session: "Session",
address: str, address: str,
key: str, key: str,
value: str, value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.encrypt_keyvalue( return misc.encrypt_keyvalue(
client, session,
address_n, address_n,
key, key,
value.encode(), value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_session
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", session: "Session",
address: str, address: str,
key: str, key: str,
value: str, value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.decrypt_keyvalue( return misc.decrypt_keyvalue(
client, session,
address_n, address_n,
key, key,
bytes.fromhex(value), bytes.fromhex(value),

View File

@ -27,7 +27,7 @@ from ..debuglink import record_screen
from . import with_client from . import with_client
if TYPE_CHECKING: if TYPE_CHECKING:
from . import TrezorConnection from . import NewTrezorConnection
@click.group(name="debug") @click.group(name="debug")
@ -40,7 +40,7 @@ def cli() -> None:
@click.argument("hex_data") @click.argument("hex_data")
@click.pass_obj @click.pass_obj
def send_bytes( def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
) -> None: ) -> None:
"""Send raw bytes to Trezor. """Send raw bytes to Trezor.
@ -86,7 +86,7 @@ def send_bytes(
@click.argument("directory", required=False) @click.argument("directory", required=False)
@click.option("-s", "--stop", is_flag=True, help="Stop the recording") @click.option("-s", "--stop", is_flag=True, help="Stop the recording")
@click.pass_obj @click.pass_obj
def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) -> None: def record(obj: "NewTrezorConnection", directory: Union[str, None], stop: bool) -> None:
"""Record screen changes into a specified directory. """Record screen changes into a specified directory.
Recording can be stopped with `-s / --stop` option. Recording can be stopped with `-s / --stop` option.
@ -95,7 +95,7 @@ def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) ->
def record_screen_from_connection( def record_screen_from_connection(
obj: "TrezorConnection", directory: Union[str, None] obj: "NewTrezorConnection", directory: Union[str, None]
) -> None: ) -> None:
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink.""" """Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
transport = obj.get_transport() transport = obj.get_transport()

View File

@ -29,7 +29,7 @@ from . import ChoiceType, with_client
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient from ..client import TrezorClient
from ..protobuf import MessageType from ..protobuf import MessageType
from . import TrezorConnection from . import NewTrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = { RECOVERY_DEVICE_INPUT_METHOD = {
"scrambled": messages.RecoveryDeviceInputMethod.ScrambledWords, "scrambled": messages.RecoveryDeviceInputMethod.ScrambledWords,
@ -67,14 +67,15 @@ def cli() -> None:
@with_client @with_client
def wipe(client: "TrezorClient", bootloader: bool) -> str: def wipe(client: "TrezorClient", bootloader: bool) -> str:
"""Reset device to factory defaults and remove all private data.""" """Reset device to factory defaults and remove all private data."""
features = client.get_management_session().get_features()
if bootloader: if bootloader:
if not client.features.bootloader_mode: if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.") click.echo("Please switch your device to bootloader mode.")
sys.exit(1) sys.exit(1)
else: else:
click.echo("Wiping user data and firmware!") click.echo("Wiping user data and firmware!")
else: else:
if client.features.bootloader_mode: if features.bootloader_mode:
click.echo( click.echo(
"Your device is in bootloader mode. This operation would also erase firmware." "Your device is in bootloader mode. This operation would also erase firmware."
) )
@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
click.echo("Wiping user data!") click.echo("Wiping user data!")
try: try:
return device.wipe(client) return device.wipe(
client
) # TODO decide where the wipe should happen - management or regular session
except exceptions.TrezorFailure as e: except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args)) click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3) sys.exit(3)
@ -233,9 +236,11 @@ def setup(
strength = int(strength) strength = int(strength)
BT = messages.BackupType BT = messages.BackupType
management_session = client.get_management_session()
version = management_session.get_version()
if backup_type is None: if backup_type is None:
if client.version >= (2, 7, 1): if version >= (2, 7, 1):
# SLIP39 extendable was introduced in 2.7.1 # SLIP39 extendable was introduced in 2.7.1
backup_type = BT.Slip39_Single_Extendable backup_type = BT.Slip39_Single_Extendable
else: else:
@ -309,7 +314,7 @@ def sd_protect(
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
def reboot_to_bootloader(obj: "TrezorConnection") -> str: def reboot_to_bootloader(obj: "NewTrezorConnection") -> str:
"""Reboot device into bootloader mode. """Reboot device into bootloader mode.
Currently only supported on Trezor Model One. Currently only supported on Trezor Model One.

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import eos, tools from .. import eos, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import messages from .. import messages
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding.""" """Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display) res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_transaction( def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx": ) -> "messages.EosSignedTx":
"""Sign EOS transaction.""" """Sign EOS transaction."""
tx_json = json.load(file) tx_json = json.load(file)
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return eos.sign_tx( return eos.sign_tx(
client, session,
address_n, address_n,
tx_json["transaction"], tx_json["transaction"],
tx_json["chain_id"], tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions from ..messages import EthereumDefinitions
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
import web3 import web3
from eth_typing import ChecksumAddress # noqa: I900 from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei from web3.types import Wei
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Ethereum address in hex encoding.""" """Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify) return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path.""" """Get Ethereum public node of given path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display) result = ethereum.get_public_node(session, address_n, show_display=show_display)
return { return {
"node": { "node": {
"depth": result.node.depth, "depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address") @click.argument("to_address")
@click.argument("amount", callback=_amount_to_int) @click.argument("amount", callback=_amount_to_int)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
chain_id: int, chain_id: int,
address: str, address: str,
amount: int, amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
from_address = ethereum.get_address( from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network session, address_n, encoded_network=encoded_network
) )
if token: if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None assert max_gas_fee is not None
assert max_priority_fee is not None assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559( sig = ethereum.sign_tx_eip1559(
client, session,
n=address_n, n=address_n,
nonce=nonce, nonce=nonce,
gas_limit=gas_limit, gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price gas_price = _get_web3().eth.gas_price
assert gas_price is not None assert gas_price is not None
sig = ethereum.sign_tx( sig = ethereum.sign_tx(
client, session,
n=address_n, n=address_n,
tx_type=tx_type, tx_type=tx_type,
nonce=nonce, nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("message") @click.argument("message")
@with_client @with_session
def sign_message( def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Sign message with Ethereum address.""" """Sign message with Ethereum address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = { output = {
"message": message, "message": message,
"address": ret.address, "address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation", help="Be compatible with Metamask's signTypedData_v4 implementation",
) )
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@with_client @with_session
def sign_typed_data( def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address. """Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network) defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read()) data = json.loads(file.read())
ret = ethereum.sign_typed_data( ret = ethereum.sign_typed_data(
client, session,
address_n, address_n,
data, data,
metamask_v4_compat=metamask_v4_compat, metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address") @click.argument("address")
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_session
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
address: str, address: str,
signature: str, signature: str,
message: str, message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address.""" """Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature) signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message( return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify session, address, signature_bytes, message, chunkify=chunkify
) )
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex") @click.argument("domain_hash_hex")
@click.argument("message_hash_hex") @click.argument("message_hash_hex")
@with_client @with_session
def sign_typed_data_hash( def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]: ) -> Dict[str, str]:
""" """
Sign hash of typed data (EIP-712) with Ethereum address. Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash( ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network session, address_n, domain_hash, message_hash, network
) )
output = { output = {
"domain_hash": domain_hash_hex, "domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click import click
from .. import fido from .. import fido
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list") @credentials.command(name="list")
@with_client @with_session
def credentials_list(client: "TrezorClient") -> None: def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device.""" """List all resident credentials on the device."""
creds = fido.list_credentials(client) creds = fido.list_credentials(session)
for cred in creds: for cred in creds:
click.echo("") click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:") click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add") @credentials.command(name="add")
@click.argument("hex_credential_id") @click.argument("hex_credential_id")
@with_client @with_session
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: def credentials_add(session: "Session", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential. """Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
""" """
return fido.add_credential(client, bytes.fromhex(hex_credential_id)) return fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove") @credentials.command(name="remove")
@click.option( @click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
) )
@with_client @with_session
def credentials_remove(client: "TrezorClient", index: int) -> str: def credentials_remove(session: "Session", index: int) -> str:
"""Remove the resident credential at the given index.""" """Remove the resident credential at the given index."""
return fido.remove_credential(client, index) return fido.remove_credential(session, index)
# #
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set") @counter.command(name="set")
@click.argument("counter", type=int) @click.argument("counter", type=int)
@with_client @with_session
def counter_set(client: "TrezorClient", counter: int) -> str: def counter_set(session: "Session", counter: int) -> str:
"""Set FIDO/U2F counter value.""" """Set FIDO/U2F counter value."""
return fido.set_counter(client, counter) return fido.set_counter(session, counter)
@counter.command(name="get-next") @counter.command(name="get-next")
@with_client @with_session
def counter_get_next(client: "TrezorClient") -> int: def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter. """Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation is returned and atomically increased. This command performs the same operation
and returns the counter value. and returns the counter value.
""" """
return fido.get_next_counter(client) return fido.get_next_counter(session)

View File

@ -41,7 +41,7 @@ from . import ChoiceType, with_client
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..client import TrezorClient
from . import TrezorConnection from . import NewTrezorConnection
MODEL_CHOICE = ChoiceType( MODEL_CHOICE = ChoiceType(
{ {
@ -74,9 +74,10 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader. 1.8.0 because that installs the appropriate bootloader.
""" """
f = client.features management_session = client.get_management_session()
version = (f.major_version, f.minor_version, f.patch_version) features = management_session.get_features()
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) version = management_session.get_version()
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2 return bootloader_onev2
@ -306,25 +307,27 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version If the specified version is not found, prints the closest available version
(higher than the specified one, if existing). (higher than the specified one, if existing).
""" """
management_session = client.get_management_session()
features = management_session.get_features()
model = management_session.get_model()
if bitcoin_only is None: if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features) bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str: def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version)) return ".".join(map(str, version))
f = client.features releases = get_all_firmware_releases(model, bitcoin_only, beta)
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
highest_version = releases[0]["version"] highest_version = releases[0]["version"]
if version: if version:
want_version = [int(x) for x in version.split(".")] want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3: if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.") click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version: if want_version[0] != features.major_version:
click.echo( click.echo(
f"Warning: Trezor {client.model.name} firmware version should be " f"Warning: Trezor {model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})" f"{features.major_version}.X.Y (requested: {version})"
) )
else: else:
want_version = highest_version want_version = highest_version
@ -359,8 +362,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal # to the newer one, in that case update to the minimal
# compatible version first # compatible version first
# Choosing the version key to compare based on (not) being in BL mode # Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version] client_version = management_session.get_version()
if f.bootloader_mode: if features.bootloader_mode:
key_to_compare = "min_bootloader_version" key_to_compare = "min_bootloader_version"
else: else:
key_to_compare = "min_firmware_version" key_to_compare = "min_firmware_version"
@ -451,7 +454,7 @@ def upload_firmware_into_device(
firmware_data: bytes, firmware_data: bytes,
) -> None: ) -> None:
"""Perform the final act of loading the firmware into Trezor.""" """Perform the final act of loading the firmware into Trezor."""
f = client.features f = client.get_management_session().get_features()
try: try:
if f.major_version == 1 and f.firmware_present is not False: if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest # Trezor One does not send ButtonRequest
@ -482,7 +485,7 @@ def _is_strict_update(client: "TrezorClient", firmware_data: bytes) -> bool:
if not isinstance(fw, firmware.VendorFirmware): if not isinstance(fw, firmware.VendorFirmware):
return False return False
f = client.features f = client.get_management_session().get_features()
cur_version = (f.major_version, f.minor_version, f.patch_version, 0) cur_version = (f.major_version, f.minor_version, f.patch_version, 0)
return ( return (
@ -519,7 +522,7 @@ def cli() -> None:
@click.pass_obj @click.pass_obj
# fmt: on # fmt: on
def verify( def verify(
obj: "TrezorConnection", obj: "NewTrezorConnection",
filename: BinaryIO, filename: BinaryIO,
check_device: bool, check_device: bool,
fingerprint: Optional[str], fingerprint: Optional[str],
@ -564,7 +567,7 @@ def verify(
@click.pass_obj @click.pass_obj
# fmt: on # fmt: on
def download( def download(
obj: "TrezorConnection", obj: "NewTrezorConnection",
output: Optional[BinaryIO], output: Optional[BinaryIO],
model: Optional[TrezorModel], model: Optional[TrezorModel],
version: Optional[str], version: Optional[str],
@ -630,7 +633,7 @@ def download(
# fmt: on # fmt: on
@click.pass_obj @click.pass_obj
def update( def update(
obj: "TrezorConnection", obj: "NewTrezorConnection",
filename: Optional[BinaryIO], filename: Optional[BinaryIO],
url: Optional[str], url: Optional[str],
version: Optional[str], version: Optional[str],

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click import click
from .. import messages, monero, tools from .. import messages, monero, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET, default=messages.MoneroNetworkType.MAINNET,
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
network_type: messages.MoneroNetworkType, network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes: ) -> bytes:
"""Get Monero address for specified path.""" """Get Monero address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify) return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command() @cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET, default=messages.MoneroNetworkType.MAINNET,
) )
@with_client @with_session
def get_watch_key( def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Get Monero watch key for specified path.""" """Get Monero watch key for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type) res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey # TODO: could be made required in MoneroWatchKey
assert res.address is not None assert res.address is not None
assert res.watch_key is not None assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests import requests
from .. import nem, tools from .. import nem, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68) @click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
network: int, network: int,
show_display: bool, show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str: ) -> str:
"""Get NEM address for specified path.""" """Get NEM address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify) return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command() @cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: str, address: str,
file: TextIO, file: TextIO,
broadcast: Optional[str], broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format. Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
""" """
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import ripple, tools from .. import ripple, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Ripple address""" """Get Ripple address"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify) return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction""" """Sign Ripple transaction"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file)) msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:") click.echo("Signature:")
click.echo(result.signature.hex()) click.echo(result.signature.hex())
click.echo() click.echo()

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click import click
from .. import messages, solana, tools from .. import messages, solana, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
) -> messages.SolanaPublicKey: ) -> messages.SolanaPublicKey:
"""Get Solana public key.""" """Get Solana public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display) return solana.get_public_key(session, address_n, show_display)
@cli.command() @cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
chunkify: bool, chunkify: bool,
) -> messages.SolanaAddress: ) -> messages.SolanaAddress:
"""Get Solana address.""" """Get Solana address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify) return solana.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.argument("serialized_tx", type=str) @click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r")) @click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: str, address: str,
serialized_tx: str, serialized_tx: str,
additional_information_file: Optional[TextIO], additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
) )
return solana.sign_tx( return solana.sign_tx(
client, session,
address_n, address_n,
bytes.fromhex(serialized_tx), bytes.fromhex(serialized_tx),
additional_information, additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click import click
from .. import stellar, tools from .. import stellar, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
try: try:
from stellar_sdk import ( from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
) )
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Stellar public address.""" """Get Stellar public address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify) return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).", help="Network passphrase (blank for public network).",
) )
@click.argument("b64envelope") @click.argument("b64envelope")
@with_client @with_session
def sign_transaction( def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes: ) -> bytes:
"""Sign a base64-encoded transaction envelope. """Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope) tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature) return base64.b64encode(resp.signature)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import messages, protobuf, tezos, tools from .. import messages, protobuf, tezos, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
@ -37,23 +37,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Tezos address for specified path.""" """Get Tezos address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display, chunkify) return tezos.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Tezos public key.""" """Get Tezos public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display) return tezos.get_public_key(session, address_n, show_display)
@cli.command() @cli.command()
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> messages.TezosSignedTx: ) -> messages.TezosSignedTx:
"""Sign Tezos transaction.""" """Sign Tezos transaction."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)

View File

@ -28,13 +28,11 @@ from .. import __version__, log, messages, protobuf
from ..client import TrezorClient from ..client import TrezorClient
from ..transport import DeviceIsBusy, new_enumerate_devices from ..transport import DeviceIsBusy, new_enumerate_devices
from ..transport.new import channel_database from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient
from ..transport.new.session import Session from ..transport.new.session import Session
from ..transport.new.udp import UdpTransport from ..transport.new.udp import UdpTransport
from . import ( from . import (
AliasedGroup, AliasedGroup,
NewTrezorConnection, NewTrezorConnection,
TrezorConnection,
binance, binance,
btc, btc,
cardano, cardano,
@ -47,7 +45,6 @@ from . import (
firmware, firmware,
monero, monero,
nem, nem,
new_with_client,
ripple, ripple,
settings, settings,
solana, solana,
@ -261,7 +258,7 @@ def print_result(res: Any, is_json: bool, script: bool, **kwargs: Any) -> None:
@cli.set_result_callback() @cli.set_result_callback()
@click.pass_obj @click.pass_obj
def stop_recording_action(obj: TrezorConnection, *args: Any, **kwargs: Any) -> None: def stop_recording_action(obj: NewTrezorConnection, *args: Any, **kwargs: Any) -> None:
"""Stop recording screen changes when the recording was started by `cli_main`. """Stop recording screen changes when the recording was started by `cli_main`.
(When user used the `-r / --record` option of `trezorctl` command.) (When user used the `-r / --record` option of `trezorctl` command.)
@ -302,14 +299,14 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
stored_channel_with_correct_transport_path = next( stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path ch for ch in stored_channels if ch.transport_path == path
) )
client = NewTrezorClient.resume( client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path transport, stored_channel_with_correct_transport_path
) )
else: else:
client = NewTrezorClient(transport) client = TrezorClient(transport)
session = client.get_management_session() session = client.get_management_session()
description = format_device_name(session.features) description = format_device_name(session.get_features())
# json_string = channel_database.channel_to_str(client.protocol) # json_string = channel_database.channel_to_str(client.protocol)
# print(json_string) # print(json_string)
channel_database.save_channel(client.protocol) channel_database.save_channel(client.protocol)
@ -348,7 +345,9 @@ def ping(session: "Session", message: str, button_protection: bool) -> str:
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
def get_session(obj: TrezorConnection) -> str: def get_session(
obj: NewTrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
"""Get a session ID for subsequent commands. """Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -362,7 +361,10 @@ def get_session(obj: TrezorConnection) -> str:
obj.session_id = None obj.session_id = None
with obj.client_context() as client: with obj.client_context() as client:
if client.features.model == "1" and client.version < (1, 9, 0): session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
if session.get_features().model == "1" and session.get_version() < (1, 9, 0):
raise click.ClickException( raise click.ClickException(
"Upgrade your firmware to enable session support." "Upgrade your firmware to enable session support."
) )
@ -388,11 +390,11 @@ def new_clear_session() -> None:
@cli.command() @cli.command()
@new_with_client @with_client
def get_features(client: "NewTrezorClient") -> messages.Features: def get_features(client: "TrezorClient") -> messages.Features:
"""Retrieve device features and settings.""" """Retrieve device features and settings."""
session = client.get_management_session() session = client.get_management_session()
return session.features return session.get_features()
@cli.command() @cli.command()

File diff suppressed because it is too large Load Diff

View File

@ -57,6 +57,7 @@ if TYPE_CHECKING:
from .messages import PinMatrixRequestType from .messages import PinMatrixRequestType
from .transport import Transport from .transport import Transport
from .transport.new.session import Session
ExpectedMessage = Union[ ExpectedMessage = Union[
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter" protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
@ -1107,7 +1108,8 @@ class TrezorClientDebugLink(TrezorClient):
Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before
using `debug.wait_layout()`, otherwise layout changes are not reported. using `debug.wait_layout()`, otherwise layout changes are not reported.
""" """
if self.version >= (2, 3, 2): version = self.get_management_session().get_version()
if version >= (2, 3, 2):
# version check is necessary because otherwise we cannot reliably detect # version check is necessary because otherwise we cannot reliably detect
# whether and where to wait for reply: # whether and where to wait for reply:
# - T1 reports unknown debuglink messages on the wirelink # - T1 reports unknown debuglink messages on the wirelink
@ -1319,7 +1321,7 @@ class TrezorClientDebugLink(TrezorClient):
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def load_device( def load_device(
client: "TrezorClient", session: "Session",
mnemonic: Union[str, Iterable[str]], mnemonic: Union[str, Iterable[str]],
pin: Optional[str], pin: Optional[str],
passphrase_protection: bool, passphrase_protection: bool,
@ -1333,12 +1335,12 @@ def load_device(
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
if client.features.initialized: if session.features.initialized:
raise RuntimeError( raise RuntimeError(
"Device is initialized already. Call device.wipe() and try again." "Device is initialized already. Call device.wipe() and try again."
) )
resp = client.call( resp = session.call(
messages.LoadDevice( messages.LoadDevice(
mnemonics=mnemonics, mnemonics=mnemonics,
pin=pin, pin=pin,
@ -1349,7 +1351,7 @@ def load_device(
no_backup=no_backup, no_backup=no_backup,
) )
) )
client.init_device() session.init_device()
return resp return resp
@ -1358,11 +1360,11 @@ load_device_by_mnemonic = load_device
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: def prodtest_t1(session: "Session") -> protobuf.MessageType:
if client.features.bootloader_mode is not True: if session.get_features().bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
return client.call( return session.call(
messages.ProdTestT1( messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
) )
@ -1418,5 +1420,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType: def optiga_set_sec_max(session: "Session") -> protobuf.MessageType:
return client.call(messages.DebugLinkOptigaSetSecMax()) return session.call(messages.DebugLinkOptigaSetSecMax())

View File

@ -23,20 +23,19 @@ from typing import TYPE_CHECKING, Callable, Iterable, Optional
from . import messages from . import messages
from .exceptions import Cancelled, TrezorException from .exceptions import Cancelled, TrezorException
from .tools import Address, expect, session from .tools import Address, expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.new.session import Session
RECOVERY_BACK = "\x08" # backspace character, sent literally RECOVERY_BACK = "\x08" # backspace character, sent literally
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def apply_settings( def apply_settings(
client: "TrezorClient", session: "Session",
label: Optional[str] = None, label: Optional[str] = None,
language: Optional[str] = None, language: Optional[str] = None,
use_passphrase: Optional[bool] = None, use_passphrase: Optional[bool] = None,
@ -67,13 +66,13 @@ def apply_settings(
haptic_feedback=haptic_feedback, haptic_feedback=haptic_feedback,
) )
out = client.call(settings) out = session.call(settings)
client.refresh_features() session.refresh_features()
return out return out
def _send_language_data( def _send_language_data(
client: "TrezorClient", session: "Session",
request: "messages.TranslationDataRequest", request: "messages.TranslationDataRequest",
language_data: bytes, language_data: bytes,
) -> "MessageType": ) -> "MessageType":
@ -83,76 +82,69 @@ def _send_language_data(
data_length = response.data_length data_length = response.data_length
data_offset = response.data_offset data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length] chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk)) response = session.call(messages.TranslationDataAck(data_chunk=chunk))
return response return response
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def change_language( def change_language(
client: "TrezorClient", session: "Session",
language_data: bytes, language_data: bytes,
show_display: bool | None = None, show_display: bool | None = None,
) -> "MessageType": ) -> "MessageType":
data_length = len(language_data) data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg) response = session.call(msg)
if data_length > 0: if data_length > 0:
assert isinstance(response, messages.TranslationDataRequest) assert isinstance(response, messages.TranslationDataRequest)
response = _send_language_data(client, response, language_data) response = _send_language_data(session, response, language_data)
assert isinstance(response, messages.Success) assert isinstance(response, messages.Success)
client.refresh_features() # changing the language in features session.refresh_features() # changing the language in features
return response return response
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def apply_flags(session: "Session", flags: int) -> "MessageType":
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": out = session.call(messages.ApplyFlags(flags=flags))
out = client.call(messages.ApplyFlags(flags=flags)) session.refresh_features()
client.refresh_features()
return out return out
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def change_pin(session: "Session", remove: bool = False) -> "MessageType":
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = session.call(messages.ChangePin(remove=remove))
ret = client.call(messages.ChangePin(remove=remove)) session.refresh_features()
client.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType":
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = session.call(messages.ChangeWipeCode(remove=remove))
ret = client.call(messages.ChangeWipeCode(remove=remove)) session.refresh_features()
client.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def sd_protect( def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType session: "Session", operation: messages.SdProtectOperationType
) -> "MessageType": ) -> "MessageType":
ret = client.call(messages.SdProtect(operation=operation)) ret = session.call(messages.SdProtect(operation=operation))
client.refresh_features() session.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def wipe(session: "Session") -> "MessageType":
def wipe(client: "TrezorClient") -> "MessageType": ret = session.call(messages.WipeDevice())
ret = client.call(messages.WipeDevice()) if not session.get_features().bootloader_mode:
if not client.features.bootloader_mode: session.init_device()
client.init_device()
return ret return ret
@session
def recover( def recover(
client: "TrezorClient", session: "Session",
word_count: int = 24, word_count: int = 24,
passphrase_protection: bool = False, passphrase_protection: bool = False,
pin_protection: bool = True, pin_protection: bool = True,
@ -188,13 +180,16 @@ def recover(
if type is None: if type is None:
type = messages.RecoveryType.NormalRecovery type = messages.RecoveryType.NormalRecovery
if client.features.model == "1" and input_callback is None: if session.get_features().model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One") raise RuntimeError("Input callback required for Trezor One")
if word_count not in (12, 18, 24): if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24") raise ValueError("Invalid word count. Use 12/18/24")
if client.features.initialized and type == messages.RecoveryType.NormalRecovery: if (
session.get_features().initialized
and type == messages.RecoveryType.NormalRecovery
):
raise RuntimeError( raise RuntimeError(
"Device already initialized. Call device.wipe() and try again." "Device already initialized. Call device.wipe() and try again."
) )
@ -216,24 +211,23 @@ def recover(
msg.label = label msg.label = label
msg.u2f_counter = u2f_counter msg.u2f_counter = u2f_counter
res = client.call(msg) res = session.call(msg)
while isinstance(res, messages.WordRequest): while isinstance(res, messages.WordRequest):
try: try:
assert input_callback is not None assert input_callback is not None
inp = input_callback(res.type) inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp)) res = session.call(messages.WordAck(word=inp))
except Cancelled: except Cancelled:
res = client.call(messages.Cancel()) res = session.call(messages.Cancel())
client.init_device() session.init_device()
return res return res
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def reset( def reset(
client: "TrezorClient", session: "Session",
display_random: bool = False, display_random: bool = False,
strength: Optional[int] = None, strength: Optional[int] = None,
passphrase_protection: bool = False, passphrase_protection: bool = False,
@ -251,13 +245,13 @@ def reset(
DeprecationWarning, DeprecationWarning,
) )
if client.features.initialized: if session.get_features().initialized:
raise RuntimeError( raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again." "Device is initialized already. Call wipe_device() and try again."
) )
if strength is None: if strength is None:
if client.features.model == "1": if session.get_features().model == "1":
strength = 256 strength = 256
else: else:
strength = 128 strength = 128
@ -275,25 +269,24 @@ def reset(
backup_type=backup_type, backup_type=backup_type,
) )
resp = client.call(msg) resp = session.call(msg)
if not isinstance(resp, messages.EntropyRequest): if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest") raise RuntimeError("Invalid response, expected EntropyRequest")
external_entropy = os.urandom(32) external_entropy = os.urandom(32)
# LOG.debug("Computer generated entropy: " + external_entropy.hex()) # LOG.debug("Computer generated entropy: " + external_entropy.hex())
ret = client.call(messages.EntropyAck(entropy=external_entropy)) ret = session.call(messages.EntropyAck(entropy=external_entropy))
client.init_device() session.init_device()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def backup( def backup(
client: "TrezorClient", session: "Session",
group_threshold: Optional[int] = None, group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (), groups: Iterable[tuple[int, int]] = (),
) -> "MessageType": ) -> "MessageType":
ret = client.call( ret = session.call(
messages.BackupDevice( messages.BackupDevice(
group_threshold=group_threshold, group_threshold=group_threshold,
groups=[ groups=[
@ -302,37 +295,36 @@ def backup(
], ],
) )
) )
client.refresh_features() session.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def cancel_authorization(client: "TrezorClient") -> "MessageType": def cancel_authorization(session: "Session") -> "MessageType":
return client.call(messages.CancelAuthorization()) return session.call(messages.CancelAuthorization())
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes) @expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType": def unlock_path(session: "Session", n: "Address") -> "MessageType":
resp = client.call(messages.UnlockPath(address_n=n)) resp = session.call(messages.UnlockPath(address_n=n))
# Cancel the UnlockPath workflow now that we have the authentication code. # Cancel the UnlockPath workflow now that we have the authentication code.
try: try:
client.call(messages.Cancel()) session.call(messages.Cancel())
except Cancelled: except Cancelled:
return resp return resp
else: else:
raise TrezorException("Unexpected response in UnlockPath flow") raise TrezorException("Unexpected response in UnlockPath flow")
@session
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader( def reboot_to_bootloader(
client: "TrezorClient", session: "Session",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None, firmware_header: Optional[bytes] = None,
language_data: bytes = b"", language_data: bytes = b"",
) -> "MessageType": ) -> "MessageType":
response = client.call( response = session.call(
messages.RebootToBootloader( messages.RebootToBootloader(
boot_command=boot_command, boot_command=boot_command,
firmware_header=firmware_header, firmware_header=firmware_header,
@ -340,42 +332,37 @@ def reboot_to_bootloader(
) )
) )
if isinstance(response, messages.TranslationDataRequest): if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data) response = _send_language_data(session, response, language_data)
return response return response
@session
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def show_device_tutorial(client: "TrezorClient") -> "MessageType": def show_device_tutorial(session: "Session") -> "MessageType":
return client.call(messages.ShowDeviceTutorial()) return session.call(messages.ShowDeviceTutorial())
@session
@expect(messages.Success, field="message", ret_type=str)
def unlock_bootloader(client: "TrezorClient") -> "MessageType":
return client.call(messages.UnlockBootloader())
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def unlock_bootloader(session: "Session") -> "MessageType":
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType": return session.call(messages.UnlockBootloader())
@expect(messages.Success, field="message", ret_type=str)
def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType":
"""Sets or clears the busy state of the device. """Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen. In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state. Setting `expiry_ms=None` clears the busy state.
""" """
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms)) ret = session.call(messages.SetBusy(expiry_ms=expiry_ms))
client.refresh_features() session.refresh_features()
return ret return ret
@expect(messages.AuthenticityProof) @expect(messages.AuthenticityProof)
def authenticate(client: "TrezorClient", challenge: bytes): def authenticate(session: "Session", challenge: bytes):
return client.call(messages.AuthenticateDevice(challenge=challenge)) return session.call(messages.AuthenticateDevice(challenge=challenge))
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def set_brightness( def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType":
client: "TrezorClient", value: Optional[int] = None return session.call(messages.SetBrightness(value=value))
) -> "MessageType":
return client.call(messages.SetBrightness(value=value))

View File

@ -18,12 +18,12 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages from . import exceptions, messages
from .tools import b58decode, expect, session from .tools import b58decode, expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
def name_to_number(name: str) -> int: def name_to_number(name: str) -> int:
@ -321,17 +321,16 @@ def parse_transaction_json(
@expect(messages.EosPublicKey) @expect(messages.EosPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False session: "Session", n: "Address", show_display: bool = False
) -> "MessageType": ) -> "MessageType":
response = client.call( response = session.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display) messages.EosGetPublicKey(address_n=n, show_display=show_display)
) )
return response return response
@session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: "Address", address: "Address",
transaction: dict, transaction: dict,
chain_id: str, chain_id: str,
@ -347,11 +346,11 @@ def sign_tx(
chunkify=chunkify, chunkify=chunkify,
) )
response = client.call(msg) response = session.call(msg)
try: try:
while isinstance(response, messages.EosTxActionRequest): while isinstance(response, messages.EosTxActionRequest):
response = client.call(actions.pop(0)) response = session.call(actions.pop(0))
except IndexError: except IndexError:
# pop from empty list # pop from empty list
raise exceptions.TrezorException( raise exceptions.TrezorException(

View File

@ -18,12 +18,12 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages from . import definitions, exceptions, messages
from .tools import expect, prepare_message_bytes, session, unharden from .tools import expect, prepare_message_bytes, unharden
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
def int_to_big_endian(value: int) -> bytes: def int_to_big_endian(value: int) -> bytes:
@ -163,13 +163,13 @@ def network_from_address_n(
@expect(messages.EthereumAddress, field="address", ret_type=str) @expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumGetAddress( messages.EthereumGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
@ -181,16 +181,15 @@ def get_address(
@expect(messages.EthereumPublicKey) @expect(messages.EthereumPublicKey)
def get_public_node( def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False session: "Session", n: "Address", show_display: bool = False
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display) messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
) )
@session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
nonce: int, nonce: int,
gas_price: int, gas_price: int,
@ -226,13 +225,13 @@ def sign_tx(
data, chunk = data[1024:], data[:1024] data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk msg.data_initial_chunk = chunk
response = client.call(msg) response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None assert response.signature_v is not None
@ -247,9 +246,8 @@ def sign_tx(
return response.signature_v, response.signature_r, response.signature_s return response.signature_v, response.signature_r, response.signature_s
@session
def sign_tx_eip1559( def sign_tx_eip1559(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
*, *,
nonce: int, nonce: int,
@ -282,13 +280,13 @@ def sign_tx_eip1559(
chunkify=chunkify, chunkify=chunkify,
) )
response = client.call(msg) response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None assert response.signature_v is not None
@ -299,13 +297,13 @@ def sign_tx_eip1559(
@expect(messages.EthereumMessageSignature) @expect(messages.EthereumMessageSignature)
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
message: AnyStr, message: AnyStr,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumSignMessage( messages.EthereumSignMessage(
address_n=n, address_n=n,
message=prepare_message_bytes(message), message=prepare_message_bytes(message),
@ -317,7 +315,7 @@ def sign_message(
@expect(messages.EthereumTypedDataSignature) @expect(messages.EthereumTypedDataSignature)
def sign_typed_data( def sign_typed_data(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
data: Dict[str, Any], data: Dict[str, Any],
*, *,
@ -333,7 +331,7 @@ def sign_typed_data(
metamask_v4_compat=metamask_v4_compat, metamask_v4_compat=metamask_v4_compat,
definitions=definitions, definitions=definitions,
) )
response = client.call(request) response = session.call(request)
# Sending all the types # Sending all the types
while isinstance(response, messages.EthereumTypedDataStructRequest): while isinstance(response, messages.EthereumTypedDataStructRequest):
@ -349,7 +347,7 @@ def sign_typed_data(
members.append(struct_member) members.append(struct_member)
request = messages.EthereumTypedDataStructAck(members=members) request = messages.EthereumTypedDataStructAck(members=members)
response = client.call(request) response = session.call(request)
# Sending the whole message that should be signed # Sending the whole message that should be signed
while isinstance(response, messages.EthereumTypedDataValueRequest): while isinstance(response, messages.EthereumTypedDataValueRequest):
@ -362,7 +360,7 @@ def sign_typed_data(
member_typename = data["primaryType"] member_typename = data["primaryType"]
member_data = data["message"] member_data = data["message"]
else: else:
client.cancel() session.cancel()
raise exceptions.TrezorException("Root index can only be 0 or 1") raise exceptions.TrezorException("Root index can only be 0 or 1")
# It can be asking for a nested structure (the member path being [X, Y, Z, ...]) # It can be asking for a nested structure (the member path being [X, Y, Z, ...])
@ -385,20 +383,20 @@ def sign_typed_data(
encoded_data = encode_data(member_data, member_typename) encoded_data = encode_data(member_data, member_typename)
request = messages.EthereumTypedDataValueAck(value=encoded_data) request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request) response = session.call(request)
return response return response
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
address: str, address: str,
signature: bytes, signature: bytes,
message: AnyStr, message: AnyStr,
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
resp = client.call( resp = session.call(
messages.EthereumVerifyMessage( messages.EthereumVerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
@ -413,13 +411,13 @@ def verify_message(
@expect(messages.EthereumTypedDataSignature) @expect(messages.EthereumTypedDataSignature)
def sign_typed_data_hash( def sign_typed_data_hash(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
domain_hash: bytes, domain_hash: bytes,
message_hash: Optional[bytes], message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumSignTypedHash( messages.EthereumSignTypedHash(
address_n=n, address_n=n,
domain_separator_hash=domain_hash, domain_separator_hash=domain_hash,

View File

@ -20,8 +20,8 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.new.session import Session
@expect( @expect(
@ -29,27 +29,27 @@ if TYPE_CHECKING:
field="credentials", field="credentials",
ret_type=List[messages.WebAuthnCredential], ret_type=List[messages.WebAuthnCredential],
) )
def list_credentials(client: "TrezorClient") -> "MessageType": def list_credentials(session: "Session") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials()) return session.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": def add_credential(session: "Session", credential_id: bytes) -> "MessageType":
return client.call( return session.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id) messages.WebAuthnAddResidentCredential(credential_id=credential_id)
) )
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def remove_credential(client: "TrezorClient", index: int) -> "MessageType": def remove_credential(session: "Session", index: int) -> "MessageType":
return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) return session.call(messages.WebAuthnRemoveResidentCredential(index=index))
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": def set_counter(session: "Session", u2f_counter: int) -> "MessageType":
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) @expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
def get_next_counter(client: "TrezorClient") -> "MessageType": def get_next_counter(session: "Session") -> "MessageType":
return client.call(messages.GetNextU2FCounter()) return session.call(messages.GetNextU2FCounter())

View File

@ -20,7 +20,7 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard from typing_extensions import Protocol, TypeGuard
from .. import messages from .. import messages
from ..tools import expect, session from ..tools import expect
from .core import VendorFirmware from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware from .legacy import LegacyFirmware, LegacyV2Firmware
@ -38,7 +38,7 @@ if True:
from .vendor import * # noqa: F401, F403 from .vendor import * # noqa: F401, F403
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
T = t.TypeVar("T", bound="FirmwareType") T = t.TypeVar("T", bound="FirmwareType")
@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
# ====== Client functions ====== # # ====== Client functions ====== #
@session
def update( def update(
client: "TrezorClient", session: "Session",
data: bytes, data: bytes,
progress_update: t.Callable[[int], t.Any] = lambda _: None, progress_update: t.Callable[[int], t.Any] = lambda _: None,
): ):
if client.features.bootloader_mode is False: if session.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
resp = client.call(messages.FirmwareErase(length=len(data))) resp = session.call(messages.FirmwareErase(length=len(data)))
# TREZORv1 method # TREZORv1 method
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
resp = client.call(messages.FirmwareUpload(payload=data)) resp = session.call(messages.FirmwareUpload(payload=data))
progress_update(len(data)) progress_update(len(data))
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
return return
@ -97,7 +96,7 @@ def update(
length = resp.length length = resp.length
payload = data[resp.offset : resp.offset + length] payload = data[resp.offset : resp.offset + length]
digest = blake2s(payload).digest() digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
progress_update(length) progress_update(length)
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
@ -107,5 +106,5 @@ def update(
@expect(messages.FirmwareHash, field="hash", ret_type=bytes) @expect(messages.FirmwareHash, field="hash", ret_type=bytes)
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]): def get_hash(session: "Session", challenge: t.Optional[bytes]):
return client.call(messages.GetFirmwareHash(challenge=challenge)) return session.call(messages.GetFirmwareHash(challenge=challenge))

View File

@ -20,25 +20,25 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
@expect(messages.Entropy, field="entropy", ret_type=bytes) @expect(messages.Entropy, field="entropy", ret_type=bytes)
def get_entropy(client: "TrezorClient", size: int) -> "MessageType": def get_entropy(session: "Session", size: int) -> "MessageType":
return client.call(messages.GetEntropy(size=size)) return session.call(messages.GetEntropy(size=size))
@expect(messages.SignedIdentity) @expect(messages.SignedIdentity)
def sign_identity( def sign_identity(
client: "TrezorClient", session: "Session",
identity: messages.IdentityType, identity: messages.IdentityType,
challenge_hidden: bytes, challenge_hidden: bytes,
challenge_visual: str, challenge_visual: str,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SignIdentity( messages.SignIdentity(
identity=identity, identity=identity,
challenge_hidden=challenge_hidden, challenge_hidden=challenge_hidden,
@ -50,12 +50,12 @@ def sign_identity(
@expect(messages.ECDHSessionKey) @expect(messages.ECDHSessionKey)
def get_ecdh_session_key( def get_ecdh_session_key(
client: "TrezorClient", session: "Session",
identity: messages.IdentityType, identity: messages.IdentityType,
peer_public_key: bytes, peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.GetECDHSessionKey( messages.GetECDHSessionKey(
identity=identity, identity=identity,
peer_public_key=peer_public_key, peer_public_key=peer_public_key,
@ -66,7 +66,7 @@ def get_ecdh_session_key(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) @expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
key: str, key: str,
value: bytes, value: bytes,
@ -74,7 +74,7 @@ def encrypt_keyvalue(
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
@ -89,7 +89,7 @@ def encrypt_keyvalue(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) @expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
key: str, key: str,
value: bytes, value: bytes,
@ -97,7 +97,7 @@ def decrypt_keyvalue(
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
@ -111,5 +111,5 @@ def decrypt_keyvalue(
@expect(messages.Nonce, field="nonce", ret_type=bytes) @expect(messages.Nonce, field="nonce", ret_type=bytes)
def get_nonce(client: "TrezorClient"): def get_nonce(session: "Session"):
return client.call(messages.GetNonce()) return session.call(messages.GetNonce())

View File

@ -20,9 +20,9 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
# MAINNET = 0 # MAINNET = 0
@ -33,13 +33,13 @@ if TYPE_CHECKING:
@expect(messages.MoneroAddress, field="address", ret_type=bytes) @expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.MoneroGetAddress( messages.MoneroGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
@ -51,10 +51,10 @@ def get_address(
@expect(messages.MoneroWatchKey) @expect(messages.MoneroWatchKey)
def get_watch_key( def get_watch_key(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type) messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
) )

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801 TYPE_IMPORTANCE_TRANSFER = 0x0801
@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
@expect(messages.NEMAddress, field="address", ret_type=str) @expect(messages.NEMAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
network: int, network: int,
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.NEMGetAddress( messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify address_n=n, network=network, show_display=show_display, chunkify=chunkify
) )
@ -213,7 +213,7 @@ def get_address(
@expect(messages.NEMSignedTx) @expect(messages.NEMSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False session: "Session", n: "Address", transaction: dict, chunkify: bool = False
) -> "MessageType": ) -> "MessageType":
try: try:
msg = create_sign_tx(transaction, chunkify=chunkify) msg = create_sign_tx(transaction, chunkify=chunkify)
@ -222,4 +222,4 @@ def sign_tx(
assert msg.transaction is not None assert msg.transaction is not None
msg.transaction.address_n = n msg.transaction.address_n = n
return client.call(msg) return session.call(msg)

View File

@ -21,9 +21,9 @@ from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect from .tools import dict_from_camelcase, expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@ -31,12 +31,12 @@ REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address", ret_type=str) @expect(messages.RippleAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.RippleGetAddress( messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -45,14 +45,14 @@ def get_address(
@expect(messages.RippleSignedTx) @expect(messages.RippleSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
msg: messages.RippleSignTx, msg: messages.RippleSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
msg.address_n = address_n msg.address_n = address_n
msg.chunkify = chunkify msg.chunkify = chunkify
return client.call(msg) return session.call(msg)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -4,29 +4,29 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.new.session import Session
@expect(messages.SolanaPublicKey) @expect(messages.SolanaPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display) messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
) )
@expect(messages.SolanaAddress) @expect(messages.SolanaAddress)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SolanaGetAddress( messages.SolanaGetAddress(
address_n=address_n, address_n=address_n,
show_display=show_display, show_display=show_display,
@ -37,12 +37,12 @@ def get_address(
@expect(messages.SolanaTxSignature) @expect(messages.SolanaTxSignature)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
serialized_tx: bytes, serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo], additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SolanaSignTx( messages.SolanaSignTx(
address_n=address_n, address_n=address_n,
serialized_tx=serialized_tx, serialized_tx=serialized_tx,

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
StellarMessageType = Union[ StellarMessageType = Union[
messages.StellarAccountMergeOp, messages.StellarAccountMergeOp,
@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
@expect(messages.StellarAddress, field="address", ret_type=str) @expect(messages.StellarAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.StellarGetAddress( messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -338,7 +338,7 @@ def get_address(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
tx: messages.StellarSignTx, tx: messages.StellarSignTx,
operations: List["StellarMessageType"], operations: List["StellarMessageType"],
address_n: "Address", address_n: "Address",
@ -354,10 +354,10 @@ def sign_tx(
# 3. Receive a StellarTxOpRequest message # 3. Receive a StellarTxOpRequest message
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
# 5. The final message received will be StellarSignedTx which is returned from this method # 5. The final message received will be StellarSignedTx which is returned from this method
resp = client.call(tx) resp = session.call(tx)
try: try:
while isinstance(resp, messages.StellarTxOpRequest): while isinstance(resp, messages.StellarTxOpRequest):
resp = client.call(operations.pop(0)) resp = session.call(operations.pop(0))
except IndexError: except IndexError:
# pop from empty list # pop from empty list
raise exceptions.TrezorException( raise exceptions.TrezorException(

View File

@ -20,19 +20,19 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.new.session import Session
@expect(messages.TezosAddress, field="address", ret_type=str) @expect(messages.TezosAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.TezosGetAddress( messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -41,12 +41,12 @@ def get_address(
@expect(messages.TezosPublicKey, field="public_key", ret_type=str) @expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.TezosGetPublicKey( messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -55,11 +55,11 @@ def get_public_key(
@expect(messages.TezosSignedTx) @expect(messages.TezosSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
sign_tx_msg: messages.TezosSignTx, sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
sign_tx_msg.address_n = address_n sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg) return session.call(sign_tx_msg)

View File

@ -40,7 +40,7 @@ if TYPE_CHECKING:
# More details: https://www.python.org/dev/peps/pep-0612/ # More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec from typing_extensions import ParamSpec
from . import client from . import client
from .protobuf import MessageType from .protobuf import MessageType
@ -284,24 +284,6 @@ def expect(
return decorator return decorator
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
return f(client, *args, **kwargs)
finally:
client.close()
print("wrap end")
return wrapped_f
# de-camelcasifier # de-camelcasifier
# https://stackoverflow.com/a/1176023/222189 # https://stackoverflow.com/a/1176023/222189

View File

@ -1,101 +0,0 @@
from __future__ import annotations
import logging
from ... import mapping
from ...mapping import ProtobufMapping
from .channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel, ProtocolV1
from .protocol_v2 import ProtocolV2
from .session import Session, SessionV1, SessionV2
from .transport import NewTransport
LOG = logging.getLogger(__name__)
class NewTrezorClient:
management_session: Session | None = None
def __init__(
self,
transport: NewTransport,
protobuf_mapping: ProtobufMapping | None = None,
protocol: ProtocolAndChannel | None = None,
) -> None:
self.transport = transport
if protobuf_mapping is None:
self.mapping = mapping.DEFAULT_MAPPING
else:
self.mapping = protobuf_mapping
if protocol is None:
try:
self.protocol = self._get_protocol()
except Exception as e:
print(e)
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
@classmethod
def resume(
cls,
transport: NewTransport,
channel_data: ChannelData,
protobuf_mapping: ProtobufMapping | None = None,
) -> NewTrezorClient:
if protobuf_mapping is None:
protobuf_mapping = mapping.DEFAULT_MAPPING
if channel_data.protocol_version == 2:
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
else:
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
return NewTrezorClient(transport, protobuf_mapping, protocol)
def get_session(
self,
passphrase: str | None = None,
derive_cardano: bool = False,
) -> Session:
if isinstance(self.protocol, ProtocolV1):
return SessionV1.new(self, passphrase, derive_cardano)
if isinstance(self.protocol, ProtocolV2):
return SessionV2.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO
def get_management_session(self):
if self.management_session is not None:
return self.management_session
if isinstance(self.protocol, ProtocolV1):
self.management_session = SessionV1.new(self, "", False)
elif isinstance(self.protocol, ProtocolV2):
self.management_session = SessionV2(self, b"\x00")
assert self.management_session is not None
return self.management_session
def resume_session(self, session_id: bytes) -> Session:
raise NotImplementedError # TODO
def _get_protocol(self) -> ProtocolAndChannel:
from ... import mapping, messages
from ...messages import FailureType
from .protocol_and_channel import ProtocolV1
self.transport.open()
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
protocol.write(messages.Initialize())
response = protocol.read()
self.transport.close()
if isinstance(response, messages.Failure):
if (
response.code == FailureType.UnexpectedMessage
and response.message == "Invalid protocol"
):
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2(self.transport, self.mapping)
return protocol

View File

@ -2,35 +2,69 @@ from __future__ import annotations
import typing as t import typing as t
from ...messages import Features, Initialize, ThpCreateNewSession, ThpNewSession from ... import models
from ...messages import (
Features,
GetFeatures,
Initialize,
ThpCreateNewSession,
ThpNewSession,
)
from .protocol_and_channel import ProtocolV1 from .protocol_and_channel import ProtocolV1
from .protocol_v2 import ProtocolV2 from .protocol_v2 import ProtocolV2
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from .client import NewTrezorClient from ...client import TrezorClient
class Session: class Session:
features: Features
def __init__(self, client: NewTrezorClient, id: bytes) -> None: def __init__(self, client: TrezorClient, id: bytes) -> None:
self.client = client self.client = client
self.id = id self.id = id
@classmethod @classmethod
def new( def new(
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> Session: ) -> Session:
raise NotImplementedError raise NotImplementedError
def call(self, msg: t.Any) -> t.Any: def call(self, msg: t.Any) -> t.Any:
raise NotImplementedError raise NotImplementedError
def refresh_features(self) -> None:
raise NotImplementedError
def get_features(self) -> Features:
raise NotImplementedError
def get_model(self) -> models.TrezorModel:
features = self.get_features()
model = models.by_name(features.model or "1")
if model is None:
raise RuntimeError(
"Unsupported Trezor model"
f" (internal_model: {features.internal_model}, model: {features.model})"
)
return model
def get_version(self) -> t.Tuple[int, int, int]:
features = self.get_features()
version = (
features.major_version,
features.minor_version,
features.patch_version,
)
return version
class SessionV1(Session): class SessionV1(Session):
features: Features
@classmethod @classmethod
def new( def new(
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV1: ) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1) assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, b"") session = SessionV1(client, b"")
@ -38,7 +72,7 @@ class SessionV1(Session):
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO # Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
Initialize() Initialize()
) )
session.id = session.features.session_id session.id = session.get_features().session_id
return session return session
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any: def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
@ -49,12 +83,18 @@ class SessionV1(Session):
self.client.protocol.write(msg) self.client.protocol.write(msg)
return self.client.protocol.read() return self.client.protocol.read()
def refresh_features(self) -> None:
self.features = self.call(GetFeatures())
def get_features(self) -> Features:
return self.features
class SessionV2(Session): class SessionV2(Session):
@classmethod @classmethod
def new( def new(
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV2: ) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2) assert isinstance(client.protocol, ProtocolV2)
session = SessionV2(client, b"\x00") session = SessionV2(client, b"\x00")
@ -66,7 +106,7 @@ class SessionV2(Session):
session.update_id_and_sid(session_id.to_bytes(1, "big")) session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session return session
def __init__(self, client: NewTrezorClient, id: bytes) -> None: def __init__(self, client: TrezorClient, id: bytes) -> None:
super().__init__(client, id) super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2) assert isinstance(client.protocol, ProtocolV2)
@ -78,6 +118,12 @@ class SessionV2(Session):
self.channel.write(self.sid, msg) self.channel.write(self.sid, msg)
return self.channel.read(self.sid) return self.channel.read(self.sid)
def get_features(self) -> Features:
return self.channel.get_features()
def refresh_features(self) -> None:
self.channel.update_features()
def update_id_and_sid(self, id: bytes) -> None: def update_id_and_sid(self, id: bytes) -> None:
self.id = id self.id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

View File

@ -24,13 +24,14 @@ from trezorlib.tools import parse_path
def main() -> None: def main() -> None:
# Use first connected device # Use first connected device
client = get_default_client() client = get_default_client()
session = client.get_session(derive_cardano=True)
# Print out Trezor's features and settings # Print out Trezor's features and settings
print(client.features) print(session.get_features())
# Get the first address of first BIP44 account # Get the first address of first BIP44 account
bip32_path = parse_path("44h/0h/0h/0/0") bip32_path = parse_path("44h/0h/0h/0/0")
address = btc.get_address(client, "Bitcoin", bip32_path, True) address = btc.get_address(session, "Bitcoin", bip32_path, False)
print("Bitcoin address:", address) print("Bitcoin address:", address)

View File

@ -63,7 +63,7 @@ MODULES = (
CALLS_DONE = [] CALLS_DONE = []
DEBUGLINK = None DEBUGLINK = None
get_client_orig = cli.TrezorConnection.get_client get_client_orig = cli.NewTrezorConnection.get_client
def get_client(conn): def get_client(conn):
@ -75,7 +75,7 @@ def get_client(conn):
return client return client
cli.TrezorConnection.get_client = get_client cli.NewTrezorConnection.get_client = get_client
def scan_layouts(dest): def scan_layouts(dest):