wip trezorlib

M1nd3r/thp-improved
M1nd3r 1 week ago
parent 11309cccd0
commit d59684988c

@ -289,12 +289,10 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
for transport in new_enumerate_devices(): for transport in new_enumerate_devices():
try: try:
print("test A")
client = NewTrezorClient(transport) client = NewTrezorClient(transport)
session = client.get_session() session = client.get_management_session()
description = format_device_name(session.features) description = format_device_name(session.features)
# client.end_session() # client.end_session()
print("after end session")
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"
except Exception: except Exception:

@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import io import io
import logging
from types import ModuleType from types import ModuleType
from typing import Dict, Optional, Tuple, Type, TypeVar from typing import Dict, Optional, Tuple, Type, TypeVar
@ -25,6 +26,7 @@ from typing_extensions import Self
from . import messages, protobuf from . import messages, protobuf
T = TypeVar("T") T = TypeVar("T")
LOG = logging.getLogger(__name__)
class ProtobufMapping: class ProtobufMapping:
@ -63,7 +65,7 @@ class ProtobufMapping:
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
if wire_type is None: if wire_type is None:
raise ValueError("Cannot encode class without wire type") raise ValueError("Cannot encode class without wire type")
print("wire type", wire_type) LOG.debug("encoding wire type %d", wire_type)
buf = io.BytesIO() buf = io.BytesIO()
protobuf.dump_message(buf, msg) protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue() return wire_type, buf.getvalue()

@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__)
class NewTrezorClient: class NewTrezorClient:
management_session: Session | None = None
def __init__( def __init__(
self, self,
transport: NewTransport, transport: NewTransport,
@ -26,10 +28,8 @@ class NewTrezorClient:
self.mapping = mapping.DEFAULT_MAPPING self.mapping = mapping.DEFAULT_MAPPING
else: else:
self.mapping = protobuf_mapping self.mapping = protobuf_mapping
print("test B")
if protocol is None: if protocol is None:
print("test C")
self.protocol = self._get_protocol() self.protocol = self._get_protocol()
else: else:
self.protocol = protocol self.protocol = protocol
@ -40,7 +40,10 @@ class NewTrezorClient:
) -> NewTrezorClient: ... ) -> NewTrezorClient: ...
def get_session( def get_session(
self, passphrase: str = "", derive_cardano: bool = False self,
passphrase: str = "",
derive_cardano: bool = False,
management_session: bool = False,
) -> Session: ) -> Session:
if isinstance(self.protocol, ProtocolV1): if isinstance(self.protocol, ProtocolV1):
return SessionV1.new(self, passphrase, derive_cardano) return SessionV1.new(self, passphrase, derive_cardano)
@ -48,6 +51,18 @@ class NewTrezorClient:
return SessionV2.new(self, passphrase, derive_cardano) return SessionV2.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO raise NotImplementedError # TODO
def get_management_session(self):
if self.management_session is not None:
return self.management_session
if isinstance(self.protocol, ProtocolV1):
self.management_session = SessionV1.new(self, "", False)
elif isinstance(self.protocol, ProtocolV2):
self.management_session = SessionV2(self, b"\x00")
assert self.management_session is not None
return self.management_session
def resume_session(self, session_id: bytes) -> Session: def resume_session(self, session_id: bytes) -> Session:
raise NotImplementedError # TODO raise NotImplementedError # TODO
@ -66,7 +81,6 @@ class NewTrezorClient:
response = protocol.read() response = protocol.read()
self.transport.close() self.transport.close()
if isinstance(response, messages.Failure): if isinstance(response, messages.Failure):
print("test F1")
if ( if (
response.code == FailureType.UnexpectedMessage response.code == FailureType.UnexpectedMessage
and response.message == "Invalid protocol" and response.message == "Invalid protocol"

@ -61,11 +61,9 @@ class ProtocolV1(ProtocolAndChannel):
self._write(msg_type, msg_bytes) self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None: def _write(self, message_type: int, message_data: bytes) -> None:
print("wooooo")
chunk_size = self.transport.CHUNK_SIZE chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data)) header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data) buffer = bytearray(b"##" + header + message_data)
print("wooooo")
while buffer: while buffer:
# Report ID, data padded to 63 bytes # Report ID, data padded to 63 bytes

@ -54,6 +54,7 @@ class ProtocolV2(ProtocolAndChannel):
sync_bit_send: int sync_bit_send: int
sync_bit_receive: int sync_bit_receive: int
has_valid_channel: bool = False has_valid_channel: bool = False
features: messages.Features
def __init__( def __init__(
self, self,
@ -67,6 +68,7 @@ class ProtocolV2(ProtocolAndChannel):
def get_channel(self) -> ProtocolV2: def get_channel(self) -> ProtocolV2:
if not self.has_valid_channel: if not self.has_valid_channel:
self._establish_new_channel() self._establish_new_channel()
self.update_features()
return self return self
def read(self, session_id: int) -> t.Any: def read(self, session_id: int) -> t.Any:
@ -77,7 +79,25 @@ class ProtocolV2(ProtocolAndChannel):
msg_type, msg_data = self.mapping.encode(msg) msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
def _establish_new_channel(self): def update_features(self) -> None:
message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message)
self.session_id: int = 0
self._encrypt_and_write(
MANAGEMENT_SESSION_ID,
message_type,
message_data,
0x14, # TODO update control byte
)
_ = self._read_until_valid_crc_check() # TODO check ACK
session_id, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
assert isinstance(features, messages.Features)
self.features = features
self._send_ack_2()
def _establish_new_channel(self) -> None:
self.sync_bit_send = 0 self.sync_bit_send = 0
self.sync_bit_receive = 0 self.sync_bit_receive = 0
# Send channel allocation request # Send channel allocation request
@ -124,7 +144,7 @@ class ProtocolV2(ProtocolAndChannel):
noise_tag = payload[80:96] noise_tag = payload[80:96]
# TODO check noise tag # TODO check noise tag
print("noise_tag: ", hexlify(noise_tag).decode()) LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# Prepare and send handshake completion request # Prepare and send handshake completion request
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
@ -143,7 +163,6 @@ class ProtocolV2(ProtocolAndChannel):
trezor_masked_static_pubkey = aes_ctx.decrypt( trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h IV_1, encrypted_trezor_static_pubkey, h
) )
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
except Exception as e: except Exception as e:
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey) h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
@ -259,7 +278,7 @@ class ProtocolV2(ProtocolAndChannel):
self.transport, header, encrypted_message self.transport, header, encrypted_message
) )
def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]: def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check() header, raw_payload = self._read_until_valid_crc_check()
if not header.is_encrypted_transport(): if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!") print("Trying to decrypt not encrypted message!")
@ -272,7 +291,7 @@ class ProtocolV2(ProtocolAndChannel):
message_type = message[1:3] message_type = message[1:3]
message_data = message[3:] message_data = message[3:]
return ( return (
int.to_bytes(session_id, 1, "big"), session_id,
int.from_bytes(message_type, "big"), int.from_bytes(message_type, "big"),
message_data, message_data,
) )

@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t import typing as t
from ...messages import Features, Initialize from ...messages import Features, Initialize, ThpCreateNewSession, ThpNewSession
from .protocol_and_channel import ProtocolV1 from .protocol_and_channel import ProtocolV1
from .protocol_v2 import ProtocolV2 from .protocol_v2 import ProtocolV2
@ -34,11 +34,11 @@ class SessionV1(Session):
) -> SessionV1: ) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1) assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, b"") session = SessionV1(client, b"")
cls.features = session.call( session.features = session.call(
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO # Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
Initialize() Initialize()
) )
session.id = cls.features.session_id session.id = session.features.session_id
return session return session
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any: def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
@ -51,15 +51,32 @@ class SessionV1(Session):
class SessionV2(Session): class SessionV2(Session):
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV2(client, b"\x00")
new_session: ThpNewSession = session.call(
ThpCreateNewSession(passphrase=passphrase, derive_cardano=derive_cardano)
)
assert new_session.new_session_id is not None
session_id = new_session.new_session_id
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: NewTrezorClient, id: bytes) -> None: def __init__(self, client: NewTrezorClient, id: bytes) -> None:
super().__init__(client, id) super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2) assert isinstance(client.protocol, ProtocolV2)
self.channel = client.protocol.get_channel() self.channel: ProtocolV2 = client.protocol.get_channel()
self.sid = self._convert_id_to_sid(id) self.update_id_and_sid(id)
self.features = self.channel.features
def call(self, msg: t.Any) -> t.Any: def call(self, msg: t.Any) -> t.Any:
self.channel.write(self.sid, msg) self.channel.write(self.sid, msg)
return self.channel.read(self.sid) return self.channel.read(self.sid)
def _convert_id_to_sid(self, id: bytes) -> int: def update_id_and_sid(self, id: bytes) -> None:
return int.from_bytes(id, "big") # TODO update to extract only sid self.id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

Loading…
Cancel
Save