1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-21 23:18:13 +00:00

feat(core): enforce return type annotations

This commit is contained in:
matejcik 2024-11-08 16:03:35 +01:00 committed by matejcik
parent 34d97ee942
commit 8fb41ee290
34 changed files with 110 additions and 71 deletions

View File

@ -26,7 +26,7 @@ from .hash_benchmark import HashBenchmark
# This is a wrapper above the trezor.crypto.curve.ed25519 module that satisfies SignCurve protocol, the modules uses `message` instead of `digest` in `sign()` and `verify()` # This is a wrapper above the trezor.crypto.curve.ed25519 module that satisfies SignCurve protocol, the modules uses `message` instead of `digest` in `sign()` and `verify()`
class Ed25519: class Ed25519:
def __init__(self): def __init__(self) -> None:
pass pass
def generate_secret(self) -> bytes: def generate_secret(self) -> bytes:

View File

@ -16,17 +16,17 @@ if TYPE_CHECKING:
class EncryptBenchmark: class EncryptBenchmark:
def __init__( def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
): ) -> None:
self.cipher_ctx_constructor = cipher_ctx_constructor self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size self.block_size = block_size
def prepare(self): def prepare(self) -> None:
self.cipher_ctx = self.cipher_ctx_constructor() self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100 self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size) self.data = random_bytes(self.blocks_count * self.block_size)
def run(self): def run(self) -> None:
for _ in range(self.iterations_count): for _ in range(self.iterations_count):
self.cipher_ctx.encrypt(self.data) self.cipher_ctx.encrypt(self.data)
@ -44,17 +44,17 @@ class EncryptBenchmark:
class DecryptBenchmark: class DecryptBenchmark:
def __init__( def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
): ) -> None:
self.cipher_ctx_constructor = cipher_ctx_constructor self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size self.block_size = block_size
def prepare(self): def prepare(self) -> None:
self.cipher_ctx = self.cipher_ctx_constructor() self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100 self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size) self.data = random_bytes(self.blocks_count * self.block_size)
def run(self): def run(self) -> None:
for _ in range(self.iterations_count): for _ in range(self.iterations_count):
self.cipher_ctx.decrypt(self.data) self.cipher_ctx.decrypt(self.data)

View File

@ -32,15 +32,15 @@ if TYPE_CHECKING:
class SignBenchmark: class SignBenchmark:
def __init__(self, curve: SignCurve): def __init__(self, curve: SignCurve) -> None:
self.curve = curve self.curve = curve
def prepare(self): def prepare(self) -> None:
self.iterations_count = 10 self.iterations_count = 10
self.secret_key = self.curve.generate_secret() self.secret_key = self.curve.generate_secret()
self.digest = random_bytes(32) self.digest = random_bytes(32)
def run(self): def run(self) -> None:
for _ in range(self.iterations_count): for _ in range(self.iterations_count):
self.curve.sign(self.secret_key, self.digest) self.curve.sign(self.secret_key, self.digest)
@ -51,17 +51,17 @@ class SignBenchmark:
class VerifyBenchmark: class VerifyBenchmark:
def __init__(self, curve: SignCurve): def __init__(self, curve: SignCurve) -> None:
self.curve = curve self.curve = curve
def prepare(self): def prepare(self) -> None:
self.iterations_count = 10 self.iterations_count = 10
self.secret_key = self.curve.generate_secret() self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.secret_key) self.public_key = self.curve.publickey(self.secret_key)
self.digest = random_bytes(32) self.digest = random_bytes(32)
self.signature = self.curve.sign(self.secret_key, self.digest) self.signature = self.curve.sign(self.secret_key, self.digest)
def run(self): def run(self) -> None:
for _ in range(self.iterations_count): for _ in range(self.iterations_count):
self.curve.verify(self.public_key, self.signature, self.digest) self.curve.verify(self.public_key, self.signature, self.digest)
@ -72,15 +72,15 @@ class VerifyBenchmark:
class MultiplyBenchmark: class MultiplyBenchmark:
def __init__(self, curve: MultiplyCurve): def __init__(self, curve: MultiplyCurve) -> None:
self.curve = curve self.curve = curve
def prepare(self): def prepare(self) -> None:
self.secret_key = self.curve.generate_secret() self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.curve.generate_secret()) self.public_key = self.curve.publickey(self.curve.generate_secret())
self.iterations_count = 10 self.iterations_count = 10
def run(self): def run(self) -> None:
for _ in range(self.iterations_count): for _ in range(self.iterations_count):
self.curve.multiply(self.secret_key, self.public_key) self.curve.multiply(self.secret_key, self.public_key)
@ -91,14 +91,14 @@ class MultiplyBenchmark:
class PublickeyBenchmark: class PublickeyBenchmark:
def __init__(self, curve: Curve): def __init__(self, curve: Curve) -> None:
self.curve = curve self.curve = curve
def prepare(self): def prepare(self) -> None:
self.iterations_count = 10 self.iterations_count = 10
self.secret_key = self.curve.generate_secret() self.secret_key = self.curve.generate_secret()
def run(self): def run(self) -> None:
for _ in range(self.iterations_count): for _ in range(self.iterations_count):
self.curve.publickey(self.secret_key) self.curve.publickey(self.secret_key)

View File

@ -14,16 +14,16 @@ if TYPE_CHECKING:
class HashBenchmark: class HashBenchmark:
def __init__(self, hash_ctx_constructor: Callable[[], HashCtx]): def __init__(self, hash_ctx_constructor: Callable[[], HashCtx]) -> None:
self.hash_ctx_constructor = hash_ctx_constructor self.hash_ctx_constructor = hash_ctx_constructor
def prepare(self): def prepare(self) -> None:
self.hash_ctx = self.hash_ctx_constructor() self.hash_ctx = self.hash_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.hash_ctx.block_size self.blocks_count = maximum_used_memory_in_bytes // self.hash_ctx.block_size
self.iterations_count = 100 self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.hash_ctx.block_size) self.data = random_bytes(self.blocks_count * self.hash_ctx.block_size)
def run(self): def run(self) -> None:
for _ in range(self.iterations_count): for _ in range(self.iterations_count):
self.hash_ctx.update(self.data) self.hash_ctx.update(self.data)

View File

@ -345,7 +345,7 @@ class AccountType:
require_bech32: bool, require_bech32: bool,
require_taproot: bool, require_taproot: bool,
account_level: bool = False, account_level: bool = False,
): ) -> None:
self.account_name = account_name self.account_name = account_name
self.pattern = pattern self.pattern = pattern
self.script_type = script_type self.script_type = script_type

View File

@ -46,7 +46,7 @@ class UiConfirmOutput(UiConfirm):
output_index: int, output_index: int,
chunkify: bool, chunkify: bool,
address_n: Bip32Path | None, address_n: Bip32Path | None,
): ) -> None:
self.output = output self.output = output
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
@ -66,7 +66,9 @@ class UiConfirmOutput(UiConfirm):
class UiConfirmDecredSSTXSubmission(UiConfirm): class UiConfirmDecredSSTXSubmission(UiConfirm):
def __init__(self, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit): def __init__(
self, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
) -> None:
self.output = output self.output = output
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
@ -83,7 +85,7 @@ class UiConfirmPaymentRequest(UiConfirm):
payment_req: TxAckPaymentRequest, payment_req: TxAckPaymentRequest,
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
): ) -> None:
self.payment_req = payment_req self.payment_req = payment_req
self.amount_unit = amount_unit self.amount_unit = amount_unit
self.coin = coin self.coin = coin
@ -97,7 +99,7 @@ class UiConfirmPaymentRequest(UiConfirm):
class UiConfirmReplacement(UiConfirm): class UiConfirmReplacement(UiConfirm):
def __init__(self, title: str, txid: bytes): def __init__(self, title: str, txid: bytes) -> None:
self.title = title self.title = title
self.txid = txid self.txid = txid
@ -112,7 +114,7 @@ class UiConfirmModifyOutput(UiConfirm):
orig_txo: TxOutput, orig_txo: TxOutput,
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
): ) -> None:
self.txo = txo self.txo = txo
self.orig_txo = orig_txo self.orig_txo = orig_txo
self.coin = coin self.coin = coin
@ -133,7 +135,7 @@ class UiConfirmModifyFee(UiConfirm):
fee_rate: float, fee_rate: float,
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
): ) -> None:
self.title = title self.title = title
self.user_fee_change = user_fee_change self.user_fee_change = user_fee_change
self.total_fee_new = total_fee_new self.total_fee_new = total_fee_new
@ -161,7 +163,7 @@ class UiConfirmTotal(UiConfirm):
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
address_n: Bip32Path | None, address_n: Bip32Path | None,
): ) -> None:
self.spending = spending self.spending = spending
self.fee = fee self.fee = fee
self.fee_rate = fee_rate self.fee_rate = fee_rate
@ -183,7 +185,7 @@ class UiConfirmTotal(UiConfirm):
class UiConfirmJointTotal(UiConfirm): class UiConfirmJointTotal(UiConfirm):
def __init__( def __init__(
self, spending: int, total: int, coin: CoinInfo, amount_unit: AmountUnit self, spending: int, total: int, coin: CoinInfo, amount_unit: AmountUnit
): ) -> None:
self.spending = spending self.spending = spending
self.total = total self.total = total
self.coin = coin self.coin = coin
@ -196,7 +198,7 @@ class UiConfirmJointTotal(UiConfirm):
class UiConfirmFeeOverThreshold(UiConfirm): class UiConfirmFeeOverThreshold(UiConfirm):
def __init__(self, fee: int, coin: CoinInfo, amount_unit: AmountUnit): def __init__(self, fee: int, coin: CoinInfo, amount_unit: AmountUnit) -> None:
self.fee = fee self.fee = fee
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
@ -206,7 +208,7 @@ class UiConfirmFeeOverThreshold(UiConfirm):
class UiConfirmChangeCountOverThreshold(UiConfirm): class UiConfirmChangeCountOverThreshold(UiConfirm):
def __init__(self, change_count: int): def __init__(self, change_count: int) -> None:
self.change_count = change_count self.change_count = change_count
def confirm_dialog(self) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
@ -219,7 +221,7 @@ class UiConfirmUnverifiedExternalInput(UiConfirm):
class UiConfirmForeignAddress(UiConfirm): class UiConfirmForeignAddress(UiConfirm):
def __init__(self, address_n: list): def __init__(self, address_n: list) -> None:
self.address_n = address_n self.address_n = address_n
def confirm_dialog(self) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
@ -229,7 +231,7 @@ class UiConfirmForeignAddress(UiConfirm):
class UiConfirmNonDefaultLocktime(UiConfirm): class UiConfirmNonDefaultLocktime(UiConfirm):
def __init__(self, lock_time: int, lock_time_disabled: bool): def __init__(self, lock_time: int, lock_time_disabled: bool) -> None:
self.lock_time = lock_time self.lock_time = lock_time
self.lock_time_disabled = lock_time_disabled self.lock_time_disabled = lock_time_disabled

View File

@ -14,7 +14,7 @@ _PREV_TX_MULTIPLIER = 5
class Progress: class Progress:
def __init__(self): def __init__(self) -> None:
self.progress = 0 self.progress = 0
self.steps = 0 self.steps = 0
self.signing = False self.signing = False

View File

@ -17,7 +17,7 @@ class SignatureVerifier:
script_sig: bytes | None, script_sig: bytes | None,
witness: bytes | None, witness: bytes | None,
coin: CoinInfo, coin: CoinInfo,
): ) -> None:
from trezor import utils from trezor import utils
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.wire import DataError # local_cache_global from trezor.wire import DataError # local_cache_global

View File

@ -42,7 +42,7 @@ class Credential:
key_hash: bytes | None, key_hash: bytes | None,
script_hash: bytes | None, script_hash: bytes | None,
pointer: messages.CardanoBlockchainPointerType | None, pointer: messages.CardanoBlockchainPointerType | None,
): ) -> None:
self.type_name = type_name self.type_name = type_name
self.address_type = address_type self.address_type = address_type
self.path = path self.path = path

View File

@ -107,7 +107,7 @@ class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
key_order_error: wire.ProcessError key_order_error: wire.ProcessError
previous_encoded_key: bytes previous_encoded_key: bytes
def __init__(self, size: int, key_order_error: wire.ProcessError): def __init__(self, size: int, key_order_error: wire.ProcessError) -> None:
super().__init__(size) super().__init__(size)
self.key_order_error = key_order_error self.key_order_error = key_order_error
self.previous_encoded_key = b"" self.previous_encoded_key = b""

View File

@ -23,7 +23,7 @@ class OrdinarySigner(Signer):
self, self,
msg: messages.CardanoSignTxInit, msg: messages.CardanoSignTxInit,
keychain: seed.Keychain, keychain: seed.Keychain,
): ) -> None:
super().__init__(msg, keychain) super().__init__(msg, keychain)
self.suite_tx_type: SuiteTxType = self._suite_tx_type() self.suite_tx_type: SuiteTxType = self._suite_tx_type()

View File

@ -12,12 +12,12 @@ def repeated_backup_enabled() -> bool:
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
def activate_repeated_backup(): def activate_repeated_backup() -> None:
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
wire.filters.append(_repeated_backup_filter) wire.filters.append(_repeated_backup_filter)
def deactivate_repeated_backup(): def deactivate_repeated_backup() -> None:
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
wire.remove_filter(_repeated_backup_filter) wire.remove_filter(_repeated_backup_filter)

View File

@ -229,7 +229,7 @@ class Tagged:
# TODO: this seems to be unused - is checked against, but is never created??? # TODO: this seems to be unused - is checked against, but is never created???
class Raw: class Raw:
def __init__(self, value: Value): def __init__(self, value: Value) -> None:
self.value = value self.value = value

View File

@ -142,7 +142,7 @@ async def require_confirm_claim(
) )
async def require_confirm_unknown_token(address_bytes: bytes): async def require_confirm_unknown_token(address_bytes: bytes) -> None:
from ubinascii import hexlify from ubinascii import hexlify
from trezor.ui.layouts import confirm_address, show_warning from trezor.ui.layouts import confirm_address, show_warning

View File

@ -233,7 +233,7 @@ def _get_slip39_mnemonics(
group_threshold: int, group_threshold: int,
groups: Sequence[tuple[int, int]], groups: Sequence[tuple[int, int]],
extendable: bool, extendable: bool,
): ) -> list[list[str]]:
if extendable: if extendable:
identifier = slip39.generate_random_identifier() identifier = slip39.generate_random_identifier()
else: else:

View File

@ -25,7 +25,7 @@ class State:
from apps.monero.xmr.mlsag_hasher import PreMlsagHasher from apps.monero.xmr.mlsag_hasher import PreMlsagHasher
# Account credentials # Account credentials
# type: AccountCreds # - type: AccountCreds
# - view private/public key # - view private/public key
# - spend private/public key # - spend private/public key
# - and its corresponding address # - and its corresponding address

View File

@ -16,6 +16,8 @@ from apps.common.paths import address_n_to_str
from .types import AddressType from .types import AddressType
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence
from .transaction.instructions import Instruction, SystemProgramTransferInstruction from .transaction.instructions import Instruction, SystemProgramTransferInstruction
from .types import AddressReference from .types import AddressReference
@ -33,7 +35,9 @@ def _format_path(path: list[int]) -> str:
return f"Solana #{unharden(account_index) + 1}" return f"Solana #{unharden(account_index) + 1}"
def _get_address_reference_props(address: AddressReference, display_name: str): def _get_address_reference_props(
address: AddressReference, display_name: str
) -> Sequence[tuple[str, str]]:
return ( return (
(TR.solana__is_provided_via_lookup_table_template.format(display_name), ""), (TR.solana__is_provided_via_lookup_table_template.format(display_name), ""),
(f"{TR.solana__lookup_table_address}:", base58.encode(address[0])), (f"{TR.solana__lookup_table_address}:", base58.encode(address[0])),
@ -293,7 +297,7 @@ async def confirm_token_transfer(
fee: int, fee: int,
signer_path: list[int], signer_path: list[int],
blockhash: bytes, blockhash: bytes,
): ) -> None:
await confirm_value( await confirm_value(
title=TR.words__recipient, title=TR.words__recipient,
value=base58.encode(destination_account), value=base58.encode(destination_account),

View File

@ -42,7 +42,7 @@ def get_create_associated_token_account_instructions(
def is_predefined_token_transfer( def is_predefined_token_transfer(
instructions: list[Instruction], instructions: list[Instruction],
): ) -> bool:
""" """
Checks that the transaction consists of one or zero create token account instructions Checks that the transaction consists of one or zero create token account instructions
and one or more transfer token instructions. Also checks that the token program, token mint and one or more transfer token instructions. Also checks that the token program, token mint

View File

@ -41,7 +41,7 @@ class Instruction:
@staticmethod @staticmethod
def parse_instruction_data( def parse_instruction_data(
instruction_data: InstructionData, property_templates: list[PropertyTemplate] instruction_data: InstructionData, property_templates: list[PropertyTemplate]
): ) -> dict[str, Any]:
from trezor.utils import BufferReader from trezor.utils import BufferReader
from trezor.wire import DataError from trezor.wire import DataError
@ -65,7 +65,7 @@ class Instruction:
@staticmethod @staticmethod
def parse_instruction_accounts( def parse_instruction_accounts(
accounts: list[Account], accounts_template: list[AccountTemplate] accounts: list[Account], accounts_template: list[AccountTemplate]
): ) -> dict[str, Account]:
parsed_account = {} parsed_account = {}
for i, account_template in enumerate(accounts_template): for i, account_template in enumerate(accounts_template):
if i >= len(accounts): if i >= len(accounts):

View File

@ -20,7 +20,7 @@ from .instruction import Instruction
from .parse import parse_byte, parse_memo, parse_pubkey, parse_string from .parse import parse_byte, parse_memo, parse_pubkey, parse_string
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Type from typing import Any, Type, TypeGuard
from ..types import Account, InstructionData, InstructionId from ..types import Account, InstructionData, InstructionId
@ -303,7 +303,7 @@ def __getattr__(name: str) -> Type[Instruction]:
class FakeClass(Instruction): class FakeClass(Instruction):
@classmethod @classmethod
def is_type_of(cls, ins: Any): def is_type_of(cls, ins: Any) -> TypeGuard[Instruction]:
return ins.program_id == id[0] and ins.instruction_id == id[1] return ins.program_id == id[0] and ins.instruction_id == id[1]
return FakeClass return FakeClass

View File

@ -52,7 +52,7 @@ from .parse import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Type from typing import Any, Type, TypeGuard
from ..types import Account, InstructionId, InstructionData from ..types import Account, InstructionId, InstructionData
@ -88,7 +88,7 @@ def __getattr__(name: str) -> Type[Instruction]:
class FakeClass(Instruction): class FakeClass(Instruction):
@classmethod @classmethod
def is_type_of(cls, ins: Any): def is_type_of(cls, ins: Any) -> TypeGuard[Instruction]:
return ins.program_id == id[0] and ins.instruction_id == id[1] return ins.program_id == id[0] and ins.instruction_id == id[1]
return FakeClass return FakeClass

View File

@ -42,7 +42,7 @@ class PropertyTemplate(Generic[T]):
is_optional: bool, is_optional: bool,
parse: Callable[[BufferReader], T], parse: Callable[[BufferReader], T],
format: Callable[[Instruction, T], str], format: Callable[[Instruction, T], str],
): ) -> None:
self.name = name self.name = name
self.is_authority = is_authority self.is_authority = is_authority
self.is_optional = is_optional self.is_optional = is_optional
@ -51,7 +51,7 @@ class PropertyTemplate(Generic[T]):
class AccountTemplate: class AccountTemplate:
def __init__(self, name: str, is_authority: bool, optional: bool): def __init__(self, name: str, is_authority: bool, optional: bool) -> None:
self.name = name self.name = name
self.is_authority = is_authority self.is_authority = is_authority
self.optional = optional self.optional = optional

View File

@ -219,7 +219,7 @@ _last_auth_valid = False
class CborError(Exception): class CborError(Exception):
def __init__(self, code: int): def __init__(self, code: int) -> None:
super().__init__() super().__init__()
self.code = code self.code = code

View File

@ -46,7 +46,7 @@ def blake_hash_writer_32(personal: bytes) -> HashWriter:
class ZcashHasher: class ZcashHasher:
def __init__(self, tx: SignTx | PrevTx): def __init__(self, tx: SignTx | PrevTx) -> None:
from trezor.utils import empty_bytearray from trezor.utils import empty_bytearray
self.header = HeaderHasher(tx) self.header = HeaderHasher(tx)
@ -130,7 +130,7 @@ class ZcashHasher:
class HeaderHasher: class HeaderHasher:
def __init__(self, tx: SignTx | PrevTx): def __init__(self, tx: SignTx | PrevTx) -> None:
h = blake_hash_writer_32(b"ZTxIdHeadersHash") h = blake_hash_writer_32(b"ZTxIdHeadersHash")
assert tx.version_group_id is not None assert tx.version_group_id is not None

View File

@ -327,7 +327,7 @@ if TYPE_CHECKING:
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]: def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs): def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key) value = get(key)
if value is None: if value is None:
value = func(*args, **kwargs) value = func(*args, **kwargs)
@ -341,7 +341,7 @@ def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]: def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs): async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = get(key) value = get(key)
if value is None: if value is None:
value = await func(*args, **kwargs) value = await func(*args, **kwargs)

View File

@ -140,7 +140,7 @@ class Share:
index: int, index: int,
threshold: int, threshold: int,
share_value: bytes, share_value: bytes,
): ) -> None:
self.identifier = identifier self.identifier = identifier
self.extendable = extendable self.extendable = extendable
self.iteration_exponent = iteration_exponent self.iteration_exponent = iteration_exponent

View File

@ -115,7 +115,7 @@ def _slip_39_checklist_items(
advanced: bool, advanced: bool,
count: int | None = None, count: int | None = None,
threshold: int | None = None, threshold: int | None = None,
): ) -> tuple[str, str, str]:
if not advanced: if not advanced:
entry_1 = ( entry_1 = (
TR.reset__slip39_checklist_num_shares_x_template.format(count) TR.reset__slip39_checklist_num_shares_x_template.format(count)

View File

@ -239,7 +239,7 @@ def slip39_prompt_number_of_shares(
min_count = 1 min_count = 1
max_count = 16 max_count = 16
def description(i: int): def description(i: int) -> str:
if group_id is None: if group_id is None:
if i == 1: if i == 1:
return TR.reset__only_one_share_will_be_created return TR.reset__only_one_share_will_be_created

View File

@ -45,6 +45,7 @@ if TYPE_CHECKING:
Msg = TypeVar("Msg", bound=protobuf.MessageType) Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType] HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask] Handler = Callable[[Msg], HandlerTask]
Filter = Callable[[int, Handler], Handler]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
@ -264,7 +265,7 @@ def find_handler(iface: WireInterface, msg_type: int) -> Handler:
return handler return handler
filters: list[Callable[[int, Handler], Handler]] = [] filters: list[Filter] = []
"""Filters for the wire handler. """Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It Filters are applied in order. Each filter gets a message id and a preceding handler. It
@ -292,7 +293,7 @@ and `filters` becomes private!
""" """
def remove_filter(filter): def remove_filter(filter: Filter) -> None:
try: try:
filters.remove(filter) filters.remove(filter)
except ValueError: except ValueError:

17
poetry.lock generated
View File

@ -488,6 +488,21 @@ mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.11.0,<2.12.0" pycodestyle = ">=2.11.0,<2.12.0"
pyflakes = ">=3.2.0,<3.3.0" pyflakes = ">=3.2.0,<3.3.0"
[[package]]
name = "flake8-annotations"
version = "3.1.1"
description = "Flake8 Type Annotation Checks"
optional = false
python-versions = ">=3.8.1"
files = [
{file = "flake8_annotations-3.1.1-py3-none-any.whl", hash = "sha256:102935bdcbfa714759a152aeb07b14aee343fc0b6f7c55ad16968ce3e0e91a8a"},
{file = "flake8_annotations-3.1.1.tar.gz", hash = "sha256:6c98968ccc6bdc0581d363bf147a87df2f01d0d078264b2da805799d911cf5fe"},
]
[package.dependencies]
attrs = ">=21.4"
flake8 = ">=5.0"
[[package]] [[package]]
name = "flake8-requirements" name = "flake8-requirements"
version = "2.1.0" version = "2.1.0"
@ -1798,4 +1813,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8.1" python-versions = "^3.8.1"
content-hash = "971d0f6f2926d839954b35b2029978046e282df3d8d595b1e176dc0cf37889fb" content-hash = "6ead2686c279a0baa17e3cc51acf49916be4649552782492ae97f2dd875e0ff6"

View File

@ -73,6 +73,7 @@ binsize = "^0.1.3"
toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"} toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"}
trezor-pylint-plugin = {path = "./tools/trezor-pylint-plugin", develop = true} trezor-pylint-plugin = {path = "./tools/trezor-pylint-plugin", develop = true}
trezor-core-tools = {path = "./core/tools", develop = true} trezor-core-tools = {path = "./core/tools", develop = true}
flake8-annotations = "^3.1.1"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
scan-build = "*" scan-build = "*"

View File

@ -22,7 +22,9 @@ ignore =
# E741: ambiguous variable name # E741: ambiguous variable name
E741, E741,
# W503: line break before binary operator # W503: line break before binary operator
W503 W503,
# flake8-annotations
ANN,
per-file-ignores = per-file-ignores =
helper-scripts/*:I helper-scripts/*:I
tools/*:I tools/*:I

View File

@ -17,9 +17,23 @@ ignore =
# W503: line break before binary operator # W503: line break before binary operator
W503, W503,
# flake8-requirements import checks # flake8-requirements import checks
I I,
# flake8-annotations self/cls type
ANN101, ANN102,
# flake8-annotations allow Any type
ANN401,
per-file-ignores = per-file-ignores =
core/mocks/generated/*:F4 core/mocks/generated/*:F4
core/src/typing.py:ANN
core/src/apps/monero/*:ANN
core/site_scons/*:ANN
core/tests/*:ANN
ci/*:ANN
common/*:ANN
crypto/*:ANN
legacy/*:ANN
storage/*:ANN
tests/*:ANN
[tool:pytest] [tool:pytest]
addopts = -rfE --strict-markers --random-order addopts = -rfE --strict-markers --random-order

View File

@ -28,7 +28,7 @@ def _silent_call(*args):
def format(file: Path): def format(file: Path):
_silent_call("isort", file) _silent_call("isort", file)
_silent_call("black", file) _silent_call("black", file)
_silent_call("flake8", file) _silent_call("flake8", file, "--config", ROOT / "setup.cfg")
def render_single(template_path: Path, programs: Munch) -> str: def render_single(template_path: Path, programs: Munch) -> str: