diff --git a/core/src/apps/bitcoin/common.py b/core/src/apps/bitcoin/common.py index e1c88b536..9fceb2dff 100644 --- a/core/src/apps/bitcoin/common.py +++ b/core/src/apps/bitcoin/common.py @@ -5,7 +5,8 @@ from trezor.crypto import bech32, bip32, der from trezor.crypto.curve import secp256k1 from trezor.utils import ensure -from apps.common.coininfo import CoinInfo +if False: + from apps.common.coininfo import CoinInfo # supported witness version for bech32 addresses _BECH32_WITVER = const(0x00) diff --git a/core/src/trezor/crypto/der.py b/core/src/trezor/crypto/der.py index ac99aa3e7..06f6d2c06 100644 --- a/core/src/trezor/crypto/der.py +++ b/core/src/trezor/crypto/der.py @@ -1,3 +1,9 @@ +from apps.common.readers import BytearrayReader + +if False: + from typing import List + + def encode_length(l: int) -> bytes: if l < 0x80: return bytes([l]) @@ -9,15 +15,80 @@ def encode_length(l: int) -> bytes: raise ValueError +def decode_length(r: BytearrayReader) -> int: + init = r.get() + if init < 0x80: + # short form encodes length in initial octet + return init + + if init == 0x80 or init == 0xFF or r.peek() == 0x00: + raise ValueError # indefinite length, RFU or not shortest possible + + # long form + n = 0 + for _ in range(init & 0x7F): + n = n * 0x100 + r.get() + + if n < 128: + raise ValueError # encoding is not the shortest possible + + return n + + def encode_int(i: bytes) -> bytes: i = i.lstrip(b"\x00") + if not i: + i = b"\00" + if i[0] >= 0x80: i = b"\x00" + i return b"\x02" + encode_length(len(i)) + i +def decode_int(r: BytearrayReader) -> bytes: + if r.get() != 0x02: + raise ValueError + + n = decode_length(r) + if n == 0: + raise ValueError + + if r.peek() & 0x80: + raise ValueError # negative integer + + if r.peek() == 0x00 and n > 1: + r.get() + n -= 1 + if r.peek() & 0x80 == 0x00: + raise ValueError # excessive zero-padding + + if r.peek() == 0x00: + raise ValueError # excessive zero-padding + + return r.read(n) + + def encode_seq(seq: tuple) -> bytes: res = b"" for i in seq: res += encode_int(i) return b"\x30" + encode_length(len(res)) + res + + +def decode_seq(data: bytes) -> List[bytes]: + r = BytearrayReader(data) + + if r.get() != 0x30: + raise ValueError + n = decode_length(r) + + seq = [] + end = r.offset + n + while r.offset < end: + i = decode_int(r) + seq.append(i) + + if r.offset != end or r.remaining_count(): + raise ValueError + + return seq diff --git a/core/tests/test_trezor.crypto.der.py b/core/tests/test_trezor.crypto.der.py index 6adf42d06..d935b85b9 100644 --- a/core/tests/test_trezor.crypto.der.py +++ b/core/tests/test_trezor.crypto.der.py @@ -46,11 +46,26 @@ class TestCryptoDer(unittest.TestCase): def test_der_encode_seq(self): for s, d in self.vectors_seq: - s = (unhexlify(i) for i in s) + s = tuple(unhexlify(i) for i in s) d = unhexlify(d) d2 = der.encode_seq(s) - self.assertEqual(d, d2) - + self.assertEqual(d2, d) + s = [i.lstrip(b"\x00") for i in s] + s2 = der.decode_seq(d) + self.assertEqual(s2, s) + + def test_der_encode_decode_long_seq(self): + for length in (1, 127, 128, 129, 255, 256, 257): + raw_int = bytes((i & 0xfe) + 1 for i in range(length)) + for leading_zeros in range(3): + encoded = der.encode_seq((b"\x00" * leading_zeros + raw_int,)) + decoded = der.decode_seq(encoded) + self.assertEqual(decoded, [raw_int]) + + for zeroes in range(3): + encoded = der.encode_seq((b"\x00" * zeroes,)) + decoded = der.decode_seq(encoded) + self.assertEqual(decoded, [b"\x00"]) if __name__ == '__main__': unittest.main()