diff --git a/trezorlib/transport/bridge.py b/trezorlib/transport/bridge.py index 1dd7e52668..7b5cfb4776 100644 --- a/trezorlib/transport/bridge.py +++ b/trezorlib/transport/bridge.py @@ -14,7 +14,6 @@ # You should have received a copy of the License along with this library. # If not, see . -import binascii import logging import struct from io import BytesIO @@ -28,10 +27,21 @@ from .. import mapping, protobuf LOG = logging.getLogger(__name__) TREZORD_HOST = "http://127.0.0.1:21325" +TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} + +CONNECTION = requests.Session() +CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) -def get_error(resp: requests.Response) -> str: - return " (error=%d str=%s)" % (resp.status_code, resp.json()["error"]) +def call_bridge(uri: str, data=None) -> requests.Response: + url = TREZORD_HOST + "/" + uri + r = CONNECTION.post(url, data=data) + if r.status_code != 200: + error_str = "trezord: {} failed with code {}: {}".format( + uri, r.status_code, r.json()["error"] + ) + raise TransportException(error_str) + return r class BridgeTransport(Transport): @@ -40,13 +50,12 @@ class BridgeTransport(Transport): """ PATH_PREFIX = "bridge" - HEADERS = {"Origin": "https://python.trezor.io"} def __init__(self, device: Dict[str, Any]) -> None: self.device = device - self.conn = requests.Session() self.session = None # type: Optional[str] - self.response = None # type: Optional[str] + self.request = None # type: Optional[str] + self.debug = False def get_path(self) -> str: return "%s:%s" % (self.PATH_PREFIX, self.device["path"]) @@ -54,39 +63,32 @@ class BridgeTransport(Transport): @classmethod def enumerate(cls) -> Iterable["BridgeTransport"]: try: - r = requests.post(TREZORD_HOST + "/enumerate", headers=cls.HEADERS) - if r.status_code != 200: - raise TransportException( - "trezord: Could not enumerate devices" + get_error(r) - ) - return [BridgeTransport(dev) for dev in r.json()] + return [BridgeTransport(dev) for dev in call_bridge("enumerate").json()] except Exception: return [] + def _call(self, action: str, data: str = None) -> requests.Response: + session = self.session or "null" + uri = action + "/" + str(session) + return call_bridge(uri, data=data) + def begin_session(self) -> None: - r = self.conn.post( - TREZORD_HOST + "/acquire/%s/null" % self.device["path"], - headers=self.HEADERS, - ) - if r.status_code != 200: - raise TransportException( - "trezord: Could not acquire session" + get_error(r) - ) - self.session = r.json()["session"] + LOG.debug("acquiring session from {}".format(self.session)) + data = self._call("acquire/" + self.device["path"]) + self.session = data.json()["session"] + LOG.debug("acquired session {}".format(self.session)) def end_session(self) -> None: + LOG.debug("releasing session {}".format(self.session)) if not self.session: return - r = self.conn.post( - TREZORD_HOST + "/release/%s" % self.session, headers=self.HEADERS - ) - if r.status_code != 200: - raise TransportException( - "trezord: Could not release session" + get_error(r) - ) + self._call("release") self.session = None def write(self, msg: protobuf.MessageType) -> None: + if self.request is not None: + raise TransportException("Cannot enqueue two requests") + LOG.debug( "sending message: {}".format(msg.__class__.__name__), extra={"protobuf": msg}, @@ -95,28 +97,26 @@ class BridgeTransport(Transport): protobuf.dump_message(buffer, msg) ser = buffer.getvalue() header = struct.pack(">HL", mapping.get_type(msg), len(ser)) - data = binascii.hexlify(header + ser).decode() - r = self.conn.post( # type: ignore # typeshed bug - TREZORD_HOST + "/call/%s" % self.session, data=data, headers=self.HEADERS - ) - if r.status_code != 200: - raise TransportException("trezord: Could not write message" + get_error(r)) - self.response = r.text + + self.request = (header + ser).hex() def read(self) -> protobuf.MessageType: - if self.response is None: - raise TransportException("No response stored") - data = binascii.unhexlify(self.response) - headerlen = struct.calcsize(">HL") - (msg_type, datalen) = struct.unpack(">HL", data[:headerlen]) - buffer = BytesIO(data[headerlen : headerlen + datalen]) - msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) - LOG.debug( - "received message: {}".format(msg.__class__.__name__), - extra={"protobuf": msg}, - ) - self.response = None - return msg + if self.request is None: + raise TransportException("No request stored") + + try: + data = bytes.fromhex(self._call("call", data=self.request).text) + headerlen = struct.calcsize(">HL") + msg_type, datalen = struct.unpack(">HL", data[:headerlen]) + buffer = BytesIO(data[headerlen : headerlen + datalen]) + msg = protobuf.load_message(buffer, mapping.get_class(msg_type)) + LOG.debug( + "received message: {}".format(msg.__class__.__name__), + extra={"protobuf": msg}, + ) + return msg + finally: + self.request = None TRANSPORT = BridgeTransport