1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-08-02 20:08:31 +00:00

trezorctl: cleanup

This commit is contained in:
Pavol Rusnak 2017-07-01 17:59:11 +02:00
parent d33e9a178b
commit 0ee1667c6f
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
8 changed files with 76 additions and 36 deletions

View File

@ -21,7 +21,7 @@ from __future__ import print_function
import unittest import unittest
import hashlib import hashlib
from trezorlib.client import TrezorClient, TrezorDebugClient from trezorlib.client import TrezorClient, TrezorClientDebugLink
from trezorlib import tx_api from trezorlib import tx_api
import config import config
@ -35,7 +35,7 @@ class TrezorTest(unittest.TestCase):
transport = config.TRANSPORT(*config.TRANSPORT_ARGS, **config.TRANSPORT_KWARGS) transport = config.TRANSPORT(*config.TRANSPORT_ARGS, **config.TRANSPORT_KWARGS)
if hasattr(config, 'DEBUG_TRANSPORT'): if hasattr(config, 'DEBUG_TRANSPORT'):
debug_transport = config.DEBUG_TRANSPORT(*config.DEBUG_TRANSPORT_ARGS, **config.DEBUG_TRANSPORT_KWARGS) debug_transport = config.DEBUG_TRANSPORT(*config.DEBUG_TRANSPORT_ARGS, **config.DEBUG_TRANSPORT_KWARGS)
self.client = TrezorDebugClient(transport) self.client = TrezorClientDebugLink(transport)
self.client.set_debuglink(debug_transport) self.client.set_debuglink(debug_transport)
else: else:
self.client = TrezorClient(transport) self.client = TrezorClient(transport)

View File

@ -10,7 +10,6 @@ import trezorlib.ckd_public as bip32
import hashlib import hashlib
from trezorlib.client import TrezorClient from trezorlib.client import TrezorClient
from trezorlib.client import TrezorClientDebug
from trezorlib.tx_api import TxApiTestnet from trezorlib.tx_api import TxApiTestnet
from trezorlib.tx_api import TxApiBitcoin from trezorlib.tx_api import TxApiBitcoin
from trezorlib.transport_hid import HidTransport from trezorlib.transport_hid import HidTransport

View File

@ -21,17 +21,16 @@
from __future__ import print_function from __future__ import print_function
import sys import sys
import os
import binascii import binascii
import argparse import argparse
import json import json
import base64 import base64
import tempfile
from io import BytesIO from io import BytesIO
from trezorlib.client import TrezorClient, TrezorClientDebug, CallException from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException
import trezorlib.types_pb2 as types import trezorlib.types_pb2 as types
ether_units = { ether_units = {
"wei": 1, "wei": 1,
"kwei": 1000, "kwei": 1000,
@ -54,6 +53,7 @@ ether_units = {
"eth": 1000000000000000000, "eth": 1000000000000000000,
} }
def init_parser(commands): def init_parser(commands):
parser = argparse.ArgumentParser(description='Commandline tool for TREZOR devices.') parser = argparse.ArgumentParser(description='Commandline tool for TREZOR devices.')
parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='Prints communication to device') parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='Prints communication to device')
@ -80,38 +80,33 @@ def init_parser(commands):
return parser return parser
def get_transport(transport_string, path, **kwargs):
if transport_string == 'usb': def get_transport_class_by_name(name):
if name == 'usb':
from trezorlib.transport_hid import HidTransport from trezorlib.transport_hid import HidTransport
return HidTransport
for d in HidTransport.enumerate(): if name == 'udp':
# Two-tuple of (normal_interface, debug_interface)
if path == '' or path in d:
return HidTransport(d, **kwargs)
raise CallException(types.Failure_NotInitialized, "Device not found")
if transport_string == 'udp':
from trezorlib.transport_udp import UdpTransport from trezorlib.transport_udp import UdpTransport
return UdpTransport(path, **kwargs) return UdpTransport
if transport_string == 'pipe': if name == 'pipe':
from trezorlib.transport_pipe import PipeTransport from trezorlib.transport_pipe import PipeTransport
if path == '': return PipeTransport
path = '/tmp/pipe.trezor'
return PipeTransport(path, is_device=False, **kwargs)
if transport_string == 'bridge': if name == 'bridge':
from trezorlib.transport_bridge import BridgeTransport from trezorlib.transport_bridge import BridgeTransport
return BridgeTransport
devices = BridgeTransport.enumerate() raise NotImplementedError('Unknown transport: "%s"' % name)
for d in devices:
if path == '' or d['path'] == binascii.hexlify(path):
return BridgeTransport(d, **kwargs)
raise CallException(types.Failure_NotInitialized, "Device not found")
raise NotImplementedError("Unknown transport") def get_transport(transport_name, path):
transport = get_transport_class_by_name(transport_name)
dev = transport.find_by_path(path)
return dev
class Commands(object): class Commands(object):
def __init__(self, client): def __init__(self, client):
@ -594,7 +589,7 @@ def main():
transport = get_transport(args.transport, args.path) transport = get_transport(args.transport, args.path)
if args.verbose: if args.verbose:
client = TrezorClientDebug(transport) client = TrezorClientVerbose(transport)
else: else:
client = TrezorClient(transport) client = TrezorClient(transport)
@ -616,3 +611,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -216,10 +216,10 @@ class BaseClient(object):
self.transport.close() self.transport.close()
class DebugWireMixin(object): class VerboseWireMixin(object):
def call_raw(self, msg): def call_raw(self, msg):
log("SENDING " + pprint(msg)) log("SENDING " + pprint(msg))
resp = super(DebugWireMixin, self).call_raw(msg) resp = super(VerboseWireMixin, self).call_raw(msg)
log("RECEIVED " + pprint(resp)) log("RECEIVED " + pprint(resp))
return resp return resp
@ -1032,9 +1032,9 @@ class TrezorClient(ProtocolMixin, TextUIMixin, BaseClient):
pass pass
class TrezorClientDebug(ProtocolMixin, TextUIMixin, DebugWireMixin, BaseClient): class TrezorClientVerbose(ProtocolMixin, TextUIMixin, VerboseWireMixin, BaseClient):
pass pass
class TrezorDebugClient(ProtocolMixin, DebugLinkMixin, DebugWireMixin, BaseClient): class TrezorClientDebugLink(ProtocolMixin, DebugLinkMixin, VerboseWireMixin, BaseClient):
pass pass

View File

@ -19,6 +19,7 @@
'''BridgeTransport implements transport TREZOR Bridge (aka trezord).''' '''BridgeTransport implements transport TREZOR Bridge (aka trezord).'''
import binascii
import json import json
import requests import requests
from . import protobuf_json from . import protobuf_json
@ -34,6 +35,7 @@ def get_error(resp):
class BridgeTransport(TransportV1): class BridgeTransport(TransportV1):
CONFIGURED = False CONFIGURED = False
def __init__(self, device, *args, **kwargs): def __init__(self, device, *args, **kwargs):
@ -71,11 +73,21 @@ class BridgeTransport(TransportV1):
r = requests.get(TREZORD_HOST + '/enumerate') r = requests.get(TREZORD_HOST + '/enumerate')
if r.status_code != 200: if r.status_code != 200:
raise Exception('trezord: Could not enumerate devices' + get_error(r)) raise Exception('trezord: Could not enumerate devices' + get_error(r))
enum = r.json() enum = r.json()
return enum return enum
@classmethod
def find_by_path(cls, path=None):
"""
Finds a device by transport-specific path.
If path is not set, return first device.
"""
devices = cls.enumerate()
for dev in devices:
if not path or dev['path'] == binascii.hexlify(path):
return cls(dev)
raise Exception('Device not found')
def _open(self): def _open(self):
r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.path) r = self.conn.post(TREZORD_HOST + '/acquire/%s' % self.path)
if r.status_code != 200: if r.status_code != 200:

View File

@ -48,6 +48,18 @@ def enumerate():
return sorted(devices.values()) return sorted(devices.values())
def find_by_path(path=None):
"""
Finds a device by transport-specific path.
If path is not set, return first device.
"""
devices = enumerate()
for dev in devices:
if not path or path in dev:
return HidTransport(dev)
raise Exception('Device not found')
def path_to_transport(path): def path_to_transport(path):
try: try:
device = [d for d in hid.enumerate(0, 0) if d['path'] == path][0] device = [d for d in hid.enumerate(0, 0) if d['path'] == path][0]
@ -168,3 +180,4 @@ def HidTransport(device, *args, **kwargs):
# Backward compatibility hack; HidTransport is a function, not a class like before # Backward compatibility hack; HidTransport is a function, not a class like before
HidTransport.enumerate = enumerate HidTransport.enumerate = enumerate
HidTransport.find_by_path = find_by_path

View File

@ -28,11 +28,22 @@ Use this transport for talking with trezor simulator."""
class PipeTransport(TransportV1): class PipeTransport(TransportV1):
def __init__(self, device, is_device, *args, **kwargs):
def __init__(self, device='/tmp/pipe.trezor', is_device=False, *args, **kwargs):
if not device:
device = '/tmp/pipe.trezor'
self.is_device = is_device # set True if act as device self.is_device = is_device # set True if act as device
super(PipeTransport, self).__init__(device, *args, **kwargs) super(PipeTransport, self).__init__(device, *args, **kwargs)
@classmethod
def enumerate(cls):
raise Exception('This transport cannot enumerate devices')
@classmethod
def find_by_path(cls, path=None):
return cls(path)
def _open(self): def _open(self):
if self.is_device: if self.is_device:
self.filename_read = self.device + '.to' self.filename_read = self.device + '.to'

View File

@ -24,6 +24,7 @@ from .transport import TransportV2
class UdpTransport(TransportV2): class UdpTransport(TransportV2):
def __init__(self, device, *args, **kwargs): def __init__(self, device, *args, **kwargs):
device = device.split(':') device = device.split(':')
if len(device) < 2: if len(device) < 2:
@ -38,6 +39,14 @@ class UdpTransport(TransportV2):
self.socket = None self.socket = None
super(UdpTransport, self).__init__(device, *args, **kwargs) super(UdpTransport, self).__init__(device, *args, **kwargs)
@classmethod
def enumerate(cls):
raise Exception('This transport cannot enumerate devices')
@classmethod
def find_by_path(cls, path=None):
return cls(path)
def _open(self): def _open(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device) self.socket.connect(self.device)