mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-08 22:40:59 +00:00
wip fix cardano part 2
This commit is contained in:
parent
ecfc3626f9
commit
5d0a52837b
@ -3,6 +3,11 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
from mnemonic import Mnemonic
|
||||||
|
from trezorlib.client import MAX_PASSPHRASE_LENGTH, PASSPHRASE_ON_DEVICE
|
||||||
|
|
||||||
|
from trezor.enums import Capability
|
||||||
|
|
||||||
from .. import exceptions, messages, models
|
from .. import exceptions, messages, models
|
||||||
from .thp.protocol_v1 import ProtocolV1
|
from .thp.protocol_v1 import ProtocolV1
|
||||||
from .thp.protocol_v2 import ProtocolV2
|
from .thp.protocol_v2 import ProtocolV2
|
||||||
@ -16,6 +21,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
class Session:
|
class Session:
|
||||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
|
||||||
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
@ -37,8 +43,9 @@ class Session:
|
|||||||
raise Exception # TODO
|
raise Exception # TODO
|
||||||
resp = self.pin_callback(self, resp)
|
resp = self.pin_callback(self, resp)
|
||||||
elif isinstance(resp, messages.PassphraseRequest):
|
elif isinstance(resp, messages.PassphraseRequest):
|
||||||
raise NotImplementedError
|
if self.passphrase_callback is None:
|
||||||
# resp = self._callback_passphrase(resp)
|
raise Exception # TODO
|
||||||
|
resp = self.passphrase_callback(self, resp)
|
||||||
elif isinstance(resp, messages.ButtonRequest):
|
elif isinstance(resp, messages.ButtonRequest):
|
||||||
if self.button_callback is None:
|
if self.button_callback is None:
|
||||||
raise Exception # TODO
|
raise Exception # TODO
|
||||||
@ -97,6 +104,7 @@ class SessionV1(Session):
|
|||||||
session = SessionV1(client, session_id)
|
session = SessionV1(client, session_id)
|
||||||
session.button_callback = client.button_callback
|
session.button_callback = client.button_callback
|
||||||
session.pin_callback = client.pin_callback
|
session.pin_callback = client.pin_callback
|
||||||
|
session.passphrase_callback = _callback_passphrase
|
||||||
session._init_session(derive_cardano=derive_cardano)
|
session._init_session(derive_cardano=derive_cardano)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
@ -111,17 +119,58 @@ class SessionV1(Session):
|
|||||||
return self.client.protocol.read()
|
return self.client.protocol.read()
|
||||||
|
|
||||||
def _init_session(self, derive_cardano: bool = False):
|
def _init_session(self, derive_cardano: bool = False):
|
||||||
self._write(
|
self.call_raw(
|
||||||
messages.Initialize(session_id=self.id, derive_cardano=derive_cardano)
|
messages.Initialize(session_id=self.id, derive_cardano=derive_cardano)
|
||||||
)
|
)
|
||||||
_ = self._read()
|
|
||||||
|
|
||||||
|
|
||||||
def _callback_button(session: Session, msg: t.Any) -> t.Any:
|
def _callback_button(session: Session, msg: t.Any) -> t.Any:
|
||||||
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
|
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
|
||||||
return session.call(messages.ButtonAck())
|
return session.call(messages.ButtonAck())
|
||||||
|
|
||||||
|
|
||||||
|
def _callback_passphrase(session: Session, msg: messages.PassphraseRequest) -> t.Any:
|
||||||
|
available_on_device = Capability.PassphraseEntry in Session.features.capabilities
|
||||||
|
|
||||||
|
def send_passphrase(
|
||||||
|
passphrase: str | None = None, on_device: bool | None = None
|
||||||
|
) -> t.Any:
|
||||||
|
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||||
|
resp = session.call_raw(msg)
|
||||||
|
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||||
|
session.session_id = resp.state
|
||||||
|
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# short-circuit old style entry
|
||||||
|
if msg._on_device is True:
|
||||||
|
return send_passphrase(None, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
passphrase = session.client.ui.get_passphrase(
|
||||||
|
available_on_device=available_on_device
|
||||||
|
) # TODO
|
||||||
|
except exceptions.Cancelled:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise
|
||||||
|
|
||||||
|
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||||
|
if not available_on_device:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise RuntimeError("Device is not capable of entering passphrase")
|
||||||
|
else:
|
||||||
|
return send_passphrase(on_device=True)
|
||||||
|
|
||||||
|
# else process host-entered passphrase
|
||||||
|
if not isinstance(passphrase, str):
|
||||||
|
raise RuntimeError("Passphrase must be a str")
|
||||||
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise ValueError("Passphrase too long")
|
||||||
|
|
||||||
|
return send_passphrase(passphrase, on_device=False)
|
||||||
|
|
||||||
|
|
||||||
class SessionV2(Session):
|
class SessionV2(Session):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user