feat(core): prefill field_cache in bitcoin app

pull/1610/head
matejcik 3 years ago committed by matejcik
parent f3db4f2dd3
commit 1822aebdb4

@ -1,5 +1,14 @@
import gc
from trezor import utils, wire
from trezor.messages.RequestType import TXFINISHED
from trezor.messages.SignTx import SignTx
from trezor.messages.TxAckInput import TxAckInput
from trezor.messages.TxAckOutput import TxAckOutput
from trezor.messages.TxAckPrevExtraData import TxAckPrevExtraData
from trezor.messages.TxAckPrevInput import TxAckPrevInput
from trezor.messages.TxAckPrevMeta import TxAckPrevMeta
from trezor.messages.TxAckPrevOutput import TxAckPrevOutput
from trezor.messages.TxRequest import TxRequest
from ..common import BITCOIN_NAMES
@ -14,14 +23,6 @@ if False:
from protobuf import FieldCache
from trezor.messages.SignTx import SignTx
from trezor.messages.TxAckInput import TxAckInput
from trezor.messages.TxAckOutput import TxAckOutput
from trezor.messages.TxAckPrevMeta import TxAckPrevMeta
from trezor.messages.TxAckPrevInput import TxAckPrevInput
from trezor.messages.TxAckPrevOutput import TxAckPrevOutput
from trezor.messages.TxAckPrevExtraData import TxAckPrevExtraData
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
@ -75,7 +76,18 @@ async def sign_tx(
signer = signer_class(msg, keychain, coin, approver).signer()
res: TxAckType | bool | None = None
gc.collect()
field_cache: FieldCache = {}
TxRequest.cache_subordinate_types(field_cache)
SignTx.cache_subordinate_types(field_cache)
TxAckInput.cache_subordinate_types(field_cache)
TxAckOutput.cache_subordinate_types(field_cache)
TxAckPrevExtraData.cache_subordinate_types(field_cache)
TxAckPrevInput.cache_subordinate_types(field_cache)
TxAckPrevMeta.cache_subordinate_types(field_cache)
TxAckPrevOutput.cache_subordinate_types(field_cache)
while True:
req = signer.send(res)
if isinstance(req, tuple):

@ -148,6 +148,24 @@ class UnicodeType:
WIRE_TYPE = 2
if False:
MessageTypeDef = Union[
type[UVarintType],
type[SVarintType],
type[BoolType],
EnumType,
type[BytesType],
type[UnicodeType],
type["MessageType"],
]
FieldDef = tuple[str, MessageTypeDef, Any]
FieldDict = dict[int, FieldDef]
FieldCache = dict[type["MessageType"], FieldDict]
LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType")
class MessageType:
WIRE_TYPE = 2
UNSTABLE = False
@ -157,9 +175,21 @@ class MessageType:
MESSAGE_WIRE_TYPE = -1
@classmethod
def get_fields(cls) -> "FieldDict":
def get_fields(cls) -> FieldDict:
return {}
@classmethod
def cache_subordinate_types(cls, field_cache: FieldCache) -> None:
if cls in field_cache:
fields = field_cache[cls]
else:
fields = cls.get_fields()
field_cache[cls] = fields
for _, field_type, _ in fields.values():
if isinstance(field_type, MessageType):
field_type.cache_subordinate_types(field_cache)
def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
@ -185,23 +215,6 @@ FLAG_REPEATED = object()
FLAG_REQUIRED = object()
FLAG_EXPERIMENTAL = object()
if False:
MessageTypeDef = Union[
type[UVarintType],
type[SVarintType],
type[BoolType],
EnumType,
type[BytesType],
type[UnicodeType],
type[MessageType],
]
FieldDef = tuple[str, MessageTypeDef, Any]
FieldDict = dict[int, FieldDef]
FieldCache = dict[type[MessageType], FieldDict]
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
def load_message(
reader: Reader,

Loading…
Cancel
Save