1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-12 05:17:09 +00:00

refactor(core/crypto): Optimize DER encoding of signatures.

This commit is contained in:
Andrew Kozlik 2021-03-23 10:52:21 +01:00 committed by Andrew Kozlik
parent 2964f2e855
commit 334103b089

View File

@ -1,4 +1,12 @@
from trezor.utils import BufferReader
from micropython import const
from trezor.utils import BufferReader, empty_bytearray
if False:
from trezor.utils import Writer
# Maximum length of a DER-encoded secp256k1 or secp256p1 signature.
MAX_DER_SIGNATURE_LENGTH = const(72)
def encode_length(l: int) -> bytes:
@ -12,7 +20,7 @@ def encode_length(l: int) -> bytes:
raise ValueError
def decode_length(r: BufferReader) -> int:
def read_length(r: BufferReader) -> int:
init = r.get()
if init < 0x80:
# short form encodes length in initial octet
@ -32,21 +40,27 @@ def decode_length(r: BufferReader) -> int:
return n
def encode_int(i: bytes) -> bytes:
i = i.lstrip(b"\x00")
if not i:
i = b"\00"
def write_int(w: Writer, number: bytes) -> None:
i = 0
while i < len(number) and number[i] == 0:
i += 1
if i[0] >= 0x80:
i = b"\x00" + i
return b"\x02" + encode_length(len(i)) + i
length = len(number) - i
w.append(0x02)
if length == 0 or number[i] >= 0x80:
w.extend(encode_length(length + 1))
w.append(0x00)
else:
w.extend(encode_length(length))
w.extend(memoryview(number)[i:])
def decode_int(r: BufferReader) -> memoryview:
def read_int(r: BufferReader) -> memoryview:
if r.get() != 0x02:
raise ValueError
n = decode_length(r)
n = read_length(r)
if n == 0:
raise ValueError
@ -66,10 +80,13 @@ def decode_int(r: BufferReader) -> memoryview:
def encode_seq(seq: tuple) -> bytes:
res = b""
# Preallocate space for a signature, which is all that this function ever encodes.
buffer = empty_bytearray(MAX_DER_SIGNATURE_LENGTH)
buffer.append(0x30)
for i in seq:
res += encode_int(i)
return b"\x30" + encode_length(len(res)) + res
write_int(buffer, i)
buffer[1:1] = encode_length(len(buffer) - 1)
return buffer
def decode_seq(data: memoryview) -> list[memoryview]:
@ -77,12 +94,12 @@ def decode_seq(data: memoryview) -> list[memoryview]:
if r.get() != 0x30:
raise ValueError
n = decode_length(r)
n = read_length(r)
seq = []
end = r.offset + n
while r.offset < end:
i = decode_int(r)
i = read_int(r)
seq.append(i)
if r.offset != end or r.remaining_count():