diff --git a/python/.changelog.d/1449.changed b/python/.changelog.d/1449.changed new file mode 100644 index 000000000..5d6ea71b5 --- /dev/null +++ b/python/.changelog.d/1449.changed @@ -0,0 +1 @@ +`trezorlib.mappings` was refactored for easier customization diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index bf726fc9f..4a2fdefd1 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -69,7 +69,7 @@ def send_bytes( click.echo(f"Response data: {response_data.hex()}") try: - msg = mapping.decode(response_type, response_data) + msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) click.echo("Parsed message:") click.echo(protobuf.format_message(msg)) except Exception as e: diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index f9945c0ea..9ab596581 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -110,6 +110,7 @@ class TrezorClient: might be removed at any time. """ LOG.info(f"creating client instance for device: {transport.get_path()}") + self.mapping = mapping.DEFAULT_MAPPING self.transport = transport self.ui = ui self.session_counter = 0 @@ -142,7 +143,7 @@ class TrezorClient: f"sending message: {msg.__class__.__name__}", extra={"protobuf": msg}, ) - msg_type, msg_bytes = mapping.encode(msg) + msg_type, msg_bytes = self.mapping.encode(msg) LOG.log( DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", @@ -156,7 +157,7 @@ class TrezorClient: DUMP_BYTES, f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - msg = mapping.decode(msg_type, msg_bytes) + msg = self.mapping.decode(msg_type, msg_bytes) LOG.debug( f"received message: {msg.__class__.__name__}", extra={"protobuf": msg}, diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 2b28c58c4..50a9bc16c 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -67,6 +67,7 @@ class DebugLink: def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: self.transport = transport self.allow_interactions = auto_interact + self.mapping = mapping.DEFAULT_MAPPING def open(self) -> None: self.transport.begin_session() @@ -79,7 +80,7 @@ class DebugLink: f"sending message: {msg.__class__.__name__}", extra={"protobuf": msg}, ) - msg_type, msg_bytes = mapping.encode(msg) + msg_type, msg_bytes = self.mapping.encode(msg) LOG.log( DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", @@ -93,7 +94,7 @@ class DebugLink: DUMP_BYTES, f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - msg = mapping.decode(ret_type, ret_bytes) + msg = self.mapping.decode(ret_type, ret_bytes) LOG.debug( f"received message: {msg.__class__.__name__}", extra={"protobuf": msg}, diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 37132ccdb..6000ee632 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -15,66 +15,85 @@ # If not, see . import io -from typing import Dict, Tuple, Type +from types import ModuleType +from typing import Dict, Optional, Tuple, Type, TypeVar from . import messages, protobuf -map_type_to_class: Dict[int, Type[protobuf.MessageType]] = {} -map_class_to_type: Dict[Type[protobuf.MessageType], int] = {} +T = TypeVar("T") -def build_map() -> None: - for entry in messages.MessageType: - msg_class = getattr(messages, entry.name, None) - if msg_class is None: - raise ValueError( - f"Implementation of protobuf message '{entry.name}' is missing" - ) +class ProtobufMapping: + """Mapping of protobuf classes to Python classes""" - if msg_class.MESSAGE_WIRE_TYPE != entry.value: - raise ValueError( - f"Inconsistent wire type and MessageType record for '{entry.name}'" - ) + def __init__(self) -> None: + self.type_to_class: Dict[int, Type[protobuf.MessageType]] = {} + self.class_to_type_override: Dict[Type[protobuf.MessageType], int] = {} - register_message(msg_class) + def register( + self, + msg_class: Type[protobuf.MessageType], + msg_wire_type: Optional[int] = None, + ) -> None: + """Register a Python class as a protobuf type. + If `msg_wire_type` is specified, it is used instead of the internal value in + `msg_class`. -def register_message(msg_class: Type[protobuf.MessageType]) -> None: - if msg_class.MESSAGE_WIRE_TYPE is None: - raise ValueError("Only messages with a wire type can be registered") + Any existing registrations are overwritten. + """ + if msg_wire_type is not None: + self.class_to_type_override[msg_class] = msg_wire_type + elif msg_class.MESSAGE_WIRE_TYPE is None: + raise ValueError("Cannot register class without wire type") + else: + msg_wire_type = msg_class.MESSAGE_WIRE_TYPE - if msg_class.MESSAGE_WIRE_TYPE in map_type_to_class: - raise Exception( - f"Message for wire type {msg_class.MESSAGE_WIRE_TYPE} is already " - f"registered by {get_class(msg_class.MESSAGE_WIRE_TYPE)}" - ) + self.type_to_class[msg_wire_type] = msg_class - map_class_to_type[msg_class] = msg_class.MESSAGE_WIRE_TYPE - map_type_to_class[msg_class.MESSAGE_WIRE_TYPE] = msg_class + def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]: + """Serialize a Python protobuf class. + Returns the message wire type and a byte representation of the protobuf message. + """ + wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) + if wire_type is None: + raise ValueError("Cannot encode class without wire type") -def get_type(msg: protobuf.MessageType) -> int: - return map_class_to_type[msg.__class__] + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return wire_type, buf.getvalue() + def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: + """Deserialize a protobuf message into a Python class.""" + cls = self.type_to_class[msg_wire_type] + buf = io.BytesIO(msg_bytes) + return protobuf.load_message(buf, cls) -def get_class(t: int) -> Type[protobuf.MessageType]: - return map_type_to_class[t] + @classmethod + def from_module(cls: Type[T], module: ModuleType) -> T: + """Generate a mapping from a module. + The module must have a `MessageType` enum that specifies individual wire types. + """ + mapping = cls() -def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]: - if msg.MESSAGE_WIRE_TYPE is None: - raise ValueError("Only messages with a wire type can be encoded") + message_types = getattr(module, "MessageType") + for entry in message_types: + msg_class = getattr(module, entry.name, None) + if msg_class is None: + raise ValueError( + f"Implementation of protobuf message '{entry.name}' is missing" + ) - message_type = msg.MESSAGE_WIRE_TYPE - buf = io.BytesIO() - protobuf.dump_message(buf, msg) - return message_type, buf.getvalue() + if msg_class.MESSAGE_WIRE_TYPE != entry.value: + raise ValueError( + f"Inconsistent wire type and MessageType record for '{entry.name}'" + ) + mapping.register(msg_class) -def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType: - cls = get_class(message_type) - buf = io.BytesIO(message_bytes) - return protobuf.load_message(buf, cls) + return mapping -build_map() +DEFAULT_MAPPING = ProtobufMapping.from_module(messages) diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index 3c62f2ed4..2974bdf55 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -36,7 +36,7 @@ class ApplySettingsCompat(protobuf.MessageType): } -mapping.map_class_to_type[ApplySettingsCompat] = ApplySettingsCompat.MESSAGE_WIRE_TYPE +mapping.DEFAULT_MAPPING.register(ApplySettingsCompat) @pytest.fixture