mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-30 03:18:20 +00:00
trezorlib: drop @field decorator
its function is replaced by @expect(field="name") -- it doesn't make sense to use @field without @expect anyway
This commit is contained in:
parent
00617817c3
commit
7083eb7a5c
@ -1,7 +1,5 @@
|
||||
from . import messages as proto
|
||||
from .tools import expect, field, CallException, normalize_nfc, session
|
||||
|
||||
### Client functions ###
|
||||
from .tools import expect, CallException, normalize_nfc, session
|
||||
|
||||
|
||||
@expect(proto.PublicKey)
|
||||
@ -9,8 +7,7 @@ def get_public_node(client, n, ecdsa_curve_name=None, show_display=False, coin_n
|
||||
return client.call(proto.GetPublicKey(address_n=n, ecdsa_curve_name=ecdsa_curve_name, show_display=show_display, coin_name=coin_name))
|
||||
|
||||
|
||||
@field('address')
|
||||
@expect(proto.Address)
|
||||
@expect(proto.Address, field="address")
|
||||
def get_address(client, coin_name, n, show_display=False, multisig=None, script_type=proto.InputScriptType.SPENDADDRESS):
|
||||
if multisig:
|
||||
return client.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, multisig=multisig, script_type=script_type))
|
||||
|
@ -415,8 +415,7 @@ class ProtocolMixin(object):
|
||||
warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning, stacklevel=2)
|
||||
return tools.parse_path(n)
|
||||
|
||||
@tools.field('message')
|
||||
@tools.expect(proto.Success)
|
||||
@tools.expect(proto.Success, field="message")
|
||||
def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False):
|
||||
msg = proto.Ping(message=msg,
|
||||
button_protection=button_protection,
|
||||
@ -450,8 +449,7 @@ class ProtocolMixin(object):
|
||||
|
||||
return txes
|
||||
|
||||
@tools.field('message')
|
||||
@tools.expect(proto.Success)
|
||||
@tools.expect(proto.Success, field="message")
|
||||
def clear_session(self):
|
||||
return self.call(proto.ClearSession())
|
||||
|
||||
|
@ -21,7 +21,7 @@ from mnemonic import Mnemonic
|
||||
|
||||
from . import messages as proto
|
||||
from . import tools
|
||||
from .tools import field, expect, session
|
||||
from .tools import expect, session
|
||||
|
||||
from .transport import enumerate_devices, get_transport
|
||||
|
||||
@ -44,8 +44,7 @@ class TrezorDevice:
|
||||
return get_transport(path, prefix_search=False)
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def apply_settings(client, label=None, language=None, use_passphrase=None, homescreen=None, passphrase_source=None, auto_lock_delay_ms=None):
|
||||
settings = proto.ApplySettings()
|
||||
if label is not None:
|
||||
@ -66,39 +65,34 @@ def apply_settings(client, label=None, language=None, use_passphrase=None, homes
|
||||
return out
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def apply_flags(client, flags):
|
||||
out = client.call(proto.ApplyFlags(flags=flags))
|
||||
client.init_device() # Reload Features
|
||||
return out
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def change_pin(client, remove=False):
|
||||
ret = client.call(proto.ChangePin(remove=remove))
|
||||
client.init_device() # Re-read features
|
||||
return ret
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def set_u2f_counter(client, u2f_counter):
|
||||
ret = client.call(proto.SetU2FCounter(u2f_counter=u2f_counter))
|
||||
return ret
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def wipe_device(client):
|
||||
ret = client.call(proto.WipeDevice())
|
||||
client.init_device()
|
||||
return ret
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def recovery_device(client, word_count, passphrase_protection, pin_protection, label, language, type=proto.RecoveryDeviceType.ScrambledWords, expand=False, dry_run=False):
|
||||
if client.features.initialized and not dry_run:
|
||||
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
|
||||
@ -127,8 +121,7 @@ def recovery_device(client, word_count, passphrase_protection, pin_protection, l
|
||||
return res
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
@session
|
||||
def reset_device(client, display_random, strength, passphrase_protection, pin_protection, label, language, u2f_counter=0, skip_backup=False):
|
||||
if client.features.initialized:
|
||||
@ -155,15 +148,13 @@ def reset_device(client, display_random, strength, passphrase_protection, pin_pr
|
||||
return ret
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def backup_device(client):
|
||||
ret = client.call(proto.BackupDevice())
|
||||
return ret
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def load_device_by_mnemonic(client, mnemonic, pin, passphrase_protection, label, language='english', skip_checksum=False, expand=False):
|
||||
# Convert mnemonic to UTF8 NKFD
|
||||
mnemonic = Mnemonic.normalize_string(mnemonic)
|
||||
@ -191,8 +182,7 @@ def load_device_by_mnemonic(client, mnemonic, pin, passphrase_protection, label,
|
||||
return resp
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, language):
|
||||
if client.features.initialized:
|
||||
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
|
||||
@ -236,8 +226,7 @@ def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, languag
|
||||
return resp
|
||||
|
||||
|
||||
@field('message')
|
||||
@expect(proto.Success)
|
||||
@expect(proto.Success, field="message")
|
||||
def self_test(client):
|
||||
if client.features.bootloader_mode is False:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
@ -1,5 +1,5 @@
|
||||
from . import messages as proto
|
||||
from .tools import field, expect, CallException, normalize_nfc, session
|
||||
from .tools import expect, CallException, normalize_nfc, session
|
||||
|
||||
|
||||
def int_to_big_endian(value):
|
||||
@ -9,8 +9,7 @@ def int_to_big_endian(value):
|
||||
### Client functions ###
|
||||
|
||||
|
||||
@field('address')
|
||||
@expect(proto.EthereumAddress)
|
||||
@expect(proto.EthereumAddress, field="address")
|
||||
def get_address(client, n, show_display=False, multisig=None):
|
||||
return client.call(proto.EthereumGetAddress(address_n=n, show_display=show_display))
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
import binascii
|
||||
|
||||
from . import messages as proto
|
||||
from .tools import field, expect, CallException, normalize_nfc
|
||||
from .tools import expect, CallException, normalize_nfc
|
||||
|
||||
|
||||
@field('address')
|
||||
@expect(proto.LiskAddress)
|
||||
@expect(proto.LiskAddress, field="address")
|
||||
def get_address(client, n, show_display=False):
|
||||
return client.call(proto.LiskGetAddress(address_n=n, show_display=show_display))
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
from . import messages as proto
|
||||
from .tools import field, expect
|
||||
from .tools import expect
|
||||
|
||||
|
||||
@field('entropy')
|
||||
@expect(proto.Entropy)
|
||||
@expect(proto.Entropy, field="entropy")
|
||||
def get_entropy(client, size):
|
||||
return client.call(proto.GetEntropy(size=size))
|
||||
|
||||
@ -18,8 +17,7 @@ def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=Non
|
||||
return client.call(proto.GetECDHSessionKey(identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name=ecdsa_curve_name))
|
||||
|
||||
|
||||
@field('value')
|
||||
@expect(proto.CipheredKeyValue)
|
||||
@expect(proto.CipheredKeyValue, field="value")
|
||||
def encrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
|
||||
return client.call(proto.CipherKeyValue(address_n=n,
|
||||
key=key,
|
||||
@ -30,8 +28,7 @@ def encrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=
|
||||
iv=iv))
|
||||
|
||||
|
||||
@field('value')
|
||||
@expect(proto.CipheredKeyValue)
|
||||
@expect(proto.CipheredKeyValue, field="value")
|
||||
def decrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
|
||||
return client.call(proto.CipherKeyValue(address_n=n,
|
||||
key=key,
|
||||
|
@ -17,7 +17,7 @@
|
||||
import binascii
|
||||
import json
|
||||
from . import messages as proto
|
||||
from .tools import expect, field, CallException
|
||||
from .tools import expect, CallException
|
||||
|
||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||
@ -170,8 +170,7 @@ def create_sign_tx(transaction):
|
||||
### Client functions ###
|
||||
|
||||
|
||||
@field("address")
|
||||
@expect(proto.NEMAddress)
|
||||
@expect(proto.NEMAddress, field="address")
|
||||
def get_address(client, n, network, show_display=False):
|
||||
return client.call(proto.NEMGetAddress(address_n=n, network=network, show_display=show_display))
|
||||
|
||||
|
@ -19,7 +19,7 @@ import struct
|
||||
import xdrlib
|
||||
|
||||
from . import messages
|
||||
from .tools import field, expect, CallException
|
||||
from .tools import expect, CallException
|
||||
|
||||
# Memo types
|
||||
MEMO_TYPE_NONE = 0
|
||||
@ -343,14 +343,12 @@ def _crc16_checksum(bytes):
|
||||
### Client functions ###
|
||||
|
||||
|
||||
@field('public_key')
|
||||
@expect(messages.StellarPublicKey)
|
||||
@expect(messages.StellarPublicKey, field="public_key")
|
||||
def get_public_key(client, address_n, show_display=False):
|
||||
return client.call(messages.StellarGetPublicKey(address_n=address_n, show_display=show_display))
|
||||
|
||||
|
||||
@field('address')
|
||||
@expect(messages.StellarAddress)
|
||||
@expect(messages.StellarAddress, field="address")
|
||||
def get_address(client, address_n, show_display=False):
|
||||
return client.call(messages.StellarGetAddress(address_n=address_n, show_display=show_display))
|
||||
|
||||
|
@ -157,7 +157,7 @@ def parse_path(nstr: str) -> Address:
|
||||
return int(x)
|
||||
|
||||
try:
|
||||
return list(str_to_harden(x) for x in n)
|
||||
return [str_to_harden(x) for x in n]
|
||||
except Exception:
|
||||
raise ValueError('Invalid BIP32 path', nstr)
|
||||
|
||||
@ -176,27 +176,13 @@ class CallException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class field:
|
||||
# Decorator extracts single value from
|
||||
# protobuf object. If the field is not
|
||||
# present, raises an exception.
|
||||
def __init__(self, field):
|
||||
self.field = field
|
||||
|
||||
def __call__(self, f):
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(*args, **kwargs):
|
||||
ret = f(*args, **kwargs)
|
||||
return getattr(ret, self.field)
|
||||
return wrapped_f
|
||||
|
||||
|
||||
class expect:
|
||||
# Decorator checks if the method
|
||||
# returned one of expected protobuf messages
|
||||
# or raises an exception
|
||||
def __init__(self, *expected):
|
||||
def __init__(self, expected, field=None):
|
||||
self.expected = expected
|
||||
self.field = field
|
||||
|
||||
def __call__(self, f):
|
||||
@functools.wraps(f)
|
||||
@ -204,7 +190,11 @@ class expect:
|
||||
ret = f(*args, **kwargs)
|
||||
if not isinstance(ret, self.expected):
|
||||
raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected))
|
||||
return ret
|
||||
if self.field is not None:
|
||||
return getattr(ret, self.field)
|
||||
else:
|
||||
return ret
|
||||
|
||||
return wrapped_f
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user