1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-09 06:50:58 +00:00

trezorlib: shuffling things from client

This commit is contained in:
matejcik 2018-05-21 14:28:53 +02:00
parent 9dc86f3955
commit 1820f529fc
11 changed files with 102 additions and 86 deletions

View File

@ -30,7 +30,7 @@ import logging
import os import os
import sys import sys
from trezorlib.client import TrezorClient, CallException from trezorlib.client import TrezorClient
from trezorlib.transport import get_transport, enumerate_devices from trezorlib.transport import get_transport, enumerate_devices
from trezorlib import coins from trezorlib import coins
from trezorlib import log from trezorlib import log
@ -250,12 +250,12 @@ def set_homescreen(connect, filename):
elif filename.endswith('.toif'): elif filename.endswith('.toif'):
img = open(filename, 'rb').read() img = open(filename, 'rb').read()
if img[:8] != b'TOIf\x90\x00\x90\x00': if img[:8] != b'TOIf\x90\x00\x90\x00':
raise CallException(proto.FailureType.DataError, 'File is not a TOIF file with size of 144x144') raise tools.CallException(proto.FailureType.DataError, 'File is not a TOIF file with size of 144x144')
else: else:
from PIL import Image from PIL import Image
im = Image.open(filename) im = Image.open(filename)
if im.size != (128, 64): if im.size != (128, 64):
raise CallException(proto.FailureType.DataError, 'Wrong size of the image') raise tools.CallException(proto.FailureType.DataError, 'Wrong size of the image')
im = im.convert('1') im = im.convert('1')
pix = im.load() pix = im.load()
img = bytearray(1024) img = bytearray(1024)
@ -297,7 +297,7 @@ def wipe_device(connect, bootloader):
try: try:
return connect().wipe_device() return connect().wipe_device()
except CallException as e: except tools.CallException as e:
click.echo('Action failed: {} {}'.format(*e.args)) click.echo('Action failed: {} {}'.format(*e.args))
sys.exit(3) sys.exit(3)
@ -314,7 +314,7 @@ def wipe_device(connect, bootloader):
@click.pass_obj @click.pass_obj
def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014): def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014):
if not mnemonic and not xprv and not slip0014: if not mnemonic and not xprv and not slip0014:
raise CallException(proto.FailureType.DataError, 'Please provide mnemonic or xprv') raise tools.CallException(proto.FailureType.DataError, 'Please provide mnemonic or xprv')
client = connect() client = connect()
if mnemonic: if mnemonic:
@ -474,7 +474,7 @@ def firmware_update(connect, filename, url, version, skip_check, fingerprint):
try: try:
return client.firmware_update(fp=io.BytesIO(fp)) return client.firmware_update(fp=io.BytesIO(fp))
except CallException as e: except tools.CallException as e:
if e.args[0] in (proto.FailureType.FirmwareError, proto.FailureType.ActionCancelled): if e.args[0] in (proto.FailureType.FirmwareError, proto.FailureType.ActionCancelled):
click.echo("Update aborted on device.") click.echo("Update aborted on device.")
else: else:
@ -806,7 +806,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
if ' ' in value: if ' ' in value:
value, unit = value.split(' ', 1) value, unit = value.split(' ', 1)
if unit.lower() not in ether_units: if unit.lower() not in ether_units:
raise CallException(proto.Failure.DataError, 'Unrecognized ether unit %r' % unit) raise tools.CallException(proto.Failure.DataError, 'Unrecognized ether unit %r' % unit)
value = int(value) * ether_units[unit.lower()] value = int(value) * ether_units[unit.lower()]
else: else:
value = int(value) value = int(value)
@ -815,7 +815,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
if ' ' in gas_price: if ' ' in gas_price:
gas_price, unit = gas_price.split(' ', 1) gas_price, unit = gas_price.split(' ', 1)
if unit.lower() not in ether_units: if unit.lower() not in ether_units:
raise CallException(proto.Failure.DataError, 'Unrecognized gas price unit %r' % unit) raise tools.CallException(proto.Failure.DataError, 'Unrecognized gas price unit %r' % unit)
gas_price = int(gas_price) * ether_units[unit.lower()] gas_price = int(gas_price) * ether_units[unit.lower()]
else: else:
gas_price = int(gas_price) gas_price = int(gas_price)

View File

@ -31,8 +31,8 @@ from . import messages as proto
from . import tools from . import tools
from . import mapping from . import mapping
from . import nem from . import nem
from . import protobuf
from . import stellar from . import stellar
from .tools import CallException, field, expect
from .debuglink import DebugLink from .debuglink import DebugLink
if sys.version_info.major < 3: if sys.version_info.major < 3:
@ -79,69 +79,26 @@ def get_buttonrequest_value(code):
return [k for k in dir(proto.ButtonRequestType) if getattr(proto.ButtonRequestType, k) == code][0] return [k for k in dir(proto.ButtonRequestType) if getattr(proto.ButtonRequestType, k) == code][0]
class CallException(Exception): class PinException(tools.CallException):
pass pass
class PinException(CallException): class MovedTo:
pass """Deprecation redirector for methods that were formerly part of TrezorClient"""
def __init__(self, where):
self.where = where
self.name = where.__module__ + '.' + where.__name__
def _deprecated_redirect(self, client, *args, **kwargs):
"""Redirector for a deprecated method on TrezorClient"""
warnings.warn("Function has been moved to %s" % self.name, DeprecationWarning, stacklevel=2)
return self.where(client, *args, **kwargs)
class field: def __get__(self, instance, cls):
# Decorator extracts single value from if instance is None:
# protobuf object. If the field is not return self._deprecated_redirect
# present, raises an exception. else:
def __init__(self, field): return functools.partial(self._deprecated_redirect, instance)
self.field = field
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
return getattr(ret, self.field)
return wrapped_f
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages
# or raises an exception
def __init__(self, *expected):
self.expected = expected
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected))
return ret
return wrapped_f
def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
client = args[0]
client.transport.session_begin()
try:
return f(*args, **kwargs)
finally:
client.transport.session_end()
return wrapped_f
def normalize_nfc(txt):
'''
Normalize message to NFC and return bytes suitable for protobuf.
This seems to be bitcoin-qt standard of doing things.
'''
if isinstance(txt, bytes):
txt = txt.decode('utf-8')
return unicodedata.normalize('NFC', txt).encode('utf-8')
class BaseClient(object): class BaseClient(object):
@ -158,13 +115,13 @@ class BaseClient(object):
def cancel(self): def cancel(self):
self.transport.write(proto.Cancel()) self.transport.write(proto.Cancel())
@session @tools.session
def call_raw(self, msg): def call_raw(self, msg):
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks __tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
self.transport.write(msg) self.transport.write(msg)
return self.transport.read() return self.transport.read()
@session @tools.session
def call(self, msg): def call(self, msg):
resp = self.call_raw(msg) resp = self.call_raw(msg)
handler_name = "callback_%s" % resp.__class__.__name__ handler_name = "callback_%s" % resp.__class__.__name__
@ -183,7 +140,7 @@ class BaseClient(object):
proto.FailureType.PinCancelled, proto.FailureType.PinExpected): proto.FailureType.PinCancelled, proto.FailureType.PinExpected):
raise PinException(msg.code, msg.message) raise PinException(msg.code, msg.message)
raise CallException(msg.code, msg.message) raise tools.CallException(msg.code, msg.message)
def register_message(self, msg): def register_message(self, msg):
'''Allow application to register custom protobuf message type''' '''Allow application to register custom protobuf message type'''
@ -451,7 +408,7 @@ class ProtocolMixin(object):
init_msg = proto.Initialize() init_msg = proto.Initialize()
if self.state is not None: if self.state is not None:
init_msg.state = self.state init_msg.state = self.state
self.features = expect(proto.Features)(self.call)(init_msg) self.features = tools.expect(proto.Features)(self.call)(init_msg)
if str(self.features.vendor) not in self.VENDORS: if str(self.features.vendor) not in self.VENDORS:
raise RuntimeError("Unsupported device") raise RuntimeError("Unsupported device")
@ -465,7 +422,7 @@ class ProtocolMixin(object):
@staticmethod @staticmethod
def expand_path(n): def expand_path(n):
warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning) warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning, stacklevel=2)
return tools.parse_path(n) return tools.parse_path(n)
@expect(proto.PublicKey) @expect(proto.PublicKey)

View File

@ -18,7 +18,7 @@ from binascii import hexlify
import pytest import pytest
from .common import TrezorTest from .common import TrezorTest
from trezorlib.client import CallException from trezorlib.tools import CallException
class TestMsgGetpublickeyCurve(TrezorTest): class TestMsgGetpublickeyCurve(TrezorTest):

View File

@ -21,9 +21,8 @@ from .common import TrezorTest
from .conftest import TREZOR_VERSION from .conftest import TREZOR_VERSION
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import parse_path
from trezorlib.tx_api import TxApiInsight from trezorlib.tx_api import TxApiInsight
from trezorlib.tools import parse_path, CallException
TxApiTestnet = TxApiInsight("insight_testnet") TxApiTestnet = TxApiInsight("insight_testnet")

View File

@ -21,9 +21,7 @@ from .common import TrezorTest
from ..support.ckd_public import deserialize from ..support.ckd_public import deserialize
from trezorlib import coins from trezorlib import coins
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib.client import CallException from trezorlib.tools import parse_path, CallException
from trezorlib.tools import parse_path
TxApiBcash = coins.tx_api['Bcash'] TxApiBcash = coins.tx_api['Bcash']

View File

@ -21,8 +21,7 @@ from .common import TrezorTest
from ..support.ckd_public import deserialize from ..support.ckd_public import deserialize
from trezorlib import coins from trezorlib import coins
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib.client import CallException from trezorlib.tools import parse_path, CallException
from trezorlib.tools import parse_path
TxApiBitcoinGold = coins.tx_api["Bgold"] TxApiBitcoinGold = coins.tx_api["Bgold"]

View File

@ -22,9 +22,8 @@ from .conftest import TREZOR_VERSION
from .common import TrezorTest from .common import TrezorTest
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib.client import CallException
from trezorlib.tools import parse_path
from trezorlib.tx_api import TxApiInsight from trezorlib.tx_api import TxApiInsight
from trezorlib.tools import parse_path, CallException
TxApiTestnet = TxApiInsight("insight_testnet") TxApiTestnet = TxApiInsight("insight_testnet")

View File

@ -20,7 +20,7 @@ import pytest
from .common import TrezorTest from .common import TrezorTest
from ..support import ckd_public as bip32 from ..support import ckd_public as bip32
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib.client import CallException from trezorlib.tools import CallException
TXHASH_c6091a = unhexlify('c6091adf4c0c23982a35899a6e58ae11e703eacd7954f588ed4b9cdefc4dba52') TXHASH_c6091a = unhexlify('c6091adf4c0c23982a35899a6e58ae11e703eacd7954f588ed4b9cdefc4dba52')

View File

@ -20,8 +20,7 @@ import pytest
from .common import TrezorTest from .common import TrezorTest
from .conftest import TREZOR_VERSION from .conftest import TREZOR_VERSION
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib.client import CallException from trezorlib.tools import CallException
TXHASH_d5f65e = unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882') TXHASH_d5f65e = unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')

View File

@ -19,7 +19,9 @@ import pytest
from .common import TrezorTest from .common import TrezorTest
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib.client import PinException, CallException from trezorlib.client import PinException
from trezorlib.tools import CallException
# FIXME TODO Add passphrase tests # FIXME TODO Add passphrase tests

View File

@ -16,6 +16,7 @@
import hashlib import hashlib
import struct import struct
import unicodedata
from typing import NewType, List from typing import NewType, List
from .coins import slip44 from .coins import slip44
@ -159,3 +160,65 @@ def parse_path(nstr: str) -> Address:
return list(str_to_harden(x) for x in n) return list(str_to_harden(x) for x in n)
except Exception: except Exception:
raise ValueError('Invalid BIP32 path', nstr) raise ValueError('Invalid BIP32 path', nstr)
def normalize_nfc(txt):
'''
Normalize message to NFC and return bytes suitable for protobuf.
This seems to be bitcoin-qt standard of doing things.
'''
if isinstance(txt, bytes):
txt = txt.decode('utf-8')
return unicodedata.normalize('NFC', txt).encode('utf-8')
class CallException(Exception):
pass
class field:
# Decorator extracts single value from
# protobuf object. If the field is not
# present, raises an exception.
def __init__(self, field):
self.field = field
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
return getattr(ret, self.field)
return wrapped_f
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages
# or raises an exception
def __init__(self, *expected):
self.expected = expected
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected))
return ret
return wrapped_f
def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
client = args[0]
client.transport.session_begin()
try:
return f(*args, **kwargs)
finally:
client.transport.session_end()
return wrapped_f