1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-05 04:50:57 +00:00

core: Fix mypy.

This commit is contained in:
Andrew Kozlik 2020-04-08 17:37:40 +02:00 committed by Andrew Kozlik
parent 50c08274b9
commit 2b74513e49
15 changed files with 122 additions and 95 deletions

View File

@ -1,9 +1,10 @@
from trezor import utils, wire from trezor import utils, wire
from trezor.messages.RequestType import TXFINISHED from trezor.messages.RequestType import TXFINISHED
from trezor.messages.SignTx import SignTx
from trezor.messages.TxAck import TxAck from trezor.messages.TxAck import TxAck
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
from apps.common import coins, paths from apps.common import coins, paths, seed
from apps.wallet.sign_tx import ( from apps.wallet.sign_tx import (
addresses, addresses,
helpers, helpers,
@ -18,12 +19,15 @@ from apps.wallet.sign_tx import (
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from apps.wallet.sign_tx import decred, bitcoinlike from apps.wallet.sign_tx import decred, bitcoinlike
if False:
from typing import Union
async def sign_tx(ctx, msg, keychain):
async def sign_tx(ctx: wire.Context, msg: SignTx, keychain: seed.Keychain) -> TxRequest:
coin_name = msg.coin_name if msg.coin_name is not None else "Bitcoin" coin_name = msg.coin_name if msg.coin_name is not None else "Bitcoin"
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)
if not utils.BITCOIN_ONLY and coin.decred: if not utils.BITCOIN_ONLY and coin.decred:
coinsig = decred.Decred() coinsig = decred.Decred() # type: signing.Bitcoin
elif not utils.BITCOIN_ONLY and coin.overwintered: elif not utils.BITCOIN_ONLY and coin.overwintered:
coinsig = bitcoinlike.Overwintered() coinsig = bitcoinlike.Overwintered()
elif not utils.BITCOIN_ONLY and coin_name not in ("Bitcoin", "Regtest", "Testnet"): elif not utils.BITCOIN_ONLY and coin_name not in ("Bitcoin", "Regtest", "Testnet"):
@ -33,7 +37,7 @@ async def sign_tx(ctx, msg, keychain):
signer = coinsig.signer(msg, keychain, coin) signer = coinsig.signer(msg, keychain, coin)
res = None res = None # type: Union[TxAck, bool]
while True: while True:
try: try:
req = signer.send(res) req = signer.send(res)

View File

@ -3,6 +3,7 @@ from micropython import const
from trezor.crypto import base58, bech32, cashaddr from trezor.crypto import base58, bech32, cashaddr
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.messages import FailureType, InputScriptType from trezor.messages import FailureType, InputScriptType
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.utils import ensure from trezor.utils import ensure
from apps.common import HARDENED, address_type, paths from apps.common import HARDENED, address_type, paths
@ -15,6 +16,7 @@ from apps.wallet.sign_tx.scripts import (
if False: if False:
from typing import List from typing import List
from trezor.crypto import bip32
# supported witness version for bech32 addresses # supported witness version for bech32 addresses
_BECH32_WITVER = const(0x00) _BECH32_WITVER = const(0x00)
@ -25,7 +27,10 @@ class AddressError(Exception):
def get_address( def get_address(
script_type: InputScriptType, coin: CoinInfo, node, multisig=None script_type: int,
coin: CoinInfo,
node: bip32.HDNode,
multisig: MultisigRedeemScriptType = None,
) -> str: ) -> str:
if ( if (
@ -88,7 +93,7 @@ def get_address(
raise AddressError(FailureType.ProcessError, "Invalid script type") raise AddressError(FailureType.ProcessError, "Invalid script type")
def address_multisig_p2sh(pubkeys: List[bytes], m: int, coin: CoinInfo): def address_multisig_p2sh(pubkeys: List[bytes], m: int, coin: CoinInfo) -> str:
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise AddressError( raise AddressError(
FailureType.ProcessError, "Multisig not enabled on this coin" FailureType.ProcessError, "Multisig not enabled on this coin"
@ -98,7 +103,7 @@ def address_multisig_p2sh(pubkeys: List[bytes], m: int, coin: CoinInfo):
return address_p2sh(redeem_script_hash, coin) return address_p2sh(redeem_script_hash, coin)
def address_multisig_p2wsh_in_p2sh(pubkeys: List[bytes], m: int, coin: CoinInfo): def address_multisig_p2wsh_in_p2sh(pubkeys: List[bytes], m: int, coin: CoinInfo) -> str:
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise AddressError( raise AddressError(
FailureType.ProcessError, "Multisig not enabled on this coin" FailureType.ProcessError, "Multisig not enabled on this coin"
@ -108,7 +113,7 @@ def address_multisig_p2wsh_in_p2sh(pubkeys: List[bytes], m: int, coin: CoinInfo)
return address_p2wsh_in_p2sh(witness_script_hash, coin) return address_p2wsh_in_p2sh(witness_script_hash, coin)
def address_multisig_p2wsh(pubkeys: List[bytes], m: int, hrp: str): def address_multisig_p2wsh(pubkeys: List[bytes], m: int, hrp: str) -> str:
if not hrp: if not hrp:
raise AddressError( raise AddressError(
FailureType.ProcessError, "Multisig not enabled on this coin" FailureType.ProcessError, "Multisig not enabled on this coin"
@ -196,7 +201,7 @@ def address_short(coin: CoinInfo, address: str) -> str:
def validate_full_path( def validate_full_path(
path: list, coin: CoinInfo, script_type: InputScriptType, validate_script_type=True path: list, coin: CoinInfo, script_type: int, validate_script_type: bool = True
) -> bool: ) -> bool:
""" """
Validates derivation path to fit Bitcoin-like coins. We mostly use Validates derivation path to fit Bitcoin-like coins. We mostly use
@ -244,9 +249,7 @@ def validate_purpose(purpose: int, coin: CoinInfo) -> bool:
return True return True
def validate_purpose_against_script_type( def validate_purpose_against_script_type(purpose: int, script_type: int) -> bool:
purpose: int, script_type: InputScriptType
) -> bool:
""" """
Validates purpose against provided input's script type: Validates purpose against provided input's script type:
- 44 for spending address (script_type == SPENDADDRESS) - 44 for spending address (script_type == SPENDADDRESS)

View File

@ -44,7 +44,7 @@ class Bitcoinlike(signing.Bitcoin):
else: else:
await super().phase2_sign_nonsegwit_input(i_sign) await super().phase2_sign_nonsegwit_input(i_sign)
async def phase2_sign_bip143_input(self, i_sign) -> None: async def phase2_sign_bip143_input(self, i_sign: int) -> None:
# STAGE_REQUEST_SEGWIT_INPUT # STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
self.input_check_wallet_path(txi_sign) self.input_check_wallet_path(txi_sign)

View File

@ -33,19 +33,19 @@ class DecredPrefixHasher:
writers.write_uint32(self.h_prefix, tx.version | DECRED_SERIALIZE_NO_WITNESS) writers.write_uint32(self.h_prefix, tx.version | DECRED_SERIALIZE_NO_WITNESS)
writers.write_varint(self.h_prefix, tx.inputs_count) writers.write_varint(self.h_prefix, tx.inputs_count)
def add_prevouts(self, txi: TxInputType): def add_prevouts(self, txi: TxInputType) -> None:
writers.write_tx_input_decred(self.h_prefix, txi) writers.write_tx_input_decred(self.h_prefix, txi)
def add_sequence(self, txi: TxInputType): def add_sequence(self, txi: TxInputType) -> None:
pass pass
def add_output_count(self, tx: SignTx): def add_output_count(self, tx: SignTx) -> None:
writers.write_varint(self.h_prefix, tx.outputs_count) writers.write_varint(self.h_prefix, tx.outputs_count)
def add_output(self, txo_bin: TxOutputBinType): def add_output(self, txo_bin: TxOutputBinType) -> None:
writers.write_tx_output(self.h_prefix, txo_bin) writers.write_tx_output(self.h_prefix, txo_bin)
def add_locktime_expiry(self, tx: SignTx): def add_locktime_expiry(self, tx: SignTx) -> None:
writers.write_uint32(self.h_prefix, tx.lock_time) writers.write_uint32(self.h_prefix, tx.lock_time)
writers.write_uint32(self.h_prefix, tx.expiry) writers.write_uint32(self.h_prefix, tx.expiry)
@ -54,21 +54,23 @@ class DecredPrefixHasher:
class Decred(Bitcoin): class Decred(Bitcoin):
def initialize(self, tx: SignTx, keychain: seed.Keychain, coin: coininfo.CoinInfo): def initialize(
self, tx: SignTx, keychain: seed.Keychain, coin: coininfo.CoinInfo
) -> None:
super().initialize(tx, keychain, coin) super().initialize(tx, keychain, coin)
# This is required because the last serialized output obtained in # This is required because the last serialized output obtained in
# `check_fee` will only be sent to the client in `sign_tx` # `check_fee` will only be sent to the client in `sign_tx`
self.last_output_bytes = None # type: bytearray self.last_output_bytes = None # type: bytearray
def init_hash143(self): def init_hash143(self) -> None:
self.hash143 = DecredPrefixHasher(self.tx) # pseudo BIP-0143 prefix hashing self.hash143 = DecredPrefixHasher(self.tx) # pseudo BIP-0143 prefix hashing
async def phase1(self): async def phase1(self) -> None:
await super().phase1() await super().phase1()
self.hash143.add_locktime_expiry(self.tx) self.hash143.add_locktime_expiry(self.tx)
async def phase1_process_input(self, i: int, txi: TxInputType): async def phase1_process_input(self, i: int, txi: TxInputType) -> None:
await super().phase1_process_input(i, txi) await super().phase1_process_input(i, txi)
w_txi = writers.empty_bytearray(8 if i == 0 else 0 + 9 + len(txi.prev_hash)) w_txi = writers.empty_bytearray(8 if i == 0 else 0 + 9 + len(txi.prev_hash))
if i == 0: # serializing first input => prepend headers if i == 0: # serializing first input => prepend headers
@ -78,7 +80,7 @@ class Decred(Bitcoin):
async def phase1_confirm_output( async def phase1_confirm_output(
self, i: int, txo: TxOutputType, txo_bin: TxOutputBinType self, i: int, txo: TxOutputType, txo_bin: TxOutputBinType
): ) -> None:
if txo.decred_script_version is not None and txo.decred_script_version != 0: if txo.decred_script_version is not None and txo.decred_script_version != 0:
raise SigningError( raise SigningError(
FailureType.ActionCancelled, FailureType.ActionCancelled,
@ -97,7 +99,7 @@ class Decred(Bitcoin):
await super().phase1_confirm_output(i, txo, txo_bin) await super().phase1_confirm_output(i, txo, txo_bin)
async def phase2(self): async def phase2(self) -> None:
self.tx_req.serialized = None self.tx_req.serialized = None
prefix_hash = self.hash143.prefix_hash() prefix_hash = self.hash143.prefix_hash()
@ -171,7 +173,7 @@ class Decred(Bitcoin):
i_sign, signature, w_txi_sign i_sign, signature, w_txi_sign
) )
return await helpers.request_tx_finish(self.tx_req) await helpers.request_tx_finish(self.tx_req)
async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int: async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int:
total_out = 0 # sum of output amounts total_out = 0 # sum of output amounts

View File

@ -22,7 +22,7 @@ from .writers import TX_HASH_SIZE
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
if False: if False:
from typing import Union from typing import Any, Awaitable, Union
MULTISIG_INPUT_SCRIPT_TYPES = ( MULTISIG_INPUT_SCRIPT_TYPES = (
InputScriptType.SPENDMULTISIG, InputScriptType.SPENDMULTISIG,
@ -92,27 +92,27 @@ class UiConfirmNonDefaultLocktime:
__eq__ = utils.obj_eq __eq__ = utils.obj_eq
def confirm_output(output: TxOutputType, coin: CoinInfo): def confirm_output(output: TxOutputType, coin: CoinInfo) -> Awaitable[Any]: # type: ignore
return (yield UiConfirmOutput(output, coin)) return (yield UiConfirmOutput(output, coin))
def confirm_total(spending: int, fee: int, coin: CoinInfo): def confirm_total(spending: int, fee: int, coin: CoinInfo) -> Awaitable[Any]: # type: ignore
return (yield UiConfirmTotal(spending, fee, coin)) return (yield UiConfirmTotal(spending, fee, coin))
def confirm_feeoverthreshold(fee: int, coin: CoinInfo): def confirm_feeoverthreshold(fee: int, coin: CoinInfo) -> Awaitable[Any]: # type: ignore
return (yield UiConfirmFeeOverThreshold(fee, coin)) return (yield UiConfirmFeeOverThreshold(fee, coin))
def confirm_foreign_address(address_n: list): def confirm_foreign_address(address_n: list) -> Awaitable[Any]: # type: ignore
return (yield UiConfirmForeignAddress(address_n)) return (yield UiConfirmForeignAddress(address_n))
def confirm_nondefault_locktime(lock_time: int): def confirm_nondefault_locktime(lock_time: int) -> Awaitable[Any]: # type: ignore
return (yield UiConfirmNonDefaultLocktime(lock_time)) return (yield UiConfirmNonDefaultLocktime(lock_time))
def request_tx_meta(tx_req: TxRequest, coin: CoinInfo, tx_hash: bytes = None): def request_tx_meta(tx_req: TxRequest, coin: CoinInfo, tx_hash: bytes = None) -> Awaitable[Any]: # type: ignore
tx_req.request_type = TXMETA tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
tx_req.details.request_index = None tx_req.details.request_index = None
@ -122,9 +122,9 @@ def request_tx_meta(tx_req: TxRequest, coin: CoinInfo, tx_hash: bytes = None):
return sanitize_tx_meta(ack.tx, coin) return sanitize_tx_meta(ack.tx, coin)
def request_tx_extra_data( def request_tx_extra_data( # type: ignore
tx_req: TxRequest, offset: int, size: int, tx_hash: bytes = None tx_req: TxRequest, offset: int, size: int, tx_hash: bytes = None
): ) -> Awaitable[Any]:
tx_req.request_type = TXEXTRADATA tx_req.request_type = TXEXTRADATA
tx_req.details.extra_data_offset = offset tx_req.details.extra_data_offset = offset
tx_req.details.extra_data_len = size tx_req.details.extra_data_len = size
@ -138,7 +138,7 @@ def request_tx_extra_data(
return ack.tx.extra_data return ack.tx.extra_data
def request_tx_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes = None): def request_tx_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes = None) -> Awaitable[Any]: # type: ignore
tx_req.request_type = TXINPUT tx_req.request_type = TXINPUT
tx_req.details.request_index = i tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
@ -148,7 +148,7 @@ def request_tx_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes =
return sanitize_tx_input(ack.tx, coin) return sanitize_tx_input(ack.tx, coin)
def request_tx_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes = None): def request_tx_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes = None) -> Awaitable[Any]: # type: ignore
tx_req.request_type = TXOUTPUT tx_req.request_type = TXOUTPUT
tx_req.details.request_index = i tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
@ -161,7 +161,7 @@ def request_tx_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes
return sanitize_tx_binoutput(ack.tx, coin) return sanitize_tx_binoutput(ack.tx, coin)
def request_tx_finish(tx_req: TxRequest): def request_tx_finish(tx_req: TxRequest) -> Awaitable[Any]: # type: ignore
tx_req.request_type = TXFINISHED tx_req.request_type = TXFINISHED
tx_req.details = None tx_req.details = None
yield tx_req yield tx_req
@ -293,7 +293,7 @@ def sanitize_tx_binoutput(tx: TransactionType, coin: CoinInfo) -> TxOutputBinTyp
def _sanitize_decred( def _sanitize_decred(
tx: Union[TxInputType, TxOutputType, TxOutputBinType], coin: CoinInfo tx: Union[TxInputType, TxOutputType, TxOutputBinType], coin: CoinInfo
): ) -> None:
if not coin.decred and tx.decred_script_version is not None: if not coin.decred and tx.decred_script_version is not None:
raise SigningError( raise SigningError(
FailureType.DataError, FailureType.DataError,

View File

@ -3,25 +3,34 @@ from ubinascii import hexlify
from trezor import ui from trezor import ui
from trezor.messages import ButtonRequestType, OutputScriptType from trezor.messages import ButtonRequestType, OutputScriptType
from trezor.messages.TxOutputType import TxOutputType
from trezor.strings import format_amount from trezor.strings import format_amount
from trezor.utils import chunks from trezor.utils import chunks
from apps.common import coininfo
if False:
from typing import Iterator, List
from trezor import wire
_LOCKTIME_TIMESTAMP_MIN_VALUE = const(500000000) _LOCKTIME_TIMESTAMP_MIN_VALUE = const(500000000)
def format_coin_amount(amount, coin): def format_coin_amount(amount: int, coin: coininfo.CoinInfo) -> str:
return "%s %s" % (format_amount(amount, coin.decimals), coin.coin_shortcut) return "%s %s" % (format_amount(amount, coin.decimals), coin.coin_shortcut)
def split_address(address): def split_address(address: str) -> Iterator[str]:
return chunks(address, 17) return chunks(address, 17)
def split_op_return(data): def split_op_return(data: str) -> Iterator[str]:
return chunks(data, 18) return chunks(data, 18)
async def confirm_output(ctx, output, coin): async def confirm_output(
ctx: wire.Context, output: TxOutputType, coin: coininfo.CoinInfo
) -> bool:
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import confirm from apps.common.confirm import confirm
from apps.wallet.sign_tx import addresses, omni from apps.wallet.sign_tx import addresses, omni
@ -34,11 +43,11 @@ async def confirm_output(ctx, output, coin):
text.normal(omni.parse(data)) text.normal(omni.parse(data))
else: else:
# generic OP_RETURN # generic OP_RETURN
data = hexlify(data).decode() hex_data = hexlify(data).decode()
if len(data) >= 18 * 5: if len(hex_data) >= 18 * 5:
data = data[: (18 * 5 - 3)] + "..." hex_data = hex_data[: (18 * 5 - 3)] + "..."
text = Text("OP_RETURN", ui.ICON_SEND, ui.GREEN) text = Text("OP_RETURN", ui.ICON_SEND, ui.GREEN)
text.mono(*split_op_return(data)) text.mono(*split_op_return(hex_data))
else: else:
address = output.address address = output.address
address_short = addresses.address_short(coin, address) address_short = addresses.address_short(coin, address)
@ -48,7 +57,9 @@ async def confirm_output(ctx, output, coin):
return await confirm(ctx, text, ButtonRequestType.ConfirmOutput) return await confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_total(ctx, spending, fee, coin): async def confirm_total(
ctx: wire.Context, spending: int, fee: int, coin: coininfo.CoinInfo
) -> bool:
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import hold_to_confirm from apps.common.confirm import hold_to_confirm
@ -60,7 +71,9 @@ async def confirm_total(ctx, spending, fee, coin):
return await hold_to_confirm(ctx, text, ButtonRequestType.SignTx) return await hold_to_confirm(ctx, text, ButtonRequestType.SignTx)
async def confirm_feeoverthreshold(ctx, fee, coin): async def confirm_feeoverthreshold(
ctx: wire.Context, fee: int, coin: coininfo.CoinInfo
) -> bool:
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import confirm from apps.common.confirm import confirm
@ -71,7 +84,9 @@ async def confirm_feeoverthreshold(ctx, fee, coin):
return await confirm(ctx, text, ButtonRequestType.FeeOverThreshold) return await confirm(ctx, text, ButtonRequestType.FeeOverThreshold)
async def confirm_foreign_address(ctx, address_n, coin): async def confirm_foreign_address(
ctx: wire.Context, address_n: List[int], coin: coininfo.CoinInfo
) -> bool:
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import confirm from apps.common.confirm import confirm
@ -80,7 +95,7 @@ async def confirm_foreign_address(ctx, address_n, coin):
return await confirm(ctx, text, ButtonRequestType.SignTx) return await confirm(ctx, text, ButtonRequestType.SignTx)
async def confirm_nondefault_locktime(ctx, lock_time): async def confirm_nondefault_locktime(ctx: wire.Context, lock_time: int) -> bool:
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import confirm from apps.common.confirm import confirm

View File

@ -70,8 +70,8 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int: def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int:
if multisig.nodes: if multisig.nodes:
for i, hd in enumerate(multisig.nodes): for i, hd_node in enumerate(multisig.nodes):
if multisig_get_pubkey(hd, multisig.address_n) == pubkey: if multisig_get_pubkey(hd_node, multisig.address_n) == pubkey:
return i return i
else: else:
for i, hd in enumerate(multisig.pubkeys): for i, hd in enumerate(multisig.pubkeys):

View File

@ -11,6 +11,10 @@ from apps.wallet.sign_tx.writers import (
write_varint, write_varint,
) )
if False:
from typing import List
from apps.wallet.sign_tx.writers import Writer
class ScriptsError(ValueError): class ScriptsError(ValueError):
pass pass
@ -145,7 +149,7 @@ def witness_p2wsh(
signature: bytes, signature: bytes,
signature_index: int, signature_index: int,
sighash: int, sighash: int,
): ) -> bytearray:
# get other signatures, stretch with None to the number of the pubkeys # get other signatures, stretch with None to the number of the pubkeys
signatures = multisig.signatures + [None] * ( signatures = multisig.signatures + [None] * (
multisig_get_pubkey_count(multisig) - len(multisig.signatures) multisig_get_pubkey_count(multisig) - len(multisig.signatures)
@ -184,7 +188,7 @@ def witness_p2wsh(
# redeem script # redeem script
write_varint(w, redeem_script_length) write_varint(w, redeem_script_length)
output_script_multisig(pubkeys, multisig.m, w) write_output_script_multisig(w, pubkeys, multisig.m)
return w return w
@ -233,12 +237,18 @@ def input_script_multisig(
# redeem script # redeem script
write_op_push(w, redeem_script_length) write_op_push(w, redeem_script_length)
output_script_multisig(pubkeys, multisig.m, w) write_output_script_multisig(w, pubkeys, multisig.m)
return w return w
def output_script_multisig(pubkeys, m: int, w: bytearray = None) -> bytearray: def output_script_multisig(pubkeys: List[bytes], m: int, w: Writer = None) -> bytearray:
w = empty_bytearray(output_script_multisig_length(pubkeys, m))
write_output_script_multisig(w, pubkeys, m)
return w
def write_output_script_multisig(w: Writer, pubkeys: List[bytes], m: int) -> None:
n = len(pubkeys) n = len(pubkeys)
if n < 1 or n > 15 or m < 1 or m > 15 or m > n: if n < 1 or n > 15 or m < 1 or m > 15 or m > n:
raise ScriptsError("Invalid multisig parameters") raise ScriptsError("Invalid multisig parameters")
@ -246,17 +256,14 @@ def output_script_multisig(pubkeys, m: int, w: bytearray = None) -> bytearray:
if len(pubkey) != 33: if len(pubkey) != 33:
raise ScriptsError("Invalid multisig parameters") raise ScriptsError("Invalid multisig parameters")
if w is None:
w = empty_bytearray(output_script_multisig_length(pubkeys, m))
w.append(0x50 + m) # numbers 1 to 16 are pushed as 0x50 + value w.append(0x50 + m) # numbers 1 to 16 are pushed as 0x50 + value
for p in pubkeys: for p in pubkeys:
append_pubkey(w, p) append_pubkey(w, p)
w.append(0x50 + n) w.append(0x50 + n)
w.append(0xAE) # OP_CHECKMULTISIG w.append(0xAE) # OP_CHECKMULTISIG
return w
def output_script_multisig_length(pubkeys, m: int) -> int: def output_script_multisig_length(pubkeys: List[bytes], m: int) -> int:
return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig
@ -276,14 +283,12 @@ def output_script_paytoopreturn(data: bytes) -> bytearray:
# === # ===
def append_signature(w: bytearray, signature: bytes, sighash: int) -> bytearray: def append_signature(w: Writer, signature: bytes, sighash: int) -> None:
write_op_push(w, len(signature) + 1) write_op_push(w, len(signature) + 1)
write_bytes_unchecked(w, signature) write_bytes_unchecked(w, signature)
w.append(sighash) w.append(sighash)
return w
def append_pubkey(w: bytearray, pubkey: bytes) -> bytearray: def append_pubkey(w: Writer, pubkey: bytes) -> None:
write_op_push(w, len(pubkey)) write_op_push(w, len(pubkey))
write_bytes_unchecked(w, pubkey) write_bytes_unchecked(w, pubkey)
return w

View File

@ -25,19 +25,19 @@ class Bip143Error(ValueError):
class Bip143: class Bip143:
def __init__(self): def __init__(self) -> None:
self.h_prevouts = HashWriter(sha256()) self.h_prevouts = HashWriter(sha256())
self.h_sequence = HashWriter(sha256()) self.h_sequence = HashWriter(sha256())
self.h_outputs = HashWriter(sha256()) self.h_outputs = HashWriter(sha256())
def add_prevouts(self, txi: TxInputType): def add_prevouts(self, txi: TxInputType) -> None:
write_bytes_reversed(self.h_prevouts, txi.prev_hash, TX_HASH_SIZE) write_bytes_reversed(self.h_prevouts, txi.prev_hash, TX_HASH_SIZE)
write_uint32(self.h_prevouts, txi.prev_index) write_uint32(self.h_prevouts, txi.prev_index)
def add_sequence(self, txi: TxInputType): def add_sequence(self, txi: TxInputType) -> None:
write_uint32(self.h_sequence, txi.sequence) write_uint32(self.h_sequence, txi.sequence)
def add_output(self, txo_bin: TxOutputBinType): def add_output(self, txo_bin: TxOutputBinType) -> None:
write_tx_output(self.h_outputs, txo_bin) write_tx_output(self.h_outputs, txo_bin)
def get_prevouts_hash(self, coin: CoinInfo) -> bytes: def get_prevouts_hash(self, coin: CoinInfo) -> bytes:

View File

@ -246,7 +246,7 @@ class Bitcoin:
await helpers.request_tx_finish(self.tx_req) await helpers.request_tx_finish(self.tx_req)
async def phase2_serialize_segwit_input(self, i_sign) -> None: async def phase2_serialize_segwit_input(self, i_sign: int) -> None:
# STAGE_REQUEST_SEGWIT_INPUT # STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
@ -270,7 +270,7 @@ class Bitcoin:
writers.write_tx_input(w_txi, txi_sign) writers.write_tx_input(w_txi, txi_sign)
self.tx_req.serialized = TxRequestSerializedType(serialized_tx=w_txi) self.tx_req.serialized = TxRequestSerializedType(serialized_tx=w_txi)
async def phase2_sign_segwit_input(self, i) -> Tuple[bytearray, bytes]: async def phase2_sign_segwit_input(self, i: int) -> Tuple[bytearray, bytes]:
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
self.input_check_wallet_path(txi) self.input_check_wallet_path(txi)
@ -552,8 +552,8 @@ class Bitcoin:
# p2wsh in p2sh # p2wsh in p2sh
pubkeys = multisig.multisig_get_pubkeys(i.multisig) pubkeys = multisig.multisig_get_pubkeys(i.multisig)
witness_script_hasher = utils.HashWriter(sha256()) witness_script_hasher = utils.HashWriter(sha256())
scripts.output_script_multisig( scripts.write_output_script_multisig(
pubkeys, i.multisig.m, witness_script_hasher witness_script_hasher, pubkeys, i.multisig.m
) )
witness_script_hash = witness_script_hasher.get_digest() witness_script_hash = witness_script_hasher.get_digest()
return scripts.input_script_p2wsh_in_p2sh(witness_script_hash) return scripts.input_script_p2wsh_in_p2sh(witness_script_hash)

View File

@ -44,13 +44,13 @@ class TxWeightCalculator:
) )
self.segwit = False self.segwit = False
def add_witness_header(self): def add_witness_header(self) -> None:
if not self.segwit: if not self.segwit:
self.counter += _TXSIZE_SEGWIT_OVERHEAD self.counter += _TXSIZE_SEGWIT_OVERHEAD
self.counter += self.ser_length_size(self.inputs_count) self.counter += self.ser_length_size(self.inputs_count)
self.segwit = True self.segwit = True
def add_input(self, i: TxInputType): def add_input(self, i: TxInputType) -> None:
if i.multisig: if i.multisig:
multisig_script_size = _TXSIZE_MULTISIGSCRIPT + len(i.multisig.pubkeys) * ( multisig_script_size = _TXSIZE_MULTISIGSCRIPT + len(i.multisig.pubkeys) * (
@ -88,7 +88,7 @@ class TxWeightCalculator:
self.counter += 4 # empty self.counter += 4 # empty
self.counter += input_script_size # discounted witness self.counter += input_script_size # discounted witness
def add_output(self, script: bytes): def add_output(self, script: bytes) -> None:
size = len(script) + self.ser_length_size(len(script)) size = len(script) + self.ser_length_size(len(script))
self.counter += 4 * (_TXSIZE_OUTPUT + size) self.counter += 4 * (_TXSIZE_OUTPUT + size)
@ -96,7 +96,7 @@ class TxWeightCalculator:
return self.counter return self.counter
@staticmethod @staticmethod
def ser_length_size(length: int): def ser_length_size(length: int) -> int:
if length < 253: if length < 253:
return 1 return 1
if length < 0x10000: if length < 0x10000:
@ -104,7 +104,7 @@ class TxWeightCalculator:
return 5 return 5
@staticmethod @staticmethod
def op_push_size(length: int): def op_push_size(length: int) -> int:
if length < 0x4C: if length < 0x4C:
return 1 return 1
if length < 0x100: if length < 0x100:

View File

@ -18,6 +18,7 @@ from apps.common.writers import ( # noqa: F401
if False: if False:
from apps.common.writers import Writer from apps.common.writers import Writer
from trezor.utils import HashWriter
write_uint16 = write_uint16_le write_uint16 = write_uint16_le
write_uint32 = write_uint32_le write_uint32 = write_uint32_le
@ -31,14 +32,14 @@ def write_bytes_prefixed(w: Writer, b: bytes) -> None:
write_bytes_unchecked(w, b) write_bytes_unchecked(w, b)
def write_tx_input(w, i: TxInputType): def write_tx_input(w: Writer, i: TxInputType) -> None:
write_bytes_reversed(w, i.prev_hash, TX_HASH_SIZE) write_bytes_reversed(w, i.prev_hash, TX_HASH_SIZE)
write_uint32(w, i.prev_index) write_uint32(w, i.prev_index)
write_bytes_prefixed(w, i.script_sig) write_bytes_prefixed(w, i.script_sig)
write_uint32(w, i.sequence) write_uint32(w, i.sequence)
def write_tx_input_check(w, i: TxInputType): def write_tx_input_check(w: Writer, i: TxInputType) -> None:
write_bytes_fixed(w, i.prev_hash, TX_HASH_SIZE) write_bytes_fixed(w, i.prev_hash, TX_HASH_SIZE)
write_uint32(w, i.prev_index) write_uint32(w, i.prev_index)
write_uint32(w, i.script_type) write_uint32(w, i.script_type)
@ -49,28 +50,28 @@ def write_tx_input_check(w, i: TxInputType):
write_uint64(w, i.amount or 0) write_uint64(w, i.amount or 0)
def write_tx_input_decred(w, i: TxInputType): def write_tx_input_decred(w: Writer, i: TxInputType) -> None:
write_bytes_reversed(w, i.prev_hash, TX_HASH_SIZE) write_bytes_reversed(w, i.prev_hash, TX_HASH_SIZE)
write_uint32(w, i.prev_index or 0) write_uint32(w, i.prev_index or 0)
write_uint8(w, i.decred_tree or 0) write_uint8(w, i.decred_tree or 0)
write_uint32(w, i.sequence) write_uint32(w, i.sequence)
def write_tx_input_decred_witness(w, i: TxInputType): def write_tx_input_decred_witness(w: Writer, i: TxInputType) -> None:
write_uint64(w, i.amount or 0) write_uint64(w, i.amount or 0)
write_uint32(w, 0) # block height fraud proof write_uint32(w, 0) # block height fraud proof
write_uint32(w, 0xFFFFFFFF) # block index fraud proof write_uint32(w, 0xFFFFFFFF) # block index fraud proof
write_bytes_prefixed(w, i.script_sig) write_bytes_prefixed(w, i.script_sig)
def write_tx_output(w, o: TxOutputBinType): def write_tx_output(w: Writer, o: TxOutputBinType) -> None:
write_uint64(w, o.amount) write_uint64(w, o.amount)
if o.decred_script_version is not None: if o.decred_script_version is not None:
write_uint16(w, o.decred_script_version) write_uint16(w, o.decred_script_version)
write_bytes_prefixed(w, o.script_pubkey) write_bytes_prefixed(w, o.script_pubkey)
def write_op_push(w, n: int): def write_op_push(w: Writer, n: int) -> None:
ensure(n >= 0 and n <= 0xFFFFFFFF) ensure(n >= 0 and n <= 0xFFFFFFFF)
if n < 0x4C: if n < 0x4C:
w.append(n & 0xFF) w.append(n & 0xFF)
@ -89,7 +90,7 @@ def write_op_push(w, n: int):
w.append((n >> 24) & 0xFF) w.append((n >> 24) & 0xFF)
def write_varint(w, n: int): def write_varint(w: Writer, n: int) -> None:
ensure(n >= 0 and n <= 0xFFFFFFFF) ensure(n >= 0 and n <= 0xFFFFFFFF)
if n < 253: if n < 253:
w.append(n & 0xFF) w.append(n & 0xFF)
@ -105,7 +106,7 @@ def write_varint(w, n: int):
w.append((n >> 24) & 0xFF) w.append((n >> 24) & 0xFF)
def get_tx_hash(w, double: bool = False, reverse: bool = False) -> bytes: def get_tx_hash(w: HashWriter, double: bool = False, reverse: bool = False) -> bytes:
d = w.get_digest() d = w.get_digest()
if double: if double:
d = sha256(d).digest() d = sha256(d).digest()

View File

@ -53,14 +53,14 @@ class Zip143:
self.h_sequence = HashWriter(blake2b(outlen=32, personal=b"ZcashSequencHash")) self.h_sequence = HashWriter(blake2b(outlen=32, personal=b"ZcashSequencHash"))
self.h_outputs = HashWriter(blake2b(outlen=32, personal=b"ZcashOutputsHash")) self.h_outputs = HashWriter(blake2b(outlen=32, personal=b"ZcashOutputsHash"))
def add_prevouts(self, txi: TxInputType): def add_prevouts(self, txi: TxInputType) -> None:
write_bytes_reversed(self.h_prevouts, txi.prev_hash, TX_HASH_SIZE) write_bytes_reversed(self.h_prevouts, txi.prev_hash, TX_HASH_SIZE)
write_uint32(self.h_prevouts, txi.prev_index) write_uint32(self.h_prevouts, txi.prev_index)
def add_sequence(self, txi: TxInputType): def add_sequence(self, txi: TxInputType) -> None:
write_uint32(self.h_sequence, txi.sequence) write_uint32(self.h_sequence, txi.sequence)
def add_output(self, txo_bin: TxOutputBinType): def add_output(self, txo_bin: TxOutputBinType) -> None:
write_tx_output(self.h_outputs, txo_bin) write_tx_output(self.h_outputs, txo_bin)
def get_prevouts_hash(self) -> bytes: def get_prevouts_hash(self) -> bytes:
@ -119,7 +119,7 @@ class Zip143:
class Zip243(Zip143): class Zip243(Zip143):
def __init__(self, branch_id) -> None: def __init__(self, branch_id: int) -> None:
super().__init__(branch_id) super().__init__(branch_id)
def preimage_hash( def preimage_hash(

View File

@ -21,7 +21,7 @@
"""Reference implementation for Bech32 and segwit addresses.""" """Reference implementation for Bech32 and segwit addresses."""
if False: if False:
from typing import List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
@ -81,7 +81,7 @@ def bech32_decode(bech: str) -> Tuple[Optional[str], Optional[List[int]]]:
def convertbits( def convertbits(
data: List[int], frombits: int, tobits: int, pad: bool = True data: Iterable[int], frombits: int, tobits: int, pad: bool = True
) -> List[int]: ) -> List[int]:
"""General power-of-2 base conversion.""" """General power-of-2 base conversion."""
acc = 0 acc = 0
@ -120,7 +120,7 @@ def decode(hrp: str, addr: str) -> Tuple[Optional[int], Optional[List[int]]]:
return (data[0], decoded) return (data[0], decoded)
def encode(hrp: str, witver: int, witprog: List[int]) -> Optional[str]: def encode(hrp: str, witver: int, witprog: Iterable[int]) -> Optional[str]:
"""Encode a segwit address.""" """Encode a segwit address."""
ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5)) ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5))
if decode(hrp, ret) == (None, None): if decode(hrp, ret) == (None, None):

View File

@ -86,9 +86,6 @@ if False:
def extend(self, buf: bytes) -> None: def extend(self, buf: bytes) -> None:
... ...
def write(self, buf: bytes) -> None:
...
class HashWriter: class HashWriter:
def __init__(self, ctx: HashContext) -> None: def __init__(self, ctx: HashContext) -> None: