mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-22 12:32:02 +00:00
wip trezorlib
This commit is contained in:
parent
11309cccd0
commit
d59684988c
@ -289,12 +289,10 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
|
||||
|
||||
for transport in new_enumerate_devices():
|
||||
try:
|
||||
print("test A")
|
||||
client = NewTrezorClient(transport)
|
||||
session = client.get_session()
|
||||
session = client.get_management_session()
|
||||
description = format_device_name(session.features)
|
||||
# client.end_session()
|
||||
print("after end session")
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
except Exception:
|
||||
|
@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
@ -25,6 +26,7 @@ from typing_extensions import Self
|
||||
from . import messages, protobuf
|
||||
|
||||
T = TypeVar("T")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtobufMapping:
|
||||
@ -63,7 +65,7 @@ class ProtobufMapping:
|
||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||
if wire_type is None:
|
||||
raise ValueError("Cannot encode class without wire type")
|
||||
print("wire type", wire_type)
|
||||
LOG.debug("encoding wire type %d", wire_type)
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return wire_type, buf.getvalue()
|
||||
|
@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NewTrezorClient:
|
||||
management_session: Session | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
@ -26,10 +28,8 @@ class NewTrezorClient:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
else:
|
||||
self.mapping = protobuf_mapping
|
||||
print("test B")
|
||||
|
||||
if protocol is None:
|
||||
print("test C")
|
||||
self.protocol = self._get_protocol()
|
||||
else:
|
||||
self.protocol = protocol
|
||||
@ -40,7 +40,10 @@ class NewTrezorClient:
|
||||
) -> NewTrezorClient: ...
|
||||
|
||||
def get_session(
|
||||
self, passphrase: str = "", derive_cardano: bool = False
|
||||
self,
|
||||
passphrase: str = "",
|
||||
derive_cardano: bool = False,
|
||||
management_session: bool = False,
|
||||
) -> Session:
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
return SessionV1.new(self, passphrase, derive_cardano)
|
||||
@ -48,6 +51,18 @@ class NewTrezorClient:
|
||||
return SessionV2.new(self, passphrase, derive_cardano)
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
def get_management_session(self):
|
||||
if self.management_session is not None:
|
||||
return self.management_session
|
||||
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
self.management_session = SessionV1.new(self, "", False)
|
||||
elif isinstance(self.protocol, ProtocolV2):
|
||||
self.management_session = SessionV2(self, b"\x00")
|
||||
|
||||
assert self.management_session is not None
|
||||
return self.management_session
|
||||
|
||||
def resume_session(self, session_id: bytes) -> Session:
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
@ -66,7 +81,6 @@ class NewTrezorClient:
|
||||
response = protocol.read()
|
||||
self.transport.close()
|
||||
if isinstance(response, messages.Failure):
|
||||
print("test F1")
|
||||
if (
|
||||
response.code == FailureType.UnexpectedMessage
|
||||
and response.message == "Invalid protocol"
|
||||
|
@ -61,11 +61,9 @@ class ProtocolV1(ProtocolAndChannel):
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
print("wooooo")
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
print("wooooo")
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
|
@ -54,6 +54,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
has_valid_channel: bool = False
|
||||
features: messages.Features
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -67,6 +68,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
def get_channel(self) -> ProtocolV2:
|
||||
if not self.has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
self.update_features()
|
||||
return self
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
@ -77,7 +79,25 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
msg_type, msg_data = self.mapping.encode(msg)
|
||||
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
|
||||
|
||||
def _establish_new_channel(self):
|
||||
def update_features(self) -> None:
|
||||
message = messages.GetFeatures()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
|
||||
self.session_id: int = 0
|
||||
self._encrypt_and_write(
|
||||
MANAGEMENT_SESSION_ID,
|
||||
message_type,
|
||||
message_data,
|
||||
0x14, # TODO update control byte
|
||||
)
|
||||
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||
session_id, msg_type, msg_data = self.read_and_decrypt()
|
||||
features = self.mapping.decode(msg_type, msg_data)
|
||||
assert isinstance(features, messages.Features)
|
||||
self.features = features
|
||||
self._send_ack_2()
|
||||
|
||||
def _establish_new_channel(self) -> None:
|
||||
self.sync_bit_send = 0
|
||||
self.sync_bit_receive = 0
|
||||
# Send channel allocation request
|
||||
@ -124,7 +144,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
noise_tag = payload[80:96]
|
||||
|
||||
# TODO check noise tag
|
||||
print("noise_tag: ", hexlify(noise_tag).decode())
|
||||
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||
|
||||
# Prepare and send handshake completion request
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
@ -143,7 +163,6 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||
IV_1, encrypted_trezor_static_pubkey, h
|
||||
)
|
||||
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
|
||||
except Exception as e:
|
||||
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
@ -259,7 +278,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
self.transport, header, encrypted_message
|
||||
)
|
||||
|
||||
def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]:
|
||||
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
|
||||
header, raw_payload = self._read_until_valid_crc_check()
|
||||
if not header.is_encrypted_transport():
|
||||
print("Trying to decrypt not encrypted message!")
|
||||
@ -272,7 +291,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
message_type = message[1:3]
|
||||
message_data = message[3:]
|
||||
return (
|
||||
int.to_bytes(session_id, 1, "big"),
|
||||
session_id,
|
||||
int.from_bytes(message_type, "big"),
|
||||
message_data,
|
||||
)
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...messages import Features, Initialize
|
||||
from ...messages import Features, Initialize, ThpCreateNewSession, ThpNewSession
|
||||
from .protocol_and_channel import ProtocolV1
|
||||
from .protocol_v2 import ProtocolV2
|
||||
|
||||
@ -34,11 +34,11 @@ class SessionV1(Session):
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, b"")
|
||||
cls.features = session.call(
|
||||
session.features = session.call(
|
||||
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
|
||||
Initialize()
|
||||
)
|
||||
session.id = cls.features.session_id
|
||||
session.id = session.features.session_id
|
||||
return session
|
||||
|
||||
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
|
||||
@ -51,15 +51,32 @@ class SessionV1(Session):
|
||||
|
||||
|
||||
class SessionV2(Session):
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
) -> SessionV2:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV2(client, b"\x00")
|
||||
new_session: ThpNewSession = session.call(
|
||||
ThpCreateNewSession(passphrase=passphrase, derive_cardano=derive_cardano)
|
||||
)
|
||||
assert new_session.new_session_id is not None
|
||||
session_id = new_session.new_session_id
|
||||
session.update_id_and_sid(session_id.to_bytes(1, "big"))
|
||||
return session
|
||||
|
||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||
super().__init__(client, id)
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
self.channel = client.protocol.get_channel()
|
||||
self.sid = self._convert_id_to_sid(id)
|
||||
self.channel: ProtocolV2 = client.protocol.get_channel()
|
||||
self.update_id_and_sid(id)
|
||||
self.features = self.channel.features
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
self.channel.write(self.sid, msg)
|
||||
return self.channel.read(self.sid)
|
||||
|
||||
def _convert_id_to_sid(self, id: bytes) -> int:
|
||||
return int.from_bytes(id, "big") # TODO update to extract only sid
|
||||
def update_id_and_sid(self, id: bytes) -> None:
|
||||
self.id = id
|
||||
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid
|
||||
|
Loading…
Reference in New Issue
Block a user