diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index 521b023e6b..247774ed0b 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -29,6 +29,8 @@ _At the moment, the project does **not** adhere to [Semantic Versioning](https:/ - `get_default_client` respects `TREZOR_PATH` environment variable - UI callback `get_passphrase` has an additional argument `available_on_device`, indicating that the connected Trezor is capable of on-device entry +- `Transport.write` and `read` method signatures changed to accept bytes instead of + protobuf messages ### Fixed diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 7f3262e976..1c0bbfa55c 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -21,7 +21,8 @@ import warnings from mnemonic import Mnemonic -from . import MINIMUM_FIRMWARE_VERSION, exceptions, messages, tools +from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools +from .log import DUMP_BYTES from .messages import Capability if sys.version_info.major < 3: @@ -114,11 +115,34 @@ class TrezorClient: def _raw_write(self, msg): __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self.transport.write(msg) + LOG.debug( + "sending message: {}".format(msg.__class__.__name__), + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = mapping.encode(msg) + LOG.log( + DUMP_BYTES, + "encoded as type {} ({} bytes): {}".format( + msg_type, len(msg_bytes), msg_bytes.hex() + ), + ) + self.transport.write(msg_type, msg_bytes) def _raw_read(self): __tracebackhide__ = True # for pytest # pylint: disable=W0612 - return self.transport.read() + msg_type, msg_bytes = self.transport.read() + LOG.log( + DUMP_BYTES, + "received type {} ({} bytes): {}".format( + msg_type, len(msg_bytes), msg_bytes.hex() + ), + ) + msg = mapping.decode(msg_type, msg_bytes) + LOG.debug( + "received message: {}".format(msg.__class__.__name__), + extra={"protobuf": msg}, + ) + return msg def _callback_pin(self, msg): try: diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 5c71aa8dcf..e17beb88cc 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -19,7 +19,7 @@ from copy import deepcopy from mnemonic import Mnemonic -from . import messages as proto, protobuf +from . import mapping, messages as proto, protobuf from .client import TrezorClient from .tools import expect @@ -45,11 +45,12 @@ class DebugLink: self.transport.end_session() def _call(self, msg, nowait=False): - self.transport.write(msg) + msg_type, msg_bytes = mapping.encode(msg) + self.transport.write(msg_type, msg_bytes) if nowait: return None - ret = self.transport.read() - return ret + ret_type, ret_bytes = self.transport.read() + return mapping.decode(ret_type, ret_bytes) def state(self): return self._call(proto.DebugLinkGetState()) diff --git a/python/src/trezorlib/log.py b/python/src/trezorlib/log.py index 8c6aa5a85e..5740ecc533 100644 --- a/python/src/trezorlib/log.py +++ b/python/src/trezorlib/log.py @@ -22,8 +22,10 @@ from . import protobuf OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]] DUMP_BYTES = 5 +DUMP_PACKETS = 4 logging.addLevelName(DUMP_BYTES, "BYTES") +logging.addLevelName(DUMP_PACKETS, "PACKETS") class PrettyProtobufFormatter(logging.Formatter): @@ -54,6 +56,8 @@ def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] = level = logging.DEBUG if verbosity > 1: level = DUMP_BYTES + if verbosity > 2: + level = DUMP_PACKETS logger = logging.getLogger("trezorlib") logger.setLevel(level) diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 64c3b0a3d4..c370712122 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -14,7 +14,10 @@ # You should have received a copy of the License along with this library. # If not, see . -from . import messages +import io +from typing import Tuple + +from . import messages, protobuf map_type_to_class = {} map_class_to_type = {} @@ -59,4 +62,17 @@ def get_class(t): return map_type_to_class[t] +def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]: + message_type = msg.MESSAGE_WIRE_TYPE + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return message_type, buf.getvalue() + + +def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType: + cls = get_class(message_type) + buf = io.BytesIO(message_bytes) + return protobuf.load_message(buf, cls) + + build_map() diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 46b1cfe628..f71642cc39 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -15,10 +15,9 @@ # If not, see . import logging -from typing import Iterable, List, Type +from typing import Iterable, List, Tuple, Type from ..exceptions import TrezorException -from ..protobuf import MessageType LOG = logging.getLogger(__name__) @@ -35,6 +34,9 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules """.strip() +MessagePayload = Tuple[int, bytes] + + class TransportException(TrezorException): pass @@ -71,10 +73,10 @@ class Transport: def end_session(self) -> None: raise NotImplementedError - def read(self) -> MessageType: + def read(self) -> MessagePayload: raise NotImplementedError - def write(self, message: MessageType) -> None: + def write(self, message_type: int, message_data: bytes) -> None: raise NotImplementedError @classmethod diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index 872cbb6122..54e8e53928 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -16,14 +16,12 @@ import logging import struct -from io import BytesIO from typing import Any, Dict, Iterable, Optional import requests -from .. import mapping, protobuf -from ..log import DUMP_BYTES -from . import Transport, TransportException +from ..log import DUMP_PACKETS +from . import MessagePayload, Transport, TransportException LOG = logging.getLogger(__name__) @@ -66,10 +64,12 @@ class BridgeHandle: class BridgeHandleModern(BridgeHandle): def write_buf(self, buf: bytes) -> None: + LOG.log(DUMP_PACKETS, "sending message: {}".format(buf.hex())) self.transport._call("post", data=buf.hex()) def read_buf(self) -> bytes: data = self.transport._call("read") + LOG.log(DUMP_PACKETS, "received message: {}".format(data.text)) return bytes.fromhex(data.text) @@ -87,7 +87,9 @@ class BridgeHandleLegacy(BridgeHandle): if self.request is None: raise TransportException("Can't read without write on legacy Bridge") try: + LOG.log(DUMP_PACKETS, "calling with message: {}".format(self.request)) data = self.transport._call("call", data=self.request) + LOG.log(DUMP_PACKETS, "received response: {}".format(data.text)) return bytes.fromhex(data.text) finally: self.request = None @@ -152,29 +154,12 @@ class BridgeTransport(Transport): self._call("release") self.session = None - def write(self, msg: protobuf.MessageType) -> None: - LOG.debug( - "sending message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - buffer = BytesIO() - protobuf.dump_message(buffer, msg) - ser = buffer.getvalue() - LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex())) - header = struct.pack(">HL", mapping.get_type(msg), len(ser)) + def write(self, message_type: int, message_data: bytes) -> None: + header = struct.pack(">HL", message_type, len(message_data)) + self.handle.write_buf(header + message_data) - self.handle.write_buf(header + ser) - - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: data = self.handle.read_buf() headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) - ser = data[headerlen : headerlen + datalen] - LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex())) - buffer = BytesIO(ser) - msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) - LOG.debug( - "received message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - return msg + return msg_type, data[headerlen : headerlen + datalen] diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 12bf4c25de..aab512b2fe 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -19,6 +19,7 @@ import sys import time from typing import Any, Dict, Iterable +from ..log import DUMP_PACKETS from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -82,9 +83,10 @@ class HidHandle: raise TransportException("Unexpected chunk size: %d" % len(chunk)) if self.hid_version == 2: - self.handle.write(b"\0" + bytearray(chunk)) - else: - self.handle.write(chunk) + chunk = b"\x00" + chunk + + LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex())) + self.handle.write(chunk) def read_chunk(self) -> bytes: while True: @@ -93,6 +95,8 @@ class HidHandle: break else: time.sleep(0.001) + + LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex())) if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) return bytes(chunk) @@ -119,8 +123,7 @@ class HidTransport(ProtocolBasedTransport): self.device = device self.handle = HidHandle(device["path"], device["serial_number"]) - protocol = ProtocolV1(self.handle) - super().__init__(protocol=protocol) + super().__init__(protocol=ProtocolV1(self.handle)) def get_path(self) -> str: return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode()) @@ -142,15 +145,11 @@ class HidTransport(ProtocolBasedTransport): return devices def find_debug(self) -> "HidTransport": - if self.protocol.VERSION >= 2: - # use the same device - return self - else: - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") def is_wirelink(dev: HidDevice) -> bool: diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index 92afcfb3db..da3806cd76 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -15,16 +15,12 @@ # If not, see . import logging -import os import struct -from io import BytesIO from typing import Tuple from typing_extensions import Protocol as StructuralType -from .. import mapping, protobuf -from ..log import DUMP_BYTES -from . import Transport +from . import MessagePayload, Transport REPLEN = 64 @@ -72,7 +68,6 @@ class Protocol: - open and close physical connections, - and send and receive binary chunks. - We declare a protocol version (we have implementations of v1 and v2). For now, the class also handles session counting and opening the underlying Handle. This will probably be removed in the future. @@ -80,8 +75,6 @@ class Protocol: its messages. """ - VERSION = None # type: int - def __init__(self, handle: Handle) -> None: self.handle = handle self.session_counter = 0 @@ -97,10 +90,10 @@ class Protocol: if self.session_counter == 0: self.handle.close() - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: raise NotImplementedError - def write(self, message: protobuf.MessageType) -> None: + def write(self, message_type: int, message_data: bytes) -> None: raise NotImplementedError @@ -114,10 +107,10 @@ class ProtocolBasedTransport(Transport): def __init__(self, protocol: Protocol) -> None: self.protocol = protocol - def write(self, message: protobuf.MessageType) -> None: - self.protocol.write(message) + def write(self, message_type: int, message_data: bytes) -> None: + self.protocol.write(message_type, message_data) - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: return self.protocol.read() def begin_session(self) -> None: @@ -132,19 +125,11 @@ class ProtocolV1(Protocol): Does not understand sessions. """ - VERSION = 1 + HEADER_LEN = struct.calcsize(">HL") - def write(self, msg: protobuf.MessageType) -> None: - LOG.debug( - "sending message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - data = BytesIO() - protobuf.dump_message(data, msg) - ser = data.getvalue() - LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex())) - header = struct.pack(">HL", mapping.get_type(msg), len(ser)) - buffer = bytearray(b"##" + header + ser) + def write(self, message_type: int, message_data: bytes) -> None: + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) while buffer: # Report ID, data padded to 63 bytes @@ -153,7 +138,7 @@ class ProtocolV1(Protocol): self.handle.write_chunk(chunk) buffer = buffer[63:] - def read(self) -> protobuf.MessageType: + def read(self) -> MessagePayload: buffer = bytearray() # Read header with first part of message data msg_type, datalen, first_chunk = self.read_first() @@ -163,30 +148,18 @@ class ProtocolV1(Protocol): while len(buffer) < datalen: buffer.extend(self.read_next()) - # Strip padding - ser = buffer[:datalen] - data = BytesIO(ser) - LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex())) - - # Parse to protobuf - msg = protobuf.load_message(data, mapping.get_class(msg_type)) - LOG.debug( - "received message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - return msg + return msg_type, buffer[:datalen] def read_first(self) -> Tuple[int, int, bytes]: chunk = self.handle.read_chunk() if chunk[:3] != b"?##": raise RuntimeError("Unexpected magic characters") try: - headerlen = struct.calcsize(">HL") - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen]) + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) except Exception: raise RuntimeError("Cannot parse header") - data = chunk[3 + headerlen :] + data = chunk[3 + self.HEADER_LEN :] return msg_type, datalen, data def read_next(self) -> bytes: @@ -194,160 +167,3 @@ class ProtocolV1(Protocol): if chunk[:1] != b"?": raise RuntimeError("Unexpected magic characters") return chunk[1:] - - -class ProtocolV2(Protocol): - """Protocol version 2. Currently (11/2018) not used. - Intended to mimic U2F/WebAuthN session handling. - """ - - VERSION = 2 - - def __init__(self, handle: Handle) -> None: - self.session = None - super().__init__(handle) - - def begin_session(self) -> None: - # ensure open connection - super().begin_session() - - # initiate session - chunk = struct.pack(">B", V2_BEGIN_SESSION) - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - - # get session identifier - resp = self.handle.read_chunk() - try: - headerlen = struct.calcsize(">BL") - magic, session = struct.unpack(">BL", resp[:headerlen]) - except Exception: - raise RuntimeError("Cannot parse header") - if magic != V2_BEGIN_SESSION: - raise RuntimeError("Unexpected magic character") - self.session = session - - LOG.debug("[session {}] session started".format(self.session)) - - def end_session(self) -> None: - if not self.session: - return - - try: - chunk = struct.pack(">BL", V2_END_SESSION, self.session) - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - resp = self.handle.read_chunk() - (magic,) = struct.unpack(">B", resp[:1]) - if magic != V2_END_SESSION: - raise RuntimeError("Expected session close") - LOG.debug("[session {}] session ended".format(self.session)) - finally: - self.session = None - # close connection if appropriate - super().end_session() - - def write(self, msg: protobuf.MessageType) -> None: - if not self.session: - raise RuntimeError("Missing session for v2 protocol") - - LOG.debug( - "[session {}] sending message: {}".format( - self.session, msg.__class__.__name__ - ), - extra={"protobuf": msg}, - ) - # Serialize whole message - data = BytesIO() - protobuf.dump_message(data, msg) - data = data.getvalue() - dataheader = struct.pack(">LL", mapping.get_type(msg), len(data)) - data = dataheader + data - seq = -1 - - # Write it out - while data: - if seq < 0: - repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session) - else: - repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq) - datalen = REPLEN - len(repheader) - chunk = repheader + data[:datalen] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - data = data[datalen:] - seq += 1 - - def read(self) -> protobuf.MessageType: - if not self.session: - raise RuntimeError("Missing session for v2 protocol") - - buffer = bytearray() - - # Read header with first part of message data - msg_type, datalen, chunk = self.read_first() - buffer.extend(chunk) - - # Read the rest of the message - while len(buffer) < datalen: - next_chunk = self.read_next() - buffer.extend(next_chunk) - - # Strip padding - buffer = BytesIO(buffer[:datalen]) - - # Parse to protobuf - msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) - LOG.debug( - "[session {}] received message: {}".format( - self.session, msg.__class__.__name__ - ), - extra={"protobuf": msg}, - ) - return msg - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - try: - headerlen = struct.calcsize(">BLLL") - magic, session, msg_type, datalen = struct.unpack( - ">BLLL", chunk[:headerlen] - ) - except Exception: - raise RuntimeError("Cannot parse header") - if magic != V2_FIRST_CHUNK: - raise RuntimeError("Unexpected magic character") - if session != self.session: - raise RuntimeError("Session id mismatch") - return msg_type, datalen, chunk[headerlen:] - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - try: - headerlen = struct.calcsize(">BLL") - magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen]) - except Exception: - raise RuntimeError("Cannot parse header") - if magic != V2_NEXT_CHUNK: - raise RuntimeError("Unexpected magic characters") - if session != self.session: - raise RuntimeError("Session id mismatch") - return chunk[headerlen:] - - -def get_protocol(handle: Handle, want_v2: bool) -> Protocol: - """Make a Protocol instance for the given handle. - - Each transport can have a preference for using a particular protocol version. - This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable, - which forces the library to use V1 anyways. - - As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible - to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask - for it (i.e., USB transports for Trezor T). - """ - force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1)) - if want_v2 and not force_v1: - return ProtocolV2(handle) - else: - return ProtocolV1(handle) diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7a79079ed9..e95830f972 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -14,15 +14,19 @@ # You should have received a copy of the License along with this library. # If not, see . +import logging import socket import time from typing import Iterable, Optional, cast +from ..log import DUMP_PACKETS from . import TransportException -from .protocol import ProtocolBasedTransport, get_protocol +from .protocol import ProtocolBasedTransport, ProtocolV1 SOCKET_TIMEOUT = 10 +LOG = logging.getLogger(__name__) + class UdpTransport(ProtocolBasedTransport): @@ -42,8 +46,7 @@ class UdpTransport(ProtocolBasedTransport): self.device = (host, port) self.socket = None # type: Optional[socket.socket] - protocol = get_protocol(self, want_v2=False) - super().__init__(protocol=protocol) + super().__init__(protocol=ProtocolV1(self)) def get_path(self) -> str: return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) @@ -126,6 +129,7 @@ class UdpTransport(ProtocolBasedTransport): assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") + LOG.log(DUMP_PACKETS, "sending packet: {}".format(chunk.hex())) self.socket.sendall(chunk) def read_chunk(self) -> bytes: @@ -136,6 +140,7 @@ class UdpTransport(ProtocolBasedTransport): break except socket.timeout: continue + LOG.log(DUMP_PACKETS, "received packet: {}".format(chunk.hex())) if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) return bytearray(chunk) diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index af9ffdfdb3..40cd1ea408 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -20,6 +20,7 @@ import sys import time from typing import Iterable, Optional +from ..log import DUMP_PACKETS from . import TREZORS, UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -65,6 +66,7 @@ class WebUsbHandle: assert self.handle is not None if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) + LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex())) self.handle.interruptWrite(self.endpoint, chunk) def read_chunk(self) -> bytes: @@ -76,6 +78,7 @@ class WebUsbHandle: break else: time.sleep(0.001) + LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex())) if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) return chunk @@ -136,14 +139,8 @@ class WebUsbTransport(ProtocolBasedTransport): return devices def find_debug(self) -> "WebUsbTransport": - if self.protocol.VERSION >= 2: - # TODO test this - # XXX this is broken right now because sessions don't really work - # For v2 protocol, use the same WebUSB interface with a different session - return WebUsbTransport(self.device, self.handle) - else: - # For v1 protocol, find debug USB interface for the same serial number - return WebUsbTransport(self.device, debug=True) + # For v1 protocol, find debug USB interface for the same serial number + return WebUsbTransport(self.device, debug=True) def is_vendor_class(dev: "usb1.USBDevice") -> bool: diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index b05b684a27..9fa5f82c94 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -59,9 +59,9 @@ def test_cancel_message_via_initialize(client, message): resp = client.call_raw(message) assert isinstance(resp, m.ButtonRequest) - client.transport.write(m.ButtonAck()) - client.transport.write(m.Initialize()) + client._raw_write(m.ButtonAck()) + client._raw_write(m.Initialize()) - resp = client.transport.read() + resp = client._raw_read() assert isinstance(resp, m.Features) diff --git a/tests/device_tests/test_msg_recoverydevice_bip39_t2.py b/tests/device_tests/test_msg_recoverydevice_bip39_t2.py index 79d516e215..886f326b4a 100644 --- a/tests/device_tests/test_msg_recoverydevice_bip39_t2.py +++ b/tests/device_tests/test_msg_recoverydevice_bip39_t2.py @@ -69,10 +69,10 @@ class TestMsgRecoverydeviceT2: # Enter mnemonic words assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput) - client.transport.write(proto.ButtonAck()) + client._raw_write(proto.ButtonAck()) for word in mnemonic: client.debug.input(word) - ret = client.transport.read() + ret = client._raw_read() # Confirm success assert isinstance(ret, proto.ButtonRequest) @@ -125,10 +125,10 @@ class TestMsgRecoverydeviceT2: # Enter mnemonic words assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput) - client.transport.write(proto.ButtonAck()) + client._raw_write(proto.ButtonAck()) for word in mnemonic: client.debug.input(word) - ret = client.transport.read() + ret = client._raw_read() # Confirm success assert isinstance(ret, proto.ButtonRequest)