1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-30 18:38:27 +00:00

wip trezorlib add passphrase sessions

This commit is contained in:
M1nd3r 2024-09-13 17:25:21 +02:00
parent a7f386f3a9
commit 6a65d62353
6 changed files with 178 additions and 44 deletions

View File

@ -14,21 +14,25 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import functools import functools
import os
import sys import sys
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click import click
from .. import exceptions, transport from .. import exceptions, transport, ui
from ..client import TrezorClient from ..client import TrezorClient
from ..messages import Capability
from ..transport.new import channel_database from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient from ..transport.new.client import NewTrezorClient
from ..transport.new.transport import NewTransport from ..transport.new.transport import NewTransport
from ..ui import ClickUI, ScriptUI from ..ui import ClickUI, ScriptUI
if TYPE_CHECKING: if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators # Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/ # More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar from typing import TypeVar
@ -43,7 +47,10 @@ if TYPE_CHECKING:
class ChoiceType(click.Choice): class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys())) super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive self.case_sensitive = case_sensitive
if case_sensitive: if case_sensitive:
@ -51,7 +58,7 @@ class ChoiceType(click.Choice):
else: else:
self.typemap = {k.lower(): v for k, v in typemap.items()} self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values(): if value in self.typemap.values():
return value return value
value = super().convert(value, param, ctx) value = super().convert(value, param, ctx)
@ -60,11 +67,48 @@ class ChoiceType(click.Choice):
return self.typemap[value] return self.typemap[value]
def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
class NewTrezorConnection: class NewTrezorConnection:
def __init__( def __init__(
self, self,
path: str, path: str,
session_id: Optional[bytes], session_id: bytes | None,
passphrase_on_host: bool, passphrase_on_host: bool,
script: bool, script: bool,
) -> None: ) -> None:
@ -73,6 +117,29 @@ class NewTrezorConnection:
self.passphrase_on_host = passphrase_on_host self.passphrase_on_host = passphrase_on_host
self.script = script self.script = script
def get_session(
self,
):
client = self.get_client()
if self.session_id is not None:
pass # TODO Try resume
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=True)
# TODO Passphrase empty by default - ???
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(passphrase=passphrase, derive_cardano=True)
return session
def get_transport(self) -> "NewTransport": def get_transport(self) -> "NewTransport":
try: try:
# look for transport without prefix search # look for transport without prefix search
@ -100,6 +167,7 @@ class NewTrezorConnection:
) )
else: else:
client = NewTrezorClient(transport) client = NewTrezorClient(transport)
return client return client
@contextmanager @contextmanager
@ -135,10 +203,11 @@ class NewTrezorConnection:
class TrezorConnection: class TrezorConnection:
def __init__( def __init__(
self, self,
path: str, path: str,
session_id: Optional[bytes], session_id: bytes | None,
passphrase_on_host: bool, passphrase_on_host: bool,
script: bool, script: bool,
) -> None: ) -> None:
@ -205,9 +274,33 @@ class TrezorConnection:
# other exceptions may cause a traceback # other exceptions may cause a traceback
from ..transport.new.session import Session
def with_session(
func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
session = obj.get_session()
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return function_with_session # type: ignore [is incompatible with return type]
def new_with_client( def new_with_client(
func: "Callable[Concatenate[NewTrezorClient, P], R]", func: "t.Callable[Concatenate[NewTrezorClient, P], R]",
) -> "Callable[P, R]": ) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`. """Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume Sessions are handled transparently. The user is warned when session did not resume
@ -243,7 +336,9 @@ def new_with_client(
return trezorctl_command_with_client # type: ignore [is incompatible with return type] return trezorctl_command_with_client # type: ignore [is incompatible with return type]
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`. """Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume Sessions are handled transparently. The user is warned when session did not resume
@ -305,14 +400,14 @@ class AliasedGroup(click.Group):
def __init__( def __init__(
self, self,
aliases: Optional[Dict[str, click.Command]] = None, aliases: t.Dict[str, click.Command] | None = None,
*args: Any, *args: t.Any,
**kwargs: Any, **kwargs: t.Any,
) -> None: ) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.aliases = aliases or {} self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-") cmd_name = cmd_name.replace("_", "-")
# try to look up the real name # try to look up the real name
cmd = super().get_command(ctx, cmd_name) cmd = super().get_command(ctx, cmd_name)

View File

@ -29,6 +29,7 @@ from ..client import TrezorClient
from ..transport import DeviceIsBusy, new_enumerate_devices from ..transport import DeviceIsBusy, new_enumerate_devices
from ..transport.new import channel_database from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient from ..transport.new.client import NewTrezorClient
from ..transport.new.session import Session
from ..transport.new.udp import UdpTransport from ..transport.new.udp import UdpTransport
from . import ( from . import (
AliasedGroup, AliasedGroup,
@ -53,6 +54,7 @@ from . import (
stellar, stellar,
tezos, tezos,
with_client, with_client,
with_session,
) )
F = TypeVar("F", bound=Callable) F = TypeVar("F", bound=Callable)
@ -334,10 +336,14 @@ def version() -> str:
@cli.command() @cli.command()
@click.argument("message") @click.argument("message")
@click.option("-b", "--button-protection", is_flag=True) @click.option("-b", "--button-protection", is_flag=True)
@with_client @with_session
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message.""" """Send ping message."""
return client.ping(message, button_protection=button_protection)
# TODO return short-circuit from old client for old Trezors
return session.call(
messages.Ping(message=message, button_protection=button_protection)
)
@cli.command() @cli.command()

View File

@ -28,9 +28,11 @@ class NewTrezorClient:
self.mapping = mapping.DEFAULT_MAPPING self.mapping = mapping.DEFAULT_MAPPING
else: else:
self.mapping = protobuf_mapping self.mapping = protobuf_mapping
if protocol is None: if protocol is None:
self.protocol = self._get_protocol() try:
self.protocol = self._get_protocol()
except Exception as e:
print(e)
else: else:
self.protocol = protocol self.protocol = protocol
self.protocol.mapping = self.mapping self.protocol.mapping = self.mapping
@ -52,9 +54,8 @@ class NewTrezorClient:
def get_session( def get_session(
self, self,
passphrase: str = "", passphrase: str | None = None,
derive_cardano: bool = False, derive_cardano: bool = False,
management_session: bool = False,
) -> Session: ) -> Session:
if isinstance(self.protocol, ProtocolV1): if isinstance(self.protocol, ProtocolV1):
return SessionV1.new(self, passphrase, derive_cardano) return SessionV1.new(self, passphrase, derive_cardano)

View File

@ -4,6 +4,7 @@ import logging
import struct import struct
import typing as t import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES from ...log import DUMP_BYTES
from ...mapping import ProtobufMapping from ...mapping import ProtobufMapping
from .channel_data import ChannelData from .channel_data import ChannelData
@ -30,12 +31,28 @@ class ProtocolAndChannel:
# def read(self, session_id: bytes) -> t.Any: ... # def read(self, session_id: bytes) -> t.Any: ...
def get_features(self) -> messages.Features:
raise NotImplementedError()
def get_channel_data(self) -> ChannelData: def get_channel_data(self) -> ChannelData:
raise NotImplementedError raise NotImplementedError
class ProtocolV1(ProtocolAndChannel): class ProtocolV1(ProtocolAndChannel):
HEADER_LEN = struct.calcsize(">HL") HEADER_LEN = struct.calcsize(">HL")
_features: messages.Features
_has_valid_features: bool = False
def get_features(self) -> messages.Features:
if not self._has_valid_features:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
self._has_valid_features = True
return self._features
def read(self) -> t.Any: def read(self) -> t.Any:
msg_type, msg_bytes = self._read() msg_type, msg_bytes = self._read()

View File

@ -10,12 +10,12 @@ from enum import IntEnum
from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import messages from ... import exceptions, messages
from ...mapping import ProtobufMapping from ...mapping import ProtobufMapping
from ..thp import checksum, curve25519, thp_io from ..thp import checksum, curve25519, thp_io
from ..thp.checksum import CHECKSUM_LENGTH from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.packet_header import PacketHeader from ..thp.packet_header import PacketHeader
from . import control_byte from . import channel_database, control_byte
from .channel_data import ChannelData from .channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel from .protocol_and_channel import ProtocolAndChannel
from .transport import NewTransport from .transport import NewTransport
@ -56,9 +56,9 @@ class ProtocolV2(ProtocolAndChannel):
sync_bit_send: int sync_bit_send: int
sync_bit_receive: int sync_bit_receive: int
has_valid_channel: bool = False _has_valid_channel: bool = False
has_valid_features: bool = False _has_valid_features: bool = False
features: messages.Features _features: messages.Features
def __init__( def __init__(
self, self,
@ -75,13 +75,14 @@ class ProtocolV2(ProtocolAndChannel):
self.nonce_response = channel_data.nonce_response self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send self.sync_bit_send = channel_data.sync_bit_send
self.has_valid_channel = True self._has_valid_channel = True
def get_channel(self) -> ProtocolV2: def get_channel(self) -> ProtocolV2:
if not self.has_valid_channel: if not self._has_valid_channel:
self._establish_new_channel() self._establish_new_channel()
if not self.has_valid_features: # TODO - Q: should ask for features now or when needed?
self.update_features() # if not self.has_valid_features:
# self.update_features()
return self return self
def get_channel_data(self) -> ChannelData: def get_channel_data(self) -> ChannelData:
@ -98,12 +99,23 @@ class ProtocolV2(ProtocolAndChannel):
) )
def read(self, session_id: int) -> t.Any: def read(self, session_id: int) -> t.Any:
header, data = self._read_until_valid_crc_check() sid, msg_type, msg_data = self.read_and_decrypt()
# TODO if sid != session_id:
raise Exception("Received messsage on different session.")
channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None: def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg) msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data) self._encrypt_and_write(session_id, msg_type, msg_data)
channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel()
if not self._has_valid_features:
self.update_features()
return self._features
def update_features(self) -> None: def update_features(self) -> None:
message = messages.GetFeatures() message = messages.GetFeatures()
@ -111,11 +123,12 @@ class ProtocolV2(ProtocolAndChannel):
self.session_id: int = 0 self.session_id: int = 0
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK _ = self._read_until_valid_crc_check() # TODO check ACK
session_id, msg_type, msg_data = self.read_and_decrypt() _, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data) features = self.mapping.decode(msg_type, msg_data)
assert isinstance(features, messages.Features) if not isinstance(features, messages.Features):
self.features = features raise exceptions.TrezorException("Unexpected response to GetFeatures")
self.has_valid_features = True self._features = features
self._has_valid_features = True
def _establish_new_channel(self) -> None: def _establish_new_channel(self) -> None:
self.sync_bit_send = 0 self.sync_bit_send = 0
@ -260,7 +273,7 @@ class ProtocolV2(ProtocolAndChannel):
maaa = self.mapping.decode(msg_type, msg_data) maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpEndResponse) assert isinstance(maaa, messages.ThpEndResponse)
self.has_valid_channel = True self._has_valid_channel = True
def _send_ack_0(self): def _send_ack_0(self):
LOG.debug("sending ack 0") LOG.debug("sending ack 0")
@ -302,8 +315,13 @@ class ProtocolV2(ProtocolAndChannel):
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check() header, raw_payload = self._read_until_valid_crc_check()
if control_byte.is_ack(header.ctrl_byte):
return self.read_and_decrypt()
if not header.is_encrypted_transport(): if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!") print("Trying to decrypt not encrypted message!")
print(
hexlify(header.to_bytes_init()).decode(), hexlify(raw_payload).decode()
)
if not control_byte.is_ack(header.ctrl_byte): if not control_byte.is_ack(header.ctrl_byte):
LOG.debug( LOG.debug(

View File

@ -19,7 +19,7 @@ class Session:
@classmethod @classmethod
def new( def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
) -> Session: ) -> Session:
raise NotImplementedError raise NotImplementedError
@ -30,7 +30,7 @@ class Session:
class SessionV1(Session): class SessionV1(Session):
@classmethod @classmethod
def new( def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV1: ) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1) assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, b"") session = SessionV1(client, b"")
@ -54,9 +54,8 @@ class SessionV2(Session):
@classmethod @classmethod
def new( def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV2: ) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2) assert isinstance(client.protocol, ProtocolV2)
session = SessionV2(client, b"\x00") session = SessionV2(client, b"\x00")
new_session: ThpNewSession = session.call( new_session: ThpNewSession = session.call(
@ -73,9 +72,7 @@ class SessionV2(Session):
self.channel: ProtocolV2 = client.protocol.get_channel() self.channel: ProtocolV2 = client.protocol.get_channel()
self.update_id_and_sid(id) self.update_id_and_sid(id)
if not self.channel.has_valid_features: self.features = self.channel.get_features()
self.channel.update_features()
self.features = self.channel.features
def call(self, msg: t.Any) -> t.Any: def call(self, msg: t.Any) -> t.Any:
self.channel.write(self.sid, msg) self.channel.write(self.sid, msg)