1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00
This commit is contained in:
M1nd3r 2025-02-17 17:13:14 +01:00
parent f51e517969
commit ac12643121
2 changed files with 20 additions and 6 deletions

View File

@ -33,6 +33,7 @@ from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
Capability,
TrezorConnection,
benchmark,
binance,
@ -45,6 +46,7 @@ from . import (
ethereum,
fido,
firmware,
get_passphrase,
monero,
nem,
ripple,
@ -225,7 +227,7 @@ def cli_main(
bytes_session_id = bytes.fromhex(session_id)
except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}")
# channel database = get_db(should_not_store=no_store)
ctx.obj = TrezorConnection(
path, bytes_session_id, passphrase_on_host, script, get_channel_db()
)
@ -342,10 +344,14 @@ def ping(session: "Session", message: str, button_protection: bool) -> str:
@cli.command()
@click.option(
"-c",
"--derive-cardano",
is_flag=True,
help="Should the session have cardano seed derived.",
)
@click.pass_obj
def get_session(
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
def get_session(obj: TrezorConnection, derive_cardano: bool) -> str:
"""Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -363,11 +369,13 @@ def get_session(
raise click.ClickException(
"Upgrade your firmware to enable session support."
)
available_on_device = Capability.PassphraseEntry in client.features.capabilities
# client.ensure_unlocked()
passphrase = get_passphrase(available_on_device, obj.passphrase_on_host)
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
) # TODO add arguments to get_passphrase, add session id=1 or from user input
if session.id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.")
else:

View File

@ -71,6 +71,7 @@ class TrezorClient:
transport: Transport,
protobuf_mapping: ProtobufMapping | None = None,
protocol: Channel | None = None,
# channel_database
) -> None:
self._is_invalidated: bool = False
self.transport = transport
@ -126,6 +127,9 @@ class TrezorClient:
protocol = ProtocolV1Channel(transport, protobuf_mapping, channel_data)
return TrezorClient(transport, protobuf_mapping, protocol)
def get_channel_data(self) -> ChannelData:
return self.protocol.get_channel_data()
def get_session(
self,
passphrase: str | object | None = None,
@ -234,7 +238,9 @@ class TrezorClient:
if isinstance(response, messages.Failure):
if response.code == messages.FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2Channel(self.transport, self.mapping)
protocol = ProtocolV2Channel(
self.transport, self.mapping
) # self.database
return protocol