mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-12 00:10:58 +00:00
core/crypto: Add functions for verifying DER encoded signatures.
This commit is contained in:
parent
9459c5a5c2
commit
3b6c1e5e6b
@ -5,7 +5,8 @@ from trezor.crypto import bech32, bip32, der
|
|||||||
from trezor.crypto.curve import secp256k1
|
from trezor.crypto.curve import secp256k1
|
||||||
from trezor.utils import ensure
|
from trezor.utils import ensure
|
||||||
|
|
||||||
from apps.common.coininfo import CoinInfo
|
if False:
|
||||||
|
from apps.common.coininfo import CoinInfo
|
||||||
|
|
||||||
# supported witness version for bech32 addresses
|
# supported witness version for bech32 addresses
|
||||||
_BECH32_WITVER = const(0x00)
|
_BECH32_WITVER = const(0x00)
|
||||||
|
@ -1,3 +1,9 @@
|
|||||||
|
from apps.common.readers import BytearrayReader
|
||||||
|
|
||||||
|
if False:
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
def encode_length(l: int) -> bytes:
|
def encode_length(l: int) -> bytes:
|
||||||
if l < 0x80:
|
if l < 0x80:
|
||||||
return bytes([l])
|
return bytes([l])
|
||||||
@ -9,15 +15,80 @@ def encode_length(l: int) -> bytes:
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
def decode_length(r: BytearrayReader) -> int:
|
||||||
|
init = r.get()
|
||||||
|
if init < 0x80:
|
||||||
|
# short form encodes length in initial octet
|
||||||
|
return init
|
||||||
|
|
||||||
|
if init == 0x80 or init == 0xFF or r.peek() == 0x00:
|
||||||
|
raise ValueError # indefinite length, RFU or not shortest possible
|
||||||
|
|
||||||
|
# long form
|
||||||
|
n = 0
|
||||||
|
for _ in range(init & 0x7F):
|
||||||
|
n = n * 0x100 + r.get()
|
||||||
|
|
||||||
|
if n < 128:
|
||||||
|
raise ValueError # encoding is not the shortest possible
|
||||||
|
|
||||||
|
return n
|
||||||
|
|
||||||
|
|
||||||
def encode_int(i: bytes) -> bytes:
|
def encode_int(i: bytes) -> bytes:
|
||||||
i = i.lstrip(b"\x00")
|
i = i.lstrip(b"\x00")
|
||||||
|
if not i:
|
||||||
|
i = b"\00"
|
||||||
|
|
||||||
if i[0] >= 0x80:
|
if i[0] >= 0x80:
|
||||||
i = b"\x00" + i
|
i = b"\x00" + i
|
||||||
return b"\x02" + encode_length(len(i)) + i
|
return b"\x02" + encode_length(len(i)) + i
|
||||||
|
|
||||||
|
|
||||||
|
def decode_int(r: BytearrayReader) -> bytes:
|
||||||
|
if r.get() != 0x02:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
n = decode_length(r)
|
||||||
|
if n == 0:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
if r.peek() & 0x80:
|
||||||
|
raise ValueError # negative integer
|
||||||
|
|
||||||
|
if r.peek() == 0x00 and n > 1:
|
||||||
|
r.get()
|
||||||
|
n -= 1
|
||||||
|
if r.peek() & 0x80 == 0x00:
|
||||||
|
raise ValueError # excessive zero-padding
|
||||||
|
|
||||||
|
if r.peek() == 0x00:
|
||||||
|
raise ValueError # excessive zero-padding
|
||||||
|
|
||||||
|
return r.read(n)
|
||||||
|
|
||||||
|
|
||||||
def encode_seq(seq: tuple) -> bytes:
|
def encode_seq(seq: tuple) -> bytes:
|
||||||
res = b""
|
res = b""
|
||||||
for i in seq:
|
for i in seq:
|
||||||
res += encode_int(i)
|
res += encode_int(i)
|
||||||
return b"\x30" + encode_length(len(res)) + res
|
return b"\x30" + encode_length(len(res)) + res
|
||||||
|
|
||||||
|
|
||||||
|
def decode_seq(data: bytes) -> List[bytes]:
|
||||||
|
r = BytearrayReader(data)
|
||||||
|
|
||||||
|
if r.get() != 0x30:
|
||||||
|
raise ValueError
|
||||||
|
n = decode_length(r)
|
||||||
|
|
||||||
|
seq = []
|
||||||
|
end = r.offset + n
|
||||||
|
while r.offset < end:
|
||||||
|
i = decode_int(r)
|
||||||
|
seq.append(i)
|
||||||
|
|
||||||
|
if r.offset != end or r.remaining_count():
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
return seq
|
||||||
|
@ -46,11 +46,26 @@ class TestCryptoDer(unittest.TestCase):
|
|||||||
def test_der_encode_seq(self):
|
def test_der_encode_seq(self):
|
||||||
|
|
||||||
for s, d in self.vectors_seq:
|
for s, d in self.vectors_seq:
|
||||||
s = (unhexlify(i) for i in s)
|
s = tuple(unhexlify(i) for i in s)
|
||||||
d = unhexlify(d)
|
d = unhexlify(d)
|
||||||
d2 = der.encode_seq(s)
|
d2 = der.encode_seq(s)
|
||||||
self.assertEqual(d, d2)
|
self.assertEqual(d2, d)
|
||||||
|
s = [i.lstrip(b"\x00") for i in s]
|
||||||
|
s2 = der.decode_seq(d)
|
||||||
|
self.assertEqual(s2, s)
|
||||||
|
|
||||||
|
def test_der_encode_decode_long_seq(self):
|
||||||
|
for length in (1, 127, 128, 129, 255, 256, 257):
|
||||||
|
raw_int = bytes((i & 0xfe) + 1 for i in range(length))
|
||||||
|
for leading_zeros in range(3):
|
||||||
|
encoded = der.encode_seq((b"\x00" * leading_zeros + raw_int,))
|
||||||
|
decoded = der.decode_seq(encoded)
|
||||||
|
self.assertEqual(decoded, [raw_int])
|
||||||
|
|
||||||
|
for zeroes in range(3):
|
||||||
|
encoded = der.encode_seq((b"\x00" * zeroes,))
|
||||||
|
decoded = der.decode_seq(encoded)
|
||||||
|
self.assertEqual(decoded, [b"\x00"])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user