mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-20 05:18:08 +00:00
ethereum/signing: streaming; all tests passing
This commit is contained in:
parent
47b3baa30a
commit
1f677306a1
@ -45,24 +45,49 @@ async def ethereum_sign_tx(ctx, msg):
|
|||||||
data += msg.data_initial_chunk
|
data += msg.data_initial_chunk
|
||||||
data_left = data_total - len(msg.data_initial_chunk)
|
data_left = data_total - len(msg.data_initial_chunk)
|
||||||
|
|
||||||
|
total_length = get_total_length(msg, data_total)
|
||||||
|
|
||||||
|
sha = sha3_256()
|
||||||
|
sha.update(rlp.encode_length(total_length, True)) # total length
|
||||||
|
|
||||||
|
for field in [msg.nonce, msg.gas_price, msg.gas_limit, msg.to, msg.value]:
|
||||||
|
sha.update(rlp.encode(field))
|
||||||
|
|
||||||
|
if data_left == 0:
|
||||||
|
sha.update(rlp.encode(data))
|
||||||
|
else:
|
||||||
|
sha.update(rlp.encode_length(data_total, False))
|
||||||
|
sha.update(rlp.encode(data, False))
|
||||||
|
|
||||||
while data_left > 0:
|
while data_left > 0:
|
||||||
resp = await send_request_chunk(ctx, data_left, data_total)
|
resp = await send_request_chunk(ctx, data_left, data_total)
|
||||||
data += resp.data_chunk
|
|
||||||
data_left -= len(resp.data_chunk)
|
data_left -= len(resp.data_chunk)
|
||||||
# todo stream
|
sha.update(resp.data_chunk)
|
||||||
|
|
||||||
|
# eip 155 replay protection
|
||||||
if msg.chain_id:
|
if msg.chain_id:
|
||||||
fields = [msg.nonce, msg.gas_price, msg.gas_limit, msg.to, msg.value, data, msg.chain_id, 0, 0]
|
sha.update(rlp.encode(msg.chain_id))
|
||||||
else:
|
sha.update(rlp.encode(0))
|
||||||
fields = [msg.nonce, msg.gas_price, msg.gas_limit, msg.to, msg.value, data]
|
sha.update(rlp.encode(0))
|
||||||
rlp_encoded = rlp.encode(fields)
|
|
||||||
sha256 = sha3_256()
|
|
||||||
sha256.update(rlp_encoded)
|
|
||||||
digest = sha256.digest(True)
|
|
||||||
|
|
||||||
|
digest = sha.digest(True)
|
||||||
return await send_signature(ctx, msg, digest)
|
return await send_signature(ctx, msg, digest)
|
||||||
|
|
||||||
|
|
||||||
|
def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
|
||||||
|
length = 0
|
||||||
|
for field in [msg.nonce, msg.gas_price, msg.gas_limit, msg.to, msg.value]:
|
||||||
|
length += rlp.field_length(len(field), field[:1])
|
||||||
|
|
||||||
|
if msg.chain_id: # forks replay protection
|
||||||
|
length += rlp.field_length(1, [msg.chain_id])
|
||||||
|
length += rlp.field_length(0, 0)
|
||||||
|
length += rlp.field_length(0, 0)
|
||||||
|
|
||||||
|
length += rlp.field_length(data_total, msg.data_initial_chunk)
|
||||||
|
return length
|
||||||
|
|
||||||
|
|
||||||
async def send_request_chunk(ctx, data_left: int, data_total: int):
|
async def send_request_chunk(ctx, data_left: int, data_total: int):
|
||||||
from trezor.messages.wire_types import EthereumTxAck
|
from trezor.messages.wire_types import EthereumTxAck
|
||||||
# todo layoutProgress ?
|
# todo layoutProgress ?
|
||||||
|
@ -20,13 +20,13 @@ def encode_length(l: int, is_list: bool) -> bytes:
|
|||||||
raise ValueError('Input too long')
|
raise ValueError('Input too long')
|
||||||
|
|
||||||
|
|
||||||
def encode(data) -> bytes:
|
def encode(data, include_length=True) -> bytes:
|
||||||
if isinstance(data, int):
|
if isinstance(data, int):
|
||||||
return encode(int_to_bytes(data))
|
return encode(int_to_bytes(data))
|
||||||
if isinstance(data, bytearray):
|
if isinstance(data, bytearray):
|
||||||
data = bytes(data)
|
data = bytes(data)
|
||||||
if isinstance(data, bytes):
|
if isinstance(data, bytes):
|
||||||
if len(data) == 1 and ord(data) < 128:
|
if (len(data) == 1 and ord(data) < 128) or not include_length:
|
||||||
return data
|
return data
|
||||||
else:
|
else:
|
||||||
return encode_length(len(data), is_list=False) + data
|
return encode_length(len(data), is_list=False) + data
|
||||||
@ -34,6 +34,21 @@ def encode(data) -> bytes:
|
|||||||
output = b''
|
output = b''
|
||||||
for item in data:
|
for item in data:
|
||||||
output += encode(item)
|
output += encode(item)
|
||||||
return encode_length(len(output), is_list=True) + output
|
if include_length:
|
||||||
|
return encode_length(len(output), is_list=True) + output
|
||||||
|
else:
|
||||||
|
return output
|
||||||
else:
|
else:
|
||||||
raise TypeError('Invalid input of type ' + str(type(data)))
|
raise TypeError('Invalid input of type ' + str(type(data)))
|
||||||
|
|
||||||
|
|
||||||
|
def field_length(length: int, first_byte: bytearray) -> int:
|
||||||
|
if length == 1 and first_byte[0] <= 0x7f:
|
||||||
|
return 1
|
||||||
|
elif length <= 55:
|
||||||
|
return 1 + length
|
||||||
|
elif length <= 0xff:
|
||||||
|
return 2 + length
|
||||||
|
elif length <= 0xffff:
|
||||||
|
return 3 + length
|
||||||
|
return 4 + length
|
||||||
|
Loading…
Reference in New Issue
Block a user