1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 23:08:14 +00:00

feat(python): make the protobuf mappings overridable

This commit is contained in:
matejcik 2021-02-05 11:57:44 +01:00 committed by matejcik
parent dbf57d745a
commit a2a8cc88d9
6 changed files with 82 additions and 60 deletions

View File

@ -0,0 +1 @@
`trezorlib.mappings` was refactored for easier customization

View File

@ -69,7 +69,7 @@ def send_bytes(
click.echo(f"Response data: {response_data.hex()}")
try:
msg = mapping.decode(response_type, response_data)
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
click.echo("Parsed message:")
click.echo(protobuf.format_message(msg))
except Exception as e:

View File

@ -110,6 +110,7 @@ class TrezorClient:
might be removed at any time.
"""
LOG.info(f"creating client instance for device: {transport.get_path()}")
self.mapping = mapping.DEFAULT_MAPPING
self.transport = transport
self.ui = ui
self.session_counter = 0
@ -142,7 +143,7 @@ class TrezorClient:
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = mapping.encode(msg)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
@ -156,7 +157,7 @@ class TrezorClient:
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = mapping.decode(msg_type, msg_bytes)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},

View File

@ -67,6 +67,7 @@ class DebugLink:
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self.transport = transport
self.allow_interactions = auto_interact
self.mapping = mapping.DEFAULT_MAPPING
def open(self) -> None:
self.transport.begin_session()
@ -79,7 +80,7 @@ class DebugLink:
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = mapping.encode(msg)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
@ -93,7 +94,7 @@ class DebugLink:
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = mapping.decode(ret_type, ret_bytes)
msg = self.mapping.decode(ret_type, ret_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},

View File

@ -15,66 +15,85 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import io
from typing import Dict, Tuple, Type
from types import ModuleType
from typing import Dict, Optional, Tuple, Type, TypeVar
from . import messages, protobuf
map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
map_class_to_type: Dict[Type[protobuf.MessageType], int] = {}
T = TypeVar("T")
def build_map() -> None:
for entry in messages.MessageType:
msg_class = getattr(messages, entry.name, None)
if msg_class is None:
raise ValueError(
f"Implementation of protobuf message '{entry.name}' is missing"
)
class ProtobufMapping:
"""Mapping of protobuf classes to Python classes"""
if msg_class.MESSAGE_WIRE_TYPE != entry.value:
raise ValueError(
f"Inconsistent wire type and MessageType record for '{entry.name}'"
)
def __init__(self) -> None:
self.type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
self.class_to_type_override: Dict[Type[protobuf.MessageType], int] = {}
register_message(msg_class)
def register(
self,
msg_class: Type[protobuf.MessageType],
msg_wire_type: Optional[int] = None,
) -> None:
"""Register a Python class as a protobuf type.
If `msg_wire_type` is specified, it is used instead of the internal value in
`msg_class`.
Any existing registrations are overwritten.
"""
if msg_wire_type is not None:
self.class_to_type_override[msg_class] = msg_wire_type
elif msg_class.MESSAGE_WIRE_TYPE is None:
raise ValueError("Cannot register class without wire type")
else:
msg_wire_type = msg_class.MESSAGE_WIRE_TYPE
self.type_to_class[msg_wire_type] = msg_class
def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]:
"""Serialize a Python protobuf class.
Returns the message wire type and a byte representation of the protobuf message.
"""
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
if wire_type is None:
raise ValueError("Cannot encode class without wire type")
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue()
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
"""Deserialize a protobuf message into a Python class."""
cls = self.type_to_class[msg_wire_type]
buf = io.BytesIO(msg_bytes)
return protobuf.load_message(buf, cls)
@classmethod
def from_module(cls: Type[T], module: ModuleType) -> T:
"""Generate a mapping from a module.
The module must have a `MessageType` enum that specifies individual wire types.
"""
mapping = cls()
message_types = getattr(module, "MessageType")
for entry in message_types:
msg_class = getattr(module, entry.name, None)
if msg_class is None:
raise ValueError(
f"Implementation of protobuf message '{entry.name}' is missing"
)
if msg_class.MESSAGE_WIRE_TYPE != entry.value:
raise ValueError(
f"Inconsistent wire type and MessageType record for '{entry.name}'"
)
mapping.register(msg_class)
return mapping
def register_message(msg_class: Type[protobuf.MessageType]) -> None:
if msg_class.MESSAGE_WIRE_TYPE is None:
raise ValueError("Only messages with a wire type can be registered")
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
raise Exception(
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already "
f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
)
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
def get_type(msg: protobuf.MessageType) -> int:
return map_class_to_type[msg.__class__]
def get_class(t: int) -> Type[protobuf.MessageType]:
return map_type_to_class[t]
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
if msg.MESSAGE_WIRE_TYPE is None:
raise ValueError("Only messages with a wire type can be encoded")
message_type = msg.MESSAGE_WIRE_TYPE
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return message_type, buf.getvalue()
def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType:
cls = get_class(message_type)
buf = io.BytesIO(message_bytes)
return protobuf.load_message(buf, cls)
build_map()
DEFAULT_MAPPING = ProtobufMapping.from_module(messages)

View File

@ -36,7 +36,7 @@ class ApplySettingsCompat(protobuf.MessageType):
}
mapping.map_class_to_type[ApplySettingsCompat] = ApplySettingsCompat.MESSAGE_WIRE_TYPE
mapping.DEFAULT_MAPPING.register(ApplySettingsCompat)
@pytest.fixture