1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-23 07:58:09 +00:00

feat(core): prefill field_cache in bitcoin app

This commit is contained in:
matejcik 2021-03-19 16:34:47 +01:00 committed by matejcik
parent f3db4f2dd3
commit 1822aebdb4
2 changed files with 51 additions and 26 deletions

View File

@ -1,5 +1,14 @@
import gc
from trezor import utils, wire from trezor import utils, wire
from trezor.messages.RequestType import TXFINISHED from trezor.messages.RequestType import TXFINISHED
from trezor.messages.SignTx import SignTx
from trezor.messages.TxAckInput import TxAckInput
from trezor.messages.TxAckOutput import TxAckOutput
from trezor.messages.TxAckPrevExtraData import TxAckPrevExtraData
from trezor.messages.TxAckPrevInput import TxAckPrevInput
from trezor.messages.TxAckPrevMeta import TxAckPrevMeta
from trezor.messages.TxAckPrevOutput import TxAckPrevOutput
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
from ..common import BITCOIN_NAMES from ..common import BITCOIN_NAMES
@ -14,14 +23,6 @@ if False:
from protobuf import FieldCache from protobuf import FieldCache
from trezor.messages.SignTx import SignTx
from trezor.messages.TxAckInput import TxAckInput
from trezor.messages.TxAckOutput import TxAckOutput
from trezor.messages.TxAckPrevMeta import TxAckPrevMeta
from trezor.messages.TxAckPrevInput import TxAckPrevInput
from trezor.messages.TxAckPrevOutput import TxAckPrevOutput
from trezor.messages.TxAckPrevExtraData import TxAckPrevExtraData
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -75,7 +76,18 @@ async def sign_tx(
signer = signer_class(msg, keychain, coin, approver).signer() signer = signer_class(msg, keychain, coin, approver).signer()
res: TxAckType | bool | None = None res: TxAckType | bool | None = None
gc.collect()
field_cache: FieldCache = {} field_cache: FieldCache = {}
TxRequest.cache_subordinate_types(field_cache)
SignTx.cache_subordinate_types(field_cache)
TxAckInput.cache_subordinate_types(field_cache)
TxAckOutput.cache_subordinate_types(field_cache)
TxAckPrevExtraData.cache_subordinate_types(field_cache)
TxAckPrevInput.cache_subordinate_types(field_cache)
TxAckPrevMeta.cache_subordinate_types(field_cache)
TxAckPrevOutput.cache_subordinate_types(field_cache)
while True: while True:
req = signer.send(res) req = signer.send(res)
if isinstance(req, tuple): if isinstance(req, tuple):

View File

@ -148,6 +148,24 @@ class UnicodeType:
WIRE_TYPE = 2 WIRE_TYPE = 2
if False:
MessageTypeDef = Union[
type[UVarintType],
type[SVarintType],
type[BoolType],
EnumType,
type[BytesType],
type[UnicodeType],
type["MessageType"],
]
FieldDef = tuple[str, MessageTypeDef, Any]
FieldDict = dict[int, FieldDef]
FieldCache = dict[type["MessageType"], FieldDict]
LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType")
class MessageType: class MessageType:
WIRE_TYPE = 2 WIRE_TYPE = 2
UNSTABLE = False UNSTABLE = False
@ -157,9 +175,21 @@ class MessageType:
MESSAGE_WIRE_TYPE = -1 MESSAGE_WIRE_TYPE = -1
@classmethod @classmethod
def get_fields(cls) -> "FieldDict": def get_fields(cls) -> FieldDict:
return {} return {}
@classmethod
def cache_subordinate_types(cls, field_cache: FieldCache) -> None:
if cls in field_cache:
fields = field_cache[cls]
else:
fields = cls.get_fields()
field_cache[cls] = fields
for _, field_type, _ in fields.values():
if isinstance(field_type, MessageType):
field_type.cache_subordinate_types(field_cache)
def __eq__(self, rhs: Any) -> bool: def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
@ -185,23 +215,6 @@ FLAG_REPEATED = object()
FLAG_REQUIRED = object() FLAG_REQUIRED = object()
FLAG_EXPERIMENTAL = object() FLAG_EXPERIMENTAL = object()
if False:
MessageTypeDef = Union[
type[UVarintType],
type[SVarintType],
type[BoolType],
EnumType,
type[BytesType],
type[UnicodeType],
type[MessageType],
]
FieldDef = tuple[str, MessageTypeDef, Any]
FieldDict = dict[int, FieldDef]
FieldCache = dict[type[MessageType], FieldDict]
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
def load_message( def load_message(
reader: Reader, reader: Reader,