1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-30 03:18:20 +00:00

trezorlib: drop @field decorator

its function is replaced by @expect(field="name") -- it doesn't make sense
to use @field without @expect anyway
This commit is contained in:
matejcik 2018-06-25 17:51:09 +02:00
parent 00617817c3
commit 7083eb7a5c
9 changed files with 37 additions and 71 deletions

View File

@ -1,7 +1,5 @@
from . import messages as proto from . import messages as proto
from .tools import expect, field, CallException, normalize_nfc, session from .tools import expect, CallException, normalize_nfc, session
### Client functions ###
@expect(proto.PublicKey) @expect(proto.PublicKey)
@ -9,8 +7,7 @@ def get_public_node(client, n, ecdsa_curve_name=None, show_display=False, coin_n
return client.call(proto.GetPublicKey(address_n=n, ecdsa_curve_name=ecdsa_curve_name, show_display=show_display, coin_name=coin_name)) return client.call(proto.GetPublicKey(address_n=n, ecdsa_curve_name=ecdsa_curve_name, show_display=show_display, coin_name=coin_name))
@field('address') @expect(proto.Address, field="address")
@expect(proto.Address)
def get_address(client, coin_name, n, show_display=False, multisig=None, script_type=proto.InputScriptType.SPENDADDRESS): def get_address(client, coin_name, n, show_display=False, multisig=None, script_type=proto.InputScriptType.SPENDADDRESS):
if multisig: if multisig:
return client.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, multisig=multisig, script_type=script_type)) return client.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, multisig=multisig, script_type=script_type))

View File

@ -415,8 +415,7 @@ class ProtocolMixin(object):
warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning, stacklevel=2) warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning, stacklevel=2)
return tools.parse_path(n) return tools.parse_path(n)
@tools.field('message') @tools.expect(proto.Success, field="message")
@tools.expect(proto.Success)
def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False): def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False):
msg = proto.Ping(message=msg, msg = proto.Ping(message=msg,
button_protection=button_protection, button_protection=button_protection,
@ -450,8 +449,7 @@ class ProtocolMixin(object):
return txes return txes
@tools.field('message') @tools.expect(proto.Success, field="message")
@tools.expect(proto.Success)
def clear_session(self): def clear_session(self):
return self.call(proto.ClearSession()) return self.call(proto.ClearSession())

View File

@ -21,7 +21,7 @@ from mnemonic import Mnemonic
from . import messages as proto from . import messages as proto
from . import tools from . import tools
from .tools import field, expect, session from .tools import expect, session
from .transport import enumerate_devices, get_transport from .transport import enumerate_devices, get_transport
@ -44,8 +44,7 @@ class TrezorDevice:
return get_transport(path, prefix_search=False) return get_transport(path, prefix_search=False)
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def apply_settings(client, label=None, language=None, use_passphrase=None, homescreen=None, passphrase_source=None, auto_lock_delay_ms=None): def apply_settings(client, label=None, language=None, use_passphrase=None, homescreen=None, passphrase_source=None, auto_lock_delay_ms=None):
settings = proto.ApplySettings() settings = proto.ApplySettings()
if label is not None: if label is not None:
@ -66,39 +65,34 @@ def apply_settings(client, label=None, language=None, use_passphrase=None, homes
return out return out
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def apply_flags(client, flags): def apply_flags(client, flags):
out = client.call(proto.ApplyFlags(flags=flags)) out = client.call(proto.ApplyFlags(flags=flags))
client.init_device() # Reload Features client.init_device() # Reload Features
return out return out
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def change_pin(client, remove=False): def change_pin(client, remove=False):
ret = client.call(proto.ChangePin(remove=remove)) ret = client.call(proto.ChangePin(remove=remove))
client.init_device() # Re-read features client.init_device() # Re-read features
return ret return ret
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def set_u2f_counter(client, u2f_counter): def set_u2f_counter(client, u2f_counter):
ret = client.call(proto.SetU2FCounter(u2f_counter=u2f_counter)) ret = client.call(proto.SetU2FCounter(u2f_counter=u2f_counter))
return ret return ret
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def wipe_device(client): def wipe_device(client):
ret = client.call(proto.WipeDevice()) ret = client.call(proto.WipeDevice())
client.init_device() client.init_device()
return ret return ret
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def recovery_device(client, word_count, passphrase_protection, pin_protection, label, language, type=proto.RecoveryDeviceType.ScrambledWords, expand=False, dry_run=False): def recovery_device(client, word_count, passphrase_protection, pin_protection, label, language, type=proto.RecoveryDeviceType.ScrambledWords, expand=False, dry_run=False):
if client.features.initialized and not dry_run: if client.features.initialized and not dry_run:
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.") raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
@ -127,8 +121,7 @@ def recovery_device(client, word_count, passphrase_protection, pin_protection, l
return res return res
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
@session @session
def reset_device(client, display_random, strength, passphrase_protection, pin_protection, label, language, u2f_counter=0, skip_backup=False): def reset_device(client, display_random, strength, passphrase_protection, pin_protection, label, language, u2f_counter=0, skip_backup=False):
if client.features.initialized: if client.features.initialized:
@ -155,15 +148,13 @@ def reset_device(client, display_random, strength, passphrase_protection, pin_pr
return ret return ret
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def backup_device(client): def backup_device(client):
ret = client.call(proto.BackupDevice()) ret = client.call(proto.BackupDevice())
return ret return ret
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def load_device_by_mnemonic(client, mnemonic, pin, passphrase_protection, label, language='english', skip_checksum=False, expand=False): def load_device_by_mnemonic(client, mnemonic, pin, passphrase_protection, label, language='english', skip_checksum=False, expand=False):
# Convert mnemonic to UTF8 NKFD # Convert mnemonic to UTF8 NKFD
mnemonic = Mnemonic.normalize_string(mnemonic) mnemonic = Mnemonic.normalize_string(mnemonic)
@ -191,8 +182,7 @@ def load_device_by_mnemonic(client, mnemonic, pin, passphrase_protection, label,
return resp return resp
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, language): def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, language):
if client.features.initialized: if client.features.initialized:
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.") raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
@ -236,8 +226,7 @@ def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, languag
return resp return resp
@field('message') @expect(proto.Success, field="message")
@expect(proto.Success)
def self_test(client): def self_test(client):
if client.features.bootloader_mode is False: if client.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")

View File

@ -1,5 +1,5 @@
from . import messages as proto from . import messages as proto
from .tools import field, expect, CallException, normalize_nfc, session from .tools import expect, CallException, normalize_nfc, session
def int_to_big_endian(value): def int_to_big_endian(value):
@ -9,8 +9,7 @@ def int_to_big_endian(value):
### Client functions ### ### Client functions ###
@field('address') @expect(proto.EthereumAddress, field="address")
@expect(proto.EthereumAddress)
def get_address(client, n, show_display=False, multisig=None): def get_address(client, n, show_display=False, multisig=None):
return client.call(proto.EthereumGetAddress(address_n=n, show_display=show_display)) return client.call(proto.EthereumGetAddress(address_n=n, show_display=show_display))

View File

@ -1,11 +1,10 @@
import binascii import binascii
from . import messages as proto from . import messages as proto
from .tools import field, expect, CallException, normalize_nfc from .tools import expect, CallException, normalize_nfc
@field('address') @expect(proto.LiskAddress, field="address")
@expect(proto.LiskAddress)
def get_address(client, n, show_display=False): def get_address(client, n, show_display=False):
return client.call(proto.LiskGetAddress(address_n=n, show_display=show_display)) return client.call(proto.LiskGetAddress(address_n=n, show_display=show_display))

View File

@ -1,9 +1,8 @@
from . import messages as proto from . import messages as proto
from .tools import field, expect from .tools import expect
@field('entropy') @expect(proto.Entropy, field="entropy")
@expect(proto.Entropy)
def get_entropy(client, size): def get_entropy(client, size):
return client.call(proto.GetEntropy(size=size)) return client.call(proto.GetEntropy(size=size))
@ -18,8 +17,7 @@ def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=Non
return client.call(proto.GetECDHSessionKey(identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name=ecdsa_curve_name)) return client.call(proto.GetECDHSessionKey(identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name=ecdsa_curve_name))
@field('value') @expect(proto.CipheredKeyValue, field="value")
@expect(proto.CipheredKeyValue)
def encrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''): def encrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
return client.call(proto.CipherKeyValue(address_n=n, return client.call(proto.CipherKeyValue(address_n=n,
key=key, key=key,
@ -30,8 +28,7 @@ def encrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=
iv=iv)) iv=iv))
@field('value') @expect(proto.CipheredKeyValue, field="value")
@expect(proto.CipheredKeyValue)
def decrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''): def decrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
return client.call(proto.CipherKeyValue(address_n=n, return client.call(proto.CipherKeyValue(address_n=n,
key=key, key=key,

View File

@ -17,7 +17,7 @@
import binascii import binascii
import json import json
from . import messages as proto from . import messages as proto
from .tools import expect, field, CallException from .tools import expect, CallException
TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801 TYPE_IMPORTANCE_TRANSFER = 0x0801
@ -170,8 +170,7 @@ def create_sign_tx(transaction):
### Client functions ### ### Client functions ###
@field("address") @expect(proto.NEMAddress, field="address")
@expect(proto.NEMAddress)
def get_address(client, n, network, show_display=False): def get_address(client, n, network, show_display=False):
return client.call(proto.NEMGetAddress(address_n=n, network=network, show_display=show_display)) return client.call(proto.NEMGetAddress(address_n=n, network=network, show_display=show_display))

View File

@ -19,7 +19,7 @@ import struct
import xdrlib import xdrlib
from . import messages from . import messages
from .tools import field, expect, CallException from .tools import expect, CallException
# Memo types # Memo types
MEMO_TYPE_NONE = 0 MEMO_TYPE_NONE = 0
@ -343,14 +343,12 @@ def _crc16_checksum(bytes):
### Client functions ### ### Client functions ###
@field('public_key') @expect(messages.StellarPublicKey, field="public_key")
@expect(messages.StellarPublicKey)
def get_public_key(client, address_n, show_display=False): def get_public_key(client, address_n, show_display=False):
return client.call(messages.StellarGetPublicKey(address_n=address_n, show_display=show_display)) return client.call(messages.StellarGetPublicKey(address_n=address_n, show_display=show_display))
@field('address') @expect(messages.StellarAddress, field="address")
@expect(messages.StellarAddress)
def get_address(client, address_n, show_display=False): def get_address(client, address_n, show_display=False):
return client.call(messages.StellarGetAddress(address_n=address_n, show_display=show_display)) return client.call(messages.StellarGetAddress(address_n=address_n, show_display=show_display))

View File

@ -157,7 +157,7 @@ def parse_path(nstr: str) -> Address:
return int(x) return int(x)
try: try:
return list(str_to_harden(x) for x in n) return [str_to_harden(x) for x in n]
except Exception: except Exception:
raise ValueError('Invalid BIP32 path', nstr) raise ValueError('Invalid BIP32 path', nstr)
@ -176,27 +176,13 @@ class CallException(Exception):
pass pass
class field:
# Decorator extracts single value from
# protobuf object. If the field is not
# present, raises an exception.
def __init__(self, field):
self.field = field
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
return getattr(ret, self.field)
return wrapped_f
class expect: class expect:
# Decorator checks if the method # Decorator checks if the method
# returned one of expected protobuf messages # returned one of expected protobuf messages
# or raises an exception # or raises an exception
def __init__(self, *expected): def __init__(self, expected, field=None):
self.expected = expected self.expected = expected
self.field = field
def __call__(self, f): def __call__(self, f):
@functools.wraps(f) @functools.wraps(f)
@ -204,7 +190,11 @@ class expect:
ret = f(*args, **kwargs) ret = f(*args, **kwargs)
if not isinstance(ret, self.expected): if not isinstance(ret, self.expected):
raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected)) raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected))
return ret if self.field is not None:
return getattr(ret, self.field)
else:
return ret
return wrapped_f return wrapped_f