mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-12 17:48:09 +00:00
wip trezorlib
This commit is contained in:
parent
11309cccd0
commit
d59684988c
@ -289,12 +289,10 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
|
|||||||
|
|
||||||
for transport in new_enumerate_devices():
|
for transport in new_enumerate_devices():
|
||||||
try:
|
try:
|
||||||
print("test A")
|
|
||||||
client = NewTrezorClient(transport)
|
client = NewTrezorClient(transport)
|
||||||
session = client.get_session()
|
session = client.get_management_session()
|
||||||
description = format_device_name(session.features)
|
description = format_device_name(session.features)
|
||||||
# client.end_session()
|
# client.end_session()
|
||||||
print("after end session")
|
|
||||||
except DeviceIsBusy:
|
except DeviceIsBusy:
|
||||||
description = "Device is in use by another process"
|
description = "Device is in use by another process"
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ from typing_extensions import Self
|
|||||||
from . import messages, protobuf
|
from . import messages, protobuf
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProtobufMapping:
|
class ProtobufMapping:
|
||||||
@ -63,7 +65,7 @@ class ProtobufMapping:
|
|||||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||||
if wire_type is None:
|
if wire_type is None:
|
||||||
raise ValueError("Cannot encode class without wire type")
|
raise ValueError("Cannot encode class without wire type")
|
||||||
print("wire type", wire_type)
|
LOG.debug("encoding wire type %d", wire_type)
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
protobuf.dump_message(buf, msg)
|
protobuf.dump_message(buf, msg)
|
||||||
return wire_type, buf.getvalue()
|
return wire_type, buf.getvalue()
|
||||||
|
@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class NewTrezorClient:
|
class NewTrezorClient:
|
||||||
|
management_session: Session | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: NewTransport,
|
transport: NewTransport,
|
||||||
@ -26,10 +28,8 @@ class NewTrezorClient:
|
|||||||
self.mapping = mapping.DEFAULT_MAPPING
|
self.mapping = mapping.DEFAULT_MAPPING
|
||||||
else:
|
else:
|
||||||
self.mapping = protobuf_mapping
|
self.mapping = protobuf_mapping
|
||||||
print("test B")
|
|
||||||
|
|
||||||
if protocol is None:
|
if protocol is None:
|
||||||
print("test C")
|
|
||||||
self.protocol = self._get_protocol()
|
self.protocol = self._get_protocol()
|
||||||
else:
|
else:
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
@ -40,7 +40,10 @@ class NewTrezorClient:
|
|||||||
) -> NewTrezorClient: ...
|
) -> NewTrezorClient: ...
|
||||||
|
|
||||||
def get_session(
|
def get_session(
|
||||||
self, passphrase: str = "", derive_cardano: bool = False
|
self,
|
||||||
|
passphrase: str = "",
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
management_session: bool = False,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
if isinstance(self.protocol, ProtocolV1):
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
return SessionV1.new(self, passphrase, derive_cardano)
|
return SessionV1.new(self, passphrase, derive_cardano)
|
||||||
@ -48,6 +51,18 @@ class NewTrezorClient:
|
|||||||
return SessionV2.new(self, passphrase, derive_cardano)
|
return SessionV2.new(self, passphrase, derive_cardano)
|
||||||
raise NotImplementedError # TODO
|
raise NotImplementedError # TODO
|
||||||
|
|
||||||
|
def get_management_session(self):
|
||||||
|
if self.management_session is not None:
|
||||||
|
return self.management_session
|
||||||
|
|
||||||
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
|
self.management_session = SessionV1.new(self, "", False)
|
||||||
|
elif isinstance(self.protocol, ProtocolV2):
|
||||||
|
self.management_session = SessionV2(self, b"\x00")
|
||||||
|
|
||||||
|
assert self.management_session is not None
|
||||||
|
return self.management_session
|
||||||
|
|
||||||
def resume_session(self, session_id: bytes) -> Session:
|
def resume_session(self, session_id: bytes) -> Session:
|
||||||
raise NotImplementedError # TODO
|
raise NotImplementedError # TODO
|
||||||
|
|
||||||
@ -66,7 +81,6 @@ class NewTrezorClient:
|
|||||||
response = protocol.read()
|
response = protocol.read()
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
if isinstance(response, messages.Failure):
|
if isinstance(response, messages.Failure):
|
||||||
print("test F1")
|
|
||||||
if (
|
if (
|
||||||
response.code == FailureType.UnexpectedMessage
|
response.code == FailureType.UnexpectedMessage
|
||||||
and response.message == "Invalid protocol"
|
and response.message == "Invalid protocol"
|
||||||
|
@ -61,11 +61,9 @@ class ProtocolV1(ProtocolAndChannel):
|
|||||||
self._write(msg_type, msg_bytes)
|
self._write(msg_type, msg_bytes)
|
||||||
|
|
||||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||||
print("wooooo")
|
|
||||||
chunk_size = self.transport.CHUNK_SIZE
|
chunk_size = self.transport.CHUNK_SIZE
|
||||||
header = struct.pack(">HL", message_type, len(message_data))
|
header = struct.pack(">HL", message_type, len(message_data))
|
||||||
buffer = bytearray(b"##" + header + message_data)
|
buffer = bytearray(b"##" + header + message_data)
|
||||||
print("wooooo")
|
|
||||||
|
|
||||||
while buffer:
|
while buffer:
|
||||||
# Report ID, data padded to 63 bytes
|
# Report ID, data padded to 63 bytes
|
||||||
|
@ -54,6 +54,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
sync_bit_send: int
|
sync_bit_send: int
|
||||||
sync_bit_receive: int
|
sync_bit_receive: int
|
||||||
has_valid_channel: bool = False
|
has_valid_channel: bool = False
|
||||||
|
features: messages.Features
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -67,6 +68,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
def get_channel(self) -> ProtocolV2:
|
def get_channel(self) -> ProtocolV2:
|
||||||
if not self.has_valid_channel:
|
if not self.has_valid_channel:
|
||||||
self._establish_new_channel()
|
self._establish_new_channel()
|
||||||
|
self.update_features()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def read(self, session_id: int) -> t.Any:
|
def read(self, session_id: int) -> t.Any:
|
||||||
@ -77,7 +79,25 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
msg_type, msg_data = self.mapping.encode(msg)
|
msg_type, msg_data = self.mapping.encode(msg)
|
||||||
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
|
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
|
||||||
|
|
||||||
def _establish_new_channel(self):
|
def update_features(self) -> None:
|
||||||
|
message = messages.GetFeatures()
|
||||||
|
message_type, message_data = self.mapping.encode(message)
|
||||||
|
|
||||||
|
self.session_id: int = 0
|
||||||
|
self._encrypt_and_write(
|
||||||
|
MANAGEMENT_SESSION_ID,
|
||||||
|
message_type,
|
||||||
|
message_data,
|
||||||
|
0x14, # TODO update control byte
|
||||||
|
)
|
||||||
|
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||||
|
session_id, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
features = self.mapping.decode(msg_type, msg_data)
|
||||||
|
assert isinstance(features, messages.Features)
|
||||||
|
self.features = features
|
||||||
|
self._send_ack_2()
|
||||||
|
|
||||||
|
def _establish_new_channel(self) -> None:
|
||||||
self.sync_bit_send = 0
|
self.sync_bit_send = 0
|
||||||
self.sync_bit_receive = 0
|
self.sync_bit_receive = 0
|
||||||
# Send channel allocation request
|
# Send channel allocation request
|
||||||
@ -124,7 +144,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
noise_tag = payload[80:96]
|
noise_tag = payload[80:96]
|
||||||
|
|
||||||
# TODO check noise tag
|
# TODO check noise tag
|
||||||
print("noise_tag: ", hexlify(noise_tag).decode())
|
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||||
|
|
||||||
# Prepare and send handshake completion request
|
# Prepare and send handshake completion request
|
||||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||||
@ -143,7 +163,6 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||||
IV_1, encrypted_trezor_static_pubkey, h
|
IV_1, encrypted_trezor_static_pubkey, h
|
||||||
)
|
)
|
||||||
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||||
@ -259,7 +278,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
self.transport, header, encrypted_message
|
self.transport, header, encrypted_message
|
||||||
)
|
)
|
||||||
|
|
||||||
def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]:
|
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
|
||||||
header, raw_payload = self._read_until_valid_crc_check()
|
header, raw_payload = self._read_until_valid_crc_check()
|
||||||
if not header.is_encrypted_transport():
|
if not header.is_encrypted_transport():
|
||||||
print("Trying to decrypt not encrypted message!")
|
print("Trying to decrypt not encrypted message!")
|
||||||
@ -272,7 +291,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
message_type = message[1:3]
|
message_type = message[1:3]
|
||||||
message_data = message[3:]
|
message_data = message[3:]
|
||||||
return (
|
return (
|
||||||
int.to_bytes(session_id, 1, "big"),
|
session_id,
|
||||||
int.from_bytes(message_type, "big"),
|
int.from_bytes(message_type, "big"),
|
||||||
message_data,
|
message_data,
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from ...messages import Features, Initialize
|
from ...messages import Features, Initialize, ThpCreateNewSession, ThpNewSession
|
||||||
from .protocol_and_channel import ProtocolV1
|
from .protocol_and_channel import ProtocolV1
|
||||||
from .protocol_v2 import ProtocolV2
|
from .protocol_v2 import ProtocolV2
|
||||||
|
|
||||||
@ -34,11 +34,11 @@ class SessionV1(Session):
|
|||||||
) -> SessionV1:
|
) -> SessionV1:
|
||||||
assert isinstance(client.protocol, ProtocolV1)
|
assert isinstance(client.protocol, ProtocolV1)
|
||||||
session = SessionV1(client, b"")
|
session = SessionV1(client, b"")
|
||||||
cls.features = session.call(
|
session.features = session.call(
|
||||||
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
|
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
|
||||||
Initialize()
|
Initialize()
|
||||||
)
|
)
|
||||||
session.id = cls.features.session_id
|
session.id = session.features.session_id
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
|
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
|
||||||
@ -51,15 +51,32 @@ class SessionV1(Session):
|
|||||||
|
|
||||||
|
|
||||||
class SessionV2(Session):
|
class SessionV2(Session):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(
|
||||||
|
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||||
|
) -> SessionV2:
|
||||||
|
assert isinstance(client.protocol, ProtocolV1)
|
||||||
|
session = SessionV2(client, b"\x00")
|
||||||
|
new_session: ThpNewSession = session.call(
|
||||||
|
ThpCreateNewSession(passphrase=passphrase, derive_cardano=derive_cardano)
|
||||||
|
)
|
||||||
|
assert new_session.new_session_id is not None
|
||||||
|
session_id = new_session.new_session_id
|
||||||
|
session.update_id_and_sid(session_id.to_bytes(1, "big"))
|
||||||
|
return session
|
||||||
|
|
||||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||||
super().__init__(client, id)
|
super().__init__(client, id)
|
||||||
assert isinstance(client.protocol, ProtocolV2)
|
assert isinstance(client.protocol, ProtocolV2)
|
||||||
self.channel = client.protocol.get_channel()
|
self.channel: ProtocolV2 = client.protocol.get_channel()
|
||||||
self.sid = self._convert_id_to_sid(id)
|
self.update_id_and_sid(id)
|
||||||
|
self.features = self.channel.features
|
||||||
|
|
||||||
def call(self, msg: t.Any) -> t.Any:
|
def call(self, msg: t.Any) -> t.Any:
|
||||||
self.channel.write(self.sid, msg)
|
self.channel.write(self.sid, msg)
|
||||||
return self.channel.read(self.sid)
|
return self.channel.read(self.sid)
|
||||||
|
|
||||||
def _convert_id_to_sid(self, id: bytes) -> int:
|
def update_id_and_sid(self, id: bytes) -> None:
|
||||||
return int.from_bytes(id, "big") # TODO update to extract only sid
|
self.id = id
|
||||||
|
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid
|
||||||
|
Loading…
Reference in New Issue
Block a user