mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-09 16:18:10 +00:00
feat(python): implement API compatibility with trezorlib 0.12
This commit is contained in:
parent
b2948ee2dc
commit
3d6d1a56ac
@ -15,7 +15,7 @@ The root is an object with the following attributes:
|
|||||||
missing, `"Bitcoin"` is used.
|
missing, `"Bitcoin"` is used.
|
||||||
* __`inputs`__: array of `TxInputType` objects. Must be present.
|
* __`inputs`__: array of `TxInputType` objects. Must be present.
|
||||||
* __`outputs`__: array of `TxOutputType` objects. Must be present.
|
* __`outputs`__: array of `TxOutputType` objects. Must be present.
|
||||||
* __`details`__: object of type `SignTx`, specifying transaction metadata. Can be
|
* __`details`__: object whose keys correspond to metadata on the `SignTx` type. Can be
|
||||||
omitted.
|
omitted.
|
||||||
* __`prev_txes`__: object whose keys are hex-encoded transaction hashes, and values are
|
* __`prev_txes`__: object whose keys are hex-encoded transaction hashes, and values are
|
||||||
objects of type `TransactionType`. When signing a transaction with non-SegWit inputs,
|
objects of type `TransactionType`. When signing a transaction with non-SegWit inputs,
|
||||||
@ -112,10 +112,8 @@ set.
|
|||||||
|
|
||||||
### Transaction metadata
|
### Transaction metadata
|
||||||
|
|
||||||
The following is a shortened definition of the `SignTx` protobuf message. Note that it
|
The following is a shortened definition of the `SignTx` protobuf message, containing
|
||||||
is possible to set fields `outputs_count`, `inputs_count` and `coin_name`, but their
|
all possible fields that are accepted in the `details` object.
|
||||||
values will be ignored. Instead, the number of elements in `outputs`, `inputs`, and the
|
|
||||||
value of `coin_name` from root object will be used.
|
|
||||||
|
|
||||||
All fields are optional unless required by your currency.
|
All fields are optional unless required by your currency.
|
||||||
|
|
||||||
@ -124,7 +122,6 @@ message SignTx {
|
|||||||
optional uint32 version = 4; // transaction version
|
optional uint32 version = 4; // transaction version
|
||||||
optional uint32 lock_time = 5; // transaction lock_time
|
optional uint32 lock_time = 5; // transaction lock_time
|
||||||
optional uint32 expiry = 6; // only for Decred and Zcash
|
optional uint32 expiry = 6; // only for Decred and Zcash
|
||||||
optional bool overwintered = 7; // only for Zcash
|
|
||||||
optional uint32 version_group_id = 8; // only for Zcash, nVersionGroupId when overwintered is set
|
optional uint32 version_group_id = 8; // only for Zcash, nVersionGroupId when overwintered is set
|
||||||
optional uint32 timestamp = 9; // only for Peercoin, transaction timestamp
|
optional uint32 timestamp = 9; // only for Peercoin, transaction timestamp
|
||||||
optional uint32 branch_id = 10; // only for Zcash, BRANCH_ID when overwintered is set
|
optional uint32 branch_id = 10; // only for Zcash, BRANCH_ID when overwintered is set
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
import warnings
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
|
||||||
|
|
||||||
@ -182,7 +183,8 @@ def sign_tx(
|
|||||||
coin_name: str,
|
coin_name: str,
|
||||||
inputs: Sequence[messages.TxInputType],
|
inputs: Sequence[messages.TxInputType],
|
||||||
outputs: Sequence[messages.TxOutputType],
|
outputs: Sequence[messages.TxOutputType],
|
||||||
prev_txes: Dict[bytes, messages.TransactionType],
|
details: messages.SignTx = None,
|
||||||
|
prev_txes: Dict[bytes, messages.TransactionType] = None,
|
||||||
preauthorized: bool = False,
|
preauthorized: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tuple[Sequence[bytes], bytes]:
|
) -> Tuple[Sequence[bytes], bytes]:
|
||||||
@ -197,14 +199,26 @@ def sign_tx(
|
|||||||
(`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments
|
(`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments
|
||||||
and cannot be overriden by kwargs.
|
and cannot be overriden by kwargs.
|
||||||
"""
|
"""
|
||||||
signtx = messages.SignTx(
|
if details is not None:
|
||||||
coin_name=coin_name,
|
warnings.warn(
|
||||||
inputs_count=len(inputs),
|
"'details' argument is deprecated, use kwargs instead",
|
||||||
outputs_count=len(outputs),
|
DeprecationWarning,
|
||||||
)
|
stacklevel=2,
|
||||||
for name, value in kwargs.items():
|
)
|
||||||
if hasattr(signtx, name):
|
signtx = details
|
||||||
setattr(signtx, name, value)
|
signtx.coin_name = coin_name
|
||||||
|
signtx.inputs_count = len(inputs)
|
||||||
|
signtx.outputs_count = len(outputs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
signtx = messages.SignTx(
|
||||||
|
coin_name=coin_name,
|
||||||
|
inputs_count=len(inputs),
|
||||||
|
outputs_count=len(outputs),
|
||||||
|
)
|
||||||
|
for name, value in kwargs.items():
|
||||||
|
if hasattr(signtx, name):
|
||||||
|
setattr(signtx, name, value)
|
||||||
|
|
||||||
if preauthorized:
|
if preauthorized:
|
||||||
res = client.call(messages.DoPreauthorized())
|
res = client.call(messages.DoPreauthorized())
|
||||||
|
@ -199,7 +199,7 @@ def sign_tx(client, json_file):
|
|||||||
"""
|
"""
|
||||||
data = json.load(json_file)
|
data = json.load(json_file)
|
||||||
coin = data.get("coin_name", DEFAULT_COIN)
|
coin = data.get("coin_name", DEFAULT_COIN)
|
||||||
details = protobuf.dict_to_proto(messages.SignTx, data.get("details", {}))
|
details = data.get("details", {})
|
||||||
inputs = [
|
inputs = [
|
||||||
protobuf.dict_to_proto(messages.TxInputType, i) for i in data.get("inputs", ())
|
protobuf.dict_to_proto(messages.TxInputType, i) for i in data.get("inputs", ())
|
||||||
]
|
]
|
||||||
@ -212,7 +212,14 @@ def sign_tx(client, json_file):
|
|||||||
for txid, tx in data.get("prev_txes", {}).items()
|
for txid, tx in data.get("prev_txes", {}).items()
|
||||||
}
|
}
|
||||||
|
|
||||||
_, serialized_tx = btc.sign_tx(client, coin, inputs, outputs, details, prev_txes)
|
_, serialized_tx = btc.sign_tx(
|
||||||
|
client,
|
||||||
|
coin,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
prev_txes=prev_txes,
|
||||||
|
**details,
|
||||||
|
)
|
||||||
|
|
||||||
click.echo()
|
click.echo()
|
||||||
click.echo("Signed Transaction:")
|
click.echo("Signed Transaction:")
|
||||||
|
@ -276,11 +276,11 @@ class MessageFilter:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_message(cls, message):
|
def from_message(cls, message):
|
||||||
fields = {}
|
fields = {}
|
||||||
for field in message.keys():
|
for fname, _, _ in message.get_fields().values():
|
||||||
value = getattr(message, field)
|
value = getattr(message, fname)
|
||||||
if value in (None, []):
|
if value in (None, [], protobuf.FLAG_REQUIRED):
|
||||||
continue
|
continue
|
||||||
fields[field] = value
|
fields[fname] = value
|
||||||
return cls(type(message), **fields)
|
return cls(type(message), **fields)
|
||||||
|
|
||||||
def match(self, message):
|
def match(self, message):
|
||||||
|
@ -24,6 +24,7 @@ For serializing (dumping) protobuf types, object with `Writer` interface is requ
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from itertools import zip_longest
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@ -37,6 +38,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
import warnings
|
||||||
|
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
@ -198,11 +200,29 @@ class UnicodeType:
|
|||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
|
||||||
class MessageType:
|
class _MessageTypeMeta(type):
|
||||||
|
def __init__(cls, name, bases, d) -> None:
|
||||||
|
super().__init__(name, bases, d)
|
||||||
|
if name != "MessageType":
|
||||||
|
cls.__init__ = MessageType.__init__
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(metaclass=_MessageTypeMeta):
|
||||||
WIRE_TYPE = 2
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls) -> Dict[int, FieldInfo]:
|
def get_fields(cls) -> Dict[int, FieldInfo]:
|
||||||
|
"""Return a field descriptor.
|
||||||
|
|
||||||
|
The descriptor is a mapping:
|
||||||
|
field_id -> (field_name, field_type, default_value)
|
||||||
|
|
||||||
|
`default_value` can also be one of the special values:
|
||||||
|
* `FLAG_REQUIRED` indicates that the field value has no default and _must_ be
|
||||||
|
provided by caller/sender.
|
||||||
|
* `FLAG_REPEATED` indicates that the field is a list of `field_type` values. In
|
||||||
|
that case the default value is an empty list.
|
||||||
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -212,10 +232,42 @@ class MessageType:
|
|||||||
return ftype
|
return ftype
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, *args, **kwargs: Any) -> None:
|
||||||
for kw in kwargs:
|
fields = self.get_fields()
|
||||||
setattr(self, kw, kwargs[kw])
|
if args:
|
||||||
self._fill_missing()
|
warnings.warn(
|
||||||
|
"Positional arguments for MessageType are deprecated",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
# process fields one by one
|
||||||
|
NOT_PROVIDED = object()
|
||||||
|
for field, val in zip_longest(fields.values(), args, fillvalue=NOT_PROVIDED):
|
||||||
|
if field is NOT_PROVIDED:
|
||||||
|
raise TypeError("too many positional arguments")
|
||||||
|
fname, _, fdefault = field
|
||||||
|
if fname in kwargs and val is not NOT_PROVIDED:
|
||||||
|
# both *args and **kwargs specify the same thing
|
||||||
|
raise TypeError(f"got multiple values for argument '{fname}'")
|
||||||
|
elif fname in kwargs:
|
||||||
|
# set in kwargs but not in args
|
||||||
|
setattr(self, fname, kwargs[fname])
|
||||||
|
elif val is not NOT_PROVIDED:
|
||||||
|
# set in args but not in kwargs
|
||||||
|
setattr(self, fname, val)
|
||||||
|
else:
|
||||||
|
# not set at all, pick a default
|
||||||
|
if fdefault is FLAG_REPEATED:
|
||||||
|
fdefault = []
|
||||||
|
elif fdefault is FLAG_EXPERIMENTAL:
|
||||||
|
fdefault = None
|
||||||
|
elif fdefault is FLAG_REQUIRED:
|
||||||
|
warnings.warn(
|
||||||
|
f"Value of required field '{fname}' must be provided in constructor",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
setattr(self, fname, fdefault)
|
||||||
|
|
||||||
def __eq__(self, rhs: Any) -> bool:
|
def __eq__(self, rhs: Any) -> bool:
|
||||||
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
||||||
@ -237,17 +289,6 @@ class MessageType:
|
|||||||
def __getitem__(self, key: str) -> Any:
|
def __getitem__(self, key: str) -> Any:
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
|
||||||
def _fill_missing(self) -> None:
|
|
||||||
# fill missing fields
|
|
||||||
for fname, _, fdefault in self.get_fields().values():
|
|
||||||
if not hasattr(self, fname):
|
|
||||||
if fdefault is FLAG_REPEATED:
|
|
||||||
setattr(self, fname, [])
|
|
||||||
elif fdefault is FLAG_REQUIRED:
|
|
||||||
raise ValueError("value for required field is missing")
|
|
||||||
else:
|
|
||||||
setattr(self, fname, fdefault)
|
|
||||||
|
|
||||||
def ByteSize(self) -> int:
|
def ByteSize(self) -> int:
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
dump_message(data, self)
|
dump_message(data, self)
|
||||||
@ -403,6 +444,8 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
|||||||
fvalue = getattr(msg, fname, None)
|
fvalue = getattr(msg, fname, None)
|
||||||
if fvalue is None:
|
if fvalue is None:
|
||||||
continue
|
continue
|
||||||
|
if fvalue is FLAG_REQUIRED:
|
||||||
|
raise ValueError # required value was not provided
|
||||||
|
|
||||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||||
|
|
||||||
|
@ -239,9 +239,11 @@ def test_required():
|
|||||||
|
|
||||||
assert msg_ok == msg
|
assert msg_ok == msg
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.deprecated_call():
|
||||||
# cannot construct instance without the required fields
|
|
||||||
msg = RequiredFields(uvarint=3)
|
msg = RequiredFields(uvarint=3)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# cannot encode instance without the required fields
|
||||||
|
dump_message(msg)
|
||||||
|
|
||||||
msg = RequiredFields(uvarint=3, nested=None)
|
msg = RequiredFields(uvarint=3, nested=None)
|
||||||
# we can always encode an invalid message
|
# we can always encode an invalid message
|
||||||
|
@ -53,6 +53,14 @@ class NestedMessage(protobuf.MessageType):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredFields(protobuf.MessageType):
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls):
|
||||||
|
return {
|
||||||
|
1: ("scalar", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_get_field_type():
|
def test_get_field_type():
|
||||||
# smoke test
|
# smoke test
|
||||||
assert SimpleMessage.get_field_type("bool") is protobuf.BoolType
|
assert SimpleMessage.get_field_type("bool") is protobuf.BoolType
|
||||||
@ -234,3 +242,24 @@ def test_unknown_enum_to_dict():
|
|||||||
simple = SimpleMessage(enum=6000)
|
simple = SimpleMessage(enum=6000)
|
||||||
converted = protobuf.to_dict(simple)
|
converted = protobuf.to_dict(simple)
|
||||||
assert converted["enum"] == 6000
|
assert converted["enum"] == 6000
|
||||||
|
|
||||||
|
|
||||||
|
def test_constructor_deprecations():
|
||||||
|
# ok:
|
||||||
|
RequiredFields(scalar=0)
|
||||||
|
|
||||||
|
# positional argument
|
||||||
|
with pytest.deprecated_call():
|
||||||
|
RequiredFields(0)
|
||||||
|
|
||||||
|
# missing required value
|
||||||
|
with pytest.deprecated_call():
|
||||||
|
RequiredFields()
|
||||||
|
|
||||||
|
# more args than fields
|
||||||
|
with pytest.deprecated_call(), pytest.raises(TypeError):
|
||||||
|
RequiredFields(0, 0)
|
||||||
|
|
||||||
|
# colliding arg and kwarg
|
||||||
|
with pytest.deprecated_call(), pytest.raises(TypeError):
|
||||||
|
RequiredFields(0, scalar=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user