mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-30 02:18:16 +00:00
wip trezorlib add passphrase sessions
This commit is contained in:
parent
a7f386f3a9
commit
6a65d62353
@ -14,21 +14,25 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
import click
|
||||
|
||||
from .. import exceptions, transport
|
||||
from .. import exceptions, transport, ui
|
||||
from ..client import TrezorClient
|
||||
from ..messages import Capability
|
||||
from ..transport.new import channel_database
|
||||
from ..transport.new.client import NewTrezorClient
|
||||
from ..transport.new.transport import NewTransport
|
||||
from ..ui import ClickUI, ScriptUI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
# Needed to enforce a return value from decorators
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
@ -43,7 +47,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ChoiceType(click.Choice):
|
||||
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
|
||||
|
||||
def __init__(
|
||||
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
|
||||
) -> None:
|
||||
super().__init__(list(typemap.keys()))
|
||||
self.case_sensitive = case_sensitive
|
||||
if case_sensitive:
|
||||
@ -51,7 +58,7 @@ class ChoiceType(click.Choice):
|
||||
else:
|
||||
self.typemap = {k.lower(): v for k, v in typemap.items()}
|
||||
|
||||
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
|
||||
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
|
||||
if value in self.typemap.values():
|
||||
return value
|
||||
value = super().convert(value, param, ctx)
|
||||
@ -60,11 +67,48 @@ class ChoiceType(click.Choice):
|
||||
return self.typemap[value]
|
||||
|
||||
|
||||
def get_passphrase(
|
||||
passphrase_on_host: bool, available_on_device: bool
|
||||
) -> t.Union[str, object]:
|
||||
if available_on_device and not passphrase_on_host:
|
||||
return ui.PASSPHRASE_ON_DEVICE
|
||||
|
||||
env_passphrase = os.getenv("PASSPHRASE")
|
||||
if env_passphrase is not None:
|
||||
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
return env_passphrase
|
||||
|
||||
while True:
|
||||
try:
|
||||
passphrase = ui.prompt(
|
||||
"Passphrase required",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
# In case user sees the input on the screen, we do not need confirmation
|
||||
if not ui.CAN_HANDLE_HIDDEN_INPUT:
|
||||
return passphrase
|
||||
second = ui.prompt(
|
||||
"Confirm your passphrase",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
if passphrase == second:
|
||||
return passphrase
|
||||
else:
|
||||
ui.echo("Passphrase did not match. Please try again.")
|
||||
except click.Abort:
|
||||
raise exceptions.Cancelled from None
|
||||
|
||||
|
||||
class NewTrezorConnection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
session_id: Optional[bytes],
|
||||
session_id: bytes | None,
|
||||
passphrase_on_host: bool,
|
||||
script: bool,
|
||||
) -> None:
|
||||
@ -73,6 +117,29 @@ class NewTrezorConnection:
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
):
|
||||
client = self.get_client()
|
||||
|
||||
if self.session_id is not None:
|
||||
pass # TODO Try resume
|
||||
features = client.protocol.get_features()
|
||||
|
||||
passphrase_enabled = True # TODO what to do here?
|
||||
|
||||
if not passphrase_enabled:
|
||||
return client.get_session(derive_cardano=True)
|
||||
|
||||
# TODO Passphrase empty by default - ???
|
||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
session = client.get_session(passphrase=passphrase, derive_cardano=True)
|
||||
return session
|
||||
|
||||
def get_transport(self) -> "NewTransport":
|
||||
try:
|
||||
# look for transport without prefix search
|
||||
@ -100,6 +167,7 @@ class NewTrezorConnection:
|
||||
)
|
||||
else:
|
||||
client = NewTrezorClient(transport)
|
||||
|
||||
return client
|
||||
|
||||
@contextmanager
|
||||
@ -135,10 +203,11 @@ class NewTrezorConnection:
|
||||
|
||||
|
||||
class TrezorConnection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
session_id: Optional[bytes],
|
||||
session_id: bytes | None,
|
||||
passphrase_on_host: bool,
|
||||
script: bool,
|
||||
) -> None:
|
||||
@ -205,9 +274,33 @@ class TrezorConnection:
|
||||
# other exceptions may cause a traceback
|
||||
|
||||
|
||||
from ..transport.new.session import Session
|
||||
|
||||
|
||||
def with_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def function_with_session(
|
||||
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
session = obj.get_session()
|
||||
try:
|
||||
return func(session, *args, **kwargs)
|
||||
finally:
|
||||
pass
|
||||
# TODO try end session if not resumed
|
||||
|
||||
# the return type of @click.pass_obj is improperly specified and pyright doesn't
|
||||
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
|
||||
return function_with_session # type: ignore [is incompatible with return type]
|
||||
|
||||
|
||||
def new_with_client(
|
||||
func: "Callable[Concatenate[NewTrezorClient, P], R]",
|
||||
) -> "Callable[P, R]":
|
||||
func: "t.Callable[Concatenate[NewTrezorClient, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||
|
||||
Sessions are handled transparently. The user is warned when session did not resume
|
||||
@ -243,7 +336,9 @@ def new_with_client(
|
||||
return trezorctl_command_with_client # type: ignore [is incompatible with return type]
|
||||
|
||||
|
||||
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
|
||||
def with_client(
|
||||
func: "t.Callable[Concatenate[TrezorClient, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||
|
||||
Sessions are handled transparently. The user is warned when session did not resume
|
||||
@ -305,14 +400,14 @@ class AliasedGroup(click.Group):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aliases: Optional[Dict[str, click.Command]] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
aliases: t.Dict[str, click.Command] | None = None,
|
||||
*args: t.Any,
|
||||
**kwargs: t.Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aliases = aliases or {}
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
cmd_name = cmd_name.replace("_", "-")
|
||||
# try to look up the real name
|
||||
cmd = super().get_command(ctx, cmd_name)
|
||||
|
@ -29,6 +29,7 @@ from ..client import TrezorClient
|
||||
from ..transport import DeviceIsBusy, new_enumerate_devices
|
||||
from ..transport.new import channel_database
|
||||
from ..transport.new.client import NewTrezorClient
|
||||
from ..transport.new.session import Session
|
||||
from ..transport.new.udp import UdpTransport
|
||||
from . import (
|
||||
AliasedGroup,
|
||||
@ -53,6 +54,7 @@ from . import (
|
||||
stellar,
|
||||
tezos,
|
||||
with_client,
|
||||
with_session,
|
||||
)
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
@ -334,10 +336,14 @@ def version() -> str:
|
||||
@cli.command()
|
||||
@click.argument("message")
|
||||
@click.option("-b", "--button-protection", is_flag=True)
|
||||
@with_client
|
||||
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
|
||||
@with_session
|
||||
def ping(session: "Session", message: str, button_protection: bool) -> str:
|
||||
"""Send ping message."""
|
||||
return client.ping(message, button_protection=button_protection)
|
||||
|
||||
# TODO return short-circuit from old client for old Trezors
|
||||
return session.call(
|
||||
messages.Ping(message=message, button_protection=button_protection)
|
||||
)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
@ -28,9 +28,11 @@ class NewTrezorClient:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
else:
|
||||
self.mapping = protobuf_mapping
|
||||
|
||||
if protocol is None:
|
||||
self.protocol = self._get_protocol()
|
||||
try:
|
||||
self.protocol = self._get_protocol()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
else:
|
||||
self.protocol = protocol
|
||||
self.protocol.mapping = self.mapping
|
||||
@ -52,9 +54,8 @@ class NewTrezorClient:
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str = "",
|
||||
passphrase: str | None = None,
|
||||
derive_cardano: bool = False,
|
||||
management_session: bool = False,
|
||||
) -> Session:
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
return SessionV1.new(self, passphrase, derive_cardano)
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...log import DUMP_BYTES
|
||||
from ...mapping import ProtobufMapping
|
||||
from .channel_data import ChannelData
|
||||
@ -30,12 +31,28 @@ class ProtocolAndChannel:
|
||||
|
||||
# def read(self, session_id: bytes) -> t.Any: ...
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolV1(ProtocolAndChannel):
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
_features: messages.Features
|
||||
_has_valid_features: bool = False
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if not self._has_valid_features:
|
||||
self.write(messages.GetFeatures())
|
||||
resp = self.read()
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = resp
|
||||
self._has_valid_features = True
|
||||
|
||||
return self._features
|
||||
|
||||
def read(self) -> t.Any:
|
||||
msg_type, msg_bytes = self._read()
|
||||
|
@ -10,12 +10,12 @@ from enum import IntEnum
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ... import messages
|
||||
from ... import exceptions, messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from ..thp import checksum, curve25519, thp_io
|
||||
from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.packet_header import PacketHeader
|
||||
from . import control_byte
|
||||
from . import channel_database, control_byte
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
from .transport import NewTransport
|
||||
@ -56,9 +56,9 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
|
||||
has_valid_channel: bool = False
|
||||
has_valid_features: bool = False
|
||||
features: messages.Features
|
||||
_has_valid_channel: bool = False
|
||||
_has_valid_features: bool = False
|
||||
_features: messages.Features
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -75,13 +75,14 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
self.nonce_response = channel_data.nonce_response
|
||||
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||
self.sync_bit_send = channel_data.sync_bit_send
|
||||
self.has_valid_channel = True
|
||||
self._has_valid_channel = True
|
||||
|
||||
def get_channel(self) -> ProtocolV2:
|
||||
if not self.has_valid_channel:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
if not self.has_valid_features:
|
||||
self.update_features()
|
||||
# TODO - Q: should ask for features now or when needed?
|
||||
# if not self.has_valid_features:
|
||||
# self.update_features()
|
||||
return self
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
@ -98,12 +99,23 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
)
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
header, data = self._read_until_valid_crc_check()
|
||||
# TODO
|
||||
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||
if sid != session_id:
|
||||
raise Exception("Received messsage on different session.")
|
||||
channel_database.save_channel(self)
|
||||
return self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
def write(self, session_id: int, msg: t.Any) -> None:
|
||||
msg_type, msg_data = self.mapping.encode(msg)
|
||||
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||
channel_database.save_channel(self)
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
if not self._has_valid_features:
|
||||
self.update_features()
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
message = messages.GetFeatures()
|
||||
@ -111,11 +123,12 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
self.session_id: int = 0
|
||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||
session_id, msg_type, msg_data = self.read_and_decrypt()
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
features = self.mapping.decode(msg_type, msg_data)
|
||||
assert isinstance(features, messages.Features)
|
||||
self.features = features
|
||||
self.has_valid_features = True
|
||||
if not isinstance(features, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = features
|
||||
self._has_valid_features = True
|
||||
|
||||
def _establish_new_channel(self) -> None:
|
||||
self.sync_bit_send = 0
|
||||
@ -260,7 +273,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
maaa = self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
assert isinstance(maaa, messages.ThpEndResponse)
|
||||
self.has_valid_channel = True
|
||||
self._has_valid_channel = True
|
||||
|
||||
def _send_ack_0(self):
|
||||
LOG.debug("sending ack 0")
|
||||
@ -302,8 +315,13 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
|
||||
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
|
||||
header, raw_payload = self._read_until_valid_crc_check()
|
||||
if control_byte.is_ack(header.ctrl_byte):
|
||||
return self.read_and_decrypt()
|
||||
if not header.is_encrypted_transport():
|
||||
print("Trying to decrypt not encrypted message!")
|
||||
print(
|
||||
hexlify(header.to_bytes_init()).decode(), hexlify(raw_payload).decode()
|
||||
)
|
||||
|
||||
if not control_byte.is_ack(header.ctrl_byte):
|
||||
LOG.debug(
|
||||
|
@ -19,7 +19,7 @@ class Session:
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
|
||||
) -> Session:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -30,7 +30,7 @@ class Session:
|
||||
class SessionV1(Session):
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, b"")
|
||||
@ -54,9 +54,8 @@ class SessionV2(Session):
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool
|
||||
) -> SessionV2:
|
||||
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
session = SessionV2(client, b"\x00")
|
||||
new_session: ThpNewSession = session.call(
|
||||
@ -73,9 +72,7 @@ class SessionV2(Session):
|
||||
|
||||
self.channel: ProtocolV2 = client.protocol.get_channel()
|
||||
self.update_id_and_sid(id)
|
||||
if not self.channel.has_valid_features:
|
||||
self.channel.update_features()
|
||||
self.features = self.channel.features
|
||||
self.features = self.channel.get_features()
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
self.channel.write(self.sid, msg)
|
||||
|
Loading…
Reference in New Issue
Block a user