diff --git a/core/src/apps/bitcoin/sign_tx/__init__.py b/core/src/apps/bitcoin/sign_tx/__init__.py index 7400e2085e..39b880f878 100644 --- a/core/src/apps/bitcoin/sign_tx/__init__.py +++ b/core/src/apps/bitcoin/sign_tx/__init__.py @@ -1,5 +1,14 @@ +import gc + from trezor import utils, wire from trezor.messages.RequestType import TXFINISHED +from trezor.messages.SignTx import SignTx +from trezor.messages.TxAckInput import TxAckInput +from trezor.messages.TxAckOutput import TxAckOutput +from trezor.messages.TxAckPrevExtraData import TxAckPrevExtraData +from trezor.messages.TxAckPrevInput import TxAckPrevInput +from trezor.messages.TxAckPrevMeta import TxAckPrevMeta +from trezor.messages.TxAckPrevOutput import TxAckPrevOutput from trezor.messages.TxRequest import TxRequest from ..common import BITCOIN_NAMES @@ -14,14 +23,6 @@ if False: from protobuf import FieldCache - from trezor.messages.SignTx import SignTx - from trezor.messages.TxAckInput import TxAckInput - from trezor.messages.TxAckOutput import TxAckOutput - from trezor.messages.TxAckPrevMeta import TxAckPrevMeta - from trezor.messages.TxAckPrevInput import TxAckPrevInput - from trezor.messages.TxAckPrevOutput import TxAckPrevOutput - from trezor.messages.TxAckPrevExtraData import TxAckPrevExtraData - from apps.common.coininfo import CoinInfo from apps.common.keychain import Keychain @@ -75,7 +76,18 @@ async def sign_tx( signer = signer_class(msg, keychain, coin, approver).signer() res: TxAckType | bool | None = None + + gc.collect() field_cache: FieldCache = {} + TxRequest.cache_subordinate_types(field_cache) + SignTx.cache_subordinate_types(field_cache) + TxAckInput.cache_subordinate_types(field_cache) + TxAckOutput.cache_subordinate_types(field_cache) + TxAckPrevExtraData.cache_subordinate_types(field_cache) + TxAckPrevInput.cache_subordinate_types(field_cache) + TxAckPrevMeta.cache_subordinate_types(field_cache) + TxAckPrevOutput.cache_subordinate_types(field_cache) + while True: req = signer.send(res) if isinstance(req, tuple): diff --git a/core/src/protobuf.py b/core/src/protobuf.py index c8e3919a1f..5e879b3297 100644 --- a/core/src/protobuf.py +++ b/core/src/protobuf.py @@ -148,6 +148,24 @@ class UnicodeType: WIRE_TYPE = 2 +if False: + MessageTypeDef = Union[ + type[UVarintType], + type[SVarintType], + type[BoolType], + EnumType, + type[BytesType], + type[UnicodeType], + type["MessageType"], + ] + FieldDef = tuple[str, MessageTypeDef, Any] + FieldDict = dict[int, FieldDef] + + FieldCache = dict[type["MessageType"], FieldDict] + + LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType") + + class MessageType: WIRE_TYPE = 2 UNSTABLE = False @@ -157,9 +175,21 @@ class MessageType: MESSAGE_WIRE_TYPE = -1 @classmethod - def get_fields(cls) -> "FieldDict": + def get_fields(cls) -> FieldDict: return {} + @classmethod + def cache_subordinate_types(cls, field_cache: FieldCache) -> None: + if cls in field_cache: + fields = field_cache[cls] + else: + fields = cls.get_fields() + field_cache[cls] = fields + + for _, field_type, _ in fields.values(): + if isinstance(field_type, MessageType): + field_type.cache_subordinate_types(field_cache) + def __eq__(self, rhs: Any) -> bool: return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ @@ -185,23 +215,6 @@ FLAG_REPEATED = object() FLAG_REQUIRED = object() FLAG_EXPERIMENTAL = object() -if False: - MessageTypeDef = Union[ - type[UVarintType], - type[SVarintType], - type[BoolType], - EnumType, - type[BytesType], - type[UnicodeType], - type[MessageType], - ] - FieldDef = tuple[str, MessageTypeDef, Any] - FieldDict = dict[int, FieldDef] - - FieldCache = dict[type[MessageType], FieldDict] - - LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType) - def load_message( reader: Reader,