diff --git a/trezorctl b/trezorctl index b12499efb..1f7a10d78 100755 --- a/trezorctl +++ b/trezorctl @@ -24,12 +24,14 @@ import base64 import binascii import click import json +import logging import os import sys from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException from trezorlib.transport import get_transport, enumerate_devices from trezorlib import coins +from trezorlib import log from trezorlib import messages as proto from trezorlib import protobuf from trezorlib import stellar @@ -64,32 +66,38 @@ CHOICE_OUTPUT_SCRIPT_TYPE = ChoiceType({ }) +def enable_logging(): + handler = logging.StreamHandler() + handler.setFormatter(log.PrettyProtobufFormatter()) + logger = logging.getLogger('trezorlib') + logger.setLevel(logging.DEBUG) + logger.addHandler(handler) + log.OMITTED_MESSAGES.add(proto.Features) + + @click.group(context_settings={'max_content_width': 400}) @click.option('-p', '--path', help='Select device by specific path.', default=os.environ.get('TREZOR_PATH')) @click.option('-v', '--verbose', is_flag=True, help='Show communication messages.') @click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object') @click.pass_context def cli(ctx, path, verbose, is_json): - if ctx.invoked_subcommand != 'list': - if verbose: - cls = TrezorClientVerbose - else: - cls = TrezorClient + if verbose: + enable_logging() - def get_device(): + def get_device(): + try: + device = get_transport(path, prefix_search=False) + except: try: - device = get_transport(path, prefix_search=False) + device = get_transport(path, prefix_search=True) except: - try: - device = get_transport(path, prefix_search=True) - except: - click.echo("Failed to find a TREZOR device.") - if path is not None: - click.echo("Using path: {}".format(path)) - sys.exit(1) - return cls(transport=device) - - ctx.obj = get_device + click.echo("Failed to find a TREZOR device.") + if path is not None: + click.echo("Using path: {}".format(path)) + sys.exit(1) + return TrezorClient(transport=device) + + ctx.obj = get_device @cli.resultcallback() diff --git a/trezorlib/log.py b/trezorlib/log.py new file mode 100644 index 000000000..f23e39ed3 --- /dev/null +++ b/trezorlib/log.py @@ -0,0 +1,19 @@ +import logging +from typing import Set, Type + +from . import protobuf + +OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]] + + +class PrettyProtobufFormatter(logging.Formatter): + + def format(self, record): + time = self.formatTime(record) + message = '[{time}] {level}: {msg}'.format(time=time, level=record.levelname, msg=super().format(record)) + if hasattr(record, 'protobuf'): + if type(record.protobuf) in OMITTED_MESSAGES: + message += " ({} bytes)".format(record.protobuf.ByteSize()) + else: + message += "\n" + protobuf.format_message(record.protobuf) + return message diff --git a/trezorlib/protobuf.py b/trezorlib/protobuf.py index b42347c3d..17e7fe3c7 100644 --- a/trezorlib/protobuf.py +++ b/trezorlib/protobuf.py @@ -377,4 +377,8 @@ def format_message(pb: MessageType, indent: int=0, sep: str= ' ' * 4) -> str: return repr(value) - return pb.__class__.__name__ + ' ' + pformat_value(pb.__dict__, indent) + return '{name} ({size} bytes) {content}'.format( + name=pb.__class__.__name__, + size=pb.ByteSize(), + content=pformat_value(pb.__dict__, indent) + ) diff --git a/trezorlib/protocol_v1.py b/trezorlib/protocol_v1.py index 6c04af325..cd263e7d7 100644 --- a/trezorlib/protocol_v1.py +++ b/trezorlib/protocol_v1.py @@ -19,22 +19,30 @@ from __future__ import absolute_import from io import BytesIO +import logging import struct +from typing import Tuple, Type + from . import mapping from . import protobuf +from .transport import Transport REPLEN = 64 +LOG = logging.getLogger(__name__) + -class ProtocolV1(object): +class ProtocolV1: - def session_begin(self, transport): + def session_begin(self, transport: Transport) -> None: pass - def session_end(self, transport): + def session_end(self, transport: Transport) -> None: pass - def write(self, transport, msg): + def write(self, transport: Transport, msg: protobuf.MessageType) -> None: + LOG.debug("sending message: {}".format(msg.__class__.__name__), + extra={'protobuf': msg}) data = BytesIO() protobuf.dump_message(data, msg) ser = data.getvalue() @@ -48,10 +56,10 @@ class ProtocolV1(object): transport.write_chunk(chunk) data = data[63:] - def read(self, transport): + def read(self, transport: Transport) -> protobuf.MessageType: # Read header with first part of message data chunk = transport.read_chunk() - (msg_type, datalen, data) = self.parse_first(chunk) + msg_type, datalen, data = self.parse_first(chunk) # Read the rest of the message while len(data) < datalen: @@ -63,21 +71,23 @@ class ProtocolV1(object): # Parse to protobuf msg = protobuf.load_message(data, mapping.get_class(msg_type)) + LOG.debug("received message: {}".format(msg.__class__.__name__), + extra={'protobuf': msg}) return msg - def parse_first(self, chunk): + def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]: if chunk[:3] != b'?##': raise RuntimeError('Unexpected magic characters') try: headerlen = struct.calcsize('>HL') - (msg_type, datalen) = struct.unpack('>HL', chunk[3:3 + headerlen]) + msg_type, datalen = struct.unpack('>HL', chunk[3:3 + headerlen]) except: raise RuntimeError('Cannot parse header') data = chunk[3 + headerlen:] - return (msg_type, datalen, data) + return msg_type, datalen, data - def parse_next(self, chunk): + def parse_next(self, chunk: bytes) -> bytes: if chunk[:1] != b'?': raise RuntimeError('Unexpected magic characters') return chunk[1:] diff --git a/trezorlib/protocol_v2.py b/trezorlib/protocol_v2.py index fd33c96a7..ba0c4f02d 100644 --- a/trezorlib/protocol_v2.py +++ b/trezorlib/protocol_v2.py @@ -18,28 +18,34 @@ from __future__ import absolute_import -import struct from io import BytesIO -from . import messages as proto +import logging +import struct +from typing import Tuple + from . import mapping from . import protobuf +from .transport import Transport REPLEN = 64 +LOG = logging.getLogger(__name__) + -class ProtocolV2(object): +class ProtocolV2: - def __init__(self): + def __init__(self) -> None: self.session = None - def session_begin(self, transport): + def session_begin(self, transport: Transport) -> None: chunk = struct.pack('>B', 0x03) chunk = chunk.ljust(REPLEN, b'\x00') transport.write_chunk(chunk) resp = transport.read_chunk() self.session = self.parse_session_open(resp) + LOG.debug("[session {}] session started".format(self.session)) - def session_end(self, transport): + def session_end(self, transport: Transport) -> None: if not self.session: return chunk = struct.pack('>BL', 0x04, self.session) @@ -49,12 +55,15 @@ class ProtocolV2(object): (magic, ) = struct.unpack('>B', resp[:1]) if magic != 0x04: raise RuntimeError('Expected session close') + LOG.debug("[session {}] session ended".format(self.session)) self.session = None - def write(self, transport, msg): + def write(self, transport: Transport, msg: protobuf.MessageType) -> None: if not self.session: raise RuntimeError('Missing session for v2 protocol') + LOG.debug("[session {}] sending message: {}".format(self.session, msg.__class__.__name__), + extra={'protobuf': msg}) # Serialize whole message data = BytesIO() protobuf.dump_message(data, msg) @@ -76,7 +85,7 @@ class ProtocolV2(object): data = data[datalen:] seq += 1 - def read(self, transport): + def read(self, transport: Transport) -> protobuf.MessageType: if not self.session: raise RuntimeError('Missing session for v2 protocol') @@ -95,12 +104,14 @@ class ProtocolV2(object): # Parse to protobuf msg = protobuf.load_message(data, mapping.get_class(msg_type)) + LOG.debug("[session {}] received message: {}".format(self.session, msg.__class__.__name__), + extra={'protobuf': msg}) return msg - def parse_first(self, chunk): + def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]: try: headerlen = struct.calcsize('>BLLL') - (magic, session, msg_type, datalen) = struct.unpack('>BLLL', chunk[:headerlen]) + magic, session, msg_type, datalen = struct.unpack('>BLLL', chunk[:headerlen]) except: raise RuntimeError('Cannot parse header') if magic != 0x01: @@ -109,10 +120,10 @@ class ProtocolV2(object): raise RuntimeError('Session id mismatch') return msg_type, datalen, chunk[headerlen:] - def parse_next(self, chunk): + def parse_next(self, chunk: bytes) -> bytes: try: headerlen = struct.calcsize('>BLL') - (magic, session, sequence) = struct.unpack('>BLL', chunk[:headerlen]) + magic, session, sequence = struct.unpack('>BLL', chunk[:headerlen]) except: raise RuntimeError('Cannot parse header') if magic != 0x02: @@ -121,10 +132,10 @@ class ProtocolV2(object): raise RuntimeError('Session id mismatch') return chunk[headerlen:] - def parse_session_open(self, chunk): + def parse_session_open(self, chunk: bytes) -> int: try: headerlen = struct.calcsize('>BL') - (magic, session) = struct.unpack('>BL', chunk[:headerlen]) + magic, session = struct.unpack('>BL', chunk[:headerlen]) except: raise RuntimeError('Cannot parse header') if magic != 0x03: