core: don't use __dict__ for messages, use __slots__

slots
Pavol Rusnak 4 years ago
parent fe03e947dc
commit 6c2b2942e0
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

@ -80,7 +80,8 @@ def _validate(msg: RecoveryDevice) -> None:
if msg.dry_run:
# check that only allowed fields are set
for key, value in msg.__dict__.items():
for key in msg.__slots__:
value = getattr(msg, key, None)
if key not in DRY_RUN_ALLOWED_FIELDS and value is not None:
raise wire.ProcessError(
"Forbidden field set in dry-run: {}".format(key)

@ -158,7 +158,14 @@ class MessageType:
setattr(self, kw, kwargs[kw])
def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
if self.__class__ is not rhs.__class__:
return False
if self.__slots__ is not rhs.__slots__:
return False
for slot in self.__slots__:
if getattr(self, slot, None) != getattr(rhs, slot, None):
return False
return True
def __repr__(self) -> str:
return "<%s>" % self.__class__.__name__

@ -126,15 +126,18 @@ def obj_eq(l: object, r: object) -> bool:
return True
def obj_dict(o: object) -> dict:
if hasattr(o, "__slots__"):
return {attr: getattr(o, attr, None) for attr in o.__slots__}
else:
return o.__dict__
def obj_repr(o: object) -> str:
"""
Returns a string representation of object, supports __slots__.
"""
if hasattr(o, "__slots__"):
d = {attr: getattr(o, attr, None) for attr in o.__slots__}
else:
d = o.__dict__
return "<%s: %s>" % (o.__class__.__name__, d)
return "<%s: %s>" % (o.__class__.__name__, obj_dict(o))
def truncate_utf8(string: str, max_bytes: int) -> str:

@ -1,6 +1,6 @@
from common import *
from trezor.utils import chunks
from trezor.utils import chunks, obj_dict
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
@ -241,7 +241,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
serialized_tx=unhexlify('8a44999c07bba32df1cacdc50987944e68e3205b4429438fdde35c76024614090000000000ffffffff02'),
)),
# the out has to be cloned not to send the same object which was modified
TxAck(tx=TransactionType(outputs=[TxOutputType(**out1.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out1))])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=TxRequestSerializedType(
# returned serialized out1
@ -249,7 +249,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
signature_index=None,
signature=None,
)),
TxAck(tx=TransactionType(outputs=[TxOutputType(**out2.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out2))])),
# segwit
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(

@ -1,6 +1,6 @@
from common import *
from trezor.utils import chunks
from trezor.utils import chunks, obj_dict
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
@ -241,7 +241,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
serialized_tx=unhexlify('d1613f483f2086d076c82fe34674385a86beb08f052d5405fe1aed397f852f4f0000000000feffffff02'),
)),
# the out has to be cloned not to send the same object which was modified
TxAck(tx=TransactionType(outputs=[TxOutputType(**out1.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out1))])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=TxRequestSerializedType(
# returned serialized out1
@ -249,7 +249,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
signature_index=None,
signature=None,
)),
TxAck(tx=TransactionType(outputs=[TxOutputType(**out2.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out2))])),
# segwit
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(

@ -1,6 +1,6 @@
from common import *
from trezor.utils import chunks
from trezor.utils import chunks, obj_dict
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
@ -240,7 +240,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
'37c361fb8f2d9056ba8c98c5611930fcb48cacfdd0fe2e0449d83eea982f91200000000017160014d16b8c0680c61fc6ed2e407455715055e41052f5ffffffff02'),
)),
# the out has to be cloned not to send the same object which was modified
TxAck(tx=TransactionType(outputs=[TxOutputType(**out1.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out1))])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None),
serialized=TxRequestSerializedType(
@ -250,7 +250,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
signature_index=None,
signature=None,
)),
TxAck(tx=TransactionType(outputs=[TxOutputType(**out2.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out2))])),
# segwit
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None),

@ -1,6 +1,6 @@
from common import *
from trezor.utils import chunks
from trezor.utils import chunks, obj_dict
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
@ -245,7 +245,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
'cf60ded29a2bd7ebf93453feace8551889d0321beab90c4f6e5c9d2fce8ba4090000000017160014d16b8c0680c61fc6ed2e407455715055e41052f5feffffff02'),
)),
# the out has to be cloned not to send the same object which was modified
TxAck(tx=TransactionType(outputs=[TxOutputType(**out1.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out1))])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None),
serialized=TxRequestSerializedType(
@ -255,7 +255,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
signature_index=None,
signature=None,
)),
TxAck(tx=TransactionType(outputs=[TxOutputType(**out2.__dict__)])),
TxAck(tx=TransactionType(outputs=[TxOutputType(**obj_dict(out2))])),
# segwit
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None),

@ -1,4 +1,4 @@
from trezor.utils import ensure
from trezor.utils import ensure, obj_eq
class SkipTest(Exception):
@ -145,8 +145,7 @@ class TestCase:
self.assertObjectEqual(syscall, expected)
def assertObjectEqual(self, a, b, msg=''):
self.assertIsInstance(a, b.__class__, msg)
self.assertEqual(a.__dict__, b.__dict__, msg)
ensure(obj_eq(a, b), msg)
def skip(msg):

Loading…
Cancel
Save