mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-08 22:40:59 +00:00
feat(python): implement API compatibility with trezorlib 0.12
This commit is contained in:
parent
b2948ee2dc
commit
3d6d1a56ac
@ -15,7 +15,7 @@ The root is an object with the following attributes:
|
||||
missing, `"Bitcoin"` is used.
|
||||
* __`inputs`__: array of `TxInputType` objects. Must be present.
|
||||
* __`outputs`__: array of `TxOutputType` objects. Must be present.
|
||||
* __`details`__: object of type `SignTx`, specifying transaction metadata. Can be
|
||||
* __`details`__: object whose keys correspond to metadata on the `SignTx` type. Can be
|
||||
omitted.
|
||||
* __`prev_txes`__: object whose keys are hex-encoded transaction hashes, and values are
|
||||
objects of type `TransactionType`. When signing a transaction with non-SegWit inputs,
|
||||
@ -112,10 +112,8 @@ set.
|
||||
|
||||
### Transaction metadata
|
||||
|
||||
The following is a shortened definition of the `SignTx` protobuf message. Note that it
|
||||
is possible to set fields `outputs_count`, `inputs_count` and `coin_name`, but their
|
||||
values will be ignored. Instead, the number of elements in `outputs`, `inputs`, and the
|
||||
value of `coin_name` from root object will be used.
|
||||
The following is a shortened definition of the `SignTx` protobuf message, containing
|
||||
all possible fields that are accepted in the `details` object.
|
||||
|
||||
All fields are optional unless required by your currency.
|
||||
|
||||
@ -124,7 +122,6 @@ message SignTx {
|
||||
optional uint32 version = 4; // transaction version
|
||||
optional uint32 lock_time = 5; // transaction lock_time
|
||||
optional uint32 expiry = 6; // only for Decred and Zcash
|
||||
optional bool overwintered = 7; // only for Zcash
|
||||
optional uint32 version_group_id = 8; // only for Zcash, nVersionGroupId when overwintered is set
|
||||
optional uint32 timestamp = 9; // only for Peercoin, transaction timestamp
|
||||
optional uint32 branch_id = 10; // only for Zcash, BRANCH_ID when overwintered is set
|
||||
|
@ -14,6 +14,7 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import warnings
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
|
||||
|
||||
@ -182,7 +183,8 @@ def sign_tx(
|
||||
coin_name: str,
|
||||
inputs: Sequence[messages.TxInputType],
|
||||
outputs: Sequence[messages.TxOutputType],
|
||||
prev_txes: Dict[bytes, messages.TransactionType],
|
||||
details: messages.SignTx = None,
|
||||
prev_txes: Dict[bytes, messages.TransactionType] = None,
|
||||
preauthorized: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[Sequence[bytes], bytes]:
|
||||
@ -197,14 +199,26 @@ def sign_tx(
|
||||
(`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments
|
||||
and cannot be overriden by kwargs.
|
||||
"""
|
||||
signtx = messages.SignTx(
|
||||
coin_name=coin_name,
|
||||
inputs_count=len(inputs),
|
||||
outputs_count=len(outputs),
|
||||
)
|
||||
for name, value in kwargs.items():
|
||||
if hasattr(signtx, name):
|
||||
setattr(signtx, name, value)
|
||||
if details is not None:
|
||||
warnings.warn(
|
||||
"'details' argument is deprecated, use kwargs instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
signtx = details
|
||||
signtx.coin_name = coin_name
|
||||
signtx.inputs_count = len(inputs)
|
||||
signtx.outputs_count = len(outputs)
|
||||
|
||||
else:
|
||||
signtx = messages.SignTx(
|
||||
coin_name=coin_name,
|
||||
inputs_count=len(inputs),
|
||||
outputs_count=len(outputs),
|
||||
)
|
||||
for name, value in kwargs.items():
|
||||
if hasattr(signtx, name):
|
||||
setattr(signtx, name, value)
|
||||
|
||||
if preauthorized:
|
||||
res = client.call(messages.DoPreauthorized())
|
||||
|
@ -199,7 +199,7 @@ def sign_tx(client, json_file):
|
||||
"""
|
||||
data = json.load(json_file)
|
||||
coin = data.get("coin_name", DEFAULT_COIN)
|
||||
details = protobuf.dict_to_proto(messages.SignTx, data.get("details", {}))
|
||||
details = data.get("details", {})
|
||||
inputs = [
|
||||
protobuf.dict_to_proto(messages.TxInputType, i) for i in data.get("inputs", ())
|
||||
]
|
||||
@ -212,7 +212,14 @@ def sign_tx(client, json_file):
|
||||
for txid, tx in data.get("prev_txes", {}).items()
|
||||
}
|
||||
|
||||
_, serialized_tx = btc.sign_tx(client, coin, inputs, outputs, details, prev_txes)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
coin,
|
||||
inputs,
|
||||
outputs,
|
||||
prev_txes=prev_txes,
|
||||
**details,
|
||||
)
|
||||
|
||||
click.echo()
|
||||
click.echo("Signed Transaction:")
|
||||
|
@ -276,11 +276,11 @@ class MessageFilter:
|
||||
@classmethod
|
||||
def from_message(cls, message):
|
||||
fields = {}
|
||||
for field in message.keys():
|
||||
value = getattr(message, field)
|
||||
if value in (None, []):
|
||||
for fname, _, _ in message.get_fields().values():
|
||||
value = getattr(message, fname)
|
||||
if value in (None, [], protobuf.FLAG_REQUIRED):
|
||||
continue
|
||||
fields[field] = value
|
||||
fields[fname] = value
|
||||
return cls(type(message), **fields)
|
||||
|
||||
def match(self, message):
|
||||
|
@ -24,6 +24,7 @@ For serializing (dumping) protobuf types, object with `Writer` interface is requ
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from itertools import zip_longest
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -37,6 +38,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
import warnings
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
@ -198,11 +200,29 @@ class UnicodeType:
|
||||
WIRE_TYPE = 2
|
||||
|
||||
|
||||
class MessageType:
|
||||
class _MessageTypeMeta(type):
|
||||
def __init__(cls, name, bases, d) -> None:
|
||||
super().__init__(name, bases, d)
|
||||
if name != "MessageType":
|
||||
cls.__init__ = MessageType.__init__
|
||||
|
||||
|
||||
class MessageType(metaclass=_MessageTypeMeta):
|
||||
WIRE_TYPE = 2
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls) -> Dict[int, FieldInfo]:
|
||||
"""Return a field descriptor.
|
||||
|
||||
The descriptor is a mapping:
|
||||
field_id -> (field_name, field_type, default_value)
|
||||
|
||||
`default_value` can also be one of the special values:
|
||||
* `FLAG_REQUIRED` indicates that the field value has no default and _must_ be
|
||||
provided by caller/sender.
|
||||
* `FLAG_REPEATED` indicates that the field is a list of `field_type` values. In
|
||||
that case the default value is an empty list.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
@ -212,10 +232,42 @@ class MessageType:
|
||||
return ftype
|
||||
return None
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
for kw in kwargs:
|
||||
setattr(self, kw, kwargs[kw])
|
||||
self._fill_missing()
|
||||
def __init__(self, *args, **kwargs: Any) -> None:
|
||||
fields = self.get_fields()
|
||||
if args:
|
||||
warnings.warn(
|
||||
"Positional arguments for MessageType are deprecated",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# process fields one by one
|
||||
NOT_PROVIDED = object()
|
||||
for field, val in zip_longest(fields.values(), args, fillvalue=NOT_PROVIDED):
|
||||
if field is NOT_PROVIDED:
|
||||
raise TypeError("too many positional arguments")
|
||||
fname, _, fdefault = field
|
||||
if fname in kwargs and val is not NOT_PROVIDED:
|
||||
# both *args and **kwargs specify the same thing
|
||||
raise TypeError(f"got multiple values for argument '{fname}'")
|
||||
elif fname in kwargs:
|
||||
# set in kwargs but not in args
|
||||
setattr(self, fname, kwargs[fname])
|
||||
elif val is not NOT_PROVIDED:
|
||||
# set in args but not in kwargs
|
||||
setattr(self, fname, val)
|
||||
else:
|
||||
# not set at all, pick a default
|
||||
if fdefault is FLAG_REPEATED:
|
||||
fdefault = []
|
||||
elif fdefault is FLAG_EXPERIMENTAL:
|
||||
fdefault = None
|
||||
elif fdefault is FLAG_REQUIRED:
|
||||
warnings.warn(
|
||||
f"Value of required field '{fname}' must be provided in constructor",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
setattr(self, fname, fdefault)
|
||||
|
||||
def __eq__(self, rhs: Any) -> bool:
|
||||
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
||||
@ -237,17 +289,6 @@ class MessageType:
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return getattr(self, key)
|
||||
|
||||
def _fill_missing(self) -> None:
|
||||
# fill missing fields
|
||||
for fname, _, fdefault in self.get_fields().values():
|
||||
if not hasattr(self, fname):
|
||||
if fdefault is FLAG_REPEATED:
|
||||
setattr(self, fname, [])
|
||||
elif fdefault is FLAG_REQUIRED:
|
||||
raise ValueError("value for required field is missing")
|
||||
else:
|
||||
setattr(self, fname, fdefault)
|
||||
|
||||
def ByteSize(self) -> int:
|
||||
data = BytesIO()
|
||||
dump_message(data, self)
|
||||
@ -403,6 +444,8 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
continue
|
||||
if fvalue is FLAG_REQUIRED:
|
||||
raise ValueError # required value was not provided
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
|
@ -239,9 +239,11 @@ def test_required():
|
||||
|
||||
assert msg_ok == msg
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# cannot construct instance without the required fields
|
||||
with pytest.deprecated_call():
|
||||
msg = RequiredFields(uvarint=3)
|
||||
with pytest.raises(ValueError):
|
||||
# cannot encode instance without the required fields
|
||||
dump_message(msg)
|
||||
|
||||
msg = RequiredFields(uvarint=3, nested=None)
|
||||
# we can always encode an invalid message
|
||||
|
@ -53,6 +53,14 @@ class NestedMessage(protobuf.MessageType):
|
||||
}
|
||||
|
||||
|
||||
class RequiredFields(protobuf.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("scalar", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
||||
}
|
||||
|
||||
|
||||
def test_get_field_type():
|
||||
# smoke test
|
||||
assert SimpleMessage.get_field_type("bool") is protobuf.BoolType
|
||||
@ -234,3 +242,24 @@ def test_unknown_enum_to_dict():
|
||||
simple = SimpleMessage(enum=6000)
|
||||
converted = protobuf.to_dict(simple)
|
||||
assert converted["enum"] == 6000
|
||||
|
||||
|
||||
def test_constructor_deprecations():
|
||||
# ok:
|
||||
RequiredFields(scalar=0)
|
||||
|
||||
# positional argument
|
||||
with pytest.deprecated_call():
|
||||
RequiredFields(0)
|
||||
|
||||
# missing required value
|
||||
with pytest.deprecated_call():
|
||||
RequiredFields()
|
||||
|
||||
# more args than fields
|
||||
with pytest.deprecated_call(), pytest.raises(TypeError):
|
||||
RequiredFields(0, 0)
|
||||
|
||||
# colliding arg and kwarg
|
||||
with pytest.deprecated_call(), pytest.raises(TypeError):
|
||||
RequiredFields(0, scalar=0)
|
||||
|
Loading…
Reference in New Issue
Block a user