1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-26 06:12:01 +00:00
This commit is contained in:
M1nd3r 2025-02-17 17:13:14 +01:00
parent 98e75f2e51
commit 9c4b4d2897
2 changed files with 20 additions and 6 deletions

View File

@ -33,6 +33,7 @@ from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
from . import ( from . import (
AliasedGroup, AliasedGroup,
Capability,
TrezorConnection, TrezorConnection,
benchmark, benchmark,
binance, binance,
@ -45,6 +46,7 @@ from . import (
ethereum, ethereum,
fido, fido,
firmware, firmware,
get_passphrase,
monero, monero,
nem, nem,
ripple, ripple,
@ -225,7 +227,7 @@ def cli_main(
bytes_session_id = bytes.fromhex(session_id) bytes_session_id = bytes.fromhex(session_id)
except ValueError: except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}") raise click.ClickException(f"Not a valid session id: {session_id}")
# channel database = get_db(should_not_store=no_store)
ctx.obj = TrezorConnection( ctx.obj = TrezorConnection(
path, bytes_session_id, passphrase_on_host, script, get_channel_db() path, bytes_session_id, passphrase_on_host, script, get_channel_db()
) )
@ -342,10 +344,14 @@ def ping(session: "Session", message: str, button_protection: bool) -> str:
@cli.command() @cli.command()
@click.option(
"-c",
"--derive-cardano",
is_flag=True,
help="Should the session have cardano seed derived.",
)
@click.pass_obj @click.pass_obj
def get_session( def get_session(obj: TrezorConnection, derive_cardano: bool) -> str:
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
"""Get a session ID for subsequent commands. """Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -363,11 +369,13 @@ def get_session(
raise click.ClickException( raise click.ClickException(
"Upgrade your firmware to enable session support." "Upgrade your firmware to enable session support."
) )
available_on_device = Capability.PassphraseEntry in client.features.capabilities
# client.ensure_unlocked() # client.ensure_unlocked()
passphrase = get_passphrase(available_on_device, obj.passphrase_on_host)
session = client.get_session( session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano passphrase=passphrase, derive_cardano=derive_cardano
) ) # TODO add arguments to get_passphrase, add session id=1 or from user input
if session.id is None: if session.id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.") raise click.ClickException("Passphrase not enabled or firmware too old.")
else: else:

View File

@ -71,6 +71,7 @@ class TrezorClient:
transport: Transport, transport: Transport,
protobuf_mapping: ProtobufMapping | None = None, protobuf_mapping: ProtobufMapping | None = None,
protocol: Channel | None = None, protocol: Channel | None = None,
# channel_database
) -> None: ) -> None:
self._is_invalidated: bool = False self._is_invalidated: bool = False
self.transport = transport self.transport = transport
@ -126,6 +127,9 @@ class TrezorClient:
protocol = ProtocolV1Channel(transport, protobuf_mapping, channel_data) protocol = ProtocolV1Channel(transport, protobuf_mapping, channel_data)
return TrezorClient(transport, protobuf_mapping, protocol) return TrezorClient(transport, protobuf_mapping, protocol)
def get_channel_data(self) -> ChannelData:
return self.protocol.get_channel_data()
def get_session( def get_session(
self, self,
passphrase: str | object | None = None, passphrase: str | object | None = None,
@ -234,7 +238,9 @@ class TrezorClient:
if isinstance(response, messages.Failure): if isinstance(response, messages.Failure):
if response.code == messages.FailureType.InvalidProtocol: if response.code == messages.FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected") LOG.debug("Protocol V2 detected")
protocol = ProtocolV2Channel(self.transport, self.mapping) protocol = ProtocolV2Channel(
self.transport, self.mapping
) # self.database
return protocol return protocol