wip trezorlib

M1nd3r/thp-improved
M1nd3r 1 week ago
parent 11309cccd0
commit d59684988c

@ -289,12 +289,10 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
for transport in new_enumerate_devices():
try:
print("test A")
client = NewTrezorClient(transport)
session = client.get_session()
session = client.get_management_session()
description = format_device_name(session.features)
# client.end_session()
print("after end session")
except DeviceIsBusy:
description = "Device is in use by another process"
except Exception:

@ -17,6 +17,7 @@
from __future__ import annotations
import io
import logging
from types import ModuleType
from typing import Dict, Optional, Tuple, Type, TypeVar
@ -25,6 +26,7 @@ from typing_extensions import Self
from . import messages, protobuf
T = TypeVar("T")
LOG = logging.getLogger(__name__)
class ProtobufMapping:
@ -63,7 +65,7 @@ class ProtobufMapping:
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
if wire_type is None:
raise ValueError("Cannot encode class without wire type")
print("wire type", wire_type)
LOG.debug("encoding wire type %d", wire_type)
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue()

@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__)
class NewTrezorClient:
management_session: Session | None = None
def __init__(
self,
transport: NewTransport,
@ -26,10 +28,8 @@ class NewTrezorClient:
self.mapping = mapping.DEFAULT_MAPPING
else:
self.mapping = protobuf_mapping
print("test B")
if protocol is None:
print("test C")
self.protocol = self._get_protocol()
else:
self.protocol = protocol
@ -40,7 +40,10 @@ class NewTrezorClient:
) -> NewTrezorClient: ...
def get_session(
self, passphrase: str = "", derive_cardano: bool = False
self,
passphrase: str = "",
derive_cardano: bool = False,
management_session: bool = False,
) -> Session:
if isinstance(self.protocol, ProtocolV1):
return SessionV1.new(self, passphrase, derive_cardano)
@ -48,6 +51,18 @@ class NewTrezorClient:
return SessionV2.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO
def get_management_session(self):
if self.management_session is not None:
return self.management_session
if isinstance(self.protocol, ProtocolV1):
self.management_session = SessionV1.new(self, "", False)
elif isinstance(self.protocol, ProtocolV2):
self.management_session = SessionV2(self, b"\x00")
assert self.management_session is not None
return self.management_session
def resume_session(self, session_id: bytes) -> Session:
raise NotImplementedError # TODO
@ -66,7 +81,6 @@ class NewTrezorClient:
response = protocol.read()
self.transport.close()
if isinstance(response, messages.Failure):
print("test F1")
if (
response.code == FailureType.UnexpectedMessage
and response.message == "Invalid protocol"

@ -61,11 +61,9 @@ class ProtocolV1(ProtocolAndChannel):
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
print("wooooo")
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
print("wooooo")
while buffer:
# Report ID, data padded to 63 bytes

@ -54,6 +54,7 @@ class ProtocolV2(ProtocolAndChannel):
sync_bit_send: int
sync_bit_receive: int
has_valid_channel: bool = False
features: messages.Features
def __init__(
self,
@ -67,6 +68,7 @@ class ProtocolV2(ProtocolAndChannel):
def get_channel(self) -> ProtocolV2:
if not self.has_valid_channel:
self._establish_new_channel()
self.update_features()
return self
def read(self, session_id: int) -> t.Any:
@ -77,7 +79,25 @@ class ProtocolV2(ProtocolAndChannel):
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
def _establish_new_channel(self):
def update_features(self) -> None:
message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message)
self.session_id: int = 0
self._encrypt_and_write(
MANAGEMENT_SESSION_ID,
message_type,
message_data,
0x14, # TODO update control byte
)
_ = self._read_until_valid_crc_check() # TODO check ACK
session_id, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
assert isinstance(features, messages.Features)
self.features = features
self._send_ack_2()
def _establish_new_channel(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
# Send channel allocation request
@ -124,7 +144,7 @@ class ProtocolV2(ProtocolAndChannel):
noise_tag = payload[80:96]
# TODO check noise tag
print("noise_tag: ", hexlify(noise_tag).decode())
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# Prepare and send handshake completion request
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
@ -143,7 +163,6 @@ class ProtocolV2(ProtocolAndChannel):
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
except Exception as e:
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
@ -259,7 +278,7 @@ class ProtocolV2(ProtocolAndChannel):
self.transport, header, encrypted_message
)
def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]:
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!")
@ -272,7 +291,7 @@ class ProtocolV2(ProtocolAndChannel):
message_type = message[1:3]
message_data = message[3:]
return (
int.to_bytes(session_id, 1, "big"),
session_id,
int.from_bytes(message_type, "big"),
message_data,
)

@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from ...messages import Features, Initialize
from ...messages import Features, Initialize, ThpCreateNewSession, ThpNewSession
from .protocol_and_channel import ProtocolV1
from .protocol_v2 import ProtocolV2
@ -34,11 +34,11 @@ class SessionV1(Session):
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, b"")
cls.features = session.call(
session.features = session.call(
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
Initialize()
)
session.id = cls.features.session_id
session.id = session.features.session_id
return session
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
@ -51,15 +51,32 @@ class SessionV1(Session):
class SessionV2(Session):
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV2(client, b"\x00")
new_session: ThpNewSession = session.call(
ThpCreateNewSession(passphrase=passphrase, derive_cardano=derive_cardano)
)
assert new_session.new_session_id is not None
session_id = new_session.new_session_id
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2)
self.channel = client.protocol.get_channel()
self.sid = self._convert_id_to_sid(id)
self.channel: ProtocolV2 = client.protocol.get_channel()
self.update_id_and_sid(id)
self.features = self.channel.features
def call(self, msg: t.Any) -> t.Any:
self.channel.write(self.sid, msg)
return self.channel.read(self.sid)
def _convert_id_to_sid(self, id: bytes) -> int:
return int.from_bytes(id, "big") # TODO update to extract only sid
def update_id_and_sid(self, id: bytes) -> None:
self.id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

Loading…
Cancel
Save