diff --git a/core/src/trezor/crypto/der.py b/core/src/trezor/crypto/der.py index bfe8c6a914..f8525507df 100644 --- a/core/src/trezor/crypto/der.py +++ b/core/src/trezor/crypto/der.py @@ -1,4 +1,12 @@ -from trezor.utils import BufferReader +from micropython import const + +from trezor.utils import BufferReader, empty_bytearray + +if False: + from trezor.utils import Writer + +# Maximum length of a DER-encoded secp256k1 or secp256p1 signature. +MAX_DER_SIGNATURE_LENGTH = const(72) def encode_length(l: int) -> bytes: @@ -12,7 +20,7 @@ def encode_length(l: int) -> bytes: raise ValueError -def decode_length(r: BufferReader) -> int: +def read_length(r: BufferReader) -> int: init = r.get() if init < 0x80: # short form encodes length in initial octet @@ -32,21 +40,27 @@ def decode_length(r: BufferReader) -> int: return n -def encode_int(i: bytes) -> bytes: - i = i.lstrip(b"\x00") - if not i: - i = b"\00" +def write_int(w: Writer, number: bytes) -> None: + i = 0 + while i < len(number) and number[i] == 0: + i += 1 - if i[0] >= 0x80: - i = b"\x00" + i - return b"\x02" + encode_length(len(i)) + i + length = len(number) - i + w.append(0x02) + if length == 0 or number[i] >= 0x80: + w.extend(encode_length(length + 1)) + w.append(0x00) + else: + w.extend(encode_length(length)) + + w.extend(memoryview(number)[i:]) -def decode_int(r: BufferReader) -> memoryview: +def read_int(r: BufferReader) -> memoryview: if r.get() != 0x02: raise ValueError - n = decode_length(r) + n = read_length(r) if n == 0: raise ValueError @@ -66,10 +80,13 @@ def decode_int(r: BufferReader) -> memoryview: def encode_seq(seq: tuple) -> bytes: - res = b"" + # Preallocate space for a signature, which is all that this function ever encodes. + buffer = empty_bytearray(MAX_DER_SIGNATURE_LENGTH) + buffer.append(0x30) for i in seq: - res += encode_int(i) - return b"\x30" + encode_length(len(res)) + res + write_int(buffer, i) + buffer[1:1] = encode_length(len(buffer) - 1) + return buffer def decode_seq(data: memoryview) -> list[memoryview]: @@ -77,12 +94,12 @@ def decode_seq(data: memoryview) -> list[memoryview]: if r.get() != 0x30: raise ValueError - n = decode_length(r) + n = read_length(r) seq = [] end = r.offset + n while r.offset < end: - i = decode_int(r) + i = read_int(r) seq.append(i) if r.offset != end or r.remaining_count():