mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-16 17:42:02 +00:00
python: add more protobuf tests
This commit is contained in:
parent
bd9bf4e2bc
commit
8f2b22a8f5
@ -205,6 +205,13 @@ class MessageType:
|
|||||||
def get_fields(cls) -> Dict[int, FieldInfo]:
|
def get_fields(cls) -> Dict[int, FieldInfo]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_field_type(cls, name: str) -> Optional[FieldType]:
|
||||||
|
for fname, ftype, flags in cls.get_fields().values():
|
||||||
|
if fname == name:
|
||||||
|
return ftype
|
||||||
|
return None
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
for kw in kwargs:
|
for kw in kwargs:
|
||||||
setattr(self, kw, kwargs[kw])
|
setattr(self, kw, kwargs[kw])
|
||||||
@ -438,16 +445,10 @@ def format_message(
|
|||||||
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
|
printable = sum(1 for byte in bytes if 0x20 <= byte <= 0x7E)
|
||||||
return printable / len(bytes) > 0.8
|
return printable / len(bytes) > 0.8
|
||||||
|
|
||||||
def get_type(name: str) -> Any:
|
|
||||||
try:
|
|
||||||
return next(ft for fn, ft, _ in pb.get_fields().values() if fn == name)
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def pformat(name: str, value: Any, indent: int) -> str:
|
def pformat(name: str, value: Any, indent: int) -> str:
|
||||||
level = sep * indent
|
level = sep * indent
|
||||||
leadin = sep * (indent + 1)
|
leadin = sep * (indent + 1)
|
||||||
ftype = get_type(name)
|
ftype = pb.get_field_type(name)
|
||||||
|
|
||||||
if isinstance(value, MessageType):
|
if isinstance(value, MessageType):
|
||||||
return format_message(value, indent, sep)
|
return format_message(value, indent, sep)
|
||||||
@ -549,13 +550,15 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
|||||||
|
|
||||||
|
|
||||||
def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
||||||
def convert_value(value: Any) -> Any:
|
def convert_value(ftype: FieldType, value: Any) -> Any:
|
||||||
if hexlify_bytes and isinstance(value, bytes):
|
if hexlify_bytes and isinstance(value, bytes):
|
||||||
return value.hex()
|
return value.hex()
|
||||||
elif isinstance(value, MessageType):
|
elif isinstance(value, MessageType):
|
||||||
return to_dict(value, hexlify_bytes)
|
return to_dict(value, hexlify_bytes)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
return [convert_value(v) for v in value]
|
return [convert_value(ftype, v) for v in value]
|
||||||
|
elif isinstance(value, int) and isinstance(ftype, EnumType):
|
||||||
|
return ftype.to_str(value)
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -563,6 +566,6 @@ def to_dict(msg: MessageType, hexlify_bytes: bool = True) -> Dict[str, Any]:
|
|||||||
for key, value in msg.__dict__.items():
|
for key, value in msg.__dict__.items():
|
||||||
if value is None or value == []:
|
if value is None or value == []:
|
||||||
continue
|
continue
|
||||||
res[key] = convert_value(value)
|
res[key] = convert_value(msg.get_field_type(key), value)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
@ -19,7 +19,6 @@ from io import BytesIO
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from trezorlib import protobuf
|
from trezorlib import protobuf
|
||||||
from trezorlib.messages import InputScriptType
|
|
||||||
|
|
||||||
|
|
||||||
class PrimitiveMessage(protobuf.MessageType):
|
class PrimitiveMessage(protobuf.MessageType):
|
||||||
@ -170,24 +169,6 @@ def test_validate_enum(caplog):
|
|||||||
assert record.getMessage() == "Value 3 unknown for type t"
|
assert record.getMessage() == "Value 3 unknown for type t"
|
||||||
|
|
||||||
|
|
||||||
def test_enum_to_str():
|
|
||||||
enum_values = [
|
|
||||||
(key, getattr(InputScriptType, key))
|
|
||||||
for key in dir(InputScriptType)
|
|
||||||
if not key.startswith("__")
|
|
||||||
]
|
|
||||||
enum_type = protobuf.EnumType("InputScriptType", [v for _, v in enum_values])
|
|
||||||
for name, value in enum_values:
|
|
||||||
assert enum_type.to_str(value) == name
|
|
||||||
assert enum_type.from_str(name) == value
|
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
enum_type.from_str("NotAValidValue")
|
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
enum_type.to_str(999)
|
|
||||||
|
|
||||||
|
|
||||||
def test_repeated():
|
def test_repeated():
|
||||||
msg = RepeatedFields(
|
msg = RepeatedFields(
|
||||||
uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"]
|
uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"]
|
218
python/tests/test_protobuf_misc.py
Normal file
218
python/tests/test_protobuf_misc.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
# This file is part of the Trezor project.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2012-2019 SatoshiLabs and contributors
|
||||||
|
#
|
||||||
|
# This library is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
|
# as published by the Free Software Foundation.
|
||||||
|
#
|
||||||
|
# This library is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Lesser General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the License along with this library.
|
||||||
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from trezorlib import protobuf
|
||||||
|
|
||||||
|
SimpleEnum = SimpleNamespace(FOO=0, BAR=5, QUUX=13)
|
||||||
|
SimpleEnumType = protobuf.EnumType("SimpleEnum", (0, 5, 13))
|
||||||
|
|
||||||
|
with_simple_enum = patch("trezorlib.messages.SimpleEnum", SimpleEnum, create=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleMessage(protobuf.MessageType):
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls):
|
||||||
|
return {
|
||||||
|
1: ("uvarint", protobuf.UVarintType, 0),
|
||||||
|
2: ("svarint", protobuf.SVarintType, 0),
|
||||||
|
3: ("bool", protobuf.BoolType, 0),
|
||||||
|
4: ("bytes", protobuf.BytesType, 0),
|
||||||
|
5: ("unicode", protobuf.UnicodeType, 0),
|
||||||
|
6: ("enum", SimpleEnumType, 0),
|
||||||
|
7: ("rep_int", protobuf.UVarintType, protobuf.FLAG_REPEATED),
|
||||||
|
8: ("rep_str", protobuf.UnicodeType, protobuf.FLAG_REPEATED),
|
||||||
|
9: ("rep_enum", SimpleEnumType, protobuf.FLAG_REPEATED),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class NestedMessage(protobuf.MessageType):
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls):
|
||||||
|
return {
|
||||||
|
1: ("scalar", protobuf.UVarintType, 0),
|
||||||
|
2: ("nested", SimpleMessage, 0),
|
||||||
|
3: ("repeated", SimpleMessage, protobuf.FLAG_REPEATED),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_field_type():
|
||||||
|
# smoke test
|
||||||
|
assert SimpleMessage.get_field_type("bool") is protobuf.BoolType
|
||||||
|
|
||||||
|
# full field list
|
||||||
|
for fname, ftype, _ in SimpleMessage.get_fields().values():
|
||||||
|
assert SimpleMessage.get_field_type(fname) is ftype
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_enum_to_str():
|
||||||
|
# smoke test
|
||||||
|
assert SimpleEnumType.to_str(5) == "BAR"
|
||||||
|
|
||||||
|
# full value list
|
||||||
|
for name, value in SimpleEnum.__dict__.items():
|
||||||
|
assert SimpleEnumType.to_str(value) == name
|
||||||
|
assert SimpleEnumType.from_str(name) == value
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
SimpleEnumType.from_str("NotAValidValue")
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
SimpleEnumType.to_str(999)
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_dict_roundtrip():
|
||||||
|
msg = SimpleMessage(
|
||||||
|
uvarint=5,
|
||||||
|
svarint=-13,
|
||||||
|
bool=False,
|
||||||
|
bytes=b"\xca\xfe\x00\xfe",
|
||||||
|
unicode="žluťoučký kůň",
|
||||||
|
enum=5,
|
||||||
|
rep_int=[1, 2, 3],
|
||||||
|
rep_str=["a", "b", "c"],
|
||||||
|
rep_enum=[0, 5, 13],
|
||||||
|
)
|
||||||
|
|
||||||
|
converted = protobuf.to_dict(msg)
|
||||||
|
recovered = protobuf.dict_to_proto(SimpleMessage, converted)
|
||||||
|
|
||||||
|
assert recovered == msg
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_to_dict():
|
||||||
|
msg = SimpleMessage(
|
||||||
|
uvarint=5,
|
||||||
|
svarint=-13,
|
||||||
|
bool=False,
|
||||||
|
bytes=b"\xca\xfe\x00\xfe",
|
||||||
|
unicode="žluťoučký kůň",
|
||||||
|
enum=5,
|
||||||
|
rep_int=[1, 2, 3],
|
||||||
|
rep_str=["a", "b", "c"],
|
||||||
|
rep_enum=[0, 5, 13],
|
||||||
|
)
|
||||||
|
|
||||||
|
converted = protobuf.to_dict(msg)
|
||||||
|
|
||||||
|
fields = [fname for fname, _, _ in msg.get_fields().values()]
|
||||||
|
assert list(sorted(converted.keys())) == list(sorted(fields))
|
||||||
|
|
||||||
|
assert converted["uvarint"] == 5
|
||||||
|
assert converted["svarint"] == -13
|
||||||
|
assert converted["bool"] is False
|
||||||
|
assert converted["bytes"] == "cafe00fe"
|
||||||
|
assert converted["unicode"] == "žluťoučký kůň"
|
||||||
|
assert converted["enum"] == "BAR"
|
||||||
|
assert converted["rep_int"] == [1, 2, 3]
|
||||||
|
assert converted["rep_str"] == ["a", "b", "c"]
|
||||||
|
assert converted["rep_enum"] == ["FOO", "BAR", "QUUX"]
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_recover_mismatch():
|
||||||
|
dictdata = {
|
||||||
|
"bool": True,
|
||||||
|
"enum": "FOO",
|
||||||
|
"another_field": "hello",
|
||||||
|
"rep_enum": ["FOO", 5, 5],
|
||||||
|
}
|
||||||
|
recovered = protobuf.dict_to_proto(SimpleMessage, dictdata)
|
||||||
|
|
||||||
|
assert recovered.bool is True
|
||||||
|
assert recovered.enum is SimpleEnum.FOO
|
||||||
|
assert not hasattr(recovered, "another_field")
|
||||||
|
assert recovered.rep_enum == [SimpleEnum.FOO, SimpleEnum.BAR, SimpleEnum.BAR]
|
||||||
|
|
||||||
|
for name, _, flags in SimpleMessage.get_fields().values():
|
||||||
|
if name not in dictdata:
|
||||||
|
if flags == protobuf.FLAG_REPEATED:
|
||||||
|
assert getattr(recovered, name) == []
|
||||||
|
else:
|
||||||
|
assert getattr(recovered, name) is None
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_hexlify():
|
||||||
|
msg = SimpleMessage(bytes=b"\xca\xfe\x00\x12\x34", unicode="žluťoučký kůň")
|
||||||
|
converted_nohex = protobuf.to_dict(msg, hexlify_bytes=False)
|
||||||
|
converted_hex = protobuf.to_dict(msg, hexlify_bytes=True)
|
||||||
|
|
||||||
|
assert converted_nohex["bytes"] == b"\xca\xfe\x00\x12\x34"
|
||||||
|
assert converted_nohex["unicode"] == "žluťoučký kůň"
|
||||||
|
assert converted_hex["bytes"] == "cafe001234"
|
||||||
|
assert converted_hex["unicode"] == "žluťoučký kůň"
|
||||||
|
|
||||||
|
recovered_nohex = protobuf.dict_to_proto(SimpleMessage, converted_nohex)
|
||||||
|
recovered_hex = protobuf.dict_to_proto(SimpleMessage, converted_hex)
|
||||||
|
|
||||||
|
assert recovered_nohex.bytes == msg.bytes
|
||||||
|
assert recovered_hex.bytes == msg.bytes
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_nested_round_trip():
|
||||||
|
msg = NestedMessage(
|
||||||
|
scalar=9,
|
||||||
|
nested=SimpleMessage(uvarint=4, enum=SimpleEnum.FOO),
|
||||||
|
repeated=[
|
||||||
|
SimpleMessage(),
|
||||||
|
SimpleMessage(rep_enum=[SimpleEnum.BAR, SimpleEnum.BAR]),
|
||||||
|
SimpleMessage(bytes=b"\xca\xfe"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
converted = protobuf.to_dict(msg)
|
||||||
|
recovered = protobuf.dict_to_proto(NestedMessage, converted)
|
||||||
|
|
||||||
|
assert msg == recovered
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_nested_to_dict():
|
||||||
|
msg = NestedMessage(
|
||||||
|
scalar=9,
|
||||||
|
nested=SimpleMessage(uvarint=4, enum=SimpleEnum.FOO),
|
||||||
|
repeated=[
|
||||||
|
SimpleMessage(),
|
||||||
|
SimpleMessage(rep_enum=[SimpleEnum.BAR, SimpleEnum.BAR]),
|
||||||
|
SimpleMessage(bytes=b"\xca\xfe"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
converted = protobuf.to_dict(msg)
|
||||||
|
assert converted["scalar"] == 9
|
||||||
|
assert isinstance(converted["nested"], dict)
|
||||||
|
assert isinstance(converted["repeated"], list)
|
||||||
|
|
||||||
|
rep = converted["repeated"]
|
||||||
|
assert rep[0] == {}
|
||||||
|
assert rep[1] == {"rep_enum": ["BAR", "BAR"]}
|
||||||
|
assert rep[2] == {"bytes": "cafe"}
|
||||||
|
|
||||||
|
|
||||||
|
@with_simple_enum
|
||||||
|
def test_nested_recover():
|
||||||
|
dictdata = {"nested": {}}
|
||||||
|
recovered = protobuf.dict_to_proto(NestedMessage, dictdata)
|
||||||
|
assert isinstance(recovered.nested, SimpleMessage)
|
Loading…
Reference in New Issue
Block a user