1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-30 18:38:27 +00:00

wip trezorlib add passphrase sessions

This commit is contained in:
M1nd3r 2024-09-13 17:25:21 +02:00
parent a7f386f3a9
commit 6a65d62353
6 changed files with 178 additions and 44 deletions

View File

@ -14,21 +14,25 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import functools
import os
import sys
import typing as t
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
from .. import exceptions, transport
from .. import exceptions, transport, ui
from ..client import TrezorClient
from ..messages import Capability
from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient
from ..transport.new.transport import NewTransport
from ..ui import ClickUI, ScriptUI
if TYPE_CHECKING:
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
@ -43,7 +47,10 @@ if TYPE_CHECKING:
class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive
if case_sensitive:
@ -51,7 +58,7 @@ class ChoiceType(click.Choice):
else:
self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values():
return value
value = super().convert(value, param, ctx)
@ -60,11 +67,48 @@ class ChoiceType(click.Choice):
return self.typemap[value]
def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
class NewTrezorConnection:
def __init__(
self,
path: str,
session_id: Optional[bytes],
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
) -> None:
@ -73,6 +117,29 @@ class NewTrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(
self,
):
client = self.get_client()
if self.session_id is not None:
pass # TODO Try resume
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=True)
# TODO Passphrase empty by default - ???
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(passphrase=passphrase, derive_cardano=True)
return session
def get_transport(self) -> "NewTransport":
try:
# look for transport without prefix search
@ -100,6 +167,7 @@ class NewTrezorConnection:
)
else:
client = NewTrezorClient(transport)
return client
@contextmanager
@ -135,10 +203,11 @@ class NewTrezorConnection:
class TrezorConnection:
def __init__(
self,
path: str,
session_id: Optional[bytes],
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
) -> None:
@ -205,9 +274,33 @@ class TrezorConnection:
# other exceptions may cause a traceback
from ..transport.new.session import Session
def with_session(
func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
session = obj.get_session()
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return function_with_session # type: ignore [is incompatible with return type]
def new_with_client(
func: "Callable[Concatenate[NewTrezorClient, P], R]",
) -> "Callable[P, R]":
func: "t.Callable[Concatenate[NewTrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
@ -243,7 +336,9 @@ def new_with_client(
return trezorctl_command_with_client # type: ignore [is incompatible with return type]
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
@ -305,14 +400,14 @@ class AliasedGroup(click.Group):
def __init__(
self,
aliases: Optional[Dict[str, click.Command]] = None,
*args: Any,
**kwargs: Any,
aliases: t.Dict[str, click.Command] | None = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
super().__init__(*args, **kwargs)
self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-")
# try to look up the real name
cmd = super().get_command(ctx, cmd_name)

View File

@ -29,6 +29,7 @@ from ..client import TrezorClient
from ..transport import DeviceIsBusy, new_enumerate_devices
from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient
from ..transport.new.session import Session
from ..transport.new.udp import UdpTransport
from . import (
AliasedGroup,
@ -53,6 +54,7 @@ from . import (
stellar,
tezos,
with_client,
with_session,
)
F = TypeVar("F", bound=Callable)
@ -334,10 +336,14 @@ def version() -> str:
@cli.command()
@click.argument("message")
@click.option("-b", "--button-protection", is_flag=True)
@with_client
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
@with_session
def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message."""
return client.ping(message, button_protection=button_protection)
# TODO return short-circuit from old client for old Trezors
return session.call(
messages.Ping(message=message, button_protection=button_protection)
)
@cli.command()

View File

@ -28,9 +28,11 @@ class NewTrezorClient:
self.mapping = mapping.DEFAULT_MAPPING
else:
self.mapping = protobuf_mapping
if protocol is None:
try:
self.protocol = self._get_protocol()
except Exception as e:
print(e)
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
@ -52,9 +54,8 @@ class NewTrezorClient:
def get_session(
self,
passphrase: str = "",
passphrase: str | None = None,
derive_cardano: bool = False,
management_session: bool = False,
) -> Session:
if isinstance(self.protocol, ProtocolV1):
return SessionV1.new(self, passphrase, derive_cardano)

View File

@ -4,6 +4,7 @@ import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from ...mapping import ProtobufMapping
from .channel_data import ChannelData
@ -30,12 +31,28 @@ class ProtocolAndChannel:
# def read(self, session_id: bytes) -> t.Any: ...
def get_features(self) -> messages.Features:
raise NotImplementedError()
def get_channel_data(self) -> ChannelData:
raise NotImplementedError
class ProtocolV1(ProtocolAndChannel):
HEADER_LEN = struct.calcsize(">HL")
_features: messages.Features
_has_valid_features: bool = False
def get_features(self) -> messages.Features:
if not self._has_valid_features:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
self._has_valid_features = True
return self._features
def read(self) -> t.Any:
msg_type, msg_bytes = self._read()

View File

@ -10,12 +10,12 @@ from enum import IntEnum
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import messages
from ... import exceptions, messages
from ...mapping import ProtobufMapping
from ..thp import checksum, curve25519, thp_io
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.packet_header import PacketHeader
from . import control_byte
from . import channel_database, control_byte
from .channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel
from .transport import NewTransport
@ -56,9 +56,9 @@ class ProtocolV2(ProtocolAndChannel):
sync_bit_send: int
sync_bit_receive: int
has_valid_channel: bool = False
has_valid_features: bool = False
features: messages.Features
_has_valid_channel: bool = False
_has_valid_features: bool = False
_features: messages.Features
def __init__(
self,
@ -75,13 +75,14 @@ class ProtocolV2(ProtocolAndChannel):
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self.has_valid_channel = True
self._has_valid_channel = True
def get_channel(self) -> ProtocolV2:
if not self.has_valid_channel:
if not self._has_valid_channel:
self._establish_new_channel()
if not self.has_valid_features:
self.update_features()
# TODO - Q: should ask for features now or when needed?
# if not self.has_valid_features:
# self.update_features()
return self
def get_channel_data(self) -> ChannelData:
@ -98,12 +99,23 @@ class ProtocolV2(ProtocolAndChannel):
)
def read(self, session_id: int) -> t.Any:
header, data = self._read_until_valid_crc_check()
# TODO
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on different session.")
channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel()
if not self._has_valid_features:
self.update_features()
return self._features
def update_features(self) -> None:
message = messages.GetFeatures()
@ -111,11 +123,12 @@ class ProtocolV2(ProtocolAndChannel):
self.session_id: int = 0
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK
session_id, msg_type, msg_data = self.read_and_decrypt()
_, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
assert isinstance(features, messages.Features)
self.features = features
self.has_valid_features = True
if not isinstance(features, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features
self._has_valid_features = True
def _establish_new_channel(self) -> None:
self.sync_bit_send = 0
@ -260,7 +273,7 @@ class ProtocolV2(ProtocolAndChannel):
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpEndResponse)
self.has_valid_channel = True
self._has_valid_channel = True
def _send_ack_0(self):
LOG.debug("sending ack 0")
@ -302,8 +315,13 @@ class ProtocolV2(ProtocolAndChannel):
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if control_byte.is_ack(header.ctrl_byte):
return self.read_and_decrypt()
if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!")
print(
hexlify(header.to_bytes_init()).decode(), hexlify(raw_payload).decode()
)
if not control_byte.is_ack(header.ctrl_byte):
LOG.debug(

View File

@ -19,7 +19,7 @@ class Session:
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
) -> Session:
raise NotImplementedError
@ -30,7 +30,7 @@ class Session:
class SessionV1(Session):
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, b"")
@ -54,9 +54,8 @@ class SessionV2(Session):
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2)
session = SessionV2(client, b"\x00")
new_session: ThpNewSession = session.call(
@ -73,9 +72,7 @@ class SessionV2(Session):
self.channel: ProtocolV2 = client.protocol.get_channel()
self.update_id_and_sid(id)
if not self.channel.has_valid_features:
self.channel.update_features()
self.features = self.channel.features
self.features = self.channel.get_features()
def call(self, msg: t.Any) -> t.Any:
self.channel.write(self.sid, msg)