diff --git a/trezorlib/transport/bridge.py b/trezorlib/transport/bridge.py index 7b5cfb4776..68e5635ff6 100644 --- a/trezorlib/transport/bridge.py +++ b/trezorlib/transport/bridge.py @@ -29,6 +29,8 @@ LOG = logging.getLogger(__name__) TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} +TREZORD_VERSION_MODERN = (2, 0, 25) + CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) @@ -44,6 +46,52 @@ def call_bridge(uri: str, data=None) -> requests.Response: return r +def is_legacy_bridge() -> bool: + config = call_bridge("configure").json() + version_tuple = tuple(map(int, config["version"].split("."))) + return version_tuple < TREZORD_VERSION_MODERN + + +class BridgeHandle: + def __init__(self, transport: "BridgeTransport") -> None: + self.transport = transport + + def read_buf(self) -> bytes: + raise NotImplementedError + + def write_buf(self, buf: bytes) -> None: + raise NotImplementedError + + +class BridgeHandleModern(BridgeHandle): + def write_buf(self, buf: bytes) -> None: + self.transport._call("post", data=buf.hex()) + + def read_buf(self) -> bytes: + data = self.transport._call("read") + return bytes.fromhex(data.text) + + +class BridgeHandleLegacy(BridgeHandle): + def __init__(self, transport: "BridgeTransport") -> None: + super().__init__(transport) + self.request = None # type: Optional[str] + + def write_buf(self, buf: bytes) -> None: + if self.request is not None: + raise TransportException("Can't write twice on legacy Bridge") + self.request = buf.hex() + + def read_buf(self) -> bytes: + if self.request is None: + raise TransportException("Can't read without write on legacy Bridge") + try: + data = self.transport._call("call", data=self.request) + return bytes.fromhex(data.text) + finally: + self.request = None + + class BridgeTransport(Transport): """ BridgeTransport implements transport through TREZOR Bridge (aka trezord). @@ -51,44 +99,58 @@ class BridgeTransport(Transport): PATH_PREFIX = "bridge" - def __init__(self, device: Dict[str, Any]) -> None: + def __init__( + self, device: Dict[str, Any], legacy: bool, debug: bool = False + ) -> None: + if legacy and debug: + raise TransportException("Debugging not supported on legacy Bridge") + self.device = device self.session = None # type: Optional[str] - self.request = None # type: Optional[str] - self.debug = False + self.debug = debug + self.legacy = legacy + + if legacy: + self.handle = BridgeHandleLegacy(self) # type: BridgeHandle + else: + self.handle = BridgeHandleModern(self) def get_path(self) -> str: return "%s:%s" % (self.PATH_PREFIX, self.device["path"]) - @classmethod - def enumerate(cls) -> Iterable["BridgeTransport"]: - try: - return [BridgeTransport(dev) for dev in call_bridge("enumerate").json()] - except Exception: - return [] + def find_debug(self) -> "BridgeTransport": + if not self.device.get("debug"): + raise TransportException("Debug device not available") + return BridgeTransport(self.device, self.legacy, debug=True) def _call(self, action: str, data: str = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) + if self.debug: + uri = "debug/" + uri return call_bridge(uri, data=data) + @classmethod + def enumerate(cls) -> Iterable["BridgeTransport"]: + try: + legacy = is_legacy_bridge() + return [ + BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() + ] + except Exception: + return [] + def begin_session(self) -> None: - LOG.debug("acquiring session from {}".format(self.session)) data = self._call("acquire/" + self.device["path"]) self.session = data.json()["session"] - LOG.debug("acquired session {}".format(self.session)) def end_session(self) -> None: - LOG.debug("releasing session {}".format(self.session)) if not self.session: return self._call("release") self.session = None def write(self, msg: protobuf.MessageType) -> None: - if self.request is not None: - raise TransportException("Cannot enqueue two requests") - LOG.debug( "sending message: {}".format(msg.__class__.__name__), extra={"protobuf": msg}, @@ -98,25 +160,19 @@ class BridgeTransport(Transport): ser = buffer.getvalue() header = struct.pack(">HL", mapping.get_type(msg), len(ser)) - self.request = (header + ser).hex() + self.handle.write_buf(header + ser) def read(self) -> protobuf.MessageType: - if self.request is None: - raise TransportException("No request stored") - - try: - data = bytes.fromhex(self._call("call", data=self.request).text) - headerlen = struct.calcsize(">HL") - msg_type, datalen = struct.unpack(">HL", data[:headerlen]) - buffer = BytesIO(data[headerlen : headerlen + datalen]) - msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) - LOG.debug( - "received message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - return msg - finally: - self.request = None + data = self.handle.read_buf() + headerlen = struct.calcsize(">HL") + msg_type, datalen = struct.unpack(">HL", data[:headerlen]) + buffer = BytesIO(data[headerlen : headerlen + datalen]) + msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) + LOG.debug( + "received message: {}".format(msg.__class__.__name__), + extra={"protobuf": msg}, + ) + return msg TRANSPORT = BridgeTransport