From 300bf5801dd05ab5cf811c9d4e7e7d04ba7daf48 Mon Sep 17 00:00:00 2001 From: Dominik Kozaczko Date: Wed, 28 Sep 2016 00:01:32 +0200 Subject: [PATCH] fix exceptions - writeout errors instead of throwing tracebacks --- trezorctl | 66 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/trezorctl b/trezorctl index 3045942b6..132e47ce7 100755 --- a/trezorctl +++ b/trezorctl @@ -6,9 +6,12 @@ import argparse import json import base64 import tempfile +from trezorlib import types_pb2 as types from io import BytesIO -from trezorlib.client import TrezorClient, TrezorClientDebug +import sys + +from trezorlib.client import TrezorClient, TrezorClientDebug, CallException ether_units = { "wei": 1, @@ -70,7 +73,7 @@ def get_transport(transport_string, path, **kwargs): if path == '' or path in d: return HidTransport(d, **kwargs) - raise Exception("Device not found") + raise CallException(types.Failure_Other, "Device not found") if transport_string == 'udp': from trezorlib.transport_udp import UdpTransport @@ -90,7 +93,7 @@ def get_transport(transport_string, path, **kwargs): if path == '' or d['path'] == binascii.hexlify(path): return BridgeTransport(d, **kwargs) - raise Exception("Device not found") + raise CallException(types.Failure_Other, "Device not found") raise NotImplementedError("Unknown transport") @@ -124,7 +127,7 @@ class Commands(object): if ' ' in value: value, unit = value.split(' ', 1) if unit.lower() not in ether_units: - raise Exception("Unrecognized ether unit %r", unit) + raise CallException(types.Failure_Other, "Unrecognized ether unit %r" % unit) value = int(value) * ether_units[unit.lower()] else: value = int(value) @@ -196,7 +199,7 @@ class Commands(object): from PIL import Image im = Image.open(args.filename) if im.size != (128, 64): - raise Exception('Wrong size of the image') + raise CallException(types.Failure_Other, 'Wrong size of the image') im = im.convert('1') pix = im.load() img = '' @@ -223,7 +226,7 @@ class Commands(object): def load_device(self, args): if not args.mnemonic and not args.xprv: - raise Exception("Please provide mnemonic or xprv") + raise CallException(types.Failure_Other, "Please provide mnemonic or xprv") if args.mnemonic: mnemonic = ' '.join(args.mnemonic) @@ -313,7 +316,7 @@ class Commands(object): if fp[:8] == b'54525a52': fp = binascii.unhexlify(fp) if fp[:4] != b'TRZR': - raise Exception("TREZOR firmware header expected") + raise CallException(types.Failure_FirmwareError, "TREZOR firmware header expected") print("Please confirm action on device...") @@ -534,32 +537,39 @@ def qt_pin_func(input_text, message=None): def main(): args = parse_args(Commands) - if args.cmd == 'list': - devices = list_usb() - if args.json: - print(json.dumps(devices)) + try: + + if args.cmd == 'list': + devices = list_usb() + if args.json: + print(json.dumps(devices)) + else: + for dev in devices: + if dev[1] != None: + print("%s - debuglink enabled" % dev[0]) + else: + print(dev[0]) + return + + transport = get_transport(args.transport, args.path) + if args.verbose: + client = TrezorClientDebug(transport) else: - for dev in devices: - if dev[1] != None: - print("%s - debuglink enabled" % dev[0]) - else: - print(dev[0]) - return + client = TrezorClient(transport) - transport = get_transport(args.transport, args.path) - if args.verbose: - client = TrezorClientDebug(transport) - else: - client = TrezorClient(transport) + cmds = Commands(client) - cmds = Commands(client) + res = args.func(cmds, args) - res = args.func(cmds, args) + if args.json: + print(json.dumps(res, sort_keys=True, indent=4)) + else: + print(res) + except CallException as e: + status, message = e.args + sys.stderr.write('failure: {message}\n'.format(message=message)) + exit(status) - if args.json: - print(json.dumps(res, sort_keys=True, indent=4)) - else: - print(res) if __name__ == '__main__': main()