mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 16:00:57 +00:00
feat(python): make the protobuf mappings overridable
This commit is contained in:
parent
dbf57d745a
commit
a2a8cc88d9
1
python/.changelog.d/1449.changed
Normal file
1
python/.changelog.d/1449.changed
Normal file
@ -0,0 +1 @@
|
||||
`trezorlib.mappings` was refactored for easier customization
|
@ -69,7 +69,7 @@ def send_bytes(
|
||||
click.echo(f"Response data: {response_data.hex()}")
|
||||
|
||||
try:
|
||||
msg = mapping.decode(response_type, response_data)
|
||||
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
click.echo("Parsed message:")
|
||||
click.echo(protobuf.format_message(msg))
|
||||
except Exception as e:
|
||||
|
@ -110,6 +110,7 @@ class TrezorClient:
|
||||
might be removed at any time.
|
||||
"""
|
||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
self.transport = transport
|
||||
self.ui = ui
|
||||
self.session_counter = 0
|
||||
@ -142,7 +143,7 @@ class TrezorClient:
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = mapping.encode(msg)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
@ -156,7 +157,7 @@ class TrezorClient:
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = mapping.decode(msg_type, msg_bytes)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
|
@ -67,6 +67,7 @@ class DebugLink:
|
||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||
self.transport = transport
|
||||
self.allow_interactions = auto_interact
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
|
||||
def open(self) -> None:
|
||||
self.transport.begin_session()
|
||||
@ -79,7 +80,7 @@ class DebugLink:
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = mapping.encode(msg)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
@ -93,7 +94,7 @@ class DebugLink:
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = mapping.decode(ret_type, ret_bytes)
|
||||
msg = self.mapping.decode(ret_type, ret_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
|
@ -15,17 +15,72 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import io
|
||||
from typing import Dict, Tuple, Type
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from . import messages, protobuf
|
||||
|
||||
map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
|
||||
map_class_to_type: Dict[Type[protobuf.MessageType], int] = {}
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def build_map() -> None:
|
||||
for entry in messages.MessageType:
|
||||
msg_class = getattr(messages, entry.name, None)
|
||||
class ProtobufMapping:
|
||||
"""Mapping of protobuf classes to Python classes"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
|
||||
self.class_to_type_override: Dict[Type[protobuf.MessageType], int] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
msg_class: Type[protobuf.MessageType],
|
||||
msg_wire_type: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Register a Python class as a protobuf type.
|
||||
|
||||
If `msg_wire_type` is specified, it is used instead of the internal value in
|
||||
`msg_class`.
|
||||
|
||||
Any existing registrations are overwritten.
|
||||
"""
|
||||
if msg_wire_type is not None:
|
||||
self.class_to_type_override[msg_class] = msg_wire_type
|
||||
elif msg_class.MESSAGE_WIRE_TYPE is None:
|
||||
raise ValueError("Cannot register class without wire type")
|
||||
else:
|
||||
msg_wire_type = msg_class.MESSAGE_WIRE_TYPE
|
||||
|
||||
self.type_to_class[msg_wire_type] = msg_class
|
||||
|
||||
def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||
"""Serialize a Python protobuf class.
|
||||
|
||||
Returns the message wire type and a byte representation of the protobuf message.
|
||||
"""
|
||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||
if wire_type is None:
|
||||
raise ValueError("Cannot encode class without wire type")
|
||||
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return wire_type, buf.getvalue()
|
||||
|
||||
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
||||
"""Deserialize a protobuf message into a Python class."""
|
||||
cls = self.type_to_class[msg_wire_type]
|
||||
buf = io.BytesIO(msg_bytes)
|
||||
return protobuf.load_message(buf, cls)
|
||||
|
||||
@classmethod
|
||||
def from_module(cls: Type[T], module: ModuleType) -> T:
|
||||
"""Generate a mapping from a module.
|
||||
|
||||
The module must have a `MessageType` enum that specifies individual wire types.
|
||||
"""
|
||||
mapping = cls()
|
||||
|
||||
message_types = getattr(module, "MessageType")
|
||||
for entry in message_types:
|
||||
msg_class = getattr(module, entry.name, None)
|
||||
if msg_class is None:
|
||||
raise ValueError(
|
||||
f"Implementation of protobuf message '{entry.name}' is missing"
|
||||
@ -36,45 +91,9 @@ def build_map() -> None:
|
||||
f"Inconsistent wire type and MessageType record for '{entry.name}'"
|
||||
)
|
||||
|
||||
register_message(msg_class)
|
||||
mapping.register(msg_class)
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def register_message(msg_class: Type[protobuf.MessageType]) -> None:
|
||||
if msg_class.MESSAGE_WIRE_TYPE is None:
|
||||
raise ValueError("Only messages with a wire type can be registered")
|
||||
|
||||
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
|
||||
raise Exception(
|
||||
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already "
|
||||
f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
|
||||
)
|
||||
|
||||
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
|
||||
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
|
||||
|
||||
|
||||
def get_type(msg: protobuf.MessageType) -> int:
|
||||
return map_class_to_type[msg.__class__]
|
||||
|
||||
|
||||
def get_class(t: int) -> Type[protobuf.MessageType]:
|
||||
return map_type_to_class[t]
|
||||
|
||||
|
||||
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||
if msg.MESSAGE_WIRE_TYPE is None:
|
||||
raise ValueError("Only messages with a wire type can be encoded")
|
||||
|
||||
message_type = msg.MESSAGE_WIRE_TYPE
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return message_type, buf.getvalue()
|
||||
|
||||
|
||||
def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType:
|
||||
cls = get_class(message_type)
|
||||
buf = io.BytesIO(message_bytes)
|
||||
return protobuf.load_message(buf, cls)
|
||||
|
||||
|
||||
build_map()
|
||||
DEFAULT_MAPPING = ProtobufMapping.from_module(messages)
|
||||
|
@ -36,7 +36,7 @@ class ApplySettingsCompat(protobuf.MessageType):
|
||||
}
|
||||
|
||||
|
||||
mapping.map_class_to_type[ApplySettingsCompat] = ApplySettingsCompat.MESSAGE_WIRE_TYPE
|
||||
mapping.DEFAULT_MAPPING.register(ApplySettingsCompat)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
Loading…
Reference in New Issue
Block a user