1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

wip trezorlib TEMP

(note: has many warnings, not all things will work correctly)
This commit is contained in:
M1nd3r 2024-09-16 18:32:16 +02:00
parent 34d3dbdeb1
commit f9a0124e0f
40 changed files with 1135 additions and 1125 deletions

View File

@ -7,7 +7,7 @@ import typing as t
from importlib import metadata
from . import device
from .client import TrezorClient
from .transport.new.session import Session
try:
cryptography_version = metadata.version("cryptography")
@ -361,7 +361,7 @@ def verify_authentication_response(
def authenticate_device(
client: TrezorClient,
session: Session,
challenge: bytes | None = None,
*,
whitelist: t.Collection[bytes] | None = None,
@ -371,7 +371,7 @@ def authenticate_device(
if challenge is None:
challenge = secrets.token_bytes(16)
resp = device.authenticate(client, challenge)
resp = device.authenticate(session, challenge)
return verify_authentication_response(
challenge,

View File

@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import expect, session
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
@expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -42,16 +42,15 @@ def get_address(
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
session: "Session", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
return session.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
)
@session
def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0]
tx_msg = tx_json.copy()
@ -60,7 +59,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
response = client.call(envelope)
response = session.call(envelope)
if not isinstance(response, messages.BinanceTxRequest):
raise RuntimeError(
@ -77,7 +76,7 @@ def sign_tx(
else:
raise ValueError("can not determine msg type")
response = client.call(msg)
response = session.call(msg)
if not isinstance(response, messages.BinanceSignedTx):
raise RuntimeError(

View File

@ -13,7 +13,6 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import warnings
from copy import copy
from decimal import Decimal
@ -24,12 +23,11 @@ from typing_extensions import Protocol, TypedDict
from . import exceptions, messages
from .tools import expect, prepare_message_bytes
from .transport.new.session import Session
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
class ScriptSig(TypedDict):
asm: str
@ -176,13 +174,13 @@ def get_authenticated_address(
# TODO this is used by tests only
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
return session.call(
messages.GetOwnershipId(
address_n=n,
coin_name=coin_name,
@ -194,7 +192,7 @@ def get_ownership_id(
# TODO this is used by tests only
def get_ownership_proof(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
@ -205,11 +203,11 @@ def get_ownership_proof(
preauthorized: bool = False,
) -> Tuple[bytes, bytes]:
if preauthorized:
res = client.call(messages.DoPreauthorized())
res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(
res = session.call(
messages.GetOwnershipProof(
address_n=n,
coin_name=coin_name,
@ -435,7 +433,7 @@ def sign_tx(
@expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin(
client: "TrezorClient",
session: "Session",
coordinator: str,
max_rounds: int,
max_coordinator_fee_rate: int,
@ -444,7 +442,7 @@ def authorize_coinjoin(
coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
return session.call(
messages.AuthorizeCoinJoin(
coordinator=coordinator,
max_rounds=max_rounds,

View File

@ -35,8 +35,8 @@ from . import exceptions, messages, tools
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.new.session import Session
PROTOCOL_MAGICS = {
"mainnet": 764824073,
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
@expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_parameters: messages.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
@ -833,7 +833,7 @@ def get_address(
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetAddress(
address_parameters=address_parameters,
protocol_magic=protocol_magic,
@ -847,12 +847,12 @@ def get_address(
@expect(messages.CardanoPublicKey)
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
show_display: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetPublicKey(
address_n=address_n,
derivation_type=derivation_type,
@ -863,12 +863,12 @@ def get_public_key(
@expect(messages.CardanoNativeScriptHash)
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetNativeScriptHash(
script=native_script,
display_format=display_format,
@ -878,7 +878,7 @@ def get_native_script_hash(
def sign_tx(
client: "TrezorClient",
session: "Session",
signing_mode: messages.CardanoTxSigningMode,
inputs: List[InputWithPath],
outputs: List[OutputWithData],
@ -915,7 +915,7 @@ def sign_tx(
signing_mode,
)
response = client.call(
response = session.call(
messages.CardanoSignTxInit(
signing_mode=signing_mode,
inputs_count=len(inputs),
@ -951,14 +951,14 @@ def sign_tx(
_get_certificates_items(certificates),
withdrawals,
):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data)
auxiliary_data_supplement = session.call(auxiliary_data)
if not isinstance(
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
):
@ -971,7 +971,7 @@ def sign_tx(
auxiliary_data_supplement.__dict__
)
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
@ -980,24 +980,24 @@ def sign_tx(
_get_collateral_inputs_items(collateral_inputs),
required_signers,
):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
if collateral_return is not None:
for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
for reference_input in reference_inputs:
response = client.call(reference_input)
response = session.call(reference_input)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"] = []
for witness_request in witness_requests:
response = client.call(witness_request)
response = session.call(witness_request)
if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"].append(
@ -1009,12 +1009,12 @@ def sign_tx(
}
)
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR

View File

@ -28,9 +28,7 @@ from .. import exceptions, transport, ui
from ..client import TrezorClient
from ..messages import Capability
from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient
from ..transport.new.transport import NewTransport
from ..ui import ClickUI, ScriptUI
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators
@ -39,9 +37,6 @@ if t.TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P")
R = TypeVar("R")
@ -117,19 +112,17 @@ class NewTrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(
self,
):
def get_session(self, derive_cardano: bool = False):
client = self.get_client()
if self.session_id is not None:
pass # TODO Try resume
pass # TODO Try resume - be careful of cardano derivation settings!
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=True)
return client.get_session(derive_cardano=derive_cardano)
# TODO Passphrase empty by default - ???
available_on_device = Capability.PassphraseEntry in features.capabilities
@ -137,7 +130,9 @@ class NewTrezorConnection:
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(passphrase=passphrase, derive_cardano=True)
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
return session
def get_transport(self) -> "NewTransport":
@ -152,7 +147,7 @@ class NewTrezorConnection:
# if this fails, we want the exception to bubble up to the caller
return transport.new_get_transport(self.path, prefix_search=True)
def get_client(self) -> NewTrezorClient:
def get_client(self) -> TrezorClient:
transport = self.get_transport()
stored_channels = channel_database.load_stored_channels()
@ -162,11 +157,11 @@ class NewTrezorConnection:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
client = NewTrezorClient.resume(
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
else:
client = NewTrezorClient(transport)
client = TrezorClient(transport)
return client
@ -202,91 +197,96 @@ class NewTrezorConnection:
# other exceptions may cause a traceback
class TrezorConnection:
# class TrezorConnection:
def __init__(
self,
path: str,
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
) -> None:
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host
self.script = script
# def __init__(
# self,
# path: str,
# session_id: bytes | None,
# passphrase_on_host: bool,
# script: bool,
# ) -> None:
# self.path = path
# self.session_id = session_id
# self.passphrase_on_host = passphrase_on_host
# self.script = script
def get_transport(self) -> "Transport":
try:
# look for transport without prefix search
return transport.get_transport(self.path, prefix_search=False)
except Exception:
# most likely not found. try again below.
pass
# def get_transport(self) -> "Transport":
# try:
# # look for transport without prefix search
# return transport.get_transport(self.path, prefix_search=False)
# except Exception:
# # most likely not found. try again below.
# pass
# look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True)
# # look for transport with prefix search
# # if this fails, we want the exception to bubble up to the caller
# return transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI":
if self.script:
# It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods)
return ScriptUI
else:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
# def get_ui(self) -> "TrezorClientUI":
# if self.script:
# # It is alright to return just the class object instead of instance,
# # as the ScriptUI class object itself is the implementation of TrezorClientUI
# # (ScriptUI is just a set of staticmethods)
# return ScriptUI
# else:
# return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self) -> TrezorClient:
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)
# def get_client(self) -> TrezorClient:
# transport = self.get_transport()
# ui = self.get_ui()
# return TrezorClient(transport, ui=ui, session_id=self.session_id)
@contextmanager
def client_context(self):
"""Get a client instance as a context manager. Handle errors in a manner
appropriate for end-users.
# @contextmanager
# def client_context(self):
# """Get a client instance as a context manager. Handle errors in a manner
# appropriate for end-users.
Usage:
>>> with obj.client_context() as client:
>>> do_your_actions_here()
"""
try:
client = self.get_client()
except transport.DeviceIsBusy:
click.echo("Device is in use by another process.")
sys.exit(1)
except Exception:
click.echo("Failed to find a Trezor device.")
if self.path is not None:
click.echo(f"Using path: {self.path}")
sys.exit(1)
try:
yield client
except exceptions.Cancelled:
# handle cancel action
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
# handle any Trezor-sent exceptions as user-readable
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
# Usage:
# >>> with obj.client_context() as client:
# >>> do_your_actions_here()
# """
# try:
# client = self.get_client()
# except transport.DeviceIsBusy:
# click.echo("Device is in use by another process.")
# sys.exit(1)
# except Exception:
# click.echo("Failed to find a Trezor device.")
# if self.path is not None:
# click.echo(f"Using path: {self.path}")
# sys.exit(1)
# try:
# yield client
# except exceptions.Cancelled:
# # handle cancel action
# click.echo("Action was cancelled.")
# sys.exit(1)
# except exceptions.TrezorException as e:
# # handle any Trezor-sent exceptions as user-readable
# raise click.ClickException(str(e)) from e
# # other exceptions may cause a traceback
from ..transport.new.session import Session
def with_session(
def with_cardano_session(
func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]":
return with_session(func=func, derive_cardano=True)
def with_session(
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
session = obj.get_session()
session = obj.get_session(derive_cardano)
try:
return func(session, *args, **kwargs)
finally:
@ -298,8 +298,8 @@ def with_session(
return function_with_session # type: ignore [is incompatible with return type]
def new_with_client(
func: "t.Callable[Concatenate[NewTrezorClient, P], R]",
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
@ -336,39 +336,39 @@ def new_with_client(
return trezorctl_command_with_client # type: ignore [is incompatible with return type]
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
# def with_client(
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
# ) -> "t.Callable[P, R]":
# """Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
cleanly. The session is closed after the command completes - unless the session
was resumed, in which case it should remain open.
"""
# Sessions are handled transparently. The user is warned when session did not resume
# cleanly. The session is closed after the command completes - unless the session
# was resumed, in which case it should remain open.
# """
@click.pass_obj
@functools.wraps(func)
def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
# @click.pass_obj
# @functools.wraps(func)
# def trezorctl_command_with_client(
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
# ) -> "R":
# with obj.client_context() as client:
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
try:
return func(client, *args, **kwargs)
finally:
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
# try:
# return func(client, *args, **kwargs)
# finally:
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return trezorctl_command_with_client # type: ignore [is incompatible with return type]
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
# return trezorctl_command_with_client
class AliasedGroup(click.Group):

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import binance, tools
from . import with_client
from ..transport.new.session import Session
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Binance address for specified path."""
address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify)
return binance.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key."""
address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex()
return binance.get_public_key(session, address_n, show_display).hex()
@cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx":
"""Sign Binance transaction.
Transaction must be provided as a JSON file.
"""
address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import cardano, messages, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_cardano_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client
@with_cardano_session
def sign_tx(
client: "TrezorClient",
session: "Session",
file: TextIO,
signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int,
@ -123,9 +123,9 @@ def sign_tx(
for p in transaction["additional_witness_requests"]
]
client.init_device(derive_cardano=True)
session.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx(
client,
session,
signing_mode,
inputs,
outputs,
@ -209,9 +209,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_cardano_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
address_type: messages.CardanoAddressType,
staking_address: str,
@ -262,9 +262,9 @@ def get_address(
script_staking_hash_bytes,
)
client.init_device(derive_cardano=True)
session.init_device(derive_cardano=True)
return cardano.get_address(
client,
session,
address_parameters,
protocol_magic,
network_id,
@ -283,18 +283,18 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_cardano_session
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
derivation_type: messages.CardanoDerivationType,
show_display: bool,
) -> messages.CardanoPublicKey:
"""Get Cardano public key."""
address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
session.init_device(derive_cardano=True)
return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display
session, address_n, derivation_type=derivation_type, show_display=show_display
)
@ -312,9 +312,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS,
)
@with_client
@with_cardano_session
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType,
@ -323,7 +323,7 @@ def get_native_script_hash(
native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True)
session.init_device(derive_cardano=True)
return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type
session, native_script, display_format, derivation_type=derivation_type
)

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click
from .. import misc, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command()
@click.argument("size", type=int)
@with_client
def get_entropy(client: "TrezorClient", size: int) -> str:
@with_session
def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device."""
return misc.get_entropy(client, size).hex()
return misc.get_entropy(session, size).hex()
@cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.encrypt_keyvalue(
client,
session,
address_n,
key,
value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(
client,
session,
address_n,
key,
bytes.fromhex(value),

View File

@ -27,7 +27,7 @@ from ..debuglink import record_screen
from . import with_client
if TYPE_CHECKING:
from . import TrezorConnection
from . import NewTrezorConnection
@click.group(name="debug")
@ -40,7 +40,7 @@ def cli() -> None:
@click.argument("hex_data")
@click.pass_obj
def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
) -> None:
"""Send raw bytes to Trezor.
@ -86,7 +86,7 @@ def send_bytes(
@click.argument("directory", required=False)
@click.option("-s", "--stop", is_flag=True, help="Stop the recording")
@click.pass_obj
def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) -> None:
def record(obj: "NewTrezorConnection", directory: Union[str, None], stop: bool) -> None:
"""Record screen changes into a specified directory.
Recording can be stopped with `-s / --stop` option.
@ -95,7 +95,7 @@ def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) ->
def record_screen_from_connection(
obj: "TrezorConnection", directory: Union[str, None]
obj: "NewTrezorConnection", directory: Union[str, None]
) -> None:
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
transport = obj.get_transport()

View File

@ -29,7 +29,7 @@ from . import ChoiceType, with_client
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..protobuf import MessageType
from . import TrezorConnection
from . import NewTrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = {
"scrambled": messages.RecoveryDeviceInputMethod.ScrambledWords,
@ -67,14 +67,15 @@ def cli() -> None:
@with_client
def wipe(client: "TrezorClient", bootloader: bool) -> str:
"""Reset device to factory defaults and remove all private data."""
features = client.get_management_session().get_features()
if bootloader:
if not client.features.bootloader_mode:
if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
else:
click.echo("Wiping user data and firmware!")
else:
if client.features.bootloader_mode:
if features.bootloader_mode:
click.echo(
"Your device is in bootloader mode. This operation would also erase firmware."
)
@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
click.echo("Wiping user data!")
try:
return device.wipe(client)
return device.wipe(
client
) # TODO decide where the wipe should happen - management or regular session
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@ -233,9 +236,11 @@ def setup(
strength = int(strength)
BT = messages.BackupType
management_session = client.get_management_session()
version = management_session.get_version()
if backup_type is None:
if client.version >= (2, 7, 1):
if version >= (2, 7, 1):
# SLIP39 extendable was introduced in 2.7.1
backup_type = BT.Slip39_Single_Extendable
else:
@ -309,7 +314,7 @@ def sd_protect(
@cli.command()
@click.pass_obj
def reboot_to_bootloader(obj: "TrezorConnection") -> str:
def reboot_to_bootloader(obj: "NewTrezorConnection") -> str:
"""Reboot device into bootloader mode.
Currently only supported on Trezor Model One.

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import eos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display)
res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx":
"""Sign EOS transaction."""
tx_json = json.load(file)
address_n = tools.parse_path(address)
return eos.sign_tx(
client,
session,
address_n,
tx_json["transaction"],
tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions
from . import with_client
from . import with_session
if TYPE_CHECKING:
import web3
from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify)
return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
@with_session
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path."""
address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display)
result = ethereum.get_public_node(session, address_n, show_display=show_display)
return {
"node": {
"depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address")
@click.argument("amount", callback=_amount_to_int)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
chain_id: int,
address: str,
amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address)
from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network
session, address_n, encoded_network=encoded_network
)
if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None
assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559(
client,
session,
n=address_n,
nonce=nonce,
gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price
assert gas_price is not None
sig = ethereum.sign_tx(
client,
session,
n=address_n,
tx_type=tx_type,
nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool
session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]:
"""Sign message with Ethereum address."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = {
"message": message,
"address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation",
)
@click.argument("file", type=click.File("r"))
@with_client
@with_session
def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read())
ret = ethereum.sign_typed_data(
client,
session,
address_n,
data,
metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: str,
message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify
session, address, signature_bytes, message, chunkify=chunkify
)
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex")
@click.argument("message_hash_hex")
@with_client
@with_session
def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]:
"""
Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network
session, address_n, domain_hash, message_hash, network
)
output = {
"domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click
from .. import fido
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list")
@with_client
def credentials_list(client: "TrezorClient") -> None:
@with_session
def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device."""
creds = fido.list_credentials(client)
creds = fido.list_credentials(session)
for cred in creds:
click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_client
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
@with_session
def credentials_add(session: "Session", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
"""
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
return fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove")
@click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_client
def credentials_remove(client: "TrezorClient", index: int) -> str:
@with_session
def credentials_remove(session: "Session", index: int) -> str:
"""Remove the resident credential at the given index."""
return fido.remove_credential(client, index)
return fido.remove_credential(session, index)
#
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set")
@click.argument("counter", type=int)
@with_client
def counter_set(client: "TrezorClient", counter: int) -> str:
@with_session
def counter_set(session: "Session", counter: int) -> str:
"""Set FIDO/U2F counter value."""
return fido.set_counter(client, counter)
return fido.set_counter(session, counter)
@counter.command(name="get-next")
@with_client
def counter_get_next(client: "TrezorClient") -> int:
@with_session
def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation
and returns the counter value.
"""
return fido.get_next_counter(client)
return fido.get_next_counter(session)

View File

@ -41,7 +41,7 @@ from . import ChoiceType, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
from . import TrezorConnection
from . import NewTrezorConnection
MODEL_CHOICE = ChoiceType(
{
@ -74,9 +74,10 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader.
"""
f = client.features
version = (f.major_version, f.minor_version, f.patch_version)
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
management_session = client.get_management_session()
features = management_session.get_features()
version = management_session.get_version()
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2
@ -306,25 +307,27 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version
(higher than the specified one, if existing).
"""
management_session = client.get_management_session()
features = management_session.get_features()
model = management_session.get_model()
if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features)
bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version))
f = client.features
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
releases = get_all_firmware_releases(model, bitcoin_only, beta)
highest_version = releases[0]["version"]
if version:
want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version:
if want_version[0] != features.major_version:
click.echo(
f"Warning: Trezor {client.model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})"
f"Warning: Trezor {model.name} firmware version should be "
f"{features.major_version}.X.Y (requested: {version})"
)
else:
want_version = highest_version
@ -359,8 +362,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal
# compatible version first
# Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version]
if f.bootloader_mode:
client_version = management_session.get_version()
if features.bootloader_mode:
key_to_compare = "min_bootloader_version"
else:
key_to_compare = "min_firmware_version"
@ -451,7 +454,7 @@ def upload_firmware_into_device(
firmware_data: bytes,
) -> None:
"""Perform the final act of loading the firmware into Trezor."""
f = client.features
f = client.get_management_session().get_features()
try:
if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest
@ -482,7 +485,7 @@ def _is_strict_update(client: "TrezorClient", firmware_data: bytes) -> bool:
if not isinstance(fw, firmware.VendorFirmware):
return False
f = client.features
f = client.get_management_session().get_features()
cur_version = (f.major_version, f.minor_version, f.patch_version, 0)
return (
@ -519,7 +522,7 @@ def cli() -> None:
@click.pass_obj
# fmt: on
def verify(
obj: "TrezorConnection",
obj: "NewTrezorConnection",
filename: BinaryIO,
check_device: bool,
fingerprint: Optional[str],
@ -564,7 +567,7 @@ def verify(
@click.pass_obj
# fmt: on
def download(
obj: "TrezorConnection",
obj: "NewTrezorConnection",
output: Optional[BinaryIO],
model: Optional[TrezorModel],
version: Optional[str],
@ -630,7 +633,7 @@ def download(
# fmt: on
@click.pass_obj
def update(
obj: "TrezorConnection",
obj: "NewTrezorConnection",
filename: Optional[BinaryIO],
url: Optional[str],
version: Optional[str],

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click
from .. import messages, monero, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes:
"""Get Monero address for specified path."""
address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify)
return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET,
)
@with_client
@with_session
def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]:
"""Get Monero watch key for specified path."""
address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type)
res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey
assert res.address is not None
assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests
from .. import nem, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
network: int,
show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str:
"""Get NEM address for specified path."""
address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify)
return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
file: TextIO,
broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
"""
address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import ripple, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ripple address"""
address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify)
return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction"""
address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:")
click.echo(result.signature.hex())
click.echo()

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import messages, solana, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
) -> messages.SolanaPublicKey:
"""Get Solana public key."""
address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display)
return solana.get_public_key(session, address_n, show_display)
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
chunkify: bool,
) -> messages.SolanaAddress:
"""Get Solana address."""
address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify)
return solana.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
serialized_tx: str,
additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
)
return solana.sign_tx(
client,
session,
address_n,
bytes.fromhex(serialized_tx),
additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click
from .. import stellar, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
try:
from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Stellar public address."""
address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify)
return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).",
)
@click.argument("b64envelope")
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes:
"""Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import messages, protobuf, tezos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
@ -37,23 +37,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Tezos address for specified path."""
address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display, chunkify)
return tezos.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Tezos public key."""
address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display)
return tezos.get_public_key(session, address_n, show_display)
@cli.command()
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> messages.TezosSignedTx:
"""Sign Tezos transaction."""
address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify)
return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)

View File

@ -28,13 +28,11 @@ from .. import __version__, log, messages, protobuf
from ..client import TrezorClient
from ..transport import DeviceIsBusy, new_enumerate_devices
from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient
from ..transport.new.session import Session
from ..transport.new.udp import UdpTransport
from . import (
AliasedGroup,
NewTrezorConnection,
TrezorConnection,
binance,
btc,
cardano,
@ -47,7 +45,6 @@ from . import (
firmware,
monero,
nem,
new_with_client,
ripple,
settings,
solana,
@ -261,7 +258,7 @@ def print_result(res: Any, is_json: bool, script: bool, **kwargs: Any) -> None:
@cli.set_result_callback()
@click.pass_obj
def stop_recording_action(obj: TrezorConnection, *args: Any, **kwargs: Any) -> None:
def stop_recording_action(obj: NewTrezorConnection, *args: Any, **kwargs: Any) -> None:
"""Stop recording screen changes when the recording was started by `cli_main`.
(When user used the `-r / --record` option of `trezorctl` command.)
@ -302,14 +299,14 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
client = NewTrezorClient.resume(
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
else:
client = NewTrezorClient(transport)
client = TrezorClient(transport)
session = client.get_management_session()
description = format_device_name(session.features)
description = format_device_name(session.get_features())
# json_string = channel_database.channel_to_str(client.protocol)
# print(json_string)
channel_database.save_channel(client.protocol)
@ -348,7 +345,9 @@ def ping(session: "Session", message: str, button_protection: bool) -> str:
@cli.command()
@click.pass_obj
def get_session(obj: TrezorConnection) -> str:
def get_session(
obj: NewTrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
"""Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -362,7 +361,10 @@ def get_session(obj: TrezorConnection) -> str:
obj.session_id = None
with obj.client_context() as client:
if client.features.model == "1" and client.version < (1, 9, 0):
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
if session.get_features().model == "1" and session.get_version() < (1, 9, 0):
raise click.ClickException(
"Upgrade your firmware to enable session support."
)
@ -388,11 +390,11 @@ def new_clear_session() -> None:
@cli.command()
@new_with_client
def get_features(client: "NewTrezorClient") -> messages.Features:
@with_client
def get_features(client: "TrezorClient") -> messages.Features:
"""Retrieve device features and settings."""
session = client.get_management_session()
return session.features
return session.get_features()
@cli.command()

File diff suppressed because it is too large Load Diff

View File

@ -57,6 +57,7 @@ if TYPE_CHECKING:
from .messages import PinMatrixRequestType
from .transport import Transport
from .transport.new.session import Session
ExpectedMessage = Union[
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
@ -1107,7 +1108,8 @@ class TrezorClientDebugLink(TrezorClient):
Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before
using `debug.wait_layout()`, otherwise layout changes are not reported.
"""
if self.version >= (2, 3, 2):
version = self.get_management_session().get_version()
if version >= (2, 3, 2):
# version check is necessary because otherwise we cannot reliably detect
# whether and where to wait for reply:
# - T1 reports unknown debuglink messages on the wirelink
@ -1319,7 +1321,7 @@ class TrezorClientDebugLink(TrezorClient):
@expect(messages.Success, field="message", ret_type=str)
def load_device(
client: "TrezorClient",
session: "Session",
mnemonic: Union[str, Iterable[str]],
pin: Optional[str],
passphrase_protection: bool,
@ -1333,12 +1335,12 @@ def load_device(
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
if client.features.initialized:
if session.features.initialized:
raise RuntimeError(
"Device is initialized already. Call device.wipe() and try again."
)
resp = client.call(
resp = session.call(
messages.LoadDevice(
mnemonics=mnemonics,
pin=pin,
@ -1349,7 +1351,7 @@ def load_device(
no_backup=no_backup,
)
)
client.init_device()
session.init_device()
return resp
@ -1358,11 +1360,11 @@ load_device_by_mnemonic = load_device
@expect(messages.Success, field="message", ret_type=str)
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
if client.features.bootloader_mode is not True:
def prodtest_t1(session: "Session") -> protobuf.MessageType:
if session.get_features().bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode")
return client.call(
return session.call(
messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
)
@ -1418,5 +1420,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
@expect(messages.Success, field="message", ret_type=str)
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType:
return client.call(messages.DebugLinkOptigaSetSecMax())
def optiga_set_sec_max(session: "Session") -> protobuf.MessageType:
return session.call(messages.DebugLinkOptigaSetSecMax())

View File

@ -23,20 +23,19 @@ from typing import TYPE_CHECKING, Callable, Iterable, Optional
from . import messages
from .exceptions import Cancelled, TrezorException
from .tools import Address, expect, session
from .tools import Address, expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.new.session import Session
RECOVERY_BACK = "\x08" # backspace character, sent literally
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_settings(
client: "TrezorClient",
session: "Session",
label: Optional[str] = None,
language: Optional[str] = None,
use_passphrase: Optional[bool] = None,
@ -67,13 +66,13 @@ def apply_settings(
haptic_feedback=haptic_feedback,
)
out = client.call(settings)
client.refresh_features()
out = session.call(settings)
session.refresh_features()
return out
def _send_language_data(
client: "TrezorClient",
session: "Session",
request: "messages.TranslationDataRequest",
language_data: bytes,
) -> "MessageType":
@ -83,76 +82,69 @@ def _send_language_data(
data_length = response.data_length
data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
return response
@expect(messages.Success, field="message", ret_type=str)
@session
def change_language(
client: "TrezorClient",
session: "Session",
language_data: bytes,
show_display: bool | None = None,
) -> "MessageType":
data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg)
response = session.call(msg)
if data_length > 0:
assert isinstance(response, messages.TranslationDataRequest)
response = _send_language_data(client, response, language_data)
response = _send_language_data(session, response, language_data)
assert isinstance(response, messages.Success)
client.refresh_features() # changing the language in features
session.refresh_features() # changing the language in features
return response
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
out = client.call(messages.ApplyFlags(flags=flags))
client.refresh_features()
def apply_flags(session: "Session", flags: int) -> "MessageType":
out = session.call(messages.ApplyFlags(flags=flags))
session.refresh_features()
return out
@expect(messages.Success, field="message", ret_type=str)
@session
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangePin(remove=remove))
client.refresh_features()
def change_pin(session: "Session", remove: bool = False) -> "MessageType":
ret = session.call(messages.ChangePin(remove=remove))
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
@session
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangeWipeCode(remove=remove))
client.refresh_features()
def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType":
ret = session.call(messages.ChangeWipeCode(remove=remove))
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
@session
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
session: "Session", operation: messages.SdProtectOperationType
) -> "MessageType":
ret = client.call(messages.SdProtect(operation=operation))
client.refresh_features()
ret = session.call(messages.SdProtect(operation=operation))
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
@session
def wipe(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.WipeDevice())
if not client.features.bootloader_mode:
client.init_device()
def wipe(session: "Session") -> "MessageType":
ret = session.call(messages.WipeDevice())
if not session.get_features().bootloader_mode:
session.init_device()
return ret
@session
def recover(
client: "TrezorClient",
session: "Session",
word_count: int = 24,
passphrase_protection: bool = False,
pin_protection: bool = True,
@ -188,13 +180,16 @@ def recover(
if type is None:
type = messages.RecoveryType.NormalRecovery
if client.features.model == "1" and input_callback is None:
if session.get_features().model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One")
if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24")
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
if (
session.get_features().initialized
and type == messages.RecoveryType.NormalRecovery
):
raise RuntimeError(
"Device already initialized. Call device.wipe() and try again."
)
@ -216,24 +211,23 @@ def recover(
msg.label = label
msg.u2f_counter = u2f_counter
res = client.call(msg)
res = session.call(msg)
while isinstance(res, messages.WordRequest):
try:
assert input_callback is not None
inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp))
res = session.call(messages.WordAck(word=inp))
except Cancelled:
res = client.call(messages.Cancel())
res = session.call(messages.Cancel())
client.init_device()
session.init_device()
return res
@expect(messages.Success, field="message", ret_type=str)
@session
def reset(
client: "TrezorClient",
session: "Session",
display_random: bool = False,
strength: Optional[int] = None,
passphrase_protection: bool = False,
@ -251,13 +245,13 @@ def reset(
DeprecationWarning,
)
if client.features.initialized:
if session.get_features().initialized:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
)
if strength is None:
if client.features.model == "1":
if session.get_features().model == "1":
strength = 256
else:
strength = 128
@ -275,25 +269,24 @@ def reset(
backup_type=backup_type,
)
resp = client.call(msg)
resp = session.call(msg)
if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest")
external_entropy = os.urandom(32)
# LOG.debug("Computer generated entropy: " + external_entropy.hex())
ret = client.call(messages.EntropyAck(entropy=external_entropy))
client.init_device()
ret = session.call(messages.EntropyAck(entropy=external_entropy))
session.init_device()
return ret
@expect(messages.Success, field="message", ret_type=str)
@session
def backup(
client: "TrezorClient",
session: "Session",
group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (),
) -> "MessageType":
ret = client.call(
ret = session.call(
messages.BackupDevice(
group_threshold=group_threshold,
groups=[
@ -302,37 +295,36 @@ def backup(
],
)
)
client.refresh_features()
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
def cancel_authorization(client: "TrezorClient") -> "MessageType":
return client.call(messages.CancelAuthorization())
def cancel_authorization(session: "Session") -> "MessageType":
return session.call(messages.CancelAuthorization())
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType":
resp = client.call(messages.UnlockPath(address_n=n))
def unlock_path(session: "Session", n: "Address") -> "MessageType":
resp = session.call(messages.UnlockPath(address_n=n))
# Cancel the UnlockPath workflow now that we have the authentication code.
try:
client.call(messages.Cancel())
session.call(messages.Cancel())
except Cancelled:
return resp
else:
raise TrezorException("Unexpected response in UnlockPath flow")
@session
@expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader(
client: "TrezorClient",
session: "Session",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None,
language_data: bytes = b"",
) -> "MessageType":
response = client.call(
response = session.call(
messages.RebootToBootloader(
boot_command=boot_command,
firmware_header=firmware_header,
@ -340,42 +332,37 @@ def reboot_to_bootloader(
)
)
if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data)
response = _send_language_data(session, response, language_data)
return response
@session
@expect(messages.Success, field="message", ret_type=str)
def show_device_tutorial(client: "TrezorClient") -> "MessageType":
return client.call(messages.ShowDeviceTutorial())
@session
@expect(messages.Success, field="message", ret_type=str)
def unlock_bootloader(client: "TrezorClient") -> "MessageType":
return client.call(messages.UnlockBootloader())
def show_device_tutorial(session: "Session") -> "MessageType":
return session.call(messages.ShowDeviceTutorial())
@expect(messages.Success, field="message", ret_type=str)
@session
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType":
def unlock_bootloader(session: "Session") -> "MessageType":
return session.call(messages.UnlockBootloader())
@expect(messages.Success, field="message", ret_type=str)
def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType":
"""Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state.
"""
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms))
client.refresh_features()
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms))
session.refresh_features()
return ret
@expect(messages.AuthenticityProof)
def authenticate(client: "TrezorClient", challenge: bytes):
return client.call(messages.AuthenticateDevice(challenge=challenge))
def authenticate(session: "Session", challenge: bytes):
return session.call(messages.AuthenticateDevice(challenge=challenge))
@expect(messages.Success, field="message", ret_type=str)
def set_brightness(
client: "TrezorClient", value: Optional[int] = None
) -> "MessageType":
return client.call(messages.SetBrightness(value=value))
def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType":
return session.call(messages.SetBrightness(value=value))

View File

@ -18,12 +18,12 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages
from .tools import b58decode, expect, session
from .tools import b58decode, expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
def name_to_number(name: str) -> int:
@ -321,17 +321,16 @@ def parse_transaction_json(
@expect(messages.EosPublicKey)
def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False
session: "Session", n: "Address", show_display: bool = False
) -> "MessageType":
response = client.call(
response = session.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display)
)
return response
@session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: "Address",
transaction: dict,
chain_id: str,
@ -347,11 +346,11 @@ def sign_tx(
chunkify=chunkify,
)
response = client.call(msg)
response = session.call(msg)
try:
while isinstance(response, messages.EosTxActionRequest):
response = client.call(actions.pop(0))
response = session.call(actions.pop(0))
except IndexError:
# pop from empty list
raise exceptions.TrezorException(

View File

@ -18,12 +18,12 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages
from .tools import expect, prepare_message_bytes, session, unharden
from .tools import expect, prepare_message_bytes, unharden
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
def int_to_big_endian(value: int) -> bytes:
@ -163,13 +163,13 @@ def network_from_address_n(
@expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
show_display: bool = False,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumGetAddress(
address_n=n,
show_display=show_display,
@ -181,16 +181,15 @@ def get_address(
@expect(messages.EthereumPublicKey)
def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False
session: "Session", n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
)
@session
def sign_tx(
client: "TrezorClient",
session: "Session",
n: "Address",
nonce: int,
gas_price: int,
@ -226,13 +225,13 @@ def sign_tx(
data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk
response = client.call(msg)
response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
@ -247,9 +246,8 @@ def sign_tx(
return response.signature_v, response.signature_r, response.signature_s
@session
def sign_tx_eip1559(
client: "TrezorClient",
session: "Session",
n: "Address",
*,
nonce: int,
@ -282,13 +280,13 @@ def sign_tx_eip1559(
chunkify=chunkify,
)
response = client.call(msg)
response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
@ -299,13 +297,13 @@ def sign_tx_eip1559(
@expect(messages.EthereumMessageSignature)
def sign_message(
client: "TrezorClient",
session: "Session",
n: "Address",
message: AnyStr,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumSignMessage(
address_n=n,
message=prepare_message_bytes(message),
@ -317,7 +315,7 @@ def sign_message(
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data(
client: "TrezorClient",
session: "Session",
n: "Address",
data: Dict[str, Any],
*,
@ -333,7 +331,7 @@ def sign_typed_data(
metamask_v4_compat=metamask_v4_compat,
definitions=definitions,
)
response = client.call(request)
response = session.call(request)
# Sending all the types
while isinstance(response, messages.EthereumTypedDataStructRequest):
@ -349,7 +347,7 @@ def sign_typed_data(
members.append(struct_member)
request = messages.EthereumTypedDataStructAck(members=members)
response = client.call(request)
response = session.call(request)
# Sending the whole message that should be signed
while isinstance(response, messages.EthereumTypedDataValueRequest):
@ -362,7 +360,7 @@ def sign_typed_data(
member_typename = data["primaryType"]
member_data = data["message"]
else:
client.cancel()
session.cancel()
raise exceptions.TrezorException("Root index can only be 0 or 1")
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
@ -385,20 +383,20 @@ def sign_typed_data(
encoded_data = encode_data(member_data, member_typename)
request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request)
response = session.call(request)
return response
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: bytes,
message: AnyStr,
chunkify: bool = False,
) -> bool:
try:
resp = client.call(
resp = session.call(
messages.EthereumVerifyMessage(
address=address,
signature=signature,
@ -413,13 +411,13 @@ def verify_message(
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data_hash(
client: "TrezorClient",
session: "Session",
n: "Address",
domain_hash: bytes,
message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None,
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumSignTypedHash(
address_n=n,
domain_separator_hash=domain_hash,

View File

@ -20,8 +20,8 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.new.session import Session
@expect(
@ -29,27 +29,27 @@ if TYPE_CHECKING:
field="credentials",
ret_type=List[messages.WebAuthnCredential],
)
def list_credentials(client: "TrezorClient") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials())
def list_credentials(session: "Session") -> "MessageType":
return session.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message", ret_type=str)
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
return client.call(
def add_credential(session: "Session", credential_id: bytes) -> "MessageType":
return session.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
)
@expect(messages.Success, field="message", ret_type=str)
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
def remove_credential(session: "Session", index: int) -> "MessageType":
return session.call(messages.WebAuthnRemoveResidentCredential(index=index))
@expect(messages.Success, field="message", ret_type=str)
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
def set_counter(session: "Session", u2f_counter: int) -> "MessageType":
return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
def get_next_counter(client: "TrezorClient") -> "MessageType":
return client.call(messages.GetNextU2FCounter())
def get_next_counter(session: "Session") -> "MessageType":
return session.call(messages.GetNextU2FCounter())

View File

@ -20,7 +20,7 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard
from .. import messages
from ..tools import expect, session
from ..tools import expect
from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware
@ -38,7 +38,7 @@ if True:
from .vendor import * # noqa: F401, F403
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
T = t.TypeVar("T", bound="FirmwareType")
@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
# ====== Client functions ====== #
@session
def update(
client: "TrezorClient",
session: "Session",
data: bytes,
progress_update: t.Callable[[int], t.Any] = lambda _: None,
):
if client.features.bootloader_mode is False:
if session.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode")
resp = client.call(messages.FirmwareErase(length=len(data)))
resp = session.call(messages.FirmwareErase(length=len(data)))
# TREZORv1 method
if isinstance(resp, messages.Success):
resp = client.call(messages.FirmwareUpload(payload=data))
resp = session.call(messages.FirmwareUpload(payload=data))
progress_update(len(data))
if isinstance(resp, messages.Success):
return
@ -97,7 +96,7 @@ def update(
length = resp.length
payload = data[resp.offset : resp.offset + length]
digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
progress_update(length)
if isinstance(resp, messages.Success):
@ -107,5 +106,5 @@ def update(
@expect(messages.FirmwareHash, field="hash", ret_type=bytes)
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]):
return client.call(messages.GetFirmwareHash(challenge=challenge))
def get_hash(session: "Session", challenge: t.Optional[bytes]):
return session.call(messages.GetFirmwareHash(challenge=challenge))

View File

@ -20,25 +20,25 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
@expect(messages.Entropy, field="entropy", ret_type=bytes)
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
return client.call(messages.GetEntropy(size=size))
def get_entropy(session: "Session", size: int) -> "MessageType":
return session.call(messages.GetEntropy(size=size))
@expect(messages.SignedIdentity)
def sign_identity(
client: "TrezorClient",
session: "Session",
identity: messages.IdentityType,
challenge_hidden: bytes,
challenge_visual: str,
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
return client.call(
return session.call(
messages.SignIdentity(
identity=identity,
challenge_hidden=challenge_hidden,
@ -50,12 +50,12 @@ def sign_identity(
@expect(messages.ECDHSessionKey)
def get_ecdh_session_key(
client: "TrezorClient",
session: "Session",
identity: messages.IdentityType,
peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
return client.call(
return session.call(
messages.GetECDHSessionKey(
identity=identity,
peer_public_key=peer_public_key,
@ -66,7 +66,7 @@ def get_ecdh_session_key(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
n: "Address",
key: str,
value: bytes,
@ -74,7 +74,7 @@ def encrypt_keyvalue(
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> "MessageType":
return client.call(
return session.call(
messages.CipherKeyValue(
address_n=n,
key=key,
@ -89,7 +89,7 @@ def encrypt_keyvalue(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
n: "Address",
key: str,
value: bytes,
@ -97,7 +97,7 @@ def decrypt_keyvalue(
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> "MessageType":
return client.call(
return session.call(
messages.CipherKeyValue(
address_n=n,
key=key,
@ -111,5 +111,5 @@ def decrypt_keyvalue(
@expect(messages.Nonce, field="nonce", ret_type=bytes)
def get_nonce(client: "TrezorClient"):
return client.call(messages.GetNonce())
def get_nonce(session: "Session"):
return session.call(messages.GetNonce())

View File

@ -20,9 +20,9 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
# MAINNET = 0
@ -33,13 +33,13 @@ if TYPE_CHECKING:
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.MoneroGetAddress(
address_n=n,
show_display=show_display,
@ -51,10 +51,10 @@ def get_address(
@expect(messages.MoneroWatchKey)
def get_watch_key(
client: "TrezorClient",
session: "Session",
n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> "MessageType":
return client.call(
return session.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
)

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801
@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
@expect(messages.NEMAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
network: int,
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify
)
@ -213,7 +213,7 @@ def get_address(
@expect(messages.NEMSignedTx)
def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
) -> "MessageType":
try:
msg = create_sign_tx(transaction, chunkify=chunkify)
@ -222,4 +222,4 @@ def sign_tx(
assert msg.transaction is not None
msg.transaction.address_n = n
return client.call(msg)
return session.call(msg)

View File

@ -21,9 +21,9 @@ from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@ -31,12 +31,12 @@ REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -45,14 +45,14 @@ def get_address(
@expect(messages.RippleSignedTx)
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: "Address",
msg: messages.RippleSignTx,
chunkify: bool = False,
) -> "MessageType":
msg.address_n = address_n
msg.chunkify = chunkify
return client.call(msg)
return session.call(msg)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -4,29 +4,29 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.new.session import Session
@expect(messages.SolanaPublicKey)
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: List[int],
show_display: bool,
) -> "MessageType":
return client.call(
return session.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
)
@expect(messages.SolanaAddress)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: List[int],
show_display: bool,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.SolanaGetAddress(
address_n=address_n,
show_display=show_display,
@ -37,12 +37,12 @@ def get_address(
@expect(messages.SolanaTxSignature)
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: List[int],
serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> "MessageType":
return client.call(
return session.call(
messages.SolanaSignTx(
address_n=address_n,
serialized_tx=serialized_tx,

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
StellarMessageType = Union[
messages.StellarAccountMergeOp,
@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
@expect(messages.StellarAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -338,7 +338,7 @@ def get_address(
def sign_tx(
client: "TrezorClient",
session: "Session",
tx: messages.StellarSignTx,
operations: List["StellarMessageType"],
address_n: "Address",
@ -354,10 +354,10 @@ def sign_tx(
# 3. Receive a StellarTxOpRequest message
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
# 5. The final message received will be StellarSignedTx which is returned from this method
resp = client.call(tx)
resp = session.call(tx)
try:
while isinstance(resp, messages.StellarTxOpRequest):
resp = client.call(operations.pop(0))
resp = session.call(operations.pop(0))
except IndexError:
# pop from empty list
raise exceptions.TrezorException(

View File

@ -20,19 +20,19 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.new.session import Session
@expect(messages.TezosAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -41,12 +41,12 @@ def get_address(
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -55,11 +55,11 @@ def get_public_key(
@expect(messages.TezosSignedTx)
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: "Address",
sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False,
) -> "MessageType":
sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg)
return session.call(sign_tx_msg)

View File

@ -40,7 +40,7 @@ if TYPE_CHECKING:
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
from . import client
from .protobuf import MessageType
@ -284,24 +284,6 @@ def expect(
return decorator
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
return f(client, *args, **kwargs)
finally:
client.close()
print("wrap end")
return wrapped_f
# de-camelcasifier
# https://stackoverflow.com/a/1176023/222189

View File

@ -1,101 +0,0 @@
from __future__ import annotations
import logging
from ... import mapping
from ...mapping import ProtobufMapping
from .channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel, ProtocolV1
from .protocol_v2 import ProtocolV2
from .session import Session, SessionV1, SessionV2
from .transport import NewTransport
LOG = logging.getLogger(__name__)
class NewTrezorClient:
management_session: Session | None = None
def __init__(
self,
transport: NewTransport,
protobuf_mapping: ProtobufMapping | None = None,
protocol: ProtocolAndChannel | None = None,
) -> None:
self.transport = transport
if protobuf_mapping is None:
self.mapping = mapping.DEFAULT_MAPPING
else:
self.mapping = protobuf_mapping
if protocol is None:
try:
self.protocol = self._get_protocol()
except Exception as e:
print(e)
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
@classmethod
def resume(
cls,
transport: NewTransport,
channel_data: ChannelData,
protobuf_mapping: ProtobufMapping | None = None,
) -> NewTrezorClient:
if protobuf_mapping is None:
protobuf_mapping = mapping.DEFAULT_MAPPING
if channel_data.protocol_version == 2:
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
else:
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
return NewTrezorClient(transport, protobuf_mapping, protocol)
def get_session(
self,
passphrase: str | None = None,
derive_cardano: bool = False,
) -> Session:
if isinstance(self.protocol, ProtocolV1):
return SessionV1.new(self, passphrase, derive_cardano)
if isinstance(self.protocol, ProtocolV2):
return SessionV2.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO
def get_management_session(self):
if self.management_session is not None:
return self.management_session
if isinstance(self.protocol, ProtocolV1):
self.management_session = SessionV1.new(self, "", False)
elif isinstance(self.protocol, ProtocolV2):
self.management_session = SessionV2(self, b"\x00")
assert self.management_session is not None
return self.management_session
def resume_session(self, session_id: bytes) -> Session:
raise NotImplementedError # TODO
def _get_protocol(self) -> ProtocolAndChannel:
from ... import mapping, messages
from ...messages import FailureType
from .protocol_and_channel import ProtocolV1
self.transport.open()
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
protocol.write(messages.Initialize())
response = protocol.read()
self.transport.close()
if isinstance(response, messages.Failure):
if (
response.code == FailureType.UnexpectedMessage
and response.message == "Invalid protocol"
):
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2(self.transport, self.mapping)
return protocol

View File

@ -2,35 +2,69 @@ from __future__ import annotations
import typing as t
from ...messages import Features, Initialize, ThpCreateNewSession, ThpNewSession
from ... import models
from ...messages import (
Features,
GetFeatures,
Initialize,
ThpCreateNewSession,
ThpNewSession,
)
from .protocol_and_channel import ProtocolV1
from .protocol_v2 import ProtocolV2
if t.TYPE_CHECKING:
from .client import NewTrezorClient
from ...client import TrezorClient
class Session:
features: Features
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
def __init__(self, client: TrezorClient, id: bytes) -> None:
self.client = client
self.id = id
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> Session:
raise NotImplementedError
def call(self, msg: t.Any) -> t.Any:
raise NotImplementedError
def refresh_features(self) -> None:
raise NotImplementedError
def get_features(self) -> Features:
raise NotImplementedError
def get_model(self) -> models.TrezorModel:
features = self.get_features()
model = models.by_name(features.model or "1")
if model is None:
raise RuntimeError(
"Unsupported Trezor model"
f" (internal_model: {features.internal_model}, model: {features.model})"
)
return model
def get_version(self) -> t.Tuple[int, int, int]:
features = self.get_features()
version = (
features.major_version,
features.minor_version,
features.patch_version,
)
return version
class SessionV1(Session):
features: Features
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, b"")
@ -38,7 +72,7 @@ class SessionV1(Session):
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
Initialize()
)
session.id = session.features.session_id
session.id = session.get_features().session_id
return session
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
@ -49,12 +83,18 @@ class SessionV1(Session):
self.client.protocol.write(msg)
return self.client.protocol.read()
def refresh_features(self) -> None:
self.features = self.call(GetFeatures())
def get_features(self) -> Features:
return self.features
class SessionV2(Session):
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2)
session = SessionV2(client, b"\x00")
@ -66,7 +106,7 @@ class SessionV2(Session):
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
def __init__(self, client: TrezorClient, id: bytes) -> None:
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2)
@ -78,6 +118,12 @@ class SessionV2(Session):
self.channel.write(self.sid, msg)
return self.channel.read(self.sid)
def get_features(self) -> Features:
return self.channel.get_features()
def refresh_features(self) -> None:
self.channel.update_features()
def update_id_and_sid(self, id: bytes) -> None:
self.id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

View File

@ -24,13 +24,14 @@ from trezorlib.tools import parse_path
def main() -> None:
# Use first connected device
client = get_default_client()
session = client.get_session(derive_cardano=True)
# Print out Trezor's features and settings
print(client.features)
print(session.get_features())
# Get the first address of first BIP44 account
bip32_path = parse_path("44h/0h/0h/0/0")
address = btc.get_address(client, "Bitcoin", bip32_path, True)
address = btc.get_address(session, "Bitcoin", bip32_path, False)
print("Bitcoin address:", address)

View File

@ -63,7 +63,7 @@ MODULES = (
CALLS_DONE = []
DEBUGLINK = None
get_client_orig = cli.TrezorConnection.get_client
get_client_orig = cli.NewTrezorConnection.get_client
def get_client(conn):
@ -75,7 +75,7 @@ def get_client(conn):
return client
cli.TrezorConnection.get_client = get_client
cli.NewTrezorConnection.get_client = get_client
def scan_layouts(dest):