1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-08 22:40:59 +00:00

feat(python): implement API compatibility with trezorlib 0.12

This commit is contained in:
matejcik 2020-10-08 11:28:11 +02:00 committed by matejcik
parent b2948ee2dc
commit 3d6d1a56ac
7 changed files with 131 additions and 39 deletions

View File

@ -15,7 +15,7 @@ The root is an object with the following attributes:
missing, `"Bitcoin"` is used.
* __`inputs`__: array of `TxInputType` objects. Must be present.
* __`outputs`__: array of `TxOutputType` objects. Must be present.
* __`details`__: object of type `SignTx`, specifying transaction metadata. Can be
* __`details`__: object whose keys correspond to metadata on the `SignTx` type. Can be
omitted.
* __`prev_txes`__: object whose keys are hex-encoded transaction hashes, and values are
objects of type `TransactionType`. When signing a transaction with non-SegWit inputs,
@ -112,10 +112,8 @@ set.
### Transaction metadata
The following is a shortened definition of the `SignTx` protobuf message. Note that it
is possible to set fields `outputs_count`, `inputs_count` and `coin_name`, but their
values will be ignored. Instead, the number of elements in `outputs`, `inputs`, and the
value of `coin_name` from root object will be used.
The following is a shortened definition of the `SignTx` protobuf message, containing
all possible fields that are accepted in the `details` object.
All fields are optional unless required by your currency.
@ -124,7 +122,6 @@ message SignTx {
optional uint32 version = 4; // transaction version
optional uint32 lock_time = 5; // transaction lock_time
optional uint32 expiry = 6; // only for Decred and Zcash
optional bool overwintered = 7; // only for Zcash
optional uint32 version_group_id = 8; // only for Zcash, nVersionGroupId when overwintered is set
optional uint32 timestamp = 9; // only for Peercoin, transaction timestamp
optional uint32 branch_id = 10; // only for Zcash, BRANCH_ID when overwintered is set

View File

@ -14,6 +14,7 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import warnings
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
@ -182,7 +183,8 @@ def sign_tx(
coin_name: str,
inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType],
prev_txes: Dict[bytes, messages.TransactionType],
details: messages.SignTx = None,
prev_txes: Dict[bytes, messages.TransactionType] = None,
preauthorized: bool = False,
**kwargs: Any,
) -> Tuple[Sequence[bytes], bytes]:
@ -197,14 +199,26 @@ def sign_tx(
(`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments
and cannot be overriden by kwargs.
"""
signtx = messages.SignTx(
coin_name=coin_name,
inputs_count=len(inputs),
outputs_count=len(outputs),
)
for name, value in kwargs.items():
if hasattr(signtx, name):
setattr(signtx, name, value)
if details is not None:
warnings.warn(
"'details' argument is deprecated, use kwargs instead",
DeprecationWarning,
stacklevel=2,
)
signtx = details
signtx.coin_name = coin_name
signtx.inputs_count = len(inputs)
signtx.outputs_count = len(outputs)
else:
signtx = messages.SignTx(
coin_name=coin_name,
inputs_count=len(inputs),
outputs_count=len(outputs),
)
for name, value in kwargs.items():
if hasattr(signtx, name):
setattr(signtx, name, value)
if preauthorized:
res = client.call(messages.DoPreauthorized())

View File

@ -199,7 +199,7 @@ def sign_tx(client, json_file):
"""
data = json.load(json_file)
coin = data.get("coin_name", DEFAULT_COIN)
details = protobuf.dict_to_proto(messages.SignTx, data.get("details", {}))
details = data.get("details", {})
inputs = [
protobuf.dict_to_proto(messages.TxInputType, i) for i in data.get("inputs", ())
]
@ -212,7 +212,14 @@ def sign_tx(client, json_file):
for txid, tx in data.get("prev_txes", {}).items()
}
_, serialized_tx = btc.sign_tx(client, coin, inputs, outputs, details, prev_txes)
_, serialized_tx = btc.sign_tx(
client,
coin,
inputs,
outputs,
prev_txes=prev_txes,
**details,
)
click.echo()
click.echo("Signed Transaction:")

View File

@ -276,11 +276,11 @@ class MessageFilter:
@classmethod
def from_message(cls, message):
fields = {}
for field in message.keys():
value = getattr(message, field)
if value in (None, []):
for fname, _, _ in message.get_fields().values():
value = getattr(message, fname)
if value in (None, [], protobuf.FLAG_REQUIRED):
continue
fields[field] = value
fields[fname] = value
return cls(type(message), **fields)
def match(self, message):

View File

@ -24,6 +24,7 @@ For serializing (dumping) protobuf types, object with `Writer` interface is requ
import logging
from io import BytesIO
from itertools import zip_longest
from typing import (
Any,
Callable,
@ -37,6 +38,7 @@ from typing import (
TypeVar,
Union,
)
import warnings
from typing_extensions import Protocol
@ -198,11 +200,29 @@ class UnicodeType:
WIRE_TYPE = 2
class MessageType:
class _MessageTypeMeta(type):
def __init__(cls, name, bases, d) -> None:
super().__init__(name, bases, d)
if name != "MessageType":
cls.__init__ = MessageType.__init__
class MessageType(metaclass=_MessageTypeMeta):
WIRE_TYPE = 2
@classmethod
def get_fields(cls) -> Dict[int, FieldInfo]:
"""Return a field descriptor.
The descriptor is a mapping:
field_id -> (field_name, field_type, default_value)
`default_value` can also be one of the special values:
* `FLAG_REQUIRED` indicates that the field value has no default and _must_ be
provided by caller/sender.
* `FLAG_REPEATED` indicates that the field is a list of `field_type` values. In
that case the default value is an empty list.
"""
return {}
@classmethod
@ -212,10 +232,42 @@ class MessageType:
return ftype
return None
def __init__(self, **kwargs: Any) -> None:
for kw in kwargs:
setattr(self, kw, kwargs[kw])
self._fill_missing()
def __init__(self, *args, **kwargs: Any) -> None:
fields = self.get_fields()
if args:
warnings.warn(
"Positional arguments for MessageType are deprecated",
DeprecationWarning,
stacklevel=2,
)
# process fields one by one
NOT_PROVIDED = object()
for field, val in zip_longest(fields.values(), args, fillvalue=NOT_PROVIDED):
if field is NOT_PROVIDED:
raise TypeError("too many positional arguments")
fname, _, fdefault = field
if fname in kwargs and val is not NOT_PROVIDED:
# both *args and **kwargs specify the same thing
raise TypeError(f"got multiple values for argument '{fname}'")
elif fname in kwargs:
# set in kwargs but not in args
setattr(self, fname, kwargs[fname])
elif val is not NOT_PROVIDED:
# set in args but not in kwargs
setattr(self, fname, val)
else:
# not set at all, pick a default
if fdefault is FLAG_REPEATED:
fdefault = []
elif fdefault is FLAG_EXPERIMENTAL:
fdefault = None
elif fdefault is FLAG_REQUIRED:
warnings.warn(
f"Value of required field '{fname}' must be provided in constructor",
DeprecationWarning,
stacklevel=2,
)
setattr(self, fname, fdefault)
def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
@ -237,17 +289,6 @@ class MessageType:
def __getitem__(self, key: str) -> Any:
return getattr(self, key)
def _fill_missing(self) -> None:
# fill missing fields
for fname, _, fdefault in self.get_fields().values():
if not hasattr(self, fname):
if fdefault is FLAG_REPEATED:
setattr(self, fname, [])
elif fdefault is FLAG_REQUIRED:
raise ValueError("value for required field is missing")
else:
setattr(self, fname, fdefault)
def ByteSize(self) -> int:
data = BytesIO()
dump_message(data, self)
@ -403,6 +444,8 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
if fvalue is FLAG_REQUIRED:
raise ValueError # required value was not provided
fkey = (ftag << 3) | ftype.WIRE_TYPE

View File

@ -239,9 +239,11 @@ def test_required():
assert msg_ok == msg
with pytest.raises(ValueError):
# cannot construct instance without the required fields
with pytest.deprecated_call():
msg = RequiredFields(uvarint=3)
with pytest.raises(ValueError):
# cannot encode instance without the required fields
dump_message(msg)
msg = RequiredFields(uvarint=3, nested=None)
# we can always encode an invalid message

View File

@ -53,6 +53,14 @@ class NestedMessage(protobuf.MessageType):
}
class RequiredFields(protobuf.MessageType):
@classmethod
def get_fields(cls):
return {
1: ("scalar", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
}
def test_get_field_type():
# smoke test
assert SimpleMessage.get_field_type("bool") is protobuf.BoolType
@ -234,3 +242,24 @@ def test_unknown_enum_to_dict():
simple = SimpleMessage(enum=6000)
converted = protobuf.to_dict(simple)
assert converted["enum"] == 6000
def test_constructor_deprecations():
# ok:
RequiredFields(scalar=0)
# positional argument
with pytest.deprecated_call():
RequiredFields(0)
# missing required value
with pytest.deprecated_call():
RequiredFields()
# more args than fields
with pytest.deprecated_call(), pytest.raises(TypeError):
RequiredFields(0, 0)
# colliding arg and kwarg
with pytest.deprecated_call(), pytest.raises(TypeError):
RequiredFields(0, scalar=0)