From 81ec2f3c657ac03b579973f37791d2871da834e0 Mon Sep 17 00:00:00 2001 From: Tomas Susanka Date: Thu, 26 Oct 2017 13:23:30 +0200 Subject: [PATCH] wallet/signing: hash writers and serialization moved to seperate file --- src/apps/wallet/sign_tx/signing.py | 130 +--------------------------- src/apps/wallet/sign_tx/writers.py | 132 +++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 128 deletions(-) create mode 100644 src/apps/wallet/sign_tx/writers.py diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index 6c955498a4..ae5ab8fe6c 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -6,8 +6,6 @@ from trezor.utils import ensure from trezor.messages.CoinType import CoinType from trezor.messages.SignTx import SignTx from trezor.messages.TxOutputType import TxOutputType -from trezor.messages.TxOutputBinType import TxOutputBinType -from trezor.messages.TxInputType import TxInputType from trezor.messages.TxRequest import TxRequest from trezor.messages.TransactionType import TransactionType from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED @@ -17,6 +15,7 @@ from trezor.messages import OutputScriptType, InputScriptType, FailureType from apps.common import address_type from apps.common import coins +from apps.wallet.sign_tx.writers import * # Machine instructions @@ -137,6 +136,7 @@ def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType: # Transaction signing # === +# Phase 1 async def check_tx_fee(tx: SignTx, root, segwit=False): coin = coins.by_name(tx.coin_name) @@ -337,15 +337,6 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde return total_out -def get_tx_hash(w, double: bool, reverse: bool=False) -> bytes: - d = w.getvalue() - if double: - d = sha256(d).digest() - if reverse: - d = bytes(reversed(d)) - return d - - def estimate_tx_size(inputs, outputs): return 10 + inputs * 149 + outputs * 35 @@ -479,120 +470,3 @@ def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray: write_op_push(w, len(pubkey)) write_bytes(w, pubkey) return w - - -# TX Serialization -# === - -_DEFAULT_SEQUENCE = 4294967295 - - -def write_tx_input(w, i: TxInputType): - i_sequence = i.sequence if i.sequence is not None else _DEFAULT_SEQUENCE - write_bytes_rev(w, i.prev_hash) - write_uint32(w, i.prev_index) - write_varint(w, len(i.script_sig)) - write_bytes(w, i.script_sig) - write_uint32(w, i_sequence) - - -def write_tx_input_check(w, i: TxInputType): - i_sequence = i.sequence if i.sequence is not None else _DEFAULT_SEQUENCE - write_bytes(w, i.prev_hash) - write_uint32(w, i.prev_index) - write_uint32(w, len(i.address_n)) - for n in i.address_n: - write_uint32(w, n) - write_uint32(w, i_sequence) - - -def write_tx_output(w, o: TxOutputBinType): - write_uint64(w, o.amount) - write_varint(w, len(o.script_pubkey)) - write_bytes(w, o.script_pubkey) - - -def write_op_push(w, n: int): - if n < 0x4C: - w.append(n & 0xFF) - elif n < 0xFF: - w.append(0x4C) - w.append(n & 0xFF) - elif n < 0xFFFF: - w.append(0x4D) - w.append(n & 0xFF) - w.append((n >> 8) & 0xFF) - else: - w.append(0x4E) - w.append(n & 0xFF) - w.append((n >> 8) & 0xFF) - w.append((n >> 16) & 0xFF) - w.append((n >> 24) & 0xFF) - - -# Buffer IO & Serialization -# === - - -def write_varint(w, n: int): - if n < 253: - w.append(n & 0xFF) - elif n < 65536: - w.append(253) - w.append(n & 0xFF) - w.append((n >> 8) & 0xFF) - else: - w.append(254) - w.append(n & 0xFF) - w.append((n >> 8) & 0xFF) - w.append((n >> 16) & 0xFF) - w.append((n >> 24) & 0xFF) - - -def write_uint32(w, n: int): - w.append(n & 0xFF) - w.append((n >> 8) & 0xFF) - w.append((n >> 16) & 0xFF) - w.append((n >> 24) & 0xFF) - - -def write_uint64(w, n: int): - w.append(n & 0xFF) - w.append((n >> 8) & 0xFF) - w.append((n >> 16) & 0xFF) - w.append((n >> 24) & 0xFF) - w.append((n >> 32) & 0xFF) - w.append((n >> 40) & 0xFF) - w.append((n >> 48) & 0xFF) - w.append((n >> 56) & 0xFF) - - -def write_bytes(w, buf: bytearray): - w.extend(buf) - - -def write_bytes_rev(w, buf: bytearray): - w.extend(bytearray(reversed(buf))) - - -def bytearray_with_cap(cap: int) -> bytearray: - b = bytearray(cap) - b[:] = bytes() - return b - - -class HashWriter: - - def __init__(self, hashfunc): - self.ctx = hashfunc() - self.buf = bytearray(1) # used in append() - - def extend(self, buf: bytearray): - self.ctx.update(buf) - - def append(self, b: int): - self.buf[0] = b - self.ctx.update(self.buf) - - def getvalue(self) -> bytes: - return self.ctx.digest() diff --git a/src/apps/wallet/sign_tx/writers.py b/src/apps/wallet/sign_tx/writers.py new file mode 100644 index 0000000000..c5eea29d06 --- /dev/null +++ b/src/apps/wallet/sign_tx/writers.py @@ -0,0 +1,132 @@ +from trezor.messages.TxOutputBinType import TxOutputBinType +from trezor.messages.TxInputType import TxInputType +from trezor.crypto.hashlib import sha256 + +# TX Serialization +# === + +_DEFAULT_SEQUENCE = 4294967295 + + +def write_tx_input(w, i: TxInputType): + i_sequence = i.sequence if i.sequence is not None else _DEFAULT_SEQUENCE + write_bytes_rev(w, i.prev_hash) + write_uint32(w, i.prev_index) + write_varint(w, len(i.script_sig)) + write_bytes(w, i.script_sig) + write_uint32(w, i_sequence) + + +def write_tx_input_check(w, i: TxInputType): + i_sequence = i.sequence if i.sequence is not None else _DEFAULT_SEQUENCE + write_bytes(w, i.prev_hash) + write_uint32(w, i.prev_index) + write_uint32(w, len(i.address_n)) + for n in i.address_n: + write_uint32(w, n) + write_uint32(w, i_sequence) + + +def write_tx_output(w, o: TxOutputBinType): + write_uint64(w, o.amount) + write_varint(w, len(o.script_pubkey)) + write_bytes(w, o.script_pubkey) + + +def write_op_push(w, n: int): + if n < 0x4C: + w.append(n & 0xFF) + elif n < 0xFF: + w.append(0x4C) + w.append(n & 0xFF) + elif n < 0xFFFF: + w.append(0x4D) + w.append(n & 0xFF) + w.append((n >> 8) & 0xFF) + else: + w.append(0x4E) + w.append(n & 0xFF) + w.append((n >> 8) & 0xFF) + w.append((n >> 16) & 0xFF) + w.append((n >> 24) & 0xFF) + + +# Buffer IO & Serialization +# === + + +def write_varint(w, n: int): + if n < 253: + w.append(n & 0xFF) + elif n < 65536: + w.append(253) + w.append(n & 0xFF) + w.append((n >> 8) & 0xFF) + else: + w.append(254) + w.append(n & 0xFF) + w.append((n >> 8) & 0xFF) + w.append((n >> 16) & 0xFF) + w.append((n >> 24) & 0xFF) + + +def write_uint32(w, n: int): + w.append(n & 0xFF) + w.append((n >> 8) & 0xFF) + w.append((n >> 16) & 0xFF) + w.append((n >> 24) & 0xFF) + + +def write_uint64(w, n: int): + w.append(n & 0xFF) + w.append((n >> 8) & 0xFF) + w.append((n >> 16) & 0xFF) + w.append((n >> 24) & 0xFF) + w.append((n >> 32) & 0xFF) + w.append((n >> 40) & 0xFF) + w.append((n >> 48) & 0xFF) + w.append((n >> 56) & 0xFF) + + +def write_bytes(w, buf: bytearray): + w.extend(buf) + + +def write_bytes_rev(w, buf: bytearray): + w.extend(bytearray(reversed(buf))) + + +def bytearray_with_cap(cap: int) -> bytearray: + b = bytearray(cap) + b[:] = bytes() + return b + + +# Hashes +# === + + +def get_tx_hash(w, double: bool, reverse: bool=False) -> bytes: + d = w.getvalue() + if double: + d = sha256(d).digest() + if reverse: + d = bytes(reversed(d)) + return d + + +class HashWriter: + + def __init__(self, hashfunc): + self.ctx = hashfunc() + self.buf = bytearray(1) # used in append() + + def extend(self, buf: bytearray): + self.ctx.update(buf) + + def append(self, b: int): + self.buf[0] = b + self.ctx.update(self.buf) + + def getvalue(self) -> bytes: + return self.ctx.digest()