1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-22 05:10:56 +00:00

wip fix cardano part 2

This commit is contained in:
M1nd3r 2024-11-22 16:09:10 +01:00
parent c15a1b54f5
commit 0074b62d43

View File

@ -3,6 +3,11 @@ from __future__ import annotations
import logging import logging
import typing as t import typing as t
from mnemonic import Mnemonic
from trezorlib.client import MAX_PASSPHRASE_LENGTH, PASSPHRASE_ON_DEVICE
from trezor.enums import Capability
from .. import exceptions, messages, models from .. import exceptions, messages, models
from .thp.protocol_v1 import ProtocolV1 from .thp.protocol_v1 import ProtocolV1
from .thp.protocol_v2 import ProtocolV2 from .thp.protocol_v2 import ProtocolV2
@ -16,6 +21,7 @@ LOG = logging.getLogger(__name__)
class Session: class Session:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
def __init__(self, client: TrezorClient, id: bytes) -> None: def __init__(self, client: TrezorClient, id: bytes) -> None:
self.client = client self.client = client
@ -37,8 +43,9 @@ class Session:
raise Exception # TODO raise Exception # TODO
resp = self.pin_callback(self, resp) resp = self.pin_callback(self, resp)
elif isinstance(resp, messages.PassphraseRequest): elif isinstance(resp, messages.PassphraseRequest):
raise NotImplementedError if self.passphrase_callback is None:
# resp = self._callback_passphrase(resp) raise Exception # TODO
resp = self.passphrase_callback(self, resp)
elif isinstance(resp, messages.ButtonRequest): elif isinstance(resp, messages.ButtonRequest):
if self.button_callback is None: if self.button_callback is None:
raise Exception # TODO raise Exception # TODO
@ -97,6 +104,7 @@ class SessionV1(Session):
session = SessionV1(client, session_id) session = SessionV1(client, session_id)
session.button_callback = client.button_callback session.button_callback = client.button_callback
session.pin_callback = client.pin_callback session.pin_callback = client.pin_callback
session.passphrase_callback = _callback_passphrase
session._init_session(derive_cardano=derive_cardano) session._init_session(derive_cardano=derive_cardano)
return session return session
@ -111,17 +119,58 @@ class SessionV1(Session):
return self.client.protocol.read() return self.client.protocol.read()
def _init_session(self, derive_cardano: bool = False): def _init_session(self, derive_cardano: bool = False):
self._write( self.call_raw(
messages.Initialize(session_id=self.id, derive_cardano=derive_cardano) messages.Initialize(session_id=self.id, derive_cardano=derive_cardano)
) )
_ = self._read()
def _callback_button(session: Session, msg: t.Any) -> t.Any: def _callback_button(session: Session, msg: t.Any) -> t.Any:
print("Please confirm action on your Trezor device.") # TODO how to handle UI? print("Please confirm action on your Trezor device.") # TODO how to handle UI?
return session.call(messages.ButtonAck()) return session.call(messages.ButtonAck())
def _callback_passphrase(session: Session, msg: messages.PassphraseRequest) -> t.Any:
available_on_device = Capability.PassphraseEntry in Session.features.capabilities
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> t.Any:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
session.session_id = resp.state
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
passphrase = session.client.ui.get_passphrase(
available_on_device=available_on_device
) # TODO
except exceptions.Cancelled:
session.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
class SessionV2(Session): class SessionV2(Session):
@classmethod @classmethod