mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 16:00:57 +00:00
feat(python): make the protobuf mappings overridable
This commit is contained in:
parent
dbf57d745a
commit
a2a8cc88d9
1
python/.changelog.d/1449.changed
Normal file
1
python/.changelog.d/1449.changed
Normal file
@ -0,0 +1 @@
|
|||||||
|
`trezorlib.mappings` was refactored for easier customization
|
@ -69,7 +69,7 @@ def send_bytes(
|
|||||||
click.echo(f"Response data: {response_data.hex()}")
|
click.echo(f"Response data: {response_data.hex()}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg = mapping.decode(response_type, response_data)
|
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||||
click.echo("Parsed message:")
|
click.echo("Parsed message:")
|
||||||
click.echo(protobuf.format_message(msg))
|
click.echo(protobuf.format_message(msg))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -110,6 +110,7 @@ class TrezorClient:
|
|||||||
might be removed at any time.
|
might be removed at any time.
|
||||||
"""
|
"""
|
||||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
||||||
|
self.mapping = mapping.DEFAULT_MAPPING
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.ui = ui
|
self.ui = ui
|
||||||
self.session_counter = 0
|
self.session_counter = 0
|
||||||
@ -142,7 +143,7 @@ class TrezorClient:
|
|||||||
f"sending message: {msg.__class__.__name__}",
|
f"sending message: {msg.__class__.__name__}",
|
||||||
extra={"protobuf": msg},
|
extra={"protobuf": msg},
|
||||||
)
|
)
|
||||||
msg_type, msg_bytes = mapping.encode(msg)
|
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||||
LOG.log(
|
LOG.log(
|
||||||
DUMP_BYTES,
|
DUMP_BYTES,
|
||||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
@ -156,7 +157,7 @@ class TrezorClient:
|
|||||||
DUMP_BYTES,
|
DUMP_BYTES,
|
||||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
)
|
)
|
||||||
msg = mapping.decode(msg_type, msg_bytes)
|
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"received message: {msg.__class__.__name__}",
|
f"received message: {msg.__class__.__name__}",
|
||||||
extra={"protobuf": msg},
|
extra={"protobuf": msg},
|
||||||
|
@ -67,6 +67,7 @@ class DebugLink:
|
|||||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.allow_interactions = auto_interact
|
self.allow_interactions = auto_interact
|
||||||
|
self.mapping = mapping.DEFAULT_MAPPING
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.transport.begin_session()
|
self.transport.begin_session()
|
||||||
@ -79,7 +80,7 @@ class DebugLink:
|
|||||||
f"sending message: {msg.__class__.__name__}",
|
f"sending message: {msg.__class__.__name__}",
|
||||||
extra={"protobuf": msg},
|
extra={"protobuf": msg},
|
||||||
)
|
)
|
||||||
msg_type, msg_bytes = mapping.encode(msg)
|
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||||
LOG.log(
|
LOG.log(
|
||||||
DUMP_BYTES,
|
DUMP_BYTES,
|
||||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
@ -93,7 +94,7 @@ class DebugLink:
|
|||||||
DUMP_BYTES,
|
DUMP_BYTES,
|
||||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
)
|
)
|
||||||
msg = mapping.decode(ret_type, ret_bytes)
|
msg = self.mapping.decode(ret_type, ret_bytes)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"received message: {msg.__class__.__name__}",
|
f"received message: {msg.__class__.__name__}",
|
||||||
extra={"protobuf": msg},
|
extra={"protobuf": msg},
|
||||||
|
@ -15,17 +15,72 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import io
|
import io
|
||||||
from typing import Dict, Tuple, Type
|
from types import ModuleType
|
||||||
|
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
from . import messages, protobuf
|
from . import messages, protobuf
|
||||||
|
|
||||||
map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
|
T = TypeVar("T")
|
||||||
map_class_to_type: Dict[Type[protobuf.MessageType], int] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def build_map() -> None:
|
class ProtobufMapping:
|
||||||
for entry in messages.MessageType:
|
"""Mapping of protobuf classes to Python classes"""
|
||||||
msg_class = getattr(messages, entry.name, None)
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
|
||||||
|
self.class_to_type_override: Dict[Type[protobuf.MessageType], int] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
msg_class: Type[protobuf.MessageType],
|
||||||
|
msg_wire_type: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register a Python class as a protobuf type.
|
||||||
|
|
||||||
|
If `msg_wire_type` is specified, it is used instead of the internal value in
|
||||||
|
`msg_class`.
|
||||||
|
|
||||||
|
Any existing registrations are overwritten.
|
||||||
|
"""
|
||||||
|
if msg_wire_type is not None:
|
||||||
|
self.class_to_type_override[msg_class] = msg_wire_type
|
||||||
|
elif msg_class.MESSAGE_WIRE_TYPE is None:
|
||||||
|
raise ValueError("Cannot register class without wire type")
|
||||||
|
else:
|
||||||
|
msg_wire_type = msg_class.MESSAGE_WIRE_TYPE
|
||||||
|
|
||||||
|
self.type_to_class[msg_wire_type] = msg_class
|
||||||
|
|
||||||
|
def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||||
|
"""Serialize a Python protobuf class.
|
||||||
|
|
||||||
|
Returns the message wire type and a byte representation of the protobuf message.
|
||||||
|
"""
|
||||||
|
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||||
|
if wire_type is None:
|
||||||
|
raise ValueError("Cannot encode class without wire type")
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
protobuf.dump_message(buf, msg)
|
||||||
|
return wire_type, buf.getvalue()
|
||||||
|
|
||||||
|
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
||||||
|
"""Deserialize a protobuf message into a Python class."""
|
||||||
|
cls = self.type_to_class[msg_wire_type]
|
||||||
|
buf = io.BytesIO(msg_bytes)
|
||||||
|
return protobuf.load_message(buf, cls)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_module(cls: Type[T], module: ModuleType) -> T:
|
||||||
|
"""Generate a mapping from a module.
|
||||||
|
|
||||||
|
The module must have a `MessageType` enum that specifies individual wire types.
|
||||||
|
"""
|
||||||
|
mapping = cls()
|
||||||
|
|
||||||
|
message_types = getattr(module, "MessageType")
|
||||||
|
for entry in message_types:
|
||||||
|
msg_class = getattr(module, entry.name, None)
|
||||||
if msg_class is None:
|
if msg_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Implementation of protobuf message '{entry.name}' is missing"
|
f"Implementation of protobuf message '{entry.name}' is missing"
|
||||||
@ -36,45 +91,9 @@ def build_map() -> None:
|
|||||||
f"Inconsistent wire type and MessageType record for '{entry.name}'"
|
f"Inconsistent wire type and MessageType record for '{entry.name}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
register_message(msg_class)
|
mapping.register(msg_class)
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
def register_message(msg_class: Type[protobuf.MessageType]) -> None:
|
DEFAULT_MAPPING = ProtobufMapping.from_module(messages)
|
||||||
if msg_class.MESSAGE_WIRE_TYPE is None:
|
|
||||||
raise ValueError("Only messages with a wire type can be registered")
|
|
||||||
|
|
||||||
if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class:
|
|
||||||
raise Exception(
|
|
||||||
f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already "
|
|
||||||
f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE
|
|
||||||
map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class
|
|
||||||
|
|
||||||
|
|
||||||
def get_type(msg: protobuf.MessageType) -> int:
|
|
||||||
return map_class_to_type[msg.__class__]
|
|
||||||
|
|
||||||
|
|
||||||
def get_class(t: int) -> Type[protobuf.MessageType]:
|
|
||||||
return map_type_to_class[t]
|
|
||||||
|
|
||||||
|
|
||||||
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
|
||||||
if msg.MESSAGE_WIRE_TYPE is None:
|
|
||||||
raise ValueError("Only messages with a wire type can be encoded")
|
|
||||||
|
|
||||||
message_type = msg.MESSAGE_WIRE_TYPE
|
|
||||||
buf = io.BytesIO()
|
|
||||||
protobuf.dump_message(buf, msg)
|
|
||||||
return message_type, buf.getvalue()
|
|
||||||
|
|
||||||
|
|
||||||
def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType:
|
|
||||||
cls = get_class(message_type)
|
|
||||||
buf = io.BytesIO(message_bytes)
|
|
||||||
return protobuf.load_message(buf, cls)
|
|
||||||
|
|
||||||
|
|
||||||
build_map()
|
|
||||||
|
@ -36,7 +36,7 @@ class ApplySettingsCompat(protobuf.MessageType):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
mapping.map_class_to_type[ApplySettingsCompat] = ApplySettingsCompat.MESSAGE_WIRE_TYPE
|
mapping.DEFAULT_MAPPING.register(ApplySettingsCompat)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
Loading…
Reference in New Issue
Block a user