1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-03 16:56:07 +00:00

python: unify protobuf-encoding code paths

Protobuf encoding now happens in TrezorClient, and transports get encoded blobs
to (chunkify and) send.  This is a better design because transports don't need
to know about protobuf.

It also lays groundwork for sending raw bytes feature (#116)

This commit also removes all vestiges of ProtocolV2 which was never used and
will probably need to be redesigned from the ground up anyway. The code is
still ready for protocol flexibility.
This commit is contained in:
matejcik 2020-03-05 17:38:31 +01:00
parent 22b167a961
commit 9a330f3475
13 changed files with 119 additions and 268 deletions

View File

@ -29,6 +29,8 @@ _At the moment, the project does **not** adhere to [Semantic Versioning](https:/
- `get_default_client` respects `TREZOR_PATH` environment variable - `get_default_client` respects `TREZOR_PATH` environment variable
- UI callback `get_passphrase` has an additional argument `available_on_device`, - UI callback `get_passphrase` has an additional argument `available_on_device`,
indicating that the connected Trezor is capable of on-device entry indicating that the connected Trezor is capable of on-device entry
- `Transport.write` and `read` method signatures changed to accept bytes instead of
protobuf messages
### Fixed ### Fixed

View File

@ -21,7 +21,8 @@ import warnings
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import MINIMUM_FIRMWARE_VERSION, exceptions, messages, tools from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
from .log import DUMP_BYTES
from .messages import Capability from .messages import Capability
if sys.version_info.major < 3: if sys.version_info.major < 3:
@ -114,11 +115,34 @@ class TrezorClient:
def _raw_write(self, msg): def _raw_write(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
self.transport.write(msg) LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
msg_type, msg_bytes = mapping.encode(msg)
LOG.log(
DUMP_BYTES,
"encoded as type {} ({} bytes): {}".format(
msg_type, len(msg_bytes), msg_bytes.hex()
),
)
self.transport.write(msg_type, msg_bytes)
def _raw_read(self): def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
return self.transport.read() msg_type, msg_bytes = self.transport.read()
LOG.log(
DUMP_BYTES,
"received type {} ({} bytes): {}".format(
msg_type, len(msg_bytes), msg_bytes.hex()
),
)
msg = mapping.decode(msg_type, msg_bytes)
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def _callback_pin(self, msg): def _callback_pin(self, msg):
try: try:

View File

@ -19,7 +19,7 @@ from copy import deepcopy
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import messages as proto, protobuf from . import mapping, messages as proto, protobuf
from .client import TrezorClient from .client import TrezorClient
from .tools import expect from .tools import expect
@ -45,11 +45,12 @@ class DebugLink:
self.transport.end_session() self.transport.end_session()
def _call(self, msg, nowait=False): def _call(self, msg, nowait=False):
self.transport.write(msg) msg_type, msg_bytes = mapping.encode(msg)
self.transport.write(msg_type, msg_bytes)
if nowait: if nowait:
return None return None
ret = self.transport.read() ret_type, ret_bytes = self.transport.read()
return ret return mapping.decode(ret_type, ret_bytes)
def state(self): def state(self):
return self._call(proto.DebugLinkGetState()) return self._call(proto.DebugLinkGetState())

View File

@ -22,8 +22,10 @@ from . import protobuf
OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]] OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]]
DUMP_BYTES = 5 DUMP_BYTES = 5
DUMP_PACKETS = 4
logging.addLevelName(DUMP_BYTES, "BYTES") logging.addLevelName(DUMP_BYTES, "BYTES")
logging.addLevelName(DUMP_PACKETS, "PACKETS")
class PrettyProtobufFormatter(logging.Formatter): class PrettyProtobufFormatter(logging.Formatter):
@ -54,6 +56,8 @@ def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] =
level = logging.DEBUG level = logging.DEBUG
if verbosity > 1: if verbosity > 1:
level = DUMP_BYTES level = DUMP_BYTES
if verbosity > 2:
level = DUMP_PACKETS
logger = logging.getLogger("trezorlib") logger = logging.getLogger("trezorlib")
logger.setLevel(level) logger.setLevel(level)

View File

@ -14,7 +14,10 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages import io
from typing import Tuple
from . import messages, protobuf
map_type_to_class = {} map_type_to_class = {}
map_class_to_type = {} map_class_to_type = {}
@ -59,4 +62,17 @@ def get_class(t):
return map_type_to_class[t] return map_type_to_class[t]
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
message_type = msg.MESSAGE_WIRE_TYPE
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return message_type, buf.getvalue()
def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType:
cls = get_class(message_type)
buf = io.BytesIO(message_bytes)
return protobuf.load_message(buf, cls)
build_map() build_map()

View File

@ -15,10 +15,9 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging import logging
from typing import Iterable, List, Type from typing import Iterable, List, Tuple, Type
from ..exceptions import TrezorException from ..exceptions import TrezorException
from ..protobuf import MessageType
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -35,6 +34,9 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip() """.strip()
MessagePayload = Tuple[int, bytes]
class TransportException(TrezorException): class TransportException(TrezorException):
pass pass
@ -71,10 +73,10 @@ class Transport:
def end_session(self) -> None: def end_session(self) -> None:
raise NotImplementedError raise NotImplementedError
def read(self) -> MessageType: def read(self) -> MessagePayload:
raise NotImplementedError raise NotImplementedError
def write(self, message: MessageType) -> None: def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod

View File

@ -16,14 +16,12 @@
import logging import logging
import struct import struct
from io import BytesIO
from typing import Any, Dict, Iterable, Optional from typing import Any, Dict, Iterable, Optional
import requests import requests
from .. import mapping, protobuf from ..log import DUMP_PACKETS
from ..log import DUMP_BYTES from . import MessagePayload, Transport, TransportException
from . import Transport, TransportException
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -66,10 +64,12 @@ class BridgeHandle:
class BridgeHandleModern(BridgeHandle): class BridgeHandleModern(BridgeHandle):
def write_buf(self, buf: bytes) -> None: def write_buf(self, buf: bytes) -> None:
LOG.log(DUMP_PACKETS, "sending message: {}".format(buf.hex()))
self.transport._call("post", data=buf.hex()) self.transport._call("post", data=buf.hex())
def read_buf(self) -> bytes: def read_buf(self) -> bytes:
data = self.transport._call("read") data = self.transport._call("read")
LOG.log(DUMP_PACKETS, "received message: {}".format(data.text))
return bytes.fromhex(data.text) return bytes.fromhex(data.text)
@ -87,7 +87,9 @@ class BridgeHandleLegacy(BridgeHandle):
if self.request is None: if self.request is None:
raise TransportException("Can't read without write on legacy Bridge") raise TransportException("Can't read without write on legacy Bridge")
try: try:
LOG.log(DUMP_PACKETS, "calling with message: {}".format(self.request))
data = self.transport._call("call", data=self.request) data = self.transport._call("call", data=self.request)
LOG.log(DUMP_PACKETS, "received response: {}".format(data.text))
return bytes.fromhex(data.text) return bytes.fromhex(data.text)
finally: finally:
self.request = None self.request = None
@ -152,29 +154,12 @@ class BridgeTransport(Transport):
self._call("release") self._call("release")
self.session = None self.session = None
def write(self, msg: protobuf.MessageType) -> None: def write(self, message_type: int, message_data: bytes) -> None:
LOG.debug( header = struct.pack(">HL", message_type, len(message_data))
"sending message: {}".format(msg.__class__.__name__), self.handle.write_buf(header + message_data)
extra={"protobuf": msg},
)
buffer = BytesIO()
protobuf.dump_message(buffer, msg)
ser = buffer.getvalue()
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
self.handle.write_buf(header + ser) def read(self) -> MessagePayload:
def read(self) -> protobuf.MessageType:
data = self.handle.read_buf() data = self.handle.read_buf()
headerlen = struct.calcsize(">HL") headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen]) msg_type, datalen = struct.unpack(">HL", data[:headerlen])
ser = data[headerlen : headerlen + datalen] return msg_type, data[headerlen : headerlen + datalen]
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
buffer = BytesIO(ser)
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg

View File

@ -19,6 +19,7 @@ import sys
import time import time
from typing import Any, Dict, Iterable from typing import Any, Dict, Iterable
from ..log import DUMP_PACKETS
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
@ -82,8 +83,9 @@ class HidHandle:
raise TransportException("Unexpected chunk size: %d" % len(chunk)) raise TransportException("Unexpected chunk size: %d" % len(chunk))
if self.hid_version == 2: if self.hid_version == 2:
self.handle.write(b"\0" + bytearray(chunk)) chunk = b"\x00" + chunk
else:
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
self.handle.write(chunk) self.handle.write(chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
@ -93,6 +95,8 @@ class HidHandle:
break break
else: else:
time.sleep(0.001) time.sleep(0.001)
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk)) raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytes(chunk) return bytes(chunk)
@ -119,8 +123,7 @@ class HidTransport(ProtocolBasedTransport):
self.device = device self.device = device
self.handle = HidHandle(device["path"], device["serial_number"]) self.handle = HidHandle(device["path"], device["serial_number"])
protocol = ProtocolV1(self.handle) super().__init__(protocol=ProtocolV1(self.handle))
super().__init__(protocol=protocol)
def get_path(self) -> str: def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode()) return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
@ -142,10 +145,6 @@ class HidTransport(ProtocolBasedTransport):
return devices return devices
def find_debug(self) -> "HidTransport": def find_debug(self) -> "HidTransport":
if self.protocol.VERSION >= 2:
# use the same device
return self
else:
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True): for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]: if debug.device["serial_number"] == self.device["serial_number"]:

View File

@ -15,16 +15,12 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging import logging
import os
import struct import struct
from io import BytesIO
from typing import Tuple from typing import Tuple
from typing_extensions import Protocol as StructuralType from typing_extensions import Protocol as StructuralType
from .. import mapping, protobuf from . import MessagePayload, Transport
from ..log import DUMP_BYTES
from . import Transport
REPLEN = 64 REPLEN = 64
@ -72,7 +68,6 @@ class Protocol:
- open and close physical connections, - open and close physical connections,
- and send and receive binary chunks. - and send and receive binary chunks.
We declare a protocol version (we have implementations of v1 and v2).
For now, the class also handles session counting and opening the underlying Handle. For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future. This will probably be removed in the future.
@ -80,8 +75,6 @@ class Protocol:
its messages. its messages.
""" """
VERSION = None # type: int
def __init__(self, handle: Handle) -> None: def __init__(self, handle: Handle) -> None:
self.handle = handle self.handle = handle
self.session_counter = 0 self.session_counter = 0
@ -97,10 +90,10 @@ class Protocol:
if self.session_counter == 0: if self.session_counter == 0:
self.handle.close() self.handle.close()
def read(self) -> protobuf.MessageType: def read(self) -> MessagePayload:
raise NotImplementedError raise NotImplementedError
def write(self, message: protobuf.MessageType) -> None: def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError raise NotImplementedError
@ -114,10 +107,10 @@ class ProtocolBasedTransport(Transport):
def __init__(self, protocol: Protocol) -> None: def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol self.protocol = protocol
def write(self, message: protobuf.MessageType) -> None: def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message) self.protocol.write(message_type, message_data)
def read(self) -> protobuf.MessageType: def read(self) -> MessagePayload:
return self.protocol.read() return self.protocol.read()
def begin_session(self) -> None: def begin_session(self) -> None:
@ -132,19 +125,11 @@ class ProtocolV1(Protocol):
Does not understand sessions. Does not understand sessions.
""" """
VERSION = 1 HEADER_LEN = struct.calcsize(">HL")
def write(self, msg: protobuf.MessageType) -> None: def write(self, message_type: int, message_data: bytes) -> None:
LOG.debug( header = struct.pack(">HL", message_type, len(message_data))
"sending message: {}".format(msg.__class__.__name__), buffer = bytearray(b"##" + header + message_data)
extra={"protobuf": msg},
)
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
buffer = bytearray(b"##" + header + ser)
while buffer: while buffer:
# Report ID, data padded to 63 bytes # Report ID, data padded to 63 bytes
@ -153,7 +138,7 @@ class ProtocolV1(Protocol):
self.handle.write_chunk(chunk) self.handle.write_chunk(chunk)
buffer = buffer[63:] buffer = buffer[63:]
def read(self) -> protobuf.MessageType: def read(self) -> MessagePayload:
buffer = bytearray() buffer = bytearray()
# Read header with first part of message data # Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first() msg_type, datalen, first_chunk = self.read_first()
@ -163,30 +148,18 @@ class ProtocolV1(Protocol):
while len(buffer) < datalen: while len(buffer) < datalen:
buffer.extend(self.read_next()) buffer.extend(self.read_next())
# Strip padding return msg_type, buffer[:datalen]
ser = buffer[:datalen]
data = BytesIO(ser)
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]: def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk() chunk = self.handle.read_chunk()
if chunk[:3] != b"?##": if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters") raise RuntimeError("Unexpected magic characters")
try: try:
headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
except Exception: except Exception:
raise RuntimeError("Cannot parse header") raise RuntimeError("Cannot parse header")
data = chunk[3 + headerlen :] data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data return msg_type, datalen, data
def read_next(self) -> bytes: def read_next(self) -> bytes:
@ -194,160 +167,3 @@ class ProtocolV1(Protocol):
if chunk[:1] != b"?": if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters") raise RuntimeError("Unexpected magic characters")
return chunk[1:] return chunk[1:]
class ProtocolV2(Protocol):
"""Protocol version 2. Currently (11/2018) not used.
Intended to mimic U2F/WebAuthN session handling.
"""
VERSION = 2
def __init__(self, handle: Handle) -> None:
self.session = None
super().__init__(handle)
def begin_session(self) -> None:
# ensure open connection
super().begin_session()
# initiate session
chunk = struct.pack(">B", V2_BEGIN_SESSION)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
# get session identifier
resp = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BL")
magic, session = struct.unpack(">BL", resp[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_BEGIN_SESSION:
raise RuntimeError("Unexpected magic character")
self.session = session
LOG.debug("[session {}] session started".format(self.session))
def end_session(self) -> None:
if not self.session:
return
try:
chunk = struct.pack(">BL", V2_END_SESSION, self.session)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
resp = self.handle.read_chunk()
(magic,) = struct.unpack(">B", resp[:1])
if magic != V2_END_SESSION:
raise RuntimeError("Expected session close")
LOG.debug("[session {}] session ended".format(self.session))
finally:
self.session = None
# close connection if appropriate
super().end_session()
def write(self, msg: protobuf.MessageType) -> None:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
LOG.debug(
"[session {}] sending message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
# Serialize whole message
data = BytesIO()
protobuf.dump_message(data, msg)
data = data.getvalue()
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session)
else:
repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self) -> protobuf.MessageType:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, chunk = self.read_first()
buffer.extend(chunk)
# Read the rest of the message
while len(buffer) < datalen:
next_chunk = self.read_next()
buffer.extend(next_chunk)
# Strip padding
buffer = BytesIO(buffer[:datalen])
# Parse to protobuf
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug(
"[session {}] received message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLLL")
magic, session, msg_type, datalen = struct.unpack(
">BLLL", chunk[:headerlen]
)
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_FIRST_CHUNK:
raise RuntimeError("Unexpected magic character")
if session != self.session:
raise RuntimeError("Session id mismatch")
return msg_type, datalen, chunk[headerlen:]
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLL")
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_NEXT_CHUNK:
raise RuntimeError("Unexpected magic characters")
if session != self.session:
raise RuntimeError("Session id mismatch")
return chunk[headerlen:]
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
"""Make a Protocol instance for the given handle.
Each transport can have a preference for using a particular protocol version.
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
which forces the library to use V1 anyways.
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
for it (i.e., USB transports for Trezor T).
"""
force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1))
if want_v2 and not force_v1:
return ProtocolV2(handle)
else:
return ProtocolV1(handle)

View File

@ -14,15 +14,19 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import socket import socket
import time import time
from typing import Iterable, Optional, cast from typing import Iterable, Optional, cast
from ..log import DUMP_PACKETS
from . import TransportException from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol from .protocol import ProtocolBasedTransport, ProtocolV1
SOCKET_TIMEOUT = 10 SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport): class UdpTransport(ProtocolBasedTransport):
@ -42,8 +46,7 @@ class UdpTransport(ProtocolBasedTransport):
self.device = (host, port) self.device = (host, port)
self.socket = None # type: Optional[socket.socket] self.socket = None # type: Optional[socket.socket]
protocol = get_protocol(self, want_v2=False) super().__init__(protocol=ProtocolV1(self))
super().__init__(protocol=protocol)
def get_path(self) -> str: def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
@ -126,6 +129,7 @@ class UdpTransport(ProtocolBasedTransport):
assert self.socket is not None assert self.socket is not None
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected data length") raise TransportException("Unexpected data length")
LOG.log(DUMP_PACKETS, "sending packet: {}".format(chunk.hex()))
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
@ -136,6 +140,7 @@ class UdpTransport(ProtocolBasedTransport):
break break
except socket.timeout: except socket.timeout:
continue continue
LOG.log(DUMP_PACKETS, "received packet: {}".format(chunk.hex()))
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk)) raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk) return bytearray(chunk)

View File

@ -20,6 +20,7 @@ import sys
import time import time
from typing import Iterable, Optional from typing import Iterable, Optional
from ..log import DUMP_PACKETS
from . import TREZORS, UDEV_RULES_STR, TransportException from . import TREZORS, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
@ -65,6 +66,7 @@ class WebUsbHandle:
assert self.handle is not None assert self.handle is not None
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk)) raise TransportException("Unexpected chunk size: %d" % len(chunk))
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
self.handle.interruptWrite(self.endpoint, chunk) self.handle.interruptWrite(self.endpoint, chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
@ -76,6 +78,7 @@ class WebUsbHandle:
break break
else: else:
time.sleep(0.001) time.sleep(0.001)
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk)) raise TransportException("Unexpected chunk size: %d" % len(chunk))
return chunk return chunk
@ -136,12 +139,6 @@ class WebUsbTransport(ProtocolBasedTransport):
return devices return devices
def find_debug(self) -> "WebUsbTransport": def find_debug(self) -> "WebUsbTransport":
if self.protocol.VERSION >= 2:
# TODO test this
# XXX this is broken right now because sessions don't really work
# For v2 protocol, use the same WebUSB interface with a different session
return WebUsbTransport(self.device, self.handle)
else:
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True) return WebUsbTransport(self.device, debug=True)

View File

@ -59,9 +59,9 @@ def test_cancel_message_via_initialize(client, message):
resp = client.call_raw(message) resp = client.call_raw(message)
assert isinstance(resp, m.ButtonRequest) assert isinstance(resp, m.ButtonRequest)
client.transport.write(m.ButtonAck()) client._raw_write(m.ButtonAck())
client.transport.write(m.Initialize()) client._raw_write(m.Initialize())
resp = client.transport.read() resp = client._raw_read()
assert isinstance(resp, m.Features) assert isinstance(resp, m.Features)

View File

@ -69,10 +69,10 @@ class TestMsgRecoverydeviceT2:
# Enter mnemonic words # Enter mnemonic words
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput) assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
client.transport.write(proto.ButtonAck()) client._raw_write(proto.ButtonAck())
for word in mnemonic: for word in mnemonic:
client.debug.input(word) client.debug.input(word)
ret = client.transport.read() ret = client._raw_read()
# Confirm success # Confirm success
assert isinstance(ret, proto.ButtonRequest) assert isinstance(ret, proto.ButtonRequest)
@ -125,10 +125,10 @@ class TestMsgRecoverydeviceT2:
# Enter mnemonic words # Enter mnemonic words
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput) assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
client.transport.write(proto.ButtonAck()) client._raw_write(proto.ButtonAck())
for word in mnemonic: for word in mnemonic:
client.debug.input(word) client.debug.input(word)
ret = client.transport.read() ret = client._raw_read()
# Confirm success # Confirm success
assert isinstance(ret, proto.ButtonRequest) assert isinstance(ret, proto.ButtonRequest)