From 9a330f34756c7f553c187d007ddce14eea73ca65 Mon Sep 17 00:00:00 2001 From: matejcik Date: Thu, 5 Mar 2020 17:38:31 +0100 Subject: [PATCH] python: unify protobuf-encoding code paths Protobuf encoding now happens in TrezorClient, and transports get encoded blobs to (chunkify and) send. This is a better design because transports don't need to know about protobuf. It also lays groundwork for sending raw bytes feature (#116) This commit also removes all vestiges of ProtocolV2 which was never used and will probably need to be redesigned from the ground up anyway. The code is still ready for protocol flexibility. --- python/CHANGELOG.md | 2 + python/src/trezorlib/client.py | 30 ++- python/src/trezorlib/debuglink.py | 9 +- python/src/trezorlib/log.py | 4 + python/src/trezorlib/mapping.py | 18 +- python/src/trezorlib/transport/__init__.py | 10 +- python/src/trezorlib/transport/bridge.py | 37 +-- python/src/trezorlib/transport/hid.py | 27 ++- python/src/trezorlib/transport/protocol.py | 212 ++---------------- python/src/trezorlib/transport/udp.py | 11 +- python/src/trezorlib/transport/webusb.py | 13 +- tests/device_tests/test_cancel.py | 6 +- .../test_msg_recoverydevice_bip39_t2.py | 8 +- 13 files changed, 119 insertions(+), 268 deletions(-) diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index 521b023e6..247774ed0 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -29,6 +29,8 @@ _At the moment, the project does **not** adhere to [Semantic Versioning](https:/ - `get_default_client` respects `TREZOR_PATH` environment variable - UI callback `get_passphrase` has an additional argument `available_on_device`, indicating that the connected Trezor is capable of on-device entry +- `Transport.write` and `read` method signatures changed to accept bytes instead of + protobuf messages ### Fixed diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 7f3262e97..1c0bbfa55 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -21,7 +21,8 @@ import warnings from mnemonic import Mnemonic -from . import MINIMUM_FIRMWARE_VERSION, exceptions, messages, tools +from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools +from .log import DUMP_BYTES from .messages import Capability if sys.version_info.major < 3: @@ -114,11 +115,34 @@ class TrezorClient: def _raw_write(self, msg): __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self.transport.write(msg) + LOG.debug( + "sending message: {}".format(msg.__class__.__name__), + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = mapping.encode(msg) + LOG.log( + DUMP_BYTES, + "encoded as type {} ({} bytes): {}".format( + msg_type, len(msg_bytes), msg_bytes.hex() + ), + ) + self.transport.write(msg_type, msg_bytes) def _raw_read(self): __tracebackhide__ = True # for pytest # pylint: disable=W0612 - return self.transport.read() + msg_type, msg_bytes = self.transport.read() + LOG.log( + DUMP_BYTES, + "received type {} ({} bytes): {}".format( + msg_type, len(msg_bytes), msg_bytes.hex() + ), + ) + msg = mapping.decode(msg_type, msg_bytes) + LOG.debug( + "received message: {}".format(msg.__class__.__name__), + extra={"protobuf": msg}, + ) + return msg def _callback_pin(self, msg): try: diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 5c71aa8dc..e17beb88c 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -19,7 +19,7 @@ from copy import deepcopy from mnemonic import Mnemonic -from . import messages as proto, protobuf +from . import mapping, messages as proto, protobuf from .client import TrezorClient from .tools import expect @@ -45,11 +45,12 @@ class DebugLink: self.transport.end_session() def _call(self, msg, nowait=False): - self.transport.write(msg) + msg_type, msg_bytes = mapping.encode(msg) + self.transport.write(msg_type, msg_bytes) if nowait: return None - ret = self.transport.read() - return ret + ret_type, ret_bytes = self.transport.read() + return mapping.decode(ret_type, ret_bytes) def state(self): return self._call(proto.DebugLinkGetState()) diff --git a/python/src/trezorlib/log.py b/python/src/trezorlib/log.py index 8c6aa5a85..5740ecc53 100644 --- a/python/src/trezorlib/log.py +++ b/python/src/trezorlib/log.py @@ -22,8 +22,10 @@ from . import protobuf OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]] DUMP_BYTES = 5 +DUMP_PACKETS = 4 logging.addLevelName(DUMP_BYTES, "BYTES") +logging.addLevelName(DUMP_PACKETS, "PACKETS") class PrettyProtobufFormatter(logging.Formatter): @@ -54,6 +56,8 @@ def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = level = logging.DEBUG if verbosity > 1: level = DUMP_BYTES + if verbosity > 2: + level = DUMP_PACKETS logger = logging.getLogger("trezorlib") logger.setLevel(level) diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 64c3b0a3d..c37071212 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -14,7 +14,10 @@ # You should have received a copy of the License along with this library. # If not, see . -from . import messages +import io +from typing import Tuple + +from . import messages, protobuf map_type_to_class = {} map_class_to_type = {} @@ -59,4 +62,17 @@ def get_class(t): return map_type_to_class[t] +def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]: + message_type = msg.MESSAGE_WIRE_TYPE + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return message_type, buf.getvalue() + + +def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType: + cls = get_class(message_type) + buf = io.BytesIO(message_bytes) + return protobuf.load_message(buf, cls) + + build_map() diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 46b1cfe62..f71642cc3 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -15,10 +15,9 @@ # If not, see . import logging -from typing import Iterable, List, Type +from typing import Iterable, List, Tuple, Type from ..exceptions import TrezorException -from ..protobuf import MessageType LOG = logging.getLogger(__name__) @@ -35,6 +34,9 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules """.strip() +MessagePayload = Tuple[int, bytes] + + class TransportException(TrezorException): pass @@ -71,10 +73,10 @@ class Transport: def end_session(self) -> None: raise NotImplementedError - def read(self) -> MessageType: + def read(self) -> MessagePayload: raise NotImplementedError - def write(self, message: MessageType) -> None: + def write(self, message_type: int, message_data: bytes) -> None: raise NotImplementedError @classmethod diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index 872cbb612..54e8e5392 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -16,14 +16,12 @@ import logging import struct -from io import BytesIO from typing import Any, Dict, Iterable, Optional import requests -from .. import mapping, protobuf -from ..log import DUMP_BYTES -from . import Transport, TransportException +from ..log import DUMP_PACKETS +from . import MessagePayload, Transport, TransportException LOG = logging.getLogger(__name__) @@ -66,10 +64,12 @@ class BridgeHandle: class BridgeHandleModern(BridgeHandle): def write_buf(self, buf: bytes) -> None: + LOG.log(DUMP_PACKETS, "sending message: {}".format(buf.hex())) self.transport._call("post", data=buf.hex()) def read_buf(self) -> bytes: data = self.transport._call("read") + LOG.log(DUMP_PACKETS, "received message: {}".format(data.text)) return bytes.fromhex(data.text) @@ -87,7 +87,9 @@ class BridgeHandleLegacy(BridgeHandle): if self.request is None: raise TransportException("Can't read without write on legacy Bridge") try: + LOG.log(DUMP_PACKETS, "calling with message: {}".format(self.request)) data = self.transport._call("call", data=self.request) + LOG.log(DUMP_PACKETS, "received response: {}".format(data.text)) return bytes.fromhex(data.text) finally: self.request = None @@ -152,29 +154,12 @@ class BridgeTransport(Transport): self._call("release") self.session = None - def write(self, msg: protobuf.MessageType) -> None: - LOG.debug( - "sending message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - buffer = BytesIO() - protobuf.dump_message(buffer, msg) - ser = buffer.getvalue() - LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex())) - header = struct.pack(">HL", mapping.get_type(msg), len(ser)) - - self.handle.write_buf(header + ser) + def write(self, message_type: int, message_data: bytes) -> None: + header = struct.pack(">HL", message_type, len(message_data)) + self.handle.write_buf(header + message_data) - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: data = self.handle.read_buf() headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) - ser = data[headerlen : headerlen + datalen] - LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex())) - buffer = BytesIO(ser) - msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) - LOG.debug( - "received message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - return msg + return msg_type, data[headerlen : headerlen + datalen] diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 12bf4c25d..aab512b2f 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -19,6 +19,7 @@ import sys import time from typing import Any, Dict, Iterable +from ..log import DUMP_PACKETS from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -82,9 +83,10 @@ class HidHandle: raise TransportException("Unexpected chunk size: %d" % len(chunk)) if self.hid_version == 2: - self.handle.write(b"\0" + bytearray(chunk)) - else: - self.handle.write(chunk) + chunk = b"\x00" + chunk + + LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex())) + self.handle.write(chunk) def read_chunk(self) -> bytes: while True: @@ -93,6 +95,8 @@ class HidHandle: break else: time.sleep(0.001) + + LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex())) if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) return bytes(chunk) @@ -119,8 +123,7 @@ class HidTransport(ProtocolBasedTransport): self.device = device self.handle = HidHandle(device["path"], device["serial_number"]) - protocol = ProtocolV1(self.handle) - super().__init__(protocol=protocol) + super().__init__(protocol=ProtocolV1(self.handle)) def get_path(self) -> str: return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode()) @@ -142,15 +145,11 @@ class HidTransport(ProtocolBasedTransport): return devices def find_debug(self) -> "HidTransport": - if self.protocol.VERSION >= 2: - # use the same device - return self - else: - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") def is_wirelink(dev: HidDevice) -> bool: diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index 92afcfb3d..da3806cd7 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -15,16 +15,12 @@ # If not, see . import logging -import os import struct -from io import BytesIO from typing import Tuple from typing_extensions import Protocol as StructuralType -from .. import mapping, protobuf -from ..log import DUMP_BYTES -from . import Transport +from . import MessagePayload, Transport REPLEN = 64 @@ -72,7 +68,6 @@ class Protocol: - open and close physical connections, - and send and receive binary chunks. - We declare a protocol version (we have implementations of v1 and v2). For now, the class also handles session counting and opening the underlying Handle. This will probably be removed in the future. @@ -80,8 +75,6 @@ class Protocol: its messages. """ - VERSION = None # type: int - def __init__(self, handle: Handle) -> None: self.handle = handle self.session_counter = 0 @@ -97,10 +90,10 @@ class Protocol: if self.session_counter == 0: self.handle.close() - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: raise NotImplementedError - def write(self, message: protobuf.MessageType) -> None: + def write(self, message_type: int, message_data: bytes) -> None: raise NotImplementedError @@ -114,10 +107,10 @@ class ProtocolBasedTransport(Transport): def __init__(self, protocol: Protocol) -> None: self.protocol = protocol - def write(self, message: protobuf.MessageType) -> None: - self.protocol.write(message) + def write(self, message_type: int, message_data: bytes) -> None: + self.protocol.write(message_type, message_data) - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: return self.protocol.read() def begin_session(self) -> None: @@ -132,19 +125,11 @@ class ProtocolV1(Protocol): Does not understand sessions. """ - VERSION = 1 + HEADER_LEN = struct.calcsize(">HL") - def write(self, msg: protobuf.MessageType) -> None: - LOG.debug( - "sending message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - data = BytesIO() - protobuf.dump_message(data, msg) - ser = data.getvalue() - LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex())) - header = struct.pack(">HL", mapping.get_type(msg), len(ser)) - buffer = bytearray(b"##" + header + ser) + def write(self, message_type: int, message_data: bytes) -> None: + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) while buffer: # Report ID, data padded to 63 bytes @@ -153,7 +138,7 @@ class ProtocolV1(Protocol): self.handle.write_chunk(chunk) buffer = buffer[63:] - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: buffer = bytearray() # Read header with first part of message data msg_type, datalen, first_chunk = self.read_first() @@ -163,30 +148,18 @@ class ProtocolV1(Protocol): while len(buffer) < datalen: buffer.extend(self.read_next()) - # Strip padding - ser = buffer[:datalen] - data = BytesIO(ser) - LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex())) - - # Parse to protobuf - msg = protobuf.load_message(data, mapping.get_class(msg_type)) - LOG.debug( - "received message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - return msg + return msg_type, buffer[:datalen] def read_first(self) -> Tuple[int, int, bytes]: chunk = self.handle.read_chunk() if chunk[:3] != b"?##": raise RuntimeError("Unexpected magic characters") try: - headerlen = struct.calcsize(">HL") - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen]) + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) except Exception: raise RuntimeError("Cannot parse header") - data = chunk[3 + headerlen :] + data = chunk[3 + self.HEADER_LEN :] return msg_type, datalen, data def read_next(self) -> bytes: @@ -194,160 +167,3 @@ class ProtocolV1(Protocol): if chunk[:1] != b"?": raise RuntimeError("Unexpected magic characters") return chunk[1:] - - -class ProtocolV2(Protocol): - """Protocol version 2. Currently (11/2018) not used. - Intended to mimic U2F/WebAuthN session handling. - """ - - VERSION = 2 - - def __init__(self, handle: Handle) -> None: - self.session = None - super().__init__(handle) - - def begin_session(self) -> None: - # ensure open connection - super().begin_session() - - # initiate session - chunk = struct.pack(">B", V2_BEGIN_SESSION) - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - - # get session identifier - resp = self.handle.read_chunk() - try: - headerlen = struct.calcsize(">BL") - magic, session = struct.unpack(">BL", resp[:headerlen]) - except Exception: - raise RuntimeError("Cannot parse header") - if magic != V2_BEGIN_SESSION: - raise RuntimeError("Unexpected magic character") - self.session = session - - LOG.debug("[session {}] session started".format(self.session)) - - def end_session(self) -> None: - if not self.session: - return - - try: - chunk = struct.pack(">BL", V2_END_SESSION, self.session) - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - resp = self.handle.read_chunk() - (magic,) = struct.unpack(">B", resp[:1]) - if magic != V2_END_SESSION: - raise RuntimeError("Expected session close") - LOG.debug("[session {}] session ended".format(self.session)) - finally: - self.session = None - # close connection if appropriate - super().end_session() - - def write(self, msg: protobuf.MessageType) -> None: - if not self.session: - raise RuntimeError("Missing session for v2 protocol") - - LOG.debug( - "[session {}] sending message: {}".format( - self.session, msg.__class__.__name__ - ), - extra={"protobuf": msg}, - ) - # Serialize whole message - data = BytesIO() - protobuf.dump_message(data, msg) - data = data.getvalue() - dataheader = struct.pack(">LL", mapping.get_type(msg), len(data)) - data = dataheader + data - seq = -1 - - # Write it out - while data: - if seq < 0: - repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session) - else: - repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq) - datalen = REPLEN - len(repheader) - chunk = repheader + data[:datalen] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - data = data[datalen:] - seq += 1 - - def read(self) -> protobuf.MessageType: - if not self.session: - raise RuntimeError("Missing session for v2 protocol") - - buffer = bytearray() - - # Read header with first part of message data - msg_type, datalen, chunk = self.read_first() - buffer.extend(chunk) - - # Read the rest of the message - while len(buffer) < datalen: - next_chunk = self.read_next() - buffer.extend(next_chunk) - - # Strip padding - buffer = BytesIO(buffer[:datalen]) - - # Parse to protobuf - msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) - LOG.debug( - "[session {}] received message: {}".format( - self.session, msg.__class__.__name__ - ), - extra={"protobuf": msg}, - ) - return msg - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - try: - headerlen = struct.calcsize(">BLLL") - magic, session, msg_type, datalen = struct.unpack( - ">BLLL", chunk[:headerlen] - ) - except Exception: - raise RuntimeError("Cannot parse header") - if magic != V2_FIRST_CHUNK: - raise RuntimeError("Unexpected magic character") - if session != self.session: - raise RuntimeError("Session id mismatch") - return msg_type, datalen, chunk[headerlen:] - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - try: - headerlen = struct.calcsize(">BLL") - magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen]) - except Exception: - raise RuntimeError("Cannot parse header") - if magic != V2_NEXT_CHUNK: - raise RuntimeError("Unexpected magic characters") - if session != self.session: - raise RuntimeError("Session id mismatch") - return chunk[headerlen:] - - -def get_protocol(handle: Handle, want_v2: bool) -> Protocol: - """Make a Protocol instance for the given handle. - - Each transport can have a preference for using a particular protocol version. - This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable, - which forces the library to use V1 anyways. - - As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible - to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask - for it (i.e., USB transports for Trezor T). - """ - force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1)) - if want_v2 and not force_v1: - return ProtocolV2(handle) - else: - return ProtocolV1(handle) diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7a79079ed..e95830f97 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -14,15 +14,19 @@ # You should have received a copy of the License along with this library. # If not, see . +import logging import socket import time from typing import Iterable, Optional, cast +from ..log import DUMP_PACKETS from . import TransportException -from .protocol import ProtocolBasedTransport, get_protocol +from .protocol import ProtocolBasedTransport, ProtocolV1 SOCKET_TIMEOUT = 10 +LOG = logging.getLogger(__name__) + class UdpTransport(ProtocolBasedTransport): @@ -42,8 +46,7 @@ class UdpTransport(ProtocolBasedTransport): self.device = (host, port) self.socket = None # type: Optional[socket.socket] - protocol = get_protocol(self, want_v2=False) - super().__init__(protocol=protocol) + super().__init__(protocol=ProtocolV1(self)) def get_path(self) -> str: return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) @@ -126,6 +129,7 @@ class UdpTransport(ProtocolBasedTransport): assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") + LOG.log(DUMP_PACKETS, "sending packet: {}".format(chunk.hex())) self.socket.sendall(chunk) def read_chunk(self) -> bytes: @@ -136,6 +140,7 @@ class UdpTransport(ProtocolBasedTransport): break except socket.timeout: continue + LOG.log(DUMP_PACKETS, "received packet: {}".format(chunk.hex())) if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) return bytearray(chunk) diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index af9ffdfdb..40cd1ea40 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -20,6 +20,7 @@ import sys import time from typing import Iterable, Optional +from ..log import DUMP_PACKETS from . import TREZORS, UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -65,6 +66,7 @@ class WebUsbHandle: assert self.handle is not None if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) + LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex())) self.handle.interruptWrite(self.endpoint, chunk) def read_chunk(self) -> bytes: @@ -76,6 +78,7 @@ class WebUsbHandle: break else: time.sleep(0.001) + LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex())) if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) return chunk @@ -136,14 +139,8 @@ class WebUsbTransport(ProtocolBasedTransport): return devices def find_debug(self) -> "WebUsbTransport": - if self.protocol.VERSION >= 2: - # TODO test this - # XXX this is broken right now because sessions don't really work - # For v2 protocol, use the same WebUSB interface with a different session - return WebUsbTransport(self.device, self.handle) - else: - # For v1 protocol, find debug USB interface for the same serial number - return WebUsbTransport(self.device, debug=True) + # For v1 protocol, find debug USB interface for the same serial number + return WebUsbTransport(self.device, debug=True) def is_vendor_class(dev: "usb1.USBDevice") -> bool: diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index b05b684a2..9fa5f82c9 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -59,9 +59,9 @@ def test_cancel_message_via_initialize(client, message): resp = client.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client.transport.write(m.ButtonAck()) - client.transport.write(m.Initialize()) + client._raw_write(m.ButtonAck()) + client._raw_write(m.Initialize()) - resp = client.transport.read() + resp = client._raw_read() assert isinstance(resp, m.Features) diff --git a/tests/device_tests/test_msg_recoverydevice_bip39_t2.py b/tests/device_tests/test_msg_recoverydevice_bip39_t2.py index 79d516e21..886f326b4 100644 --- a/tests/device_tests/test_msg_recoverydevice_bip39_t2.py +++ b/tests/device_tests/test_msg_recoverydevice_bip39_t2.py @@ -69,10 +69,10 @@ class TestMsgRecoverydeviceT2: # Enter mnemonic words assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput) - client.transport.write(proto.ButtonAck()) + client._raw_write(proto.ButtonAck()) for word in mnemonic: client.debug.input(word) - ret = client.transport.read() + ret = client._raw_read() # Confirm success assert isinstance(ret, proto.ButtonRequest) @@ -125,10 +125,10 @@ class TestMsgRecoverydeviceT2: # Enter mnemonic words assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput) - client.transport.write(proto.ButtonAck()) + client._raw_write(proto.ButtonAck()) for word in mnemonic: client.debug.input(word) - ret = client.transport.read() + ret = client._raw_read() # Confirm success assert isinstance(ret, proto.ButtonRequest)