diff --git a/python/docs/transaction-format.md b/python/docs/transaction-format.md index 61a4e8029..1b033bb48 100644 --- a/python/docs/transaction-format.md +++ b/python/docs/transaction-format.md @@ -15,7 +15,7 @@ The root is an object with the following attributes: missing, `"Bitcoin"` is used. * __`inputs`__: array of `TxInputType` objects. Must be present. * __`outputs`__: array of `TxOutputType` objects. Must be present. -* __`details`__: object of type `SignTx`, specifying transaction metadata. Can be +* __`details`__: object whose keys correspond to metadata on the `SignTx` type. Can be omitted. * __`prev_txes`__: object whose keys are hex-encoded transaction hashes, and values are objects of type `TransactionType`. When signing a transaction with non-SegWit inputs, @@ -112,10 +112,8 @@ set. ### Transaction metadata -The following is a shortened definition of the `SignTx` protobuf message. Note that it -is possible to set fields `outputs_count`, `inputs_count` and `coin_name`, but their -values will be ignored. Instead, the number of elements in `outputs`, `inputs`, and the -value of `coin_name` from root object will be used. +The following is a shortened definition of the `SignTx` protobuf message, containing +all possible fields that are accepted in the `details` object. All fields are optional unless required by your currency. @@ -124,7 +122,6 @@ message SignTx { optional uint32 version = 4; // transaction version optional uint32 lock_time = 5; // transaction lock_time optional uint32 expiry = 6; // only for Decred and Zcash - optional bool overwintered = 7; // only for Zcash optional uint32 version_group_id = 8; // only for Zcash, nVersionGroupId when overwintered is set optional uint32 timestamp = 9; // only for Peercoin, transaction timestamp optional uint32 branch_id = 10; // only for Zcash, BRANCH_ID when overwintered is set diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index 0570ad643..9454fcfd2 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -14,6 +14,7 @@ # You should have received a copy of the License along with this library. # If not, see . +import warnings from decimal import Decimal from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple @@ -182,7 +183,8 @@ def sign_tx( coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], - prev_txes: Dict[bytes, messages.TransactionType], + details: messages.SignTx = None, + prev_txes: Dict[bytes, messages.TransactionType] = None, preauthorized: bool = False, **kwargs: Any, ) -> Tuple[Sequence[bytes], bytes]: @@ -197,14 +199,26 @@ def sign_tx( (`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments and cannot be overriden by kwargs. """ - signtx = messages.SignTx( - coin_name=coin_name, - inputs_count=len(inputs), - outputs_count=len(outputs), - ) - for name, value in kwargs.items(): - if hasattr(signtx, name): - setattr(signtx, name, value) + if details is not None: + warnings.warn( + "'details' argument is deprecated, use kwargs instead", + DeprecationWarning, + stacklevel=2, + ) + signtx = details + signtx.coin_name = coin_name + signtx.inputs_count = len(inputs) + signtx.outputs_count = len(outputs) + + else: + signtx = messages.SignTx( + coin_name=coin_name, + inputs_count=len(inputs), + outputs_count=len(outputs), + ) + for name, value in kwargs.items(): + if hasattr(signtx, name): + setattr(signtx, name, value) if preauthorized: res = client.call(messages.DoPreauthorized()) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index 131d305db..7ecfa6eed 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -199,7 +199,7 @@ def sign_tx(client, json_file): """ data = json.load(json_file) coin = data.get("coin_name", DEFAULT_COIN) - details = protobuf.dict_to_proto(messages.SignTx, data.get("details", {})) + details = data.get("details", {}) inputs = [ protobuf.dict_to_proto(messages.TxInputType, i) for i in data.get("inputs", ()) ] @@ -212,7 +212,14 @@ def sign_tx(client, json_file): for txid, tx in data.get("prev_txes", {}).items() } - _, serialized_tx = btc.sign_tx(client, coin, inputs, outputs, details, prev_txes) + _, serialized_tx = btc.sign_tx( + client, + coin, + inputs, + outputs, + prev_txes=prev_txes, + **details, + ) click.echo() click.echo("Signed Transaction:") diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 2edae501b..0b113a36b 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -276,11 +276,11 @@ class MessageFilter: @classmethod def from_message(cls, message): fields = {} - for field in message.keys(): - value = getattr(message, field) - if value in (None, []): + for fname, _, _ in message.get_fields().values(): + value = getattr(message, fname) + if value in (None, [], protobuf.FLAG_REQUIRED): continue - fields[field] = value + fields[fname] = value return cls(type(message), **fields) def match(self, message): diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index 22b050af7..048620ee9 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -24,6 +24,7 @@ For serializing (dumping) protobuf types, object with `Writer` interface is requ import logging from io import BytesIO +from itertools import zip_longest from typing import ( Any, Callable, @@ -37,6 +38,7 @@ from typing import ( TypeVar, Union, ) +import warnings from typing_extensions import Protocol @@ -198,11 +200,29 @@ class UnicodeType: WIRE_TYPE = 2 -class MessageType: +class _MessageTypeMeta(type): + def __init__(cls, name, bases, d) -> None: + super().__init__(name, bases, d) + if name != "MessageType": + cls.__init__ = MessageType.__init__ + + +class MessageType(metaclass=_MessageTypeMeta): WIRE_TYPE = 2 @classmethod def get_fields(cls) -> Dict[int, FieldInfo]: + """Return a field descriptor. + + The descriptor is a mapping: + field_id -> (field_name, field_type, default_value) + + `default_value` can also be one of the special values: + * `FLAG_REQUIRED` indicates that the field value has no default and _must_ be + provided by caller/sender. + * `FLAG_REPEATED` indicates that the field is a list of `field_type` values. In + that case the default value is an empty list. + """ return {} @classmethod @@ -212,10 +232,42 @@ class MessageType: return ftype return None - def __init__(self, **kwargs: Any) -> None: - for kw in kwargs: - setattr(self, kw, kwargs[kw]) - self._fill_missing() + def __init__(self, *args, **kwargs: Any) -> None: + fields = self.get_fields() + if args: + warnings.warn( + "Positional arguments for MessageType are deprecated", + DeprecationWarning, + stacklevel=2, + ) + # process fields one by one + NOT_PROVIDED = object() + for field, val in zip_longest(fields.values(), args, fillvalue=NOT_PROVIDED): + if field is NOT_PROVIDED: + raise TypeError("too many positional arguments") + fname, _, fdefault = field + if fname in kwargs and val is not NOT_PROVIDED: + # both *args and **kwargs specify the same thing + raise TypeError(f"got multiple values for argument '{fname}'") + elif fname in kwargs: + # set in kwargs but not in args + setattr(self, fname, kwargs[fname]) + elif val is not NOT_PROVIDED: + # set in args but not in kwargs + setattr(self, fname, val) + else: + # not set at all, pick a default + if fdefault is FLAG_REPEATED: + fdefault = [] + elif fdefault is FLAG_EXPERIMENTAL: + fdefault = None + elif fdefault is FLAG_REQUIRED: + warnings.warn( + f"Value of required field '{fname}' must be provided in constructor", + DeprecationWarning, + stacklevel=2, + ) + setattr(self, fname, fdefault) def __eq__(self, rhs: Any) -> bool: return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ @@ -237,17 +289,6 @@ class MessageType: def __getitem__(self, key: str) -> Any: return getattr(self, key) - def _fill_missing(self) -> None: - # fill missing fields - for fname, _, fdefault in self.get_fields().values(): - if not hasattr(self, fname): - if fdefault is FLAG_REPEATED: - setattr(self, fname, []) - elif fdefault is FLAG_REQUIRED: - raise ValueError("value for required field is missing") - else: - setattr(self, fname, fdefault) - def ByteSize(self) -> int: data = BytesIO() dump_message(data, self) @@ -403,6 +444,8 @@ def dump_message(writer: Writer, msg: MessageType) -> None: fvalue = getattr(msg, fname, None) if fvalue is None: continue + if fvalue is FLAG_REQUIRED: + raise ValueError # required value was not provided fkey = (ftag << 3) | ftype.WIRE_TYPE diff --git a/python/tests/test_protobuf_encoding.py b/python/tests/test_protobuf_encoding.py index 7a3bb80b8..ebf17b33b 100644 --- a/python/tests/test_protobuf_encoding.py +++ b/python/tests/test_protobuf_encoding.py @@ -239,9 +239,11 @@ def test_required(): assert msg_ok == msg - with pytest.raises(ValueError): - # cannot construct instance without the required fields + with pytest.deprecated_call(): msg = RequiredFields(uvarint=3) + with pytest.raises(ValueError): + # cannot encode instance without the required fields + dump_message(msg) msg = RequiredFields(uvarint=3, nested=None) # we can always encode an invalid message diff --git a/python/tests/test_protobuf_misc.py b/python/tests/test_protobuf_misc.py index 644d5becb..ea722c4f8 100644 --- a/python/tests/test_protobuf_misc.py +++ b/python/tests/test_protobuf_misc.py @@ -53,6 +53,14 @@ class NestedMessage(protobuf.MessageType): } +class RequiredFields(protobuf.MessageType): + @classmethod + def get_fields(cls): + return { + 1: ("scalar", protobuf.UVarintType, protobuf.FLAG_REQUIRED), + } + + def test_get_field_type(): # smoke test assert SimpleMessage.get_field_type("bool") is protobuf.BoolType @@ -234,3 +242,24 @@ def test_unknown_enum_to_dict(): simple = SimpleMessage(enum=6000) converted = protobuf.to_dict(simple) assert converted["enum"] == 6000 + + +def test_constructor_deprecations(): + # ok: + RequiredFields(scalar=0) + + # positional argument + with pytest.deprecated_call(): + RequiredFields(0) + + # missing required value + with pytest.deprecated_call(): + RequiredFields() + + # more args than fields + with pytest.deprecated_call(), pytest.raises(TypeError): + RequiredFields(0, 0) + + # colliding arg and kwarg + with pytest.deprecated_call(), pytest.raises(TypeError): + RequiredFields(0, scalar=0)