1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-17 01:52:02 +00:00

protocol: python logging to supersede VerboseWire

This commit is contained in:
matejcik 2018-04-23 12:58:30 +02:00
parent b7c7190573
commit eed91db880
5 changed files with 93 additions and 41 deletions

View File

@ -24,12 +24,14 @@ import base64
import binascii import binascii
import click import click
import json import json
import logging
import os import os
import sys import sys
from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException
from trezorlib.transport import get_transport, enumerate_devices from trezorlib.transport import get_transport, enumerate_devices
from trezorlib import coins from trezorlib import coins
from trezorlib import log
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib import protobuf from trezorlib import protobuf
from trezorlib import stellar from trezorlib import stellar
@ -64,32 +66,38 @@ CHOICE_OUTPUT_SCRIPT_TYPE = ChoiceType({
}) })
def enable_logging():
handler = logging.StreamHandler()
handler.setFormatter(log.PrettyProtobufFormatter())
logger = logging.getLogger('trezorlib')
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
log.OMITTED_MESSAGES.add(proto.Features)
@click.group(context_settings={'max_content_width': 400}) @click.group(context_settings={'max_content_width': 400})
@click.option('-p', '--path', help='Select device by specific path.', default=os.environ.get('TREZOR_PATH')) @click.option('-p', '--path', help='Select device by specific path.', default=os.environ.get('TREZOR_PATH'))
@click.option('-v', '--verbose', is_flag=True, help='Show communication messages.') @click.option('-v', '--verbose', is_flag=True, help='Show communication messages.')
@click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object') @click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object')
@click.pass_context @click.pass_context
def cli(ctx, path, verbose, is_json): def cli(ctx, path, verbose, is_json):
if ctx.invoked_subcommand != 'list': if verbose:
if verbose: enable_logging()
cls = TrezorClientVerbose
else:
cls = TrezorClient
def get_device(): def get_device():
try:
device = get_transport(path, prefix_search=False)
except:
try: try:
device = get_transport(path, prefix_search=False) device = get_transport(path, prefix_search=True)
except: except:
try: click.echo("Failed to find a TREZOR device.")
device = get_transport(path, prefix_search=True) if path is not None:
except: click.echo("Using path: {}".format(path))
click.echo("Failed to find a TREZOR device.") sys.exit(1)
if path is not None: return TrezorClient(transport=device)
click.echo("Using path: {}".format(path))
sys.exit(1)
return cls(transport=device)
ctx.obj = get_device ctx.obj = get_device
@cli.resultcallback() @cli.resultcallback()

19
trezorlib/log.py Normal file
View File

@ -0,0 +1,19 @@
import logging
from typing import Set, Type
from . import protobuf
OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]]
class PrettyProtobufFormatter(logging.Formatter):
def format(self, record):
time = self.formatTime(record)
message = '[{time}] {level}: {msg}'.format(time=time, level=record.levelname, msg=super().format(record))
if hasattr(record, 'protobuf'):
if type(record.protobuf) in OMITTED_MESSAGES:
message += " ({} bytes)".format(record.protobuf.ByteSize())
else:
message += "\n" + protobuf.format_message(record.protobuf)
return message

View File

@ -377,4 +377,8 @@ def format_message(pb: MessageType, indent: int=0, sep: str= ' ' * 4) -> str:
return repr(value) return repr(value)
return pb.__class__.__name__ + ' ' + pformat_value(pb.__dict__, indent) return '{name} ({size} bytes) {content}'.format(
name=pb.__class__.__name__,
size=pb.ByteSize(),
content=pformat_value(pb.__dict__, indent)
)

View File

@ -19,22 +19,30 @@
from __future__ import absolute_import from __future__ import absolute_import
from io import BytesIO from io import BytesIO
import logging
import struct import struct
from typing import Tuple, Type
from . import mapping from . import mapping
from . import protobuf from . import protobuf
from .transport import Transport
REPLEN = 64 REPLEN = 64
LOG = logging.getLogger(__name__)
class ProtocolV1(object):
def session_begin(self, transport): class ProtocolV1:
def session_begin(self, transport: Transport) -> None:
pass pass
def session_end(self, transport): def session_end(self, transport: Transport) -> None:
pass pass
def write(self, transport, msg): def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
LOG.debug("sending message: {}".format(msg.__class__.__name__),
extra={'protobuf': msg})
data = BytesIO() data = BytesIO()
protobuf.dump_message(data, msg) protobuf.dump_message(data, msg)
ser = data.getvalue() ser = data.getvalue()
@ -48,10 +56,10 @@ class ProtocolV1(object):
transport.write_chunk(chunk) transport.write_chunk(chunk)
data = data[63:] data = data[63:]
def read(self, transport): def read(self, transport: Transport) -> protobuf.MessageType:
# Read header with first part of message data # Read header with first part of message data
chunk = transport.read_chunk() chunk = transport.read_chunk()
(msg_type, datalen, data) = self.parse_first(chunk) msg_type, datalen, data = self.parse_first(chunk)
# Read the rest of the message # Read the rest of the message
while len(data) < datalen: while len(data) < datalen:
@ -63,21 +71,23 @@ class ProtocolV1(object):
# Parse to protobuf # Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type)) msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug("received message: {}".format(msg.__class__.__name__),
extra={'protobuf': msg})
return msg return msg
def parse_first(self, chunk): def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
if chunk[:3] != b'?##': if chunk[:3] != b'?##':
raise RuntimeError('Unexpected magic characters') raise RuntimeError('Unexpected magic characters')
try: try:
headerlen = struct.calcsize('>HL') headerlen = struct.calcsize('>HL')
(msg_type, datalen) = struct.unpack('>HL', chunk[3:3 + headerlen]) msg_type, datalen = struct.unpack('>HL', chunk[3:3 + headerlen])
except: except:
raise RuntimeError('Cannot parse header') raise RuntimeError('Cannot parse header')
data = chunk[3 + headerlen:] data = chunk[3 + headerlen:]
return (msg_type, datalen, data) return msg_type, datalen, data
def parse_next(self, chunk): def parse_next(self, chunk: bytes) -> bytes:
if chunk[:1] != b'?': if chunk[:1] != b'?':
raise RuntimeError('Unexpected magic characters') raise RuntimeError('Unexpected magic characters')
return chunk[1:] return chunk[1:]

View File

@ -18,28 +18,34 @@
from __future__ import absolute_import from __future__ import absolute_import
import struct
from io import BytesIO from io import BytesIO
from . import messages as proto import logging
import struct
from typing import Tuple
from . import mapping from . import mapping
from . import protobuf from . import protobuf
from .transport import Transport
REPLEN = 64 REPLEN = 64
LOG = logging.getLogger(__name__)
class ProtocolV2(object):
def __init__(self): class ProtocolV2:
def __init__(self) -> None:
self.session = None self.session = None
def session_begin(self, transport): def session_begin(self, transport: Transport) -> None:
chunk = struct.pack('>B', 0x03) chunk = struct.pack('>B', 0x03)
chunk = chunk.ljust(REPLEN, b'\x00') chunk = chunk.ljust(REPLEN, b'\x00')
transport.write_chunk(chunk) transport.write_chunk(chunk)
resp = transport.read_chunk() resp = transport.read_chunk()
self.session = self.parse_session_open(resp) self.session = self.parse_session_open(resp)
LOG.debug("[session {}] session started".format(self.session))
def session_end(self, transport): def session_end(self, transport: Transport) -> None:
if not self.session: if not self.session:
return return
chunk = struct.pack('>BL', 0x04, self.session) chunk = struct.pack('>BL', 0x04, self.session)
@ -49,12 +55,15 @@ class ProtocolV2(object):
(magic, ) = struct.unpack('>B', resp[:1]) (magic, ) = struct.unpack('>B', resp[:1])
if magic != 0x04: if magic != 0x04:
raise RuntimeError('Expected session close') raise RuntimeError('Expected session close')
LOG.debug("[session {}] session ended".format(self.session))
self.session = None self.session = None
def write(self, transport, msg): def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
if not self.session: if not self.session:
raise RuntimeError('Missing session for v2 protocol') raise RuntimeError('Missing session for v2 protocol')
LOG.debug("[session {}] sending message: {}".format(self.session, msg.__class__.__name__),
extra={'protobuf': msg})
# Serialize whole message # Serialize whole message
data = BytesIO() data = BytesIO()
protobuf.dump_message(data, msg) protobuf.dump_message(data, msg)
@ -76,7 +85,7 @@ class ProtocolV2(object):
data = data[datalen:] data = data[datalen:]
seq += 1 seq += 1
def read(self, transport): def read(self, transport: Transport) -> protobuf.MessageType:
if not self.session: if not self.session:
raise RuntimeError('Missing session for v2 protocol') raise RuntimeError('Missing session for v2 protocol')
@ -95,12 +104,14 @@ class ProtocolV2(object):
# Parse to protobuf # Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type)) msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug("[session {}] received message: {}".format(self.session, msg.__class__.__name__),
extra={'protobuf': msg})
return msg return msg
def parse_first(self, chunk): def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
try: try:
headerlen = struct.calcsize('>BLLL') headerlen = struct.calcsize('>BLLL')
(magic, session, msg_type, datalen) = struct.unpack('>BLLL', chunk[:headerlen]) magic, session, msg_type, datalen = struct.unpack('>BLLL', chunk[:headerlen])
except: except:
raise RuntimeError('Cannot parse header') raise RuntimeError('Cannot parse header')
if magic != 0x01: if magic != 0x01:
@ -109,10 +120,10 @@ class ProtocolV2(object):
raise RuntimeError('Session id mismatch') raise RuntimeError('Session id mismatch')
return msg_type, datalen, chunk[headerlen:] return msg_type, datalen, chunk[headerlen:]
def parse_next(self, chunk): def parse_next(self, chunk: bytes) -> bytes:
try: try:
headerlen = struct.calcsize('>BLL') headerlen = struct.calcsize('>BLL')
(magic, session, sequence) = struct.unpack('>BLL', chunk[:headerlen]) magic, session, sequence = struct.unpack('>BLL', chunk[:headerlen])
except: except:
raise RuntimeError('Cannot parse header') raise RuntimeError('Cannot parse header')
if magic != 0x02: if magic != 0x02:
@ -121,10 +132,10 @@ class ProtocolV2(object):
raise RuntimeError('Session id mismatch') raise RuntimeError('Session id mismatch')
return chunk[headerlen:] return chunk[headerlen:]
def parse_session_open(self, chunk): def parse_session_open(self, chunk: bytes) -> int:
try: try:
headerlen = struct.calcsize('>BL') headerlen = struct.calcsize('>BL')
(magic, session) = struct.unpack('>BL', chunk[:headerlen]) magic, session = struct.unpack('>BL', chunk[:headerlen])
except: except:
raise RuntimeError('Cannot parse header') raise RuntimeError('Cannot parse header')
if magic != 0x03: if magic != 0x03: