1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 11:39:03 +00:00

trezorlib: shuffling things from client

This commit is contained in:
matejcik 2018-05-21 14:28:53 +02:00
parent 9dc86f3955
commit 1820f529fc
11 changed files with 102 additions and 86 deletions

View File

@ -30,7 +30,7 @@ import logging
import os
import sys
from trezorlib.client import TrezorClient, CallException
from trezorlib.client import TrezorClient
from trezorlib.transport import get_transport, enumerate_devices
from trezorlib import coins
from trezorlib import log
@ -250,12 +250,12 @@ def set_homescreen(connect, filename):
elif filename.endswith('.toif'):
img = open(filename, 'rb').read()
if img[:8] != b'TOIf\x90\x00\x90\x00':
raise CallException(proto.FailureType.DataError, 'File is not a TOIF file with size of 144x144')
raise tools.CallException(proto.FailureType.DataError, 'File is not a TOIF file with size of 144x144')
else:
from PIL import Image
im = Image.open(filename)
if im.size != (128, 64):
raise CallException(proto.FailureType.DataError, 'Wrong size of the image')
raise tools.CallException(proto.FailureType.DataError, 'Wrong size of the image')
im = im.convert('1')
pix = im.load()
img = bytearray(1024)
@ -297,7 +297,7 @@ def wipe_device(connect, bootloader):
try:
return connect().wipe_device()
except CallException as e:
except tools.CallException as e:
click.echo('Action failed: {} {}'.format(*e.args))
sys.exit(3)
@ -314,7 +314,7 @@ def wipe_device(connect, bootloader):
@click.pass_obj
def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014):
if not mnemonic and not xprv and not slip0014:
raise CallException(proto.FailureType.DataError, 'Please provide mnemonic or xprv')
raise tools.CallException(proto.FailureType.DataError, 'Please provide mnemonic or xprv')
client = connect()
if mnemonic:
@ -474,7 +474,7 @@ def firmware_update(connect, filename, url, version, skip_check, fingerprint):
try:
return client.firmware_update(fp=io.BytesIO(fp))
except CallException as e:
except tools.CallException as e:
if e.args[0] in (proto.FailureType.FirmwareError, proto.FailureType.ActionCancelled):
click.echo("Update aborted on device.")
else:
@ -806,7 +806,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
if ' ' in value:
value, unit = value.split(' ', 1)
if unit.lower() not in ether_units:
raise CallException(proto.Failure.DataError, 'Unrecognized ether unit %r' % unit)
raise tools.CallException(proto.Failure.DataError, 'Unrecognized ether unit %r' % unit)
value = int(value) * ether_units[unit.lower()]
else:
value = int(value)
@ -815,7 +815,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
if ' ' in gas_price:
gas_price, unit = gas_price.split(' ', 1)
if unit.lower() not in ether_units:
raise CallException(proto.Failure.DataError, 'Unrecognized gas price unit %r' % unit)
raise tools.CallException(proto.Failure.DataError, 'Unrecognized gas price unit %r' % unit)
gas_price = int(gas_price) * ether_units[unit.lower()]
else:
gas_price = int(gas_price)

View File

@ -31,8 +31,8 @@ from . import messages as proto
from . import tools
from . import mapping
from . import nem
from . import protobuf
from . import stellar
from .tools import CallException, field, expect
from .debuglink import DebugLink
if sys.version_info.major < 3:
@ -79,69 +79,26 @@ def get_buttonrequest_value(code):
return [k for k in dir(proto.ButtonRequestType) if getattr(proto.ButtonRequestType, k) == code][0]
class CallException(Exception):
class PinException(tools.CallException):
pass
class PinException(CallException):
pass
class MovedTo:
"""Deprecation redirector for methods that were formerly part of TrezorClient"""
def __init__(self, where):
self.where = where
self.name = where.__module__ + '.' + where.__name__
def _deprecated_redirect(self, client, *args, **kwargs):
"""Redirector for a deprecated method on TrezorClient"""
warnings.warn("Function has been moved to %s" % self.name, DeprecationWarning, stacklevel=2)
return self.where(client, *args, **kwargs)
class field:
# Decorator extracts single value from
# protobuf object. If the field is not
# present, raises an exception.
def __init__(self, field):
self.field = field
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
return getattr(ret, self.field)
return wrapped_f
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages
# or raises an exception
def __init__(self, *expected):
self.expected = expected
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected))
return ret
return wrapped_f
def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
client = args[0]
client.transport.session_begin()
try:
return f(*args, **kwargs)
finally:
client.transport.session_end()
return wrapped_f
def normalize_nfc(txt):
'''
Normalize message to NFC and return bytes suitable for protobuf.
This seems to be bitcoin-qt standard of doing things.
'''
if isinstance(txt, bytes):
txt = txt.decode('utf-8')
return unicodedata.normalize('NFC', txt).encode('utf-8')
def __get__(self, instance, cls):
if instance is None:
return self._deprecated_redirect
else:
return functools.partial(self._deprecated_redirect, instance)
class BaseClient(object):
@ -158,13 +115,13 @@ class BaseClient(object):
def cancel(self):
self.transport.write(proto.Cancel())
@session
@tools.session
def call_raw(self, msg):
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
self.transport.write(msg)
return self.transport.read()
@session
@tools.session
def call(self, msg):
resp = self.call_raw(msg)
handler_name = "callback_%s" % resp.__class__.__name__
@ -183,7 +140,7 @@ class BaseClient(object):
proto.FailureType.PinCancelled, proto.FailureType.PinExpected):
raise PinException(msg.code, msg.message)
raise CallException(msg.code, msg.message)
raise tools.CallException(msg.code, msg.message)
def register_message(self, msg):
'''Allow application to register custom protobuf message type'''
@ -451,7 +408,7 @@ class ProtocolMixin(object):
init_msg = proto.Initialize()
if self.state is not None:
init_msg.state = self.state
self.features = expect(proto.Features)(self.call)(init_msg)
self.features = tools.expect(proto.Features)(self.call)(init_msg)
if str(self.features.vendor) not in self.VENDORS:
raise RuntimeError("Unsupported device")
@ -465,7 +422,7 @@ class ProtocolMixin(object):
@staticmethod
def expand_path(n):
warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning)
warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning, stacklevel=2)
return tools.parse_path(n)
@expect(proto.PublicKey)

View File

@ -18,7 +18,7 @@ from binascii import hexlify
import pytest
from .common import TrezorTest
from trezorlib.client import CallException
from trezorlib.tools import CallException
class TestMsgGetpublickeyCurve(TrezorTest):

View File

@ -21,9 +21,8 @@ from .common import TrezorTest
from .conftest import TREZOR_VERSION
from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import parse_path
from trezorlib.tx_api import TxApiInsight
from trezorlib.tools import parse_path, CallException
TxApiTestnet = TxApiInsight("insight_testnet")

View File

@ -21,9 +21,7 @@ from .common import TrezorTest
from ..support.ckd_public import deserialize
from trezorlib import coins
from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import parse_path
from trezorlib.tools import parse_path, CallException
TxApiBcash = coins.tx_api['Bcash']

View File

@ -21,8 +21,7 @@ from .common import TrezorTest
from ..support.ckd_public import deserialize
from trezorlib import coins
from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import parse_path
from trezorlib.tools import parse_path, CallException
TxApiBitcoinGold = coins.tx_api["Bgold"]

View File

@ -22,9 +22,8 @@ from .conftest import TREZOR_VERSION
from .common import TrezorTest
from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import parse_path
from trezorlib.tx_api import TxApiInsight
from trezorlib.tools import parse_path, CallException
TxApiTestnet = TxApiInsight("insight_testnet")

View File

@ -20,7 +20,7 @@ import pytest
from .common import TrezorTest
from ..support import ckd_public as bip32
from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import CallException
TXHASH_c6091a = unhexlify('c6091adf4c0c23982a35899a6e58ae11e703eacd7954f588ed4b9cdefc4dba52')

View File

@ -20,8 +20,7 @@ import pytest
from .common import TrezorTest
from .conftest import TREZOR_VERSION
from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import CallException
TXHASH_d5f65e = unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')

View File

@ -19,7 +19,9 @@ import pytest
from .common import TrezorTest
from trezorlib import messages as proto
from trezorlib.client import PinException, CallException
from trezorlib.client import PinException
from trezorlib.tools import CallException
# FIXME TODO Add passphrase tests

View File

@ -16,6 +16,7 @@
import hashlib
import struct
import unicodedata
from typing import NewType, List
from .coins import slip44
@ -159,3 +160,65 @@ def parse_path(nstr: str) -> Address:
return list(str_to_harden(x) for x in n)
except Exception:
raise ValueError('Invalid BIP32 path', nstr)
def normalize_nfc(txt):
'''
Normalize message to NFC and return bytes suitable for protobuf.
This seems to be bitcoin-qt standard of doing things.
'''
if isinstance(txt, bytes):
txt = txt.decode('utf-8')
return unicodedata.normalize('NFC', txt).encode('utf-8')
class CallException(Exception):
pass
class field:
# Decorator extracts single value from
# protobuf object. If the field is not
# present, raises an exception.
def __init__(self, field):
self.field = field
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
return getattr(ret, self.field)
return wrapped_f
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages
# or raises an exception
def __init__(self, *expected):
self.expected = expected
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected))
return ret
return wrapped_f
def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
client = args[0]
client.transport.session_begin()
try:
return f(*args, **kwargs)
finally:
client.transport.session_end()
return wrapped_f