diff --git a/trezorctl b/trezorctl index 37f5032ae..2e85c7967 100755 --- a/trezorctl +++ b/trezorctl @@ -30,7 +30,7 @@ import logging import os import sys -from trezorlib.client import TrezorClient, CallException +from trezorlib.client import TrezorClient from trezorlib.transport import get_transport, enumerate_devices from trezorlib import coins from trezorlib import log @@ -250,12 +250,12 @@ def set_homescreen(connect, filename): elif filename.endswith('.toif'): img = open(filename, 'rb').read() if img[:8] != b'TOIf\x90\x00\x90\x00': - raise CallException(proto.FailureType.DataError, 'File is not a TOIF file with size of 144x144') + raise tools.CallException(proto.FailureType.DataError, 'File is not a TOIF file with size of 144x144') else: from PIL import Image im = Image.open(filename) if im.size != (128, 64): - raise CallException(proto.FailureType.DataError, 'Wrong size of the image') + raise tools.CallException(proto.FailureType.DataError, 'Wrong size of the image') im = im.convert('1') pix = im.load() img = bytearray(1024) @@ -297,7 +297,7 @@ def wipe_device(connect, bootloader): try: return connect().wipe_device() - except CallException as e: + except tools.CallException as e: click.echo('Action failed: {} {}'.format(*e.args)) sys.exit(3) @@ -314,7 +314,7 @@ def wipe_device(connect, bootloader): @click.pass_obj def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014): if not mnemonic and not xprv and not slip0014: - raise CallException(proto.FailureType.DataError, 'Please provide mnemonic or xprv') + raise tools.CallException(proto.FailureType.DataError, 'Please provide mnemonic or xprv') client = connect() if mnemonic: @@ -474,7 +474,7 @@ def firmware_update(connect, filename, url, version, skip_check, fingerprint): try: return client.firmware_update(fp=io.BytesIO(fp)) - except CallException as e: + except tools.CallException as e: if e.args[0] in (proto.FailureType.FirmwareError, proto.FailureType.ActionCancelled): click.echo("Update aborted on device.") else: @@ -806,7 +806,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri if ' ' in value: value, unit = value.split(' ', 1) if unit.lower() not in ether_units: - raise CallException(proto.Failure.DataError, 'Unrecognized ether unit %r' % unit) + raise tools.CallException(proto.Failure.DataError, 'Unrecognized ether unit %r' % unit) value = int(value) * ether_units[unit.lower()] else: value = int(value) @@ -815,7 +815,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri if ' ' in gas_price: gas_price, unit = gas_price.split(' ', 1) if unit.lower() not in ether_units: - raise CallException(proto.Failure.DataError, 'Unrecognized gas price unit %r' % unit) + raise tools.CallException(proto.Failure.DataError, 'Unrecognized gas price unit %r' % unit) gas_price = int(gas_price) * ether_units[unit.lower()] else: gas_price = int(gas_price) diff --git a/trezorlib/client.py b/trezorlib/client.py index e8cc4d67a..fbf44cd3b 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -31,8 +31,8 @@ from . import messages as proto from . import tools from . import mapping from . import nem -from . import protobuf from . import stellar +from .tools import CallException, field, expect from .debuglink import DebugLink if sys.version_info.major < 3: @@ -79,69 +79,26 @@ def get_buttonrequest_value(code): return [k for k in dir(proto.ButtonRequestType) if getattr(proto.ButtonRequestType, k) == code][0] -class CallException(Exception): +class PinException(tools.CallException): pass -class PinException(CallException): - pass +class MovedTo: + """Deprecation redirector for methods that were formerly part of TrezorClient""" + def __init__(self, where): + self.where = where + self.name = where.__module__ + '.' + where.__name__ + def _deprecated_redirect(self, client, *args, **kwargs): + """Redirector for a deprecated method on TrezorClient""" + warnings.warn("Function has been moved to %s" % self.name, DeprecationWarning, stacklevel=2) + return self.where(client, *args, **kwargs) -class field: - # Decorator extracts single value from - # protobuf object. If the field is not - # present, raises an exception. - def __init__(self, field): - self.field = field - - def __call__(self, f): - @functools.wraps(f) - def wrapped_f(*args, **kwargs): - ret = f(*args, **kwargs) - return getattr(ret, self.field) - return wrapped_f - - -class expect: - # Decorator checks if the method - # returned one of expected protobuf messages - # or raises an exception - def __init__(self, *expected): - self.expected = expected - - def __call__(self, f): - @functools.wraps(f) - def wrapped_f(*args, **kwargs): - ret = f(*args, **kwargs) - if not isinstance(ret, self.expected): - raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected)) - return ret - return wrapped_f - - -def session(f): - # Decorator wraps a BaseClient method - # with session activation / deactivation - @functools.wraps(f) - def wrapped_f(*args, **kwargs): - __tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks - client = args[0] - client.transport.session_begin() - try: - return f(*args, **kwargs) - finally: - client.transport.session_end() - return wrapped_f - - -def normalize_nfc(txt): - ''' - Normalize message to NFC and return bytes suitable for protobuf. - This seems to be bitcoin-qt standard of doing things. - ''' - if isinstance(txt, bytes): - txt = txt.decode('utf-8') - return unicodedata.normalize('NFC', txt).encode('utf-8') + def __get__(self, instance, cls): + if instance is None: + return self._deprecated_redirect + else: + return functools.partial(self._deprecated_redirect, instance) class BaseClient(object): @@ -158,13 +115,13 @@ class BaseClient(object): def cancel(self): self.transport.write(proto.Cancel()) - @session + @tools.session def call_raw(self, msg): __tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks self.transport.write(msg) return self.transport.read() - @session + @tools.session def call(self, msg): resp = self.call_raw(msg) handler_name = "callback_%s" % resp.__class__.__name__ @@ -183,7 +140,7 @@ class BaseClient(object): proto.FailureType.PinCancelled, proto.FailureType.PinExpected): raise PinException(msg.code, msg.message) - raise CallException(msg.code, msg.message) + raise tools.CallException(msg.code, msg.message) def register_message(self, msg): '''Allow application to register custom protobuf message type''' @@ -451,7 +408,7 @@ class ProtocolMixin(object): init_msg = proto.Initialize() if self.state is not None: init_msg.state = self.state - self.features = expect(proto.Features)(self.call)(init_msg) + self.features = tools.expect(proto.Features)(self.call)(init_msg) if str(self.features.vendor) not in self.VENDORS: raise RuntimeError("Unsupported device") @@ -465,7 +422,7 @@ class ProtocolMixin(object): @staticmethod def expand_path(n): - warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning) + warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning, stacklevel=2) return tools.parse_path(n) @expect(proto.PublicKey) diff --git a/trezorlib/tests/device_tests/test_msg_getpublickey_curve.py b/trezorlib/tests/device_tests/test_msg_getpublickey_curve.py index 489b4d20c..a6d9550ea 100644 --- a/trezorlib/tests/device_tests/test_msg_getpublickey_curve.py +++ b/trezorlib/tests/device_tests/test_msg_getpublickey_curve.py @@ -18,7 +18,7 @@ from binascii import hexlify import pytest from .common import TrezorTest -from trezorlib.client import CallException +from trezorlib.tools import CallException class TestMsgGetpublickeyCurve(TrezorTest): diff --git a/trezorlib/tests/device_tests/test_msg_signtx.py b/trezorlib/tests/device_tests/test_msg_signtx.py index 0ea9529fb..ec0dfecd0 100644 --- a/trezorlib/tests/device_tests/test_msg_signtx.py +++ b/trezorlib/tests/device_tests/test_msg_signtx.py @@ -21,9 +21,8 @@ from .common import TrezorTest from .conftest import TREZOR_VERSION from trezorlib import messages as proto -from trezorlib.client import CallException -from trezorlib.tools import parse_path from trezorlib.tx_api import TxApiInsight +from trezorlib.tools import parse_path, CallException TxApiTestnet = TxApiInsight("insight_testnet") diff --git a/trezorlib/tests/device_tests/test_msg_signtx_bcash.py b/trezorlib/tests/device_tests/test_msg_signtx_bcash.py index 473fbada5..e9abbbac5 100644 --- a/trezorlib/tests/device_tests/test_msg_signtx_bcash.py +++ b/trezorlib/tests/device_tests/test_msg_signtx_bcash.py @@ -21,9 +21,7 @@ from .common import TrezorTest from ..support.ckd_public import deserialize from trezorlib import coins from trezorlib import messages as proto -from trezorlib.client import CallException -from trezorlib.tools import parse_path - +from trezorlib.tools import parse_path, CallException TxApiBcash = coins.tx_api['Bcash'] diff --git a/trezorlib/tests/device_tests/test_msg_signtx_bgold.py b/trezorlib/tests/device_tests/test_msg_signtx_bgold.py index a6149addd..27717a093 100644 --- a/trezorlib/tests/device_tests/test_msg_signtx_bgold.py +++ b/trezorlib/tests/device_tests/test_msg_signtx_bgold.py @@ -21,8 +21,7 @@ from .common import TrezorTest from ..support.ckd_public import deserialize from trezorlib import coins from trezorlib import messages as proto -from trezorlib.client import CallException -from trezorlib.tools import parse_path +from trezorlib.tools import parse_path, CallException TxApiBitcoinGold = coins.tx_api["Bgold"] diff --git a/trezorlib/tests/device_tests/test_msg_signtx_segwit.py b/trezorlib/tests/device_tests/test_msg_signtx_segwit.py index f20d49cb2..5c1f37efb 100644 --- a/trezorlib/tests/device_tests/test_msg_signtx_segwit.py +++ b/trezorlib/tests/device_tests/test_msg_signtx_segwit.py @@ -22,9 +22,8 @@ from .conftest import TREZOR_VERSION from .common import TrezorTest from trezorlib import messages as proto -from trezorlib.client import CallException -from trezorlib.tools import parse_path from trezorlib.tx_api import TxApiInsight +from trezorlib.tools import parse_path, CallException TxApiTestnet = TxApiInsight("insight_testnet") diff --git a/trezorlib/tests/device_tests/test_multisig.py b/trezorlib/tests/device_tests/test_multisig.py index 3bbf90308..608d20925 100644 --- a/trezorlib/tests/device_tests/test_multisig.py +++ b/trezorlib/tests/device_tests/test_multisig.py @@ -20,7 +20,7 @@ import pytest from .common import TrezorTest from ..support import ckd_public as bip32 from trezorlib import messages as proto -from trezorlib.client import CallException +from trezorlib.tools import CallException TXHASH_c6091a = unhexlify('c6091adf4c0c23982a35899a6e58ae11e703eacd7954f588ed4b9cdefc4dba52') diff --git a/trezorlib/tests/device_tests/test_op_return.py b/trezorlib/tests/device_tests/test_op_return.py index 480858571..db1373394 100644 --- a/trezorlib/tests/device_tests/test_op_return.py +++ b/trezorlib/tests/device_tests/test_op_return.py @@ -20,8 +20,7 @@ import pytest from .common import TrezorTest from .conftest import TREZOR_VERSION from trezorlib import messages as proto -from trezorlib.client import CallException - +from trezorlib.tools import CallException TXHASH_d5f65e = unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882') diff --git a/trezorlib/tests/device_tests/test_protect_call.py b/trezorlib/tests/device_tests/test_protect_call.py index 447061643..e1d73b2ab 100644 --- a/trezorlib/tests/device_tests/test_protect_call.py +++ b/trezorlib/tests/device_tests/test_protect_call.py @@ -19,7 +19,9 @@ import pytest from .common import TrezorTest from trezorlib import messages as proto -from trezorlib.client import PinException, CallException +from trezorlib.client import PinException +from trezorlib.tools import CallException + # FIXME TODO Add passphrase tests diff --git a/trezorlib/tools.py b/trezorlib/tools.py index cd4c55672..43770d22e 100644 --- a/trezorlib/tools.py +++ b/trezorlib/tools.py @@ -16,6 +16,7 @@ import hashlib import struct +import unicodedata from typing import NewType, List from .coins import slip44 @@ -159,3 +160,65 @@ def parse_path(nstr: str) -> Address: return list(str_to_harden(x) for x in n) except Exception: raise ValueError('Invalid BIP32 path', nstr) + + + +def normalize_nfc(txt): + ''' + Normalize message to NFC and return bytes suitable for protobuf. + This seems to be bitcoin-qt standard of doing things. + ''' + if isinstance(txt, bytes): + txt = txt.decode('utf-8') + return unicodedata.normalize('NFC', txt).encode('utf-8') + + +class CallException(Exception): + pass + + +class field: + # Decorator extracts single value from + # protobuf object. If the field is not + # present, raises an exception. + def __init__(self, field): + self.field = field + + def __call__(self, f): + @functools.wraps(f) + def wrapped_f(*args, **kwargs): + ret = f(*args, **kwargs) + return getattr(ret, self.field) + return wrapped_f + + +class expect: + # Decorator checks if the method + # returned one of expected protobuf messages + # or raises an exception + def __init__(self, *expected): + self.expected = expected + + def __call__(self, f): + @functools.wraps(f) + def wrapped_f(*args, **kwargs): + ret = f(*args, **kwargs) + if not isinstance(ret, self.expected): + raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected)) + return ret + return wrapped_f + + +def session(f): + # Decorator wraps a BaseClient method + # with session activation / deactivation + @functools.wraps(f) + def wrapped_f(*args, **kwargs): + __tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks + client = args[0] + client.transport.session_begin() + try: + return f(*args, **kwargs) + finally: + client.transport.session_end() + return wrapped_f