1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-12 00:10:58 +00:00

core/crypto: Add functions for verifying DER encoded signatures.

This commit is contained in:
Andrew Kozlik 2020-05-26 00:13:23 +02:00 committed by Andrew Kozlik
parent 9459c5a5c2
commit 3b6c1e5e6b
3 changed files with 90 additions and 3 deletions

View File

@ -5,7 +5,8 @@ from trezor.crypto import bech32, bip32, der
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.utils import ensure from trezor.utils import ensure
from apps.common.coininfo import CoinInfo if False:
from apps.common.coininfo import CoinInfo
# supported witness version for bech32 addresses # supported witness version for bech32 addresses
_BECH32_WITVER = const(0x00) _BECH32_WITVER = const(0x00)

View File

@ -1,3 +1,9 @@
from apps.common.readers import BytearrayReader
if False:
from typing import List
def encode_length(l: int) -> bytes: def encode_length(l: int) -> bytes:
if l < 0x80: if l < 0x80:
return bytes([l]) return bytes([l])
@ -9,15 +15,80 @@ def encode_length(l: int) -> bytes:
raise ValueError raise ValueError
def decode_length(r: BytearrayReader) -> int:
init = r.get()
if init < 0x80:
# short form encodes length in initial octet
return init
if init == 0x80 or init == 0xFF or r.peek() == 0x00:
raise ValueError # indefinite length, RFU or not shortest possible
# long form
n = 0
for _ in range(init & 0x7F):
n = n * 0x100 + r.get()
if n < 128:
raise ValueError # encoding is not the shortest possible
return n
def encode_int(i: bytes) -> bytes: def encode_int(i: bytes) -> bytes:
i = i.lstrip(b"\x00") i = i.lstrip(b"\x00")
if not i:
i = b"\00"
if i[0] >= 0x80: if i[0] >= 0x80:
i = b"\x00" + i i = b"\x00" + i
return b"\x02" + encode_length(len(i)) + i return b"\x02" + encode_length(len(i)) + i
def decode_int(r: BytearrayReader) -> bytes:
if r.get() != 0x02:
raise ValueError
n = decode_length(r)
if n == 0:
raise ValueError
if r.peek() & 0x80:
raise ValueError # negative integer
if r.peek() == 0x00 and n > 1:
r.get()
n -= 1
if r.peek() & 0x80 == 0x00:
raise ValueError # excessive zero-padding
if r.peek() == 0x00:
raise ValueError # excessive zero-padding
return r.read(n)
def encode_seq(seq: tuple) -> bytes: def encode_seq(seq: tuple) -> bytes:
res = b"" res = b""
for i in seq: for i in seq:
res += encode_int(i) res += encode_int(i)
return b"\x30" + encode_length(len(res)) + res return b"\x30" + encode_length(len(res)) + res
def decode_seq(data: bytes) -> List[bytes]:
r = BytearrayReader(data)
if r.get() != 0x30:
raise ValueError
n = decode_length(r)
seq = []
end = r.offset + n
while r.offset < end:
i = decode_int(r)
seq.append(i)
if r.offset != end or r.remaining_count():
raise ValueError
return seq

View File

@ -46,11 +46,26 @@ class TestCryptoDer(unittest.TestCase):
def test_der_encode_seq(self): def test_der_encode_seq(self):
for s, d in self.vectors_seq: for s, d in self.vectors_seq:
s = (unhexlify(i) for i in s) s = tuple(unhexlify(i) for i in s)
d = unhexlify(d) d = unhexlify(d)
d2 = der.encode_seq(s) d2 = der.encode_seq(s)
self.assertEqual(d, d2) self.assertEqual(d2, d)
s = [i.lstrip(b"\x00") for i in s]
s2 = der.decode_seq(d)
self.assertEqual(s2, s)
def test_der_encode_decode_long_seq(self):
for length in (1, 127, 128, 129, 255, 256, 257):
raw_int = bytes((i & 0xfe) + 1 for i in range(length))
for leading_zeros in range(3):
encoded = der.encode_seq((b"\x00" * leading_zeros + raw_int,))
decoded = der.decode_seq(encoded)
self.assertEqual(decoded, [raw_int])
for zeroes in range(3):
encoded = der.encode_seq((b"\x00" * zeroes,))
decoded = der.decode_seq(encoded)
self.assertEqual(decoded, [b"\x00"])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()