1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-21 15:08:12 +00:00

feat(core): enforce return type annotations

This commit is contained in:
matejcik 2024-11-08 16:03:35 +01:00 committed by matejcik
parent 34d97ee942
commit 8fb41ee290
34 changed files with 110 additions and 71 deletions

View File

@ -26,7 +26,7 @@ from .hash_benchmark import HashBenchmark
# This is a wrapper above the trezor.crypto.curve.ed25519 module that satisfies SignCurve protocol, the modules uses `message` instead of `digest` in `sign()` and `verify()`
class Ed25519:
def __init__(self):
def __init__(self) -> None:
pass
def generate_secret(self) -> bytes:

View File

@ -16,17 +16,17 @@ if TYPE_CHECKING:
class EncryptBenchmark:
def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
):
) -> None:
self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size
def prepare(self):
def prepare(self) -> None:
self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size)
def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.cipher_ctx.encrypt(self.data)
@ -44,17 +44,17 @@ class EncryptBenchmark:
class DecryptBenchmark:
def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
):
) -> None:
self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size
def prepare(self):
def prepare(self) -> None:
self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size)
def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.cipher_ctx.decrypt(self.data)

View File

@ -32,15 +32,15 @@ if TYPE_CHECKING:
class SignBenchmark:
def __init__(self, curve: SignCurve):
def __init__(self, curve: SignCurve) -> None:
self.curve = curve
def prepare(self):
def prepare(self) -> None:
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
self.digest = random_bytes(32)
def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.sign(self.secret_key, self.digest)
@ -51,17 +51,17 @@ class SignBenchmark:
class VerifyBenchmark:
def __init__(self, curve: SignCurve):
def __init__(self, curve: SignCurve) -> None:
self.curve = curve
def prepare(self):
def prepare(self) -> None:
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.secret_key)
self.digest = random_bytes(32)
self.signature = self.curve.sign(self.secret_key, self.digest)
def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.verify(self.public_key, self.signature, self.digest)
@ -72,15 +72,15 @@ class VerifyBenchmark:
class MultiplyBenchmark:
def __init__(self, curve: MultiplyCurve):
def __init__(self, curve: MultiplyCurve) -> None:
self.curve = curve
def prepare(self):
def prepare(self) -> None:
self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.curve.generate_secret())
self.iterations_count = 10
def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.multiply(self.secret_key, self.public_key)
@ -91,14 +91,14 @@ class MultiplyBenchmark:
class PublickeyBenchmark:
def __init__(self, curve: Curve):
def __init__(self, curve: Curve) -> None:
self.curve = curve
def prepare(self):
def prepare(self) -> None:
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.publickey(self.secret_key)

View File

@ -14,16 +14,16 @@ if TYPE_CHECKING:
class HashBenchmark:
def __init__(self, hash_ctx_constructor: Callable[[], HashCtx]):
def __init__(self, hash_ctx_constructor: Callable[[], HashCtx]) -> None:
self.hash_ctx_constructor = hash_ctx_constructor
def prepare(self):
def prepare(self) -> None:
self.hash_ctx = self.hash_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.hash_ctx.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.hash_ctx.block_size)
def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.hash_ctx.update(self.data)

View File

@ -345,7 +345,7 @@ class AccountType:
require_bech32: bool,
require_taproot: bool,
account_level: bool = False,
):
) -> None:
self.account_name = account_name
self.pattern = pattern
self.script_type = script_type

View File

@ -46,7 +46,7 @@ class UiConfirmOutput(UiConfirm):
output_index: int,
chunkify: bool,
address_n: Bip32Path | None,
):
) -> None:
self.output = output
self.coin = coin
self.amount_unit = amount_unit
@ -66,7 +66,9 @@ class UiConfirmOutput(UiConfirm):
class UiConfirmDecredSSTXSubmission(UiConfirm):
def __init__(self, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit):
def __init__(
self, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
) -> None:
self.output = output
self.coin = coin
self.amount_unit = amount_unit
@ -83,7 +85,7 @@ class UiConfirmPaymentRequest(UiConfirm):
payment_req: TxAckPaymentRequest,
coin: CoinInfo,
amount_unit: AmountUnit,
):
) -> None:
self.payment_req = payment_req
self.amount_unit = amount_unit
self.coin = coin
@ -97,7 +99,7 @@ class UiConfirmPaymentRequest(UiConfirm):
class UiConfirmReplacement(UiConfirm):
def __init__(self, title: str, txid: bytes):
def __init__(self, title: str, txid: bytes) -> None:
self.title = title
self.txid = txid
@ -112,7 +114,7 @@ class UiConfirmModifyOutput(UiConfirm):
orig_txo: TxOutput,
coin: CoinInfo,
amount_unit: AmountUnit,
):
) -> None:
self.txo = txo
self.orig_txo = orig_txo
self.coin = coin
@ -133,7 +135,7 @@ class UiConfirmModifyFee(UiConfirm):
fee_rate: float,
coin: CoinInfo,
amount_unit: AmountUnit,
):
) -> None:
self.title = title
self.user_fee_change = user_fee_change
self.total_fee_new = total_fee_new
@ -161,7 +163,7 @@ class UiConfirmTotal(UiConfirm):
coin: CoinInfo,
amount_unit: AmountUnit,
address_n: Bip32Path | None,
):
) -> None:
self.spending = spending
self.fee = fee
self.fee_rate = fee_rate
@ -183,7 +185,7 @@ class UiConfirmTotal(UiConfirm):
class UiConfirmJointTotal(UiConfirm):
def __init__(
self, spending: int, total: int, coin: CoinInfo, amount_unit: AmountUnit
):
) -> None:
self.spending = spending
self.total = total
self.coin = coin
@ -196,7 +198,7 @@ class UiConfirmJointTotal(UiConfirm):
class UiConfirmFeeOverThreshold(UiConfirm):
def __init__(self, fee: int, coin: CoinInfo, amount_unit: AmountUnit):
def __init__(self, fee: int, coin: CoinInfo, amount_unit: AmountUnit) -> None:
self.fee = fee
self.coin = coin
self.amount_unit = amount_unit
@ -206,7 +208,7 @@ class UiConfirmFeeOverThreshold(UiConfirm):
class UiConfirmChangeCountOverThreshold(UiConfirm):
def __init__(self, change_count: int):
def __init__(self, change_count: int) -> None:
self.change_count = change_count
def confirm_dialog(self) -> Awaitable[Any]:
@ -219,7 +221,7 @@ class UiConfirmUnverifiedExternalInput(UiConfirm):
class UiConfirmForeignAddress(UiConfirm):
def __init__(self, address_n: list):
def __init__(self, address_n: list) -> None:
self.address_n = address_n
def confirm_dialog(self) -> Awaitable[Any]:
@ -229,7 +231,7 @@ class UiConfirmForeignAddress(UiConfirm):
class UiConfirmNonDefaultLocktime(UiConfirm):
def __init__(self, lock_time: int, lock_time_disabled: bool):
def __init__(self, lock_time: int, lock_time_disabled: bool) -> None:
self.lock_time = lock_time
self.lock_time_disabled = lock_time_disabled

View File

@ -14,7 +14,7 @@ _PREV_TX_MULTIPLIER = 5
class Progress:
def __init__(self):
def __init__(self) -> None:
self.progress = 0
self.steps = 0
self.signing = False

View File

@ -17,7 +17,7 @@ class SignatureVerifier:
script_sig: bytes | None,
witness: bytes | None,
coin: CoinInfo,
):
) -> None:
from trezor import utils
from trezor.crypto.hashlib import sha256
from trezor.wire import DataError # local_cache_global

View File

@ -42,7 +42,7 @@ class Credential:
key_hash: bytes | None,
script_hash: bytes | None,
pointer: messages.CardanoBlockchainPointerType | None,
):
) -> None:
self.type_name = type_name
self.address_type = address_type
self.path = path

View File

@ -107,7 +107,7 @@ class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
key_order_error: wire.ProcessError
previous_encoded_key: bytes
def __init__(self, size: int, key_order_error: wire.ProcessError):
def __init__(self, size: int, key_order_error: wire.ProcessError) -> None:
super().__init__(size)
self.key_order_error = key_order_error
self.previous_encoded_key = b""

View File

@ -23,7 +23,7 @@ class OrdinarySigner(Signer):
self,
msg: messages.CardanoSignTxInit,
keychain: seed.Keychain,
):
) -> None:
super().__init__(msg, keychain)
self.suite_tx_type: SuiteTxType = self._suite_tx_type()

View File

@ -12,12 +12,12 @@ def repeated_backup_enabled() -> bool:
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
def activate_repeated_backup():
def activate_repeated_backup() -> None:
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
wire.filters.append(_repeated_backup_filter)
def deactivate_repeated_backup():
def deactivate_repeated_backup() -> None:
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
wire.remove_filter(_repeated_backup_filter)

View File

@ -229,7 +229,7 @@ class Tagged:
# TODO: this seems to be unused - is checked against, but is never created???
class Raw:
def __init__(self, value: Value):
def __init__(self, value: Value) -> None:
self.value = value

View File

@ -142,7 +142,7 @@ async def require_confirm_claim(
)
async def require_confirm_unknown_token(address_bytes: bytes):
async def require_confirm_unknown_token(address_bytes: bytes) -> None:
from ubinascii import hexlify
from trezor.ui.layouts import confirm_address, show_warning

View File

@ -233,7 +233,7 @@ def _get_slip39_mnemonics(
group_threshold: int,
groups: Sequence[tuple[int, int]],
extendable: bool,
):
) -> list[list[str]]:
if extendable:
identifier = slip39.generate_random_identifier()
else:

View File

@ -25,7 +25,7 @@ class State:
from apps.monero.xmr.mlsag_hasher import PreMlsagHasher
# Account credentials
# type: AccountCreds
# - type: AccountCreds
# - view private/public key
# - spend private/public key
# - and its corresponding address

View File

@ -16,6 +16,8 @@ from apps.common.paths import address_n_to_str
from .types import AddressType
if TYPE_CHECKING:
from typing import Sequence
from .transaction.instructions import Instruction, SystemProgramTransferInstruction
from .types import AddressReference
@ -33,7 +35,9 @@ def _format_path(path: list[int]) -> str:
return f"Solana #{unharden(account_index) + 1}"
def _get_address_reference_props(address: AddressReference, display_name: str):
def _get_address_reference_props(
address: AddressReference, display_name: str
) -> Sequence[tuple[str, str]]:
return (
(TR.solana__is_provided_via_lookup_table_template.format(display_name), ""),
(f"{TR.solana__lookup_table_address}:", base58.encode(address[0])),
@ -293,7 +297,7 @@ async def confirm_token_transfer(
fee: int,
signer_path: list[int],
blockhash: bytes,
):
) -> None:
await confirm_value(
title=TR.words__recipient,
value=base58.encode(destination_account),

View File

@ -42,7 +42,7 @@ def get_create_associated_token_account_instructions(
def is_predefined_token_transfer(
instructions: list[Instruction],
):
) -> bool:
"""
Checks that the transaction consists of one or zero create token account instructions
and one or more transfer token instructions. Also checks that the token program, token mint

View File

@ -41,7 +41,7 @@ class Instruction:
@staticmethod
def parse_instruction_data(
instruction_data: InstructionData, property_templates: list[PropertyTemplate]
):
) -> dict[str, Any]:
from trezor.utils import BufferReader
from trezor.wire import DataError
@ -65,7 +65,7 @@ class Instruction:
@staticmethod
def parse_instruction_accounts(
accounts: list[Account], accounts_template: list[AccountTemplate]
):
) -> dict[str, Account]:
parsed_account = {}
for i, account_template in enumerate(accounts_template):
if i >= len(accounts):

View File

@ -20,7 +20,7 @@ from .instruction import Instruction
from .parse import parse_byte, parse_memo, parse_pubkey, parse_string
if TYPE_CHECKING:
from typing import Any, Type
from typing import Any, Type, TypeGuard
from ..types import Account, InstructionData, InstructionId
@ -303,7 +303,7 @@ def __getattr__(name: str) -> Type[Instruction]:
class FakeClass(Instruction):
@classmethod
def is_type_of(cls, ins: Any):
def is_type_of(cls, ins: Any) -> TypeGuard[Instruction]:
return ins.program_id == id[0] and ins.instruction_id == id[1]
return FakeClass

View File

@ -52,7 +52,7 @@ from .parse import (
)
if TYPE_CHECKING:
from typing import Any, Type
from typing import Any, Type, TypeGuard
from ..types import Account, InstructionId, InstructionData
@ -88,7 +88,7 @@ def __getattr__(name: str) -> Type[Instruction]:
class FakeClass(Instruction):
@classmethod
def is_type_of(cls, ins: Any):
def is_type_of(cls, ins: Any) -> TypeGuard[Instruction]:
return ins.program_id == id[0] and ins.instruction_id == id[1]
return FakeClass

View File

@ -42,7 +42,7 @@ class PropertyTemplate(Generic[T]):
is_optional: bool,
parse: Callable[[BufferReader], T],
format: Callable[[Instruction, T], str],
):
) -> None:
self.name = name
self.is_authority = is_authority
self.is_optional = is_optional
@ -51,7 +51,7 @@ class PropertyTemplate(Generic[T]):
class AccountTemplate:
def __init__(self, name: str, is_authority: bool, optional: bool):
def __init__(self, name: str, is_authority: bool, optional: bool) -> None:
self.name = name
self.is_authority = is_authority
self.optional = optional

View File

@ -219,7 +219,7 @@ _last_auth_valid = False
class CborError(Exception):
def __init__(self, code: int):
def __init__(self, code: int) -> None:
super().__init__()
self.code = code

View File

@ -46,7 +46,7 @@ def blake_hash_writer_32(personal: bytes) -> HashWriter:
class ZcashHasher:
def __init__(self, tx: SignTx | PrevTx):
def __init__(self, tx: SignTx | PrevTx) -> None:
from trezor.utils import empty_bytearray
self.header = HeaderHasher(tx)
@ -130,7 +130,7 @@ class ZcashHasher:
class HeaderHasher:
def __init__(self, tx: SignTx | PrevTx):
def __init__(self, tx: SignTx | PrevTx) -> None:
h = blake_hash_writer_32(b"ZTxIdHeadersHash")
assert tx.version_group_id is not None

View File

@ -327,7 +327,7 @@ if TYPE_CHECKING:
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key)
if value is None:
value = func(*args, **kwargs)
@ -341,7 +341,7 @@ def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs):
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key)
if value is None:
value = await func(*args, **kwargs)

View File

@ -140,7 +140,7 @@ class Share:
index: int,
threshold: int,
share_value: bytes,
):
) -> None:
self.identifier = identifier
self.extendable = extendable
self.iteration_exponent = iteration_exponent

View File

@ -115,7 +115,7 @@ def _slip_39_checklist_items(
advanced: bool,
count: int | None = None,
threshold: int | None = None,
):
) -> tuple[str, str, str]:
if not advanced:
entry_1 = (
TR.reset__slip39_checklist_num_shares_x_template.format(count)

View File

@ -239,7 +239,7 @@ def slip39_prompt_number_of_shares(
min_count = 1
max_count = 16
def description(i: int):
def description(i: int) -> str:
if group_id is None:
if i == 1:
return TR.reset__only_one_share_will_be_created

View File

@ -45,6 +45,7 @@ if TYPE_CHECKING:
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask]
Filter = Callable[[int, Handler], Handler]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
@ -264,7 +265,7 @@ def find_handler(iface: WireInterface, msg_type: int) -> Handler:
return handler
filters: list[Callable[[int, Handler], Handler]] = []
filters: list[Filter] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
@ -292,7 +293,7 @@ and `filters` becomes private!
"""
def remove_filter(filter):
def remove_filter(filter: Filter) -> None:
try:
filters.remove(filter)
except ValueError:

17
poetry.lock generated
View File

@ -488,6 +488,21 @@ mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.11.0,<2.12.0"
pyflakes = ">=3.2.0,<3.3.0"
[[package]]
name = "flake8-annotations"
version = "3.1.1"
description = "Flake8 Type Annotation Checks"
optional = false
python-versions = ">=3.8.1"
files = [
{file = "flake8_annotations-3.1.1-py3-none-any.whl", hash = "sha256:102935bdcbfa714759a152aeb07b14aee343fc0b6f7c55ad16968ce3e0e91a8a"},
{file = "flake8_annotations-3.1.1.tar.gz", hash = "sha256:6c98968ccc6bdc0581d363bf147a87df2f01d0d078264b2da805799d911cf5fe"},
]
[package.dependencies]
attrs = ">=21.4"
flake8 = ">=5.0"
[[package]]
name = "flake8-requirements"
version = "2.1.0"
@ -1798,4 +1813,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
content-hash = "971d0f6f2926d839954b35b2029978046e282df3d8d595b1e176dc0cf37889fb"
content-hash = "6ead2686c279a0baa17e3cc51acf49916be4649552782492ae97f2dd875e0ff6"

View File

@ -73,6 +73,7 @@ binsize = "^0.1.3"
toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"}
trezor-pylint-plugin = {path = "./tools/trezor-pylint-plugin", develop = true}
trezor-core-tools = {path = "./core/tools", develop = true}
flake8-annotations = "^3.1.1"
[tool.poetry.dev-dependencies]
scan-build = "*"

View File

@ -22,7 +22,9 @@ ignore =
# E741: ambiguous variable name
E741,
# W503: line break before binary operator
W503
W503,
# flake8-annotations
ANN,
per-file-ignores =
helper-scripts/*:I
tools/*:I

View File

@ -17,9 +17,23 @@ ignore =
# W503: line break before binary operator
W503,
# flake8-requirements import checks
I
I,
# flake8-annotations self/cls type
ANN101, ANN102,
# flake8-annotations allow Any type
ANN401,
per-file-ignores =
core/mocks/generated/*:F4
core/src/typing.py:ANN
core/src/apps/monero/*:ANN
core/site_scons/*:ANN
core/tests/*:ANN
ci/*:ANN
common/*:ANN
crypto/*:ANN
legacy/*:ANN
storage/*:ANN
tests/*:ANN
[tool:pytest]
addopts = -rfE --strict-markers --random-order

View File

@ -28,7 +28,7 @@ def _silent_call(*args):
def format(file: Path):
_silent_call("isort", file)
_silent_call("black", file)
_silent_call("flake8", file)
_silent_call("flake8", file, "--config", ROOT / "setup.cfg")
def render_single(template_path: Path, programs: Munch) -> str: