From e5741ac308a587fecc14a46014452d1e466ec153 Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Fri, 5 Feb 2021 19:51:01 +0100 Subject: [PATCH] chore(core): Use BufferReader for CBOR decoding. --- core/src/apps/common/cbor.py | 100 ++++++++++++---------------- core/src/apps/common/readers.py | 19 ++++-- core/tests/test_apps.common.cbor.py | 14 ++-- 3 files changed, 66 insertions(+), 67 deletions(-) diff --git a/core/src/apps/common/cbor.py b/core/src/apps/common/cbor.py index 6a78c4612..3db8579ea 100644 --- a/core/src/apps/common/cbor.py +++ b/core/src/apps/common/cbor.py @@ -5,7 +5,9 @@ Minimalistic CBOR implementation, supports only what we need in cardano. import ustruct as struct from micropython import const -from trezor import log +from trezor import log, utils + +from . import readers if False: from typing import Any, Iterable, List, Tuple, Union @@ -103,110 +105,93 @@ def _cbor_encode(value: Value) -> Iterable[bytes]: raise NotImplementedError -def _read_length(cbor: bytes, aux: int) -> Tuple[int, bytes]: +def _read_length(r: utils.BufferReader, aux: int) -> int: if aux < _CBOR_UINT8_FOLLOWS: - return (aux, cbor) + return aux elif aux == _CBOR_UINT8_FOLLOWS: - return (cbor[0], cbor[1:]) + return r.get() elif aux == _CBOR_UINT16_FOLLOWS: - res = cbor[1] - res += cbor[0] << 8 - return (res, cbor[2:]) + return readers.read_uint16_be(r) elif aux == _CBOR_UINT32_FOLLOWS: - res = cbor[3] - res += cbor[2] << 8 - res += cbor[1] << 16 - res += cbor[0] << 24 - return (res, cbor[4:]) + return readers.read_uint32_be(r) elif aux == _CBOR_UINT64_FOLLOWS: - res = cbor[7] - res += cbor[6] << 8 - res += cbor[5] << 16 - res += cbor[4] << 24 - res += cbor[3] << 32 - res += cbor[2] << 40 - res += cbor[1] << 48 - res += cbor[0] << 56 - return (res, cbor[8:]) + return readers.read_uint64_be(r) else: raise NotImplementedError("Length %d not suppported" % aux) -def _cbor_decode(cbor: bytes) -> Tuple[Value, bytes]: - fb = cbor[0] - data = b"" +def _cbor_decode(r: utils.BufferReader) -> Value: + fb = r.get() fb_type = fb & _CBOR_TYPE_MASK fb_aux = fb & _CBOR_INFO_BITS if fb_type == _CBOR_UNSIGNED_INT: - return _read_length(cbor[1:], fb_aux) + return _read_length(r, fb_aux) elif fb_type == _CBOR_NEGATIVE_INT: - val, data = _read_length(cbor[1:], fb_aux) - return (-1 - val, data) + val = _read_length(r, fb_aux) + return -1 - val elif fb_type == _CBOR_BYTE_STRING: - ln, data = _read_length(cbor[1:], fb_aux) - return (data[0:ln], data[ln:]) + ln = _read_length(r, fb_aux) + return r.read(ln) elif fb_type == _CBOR_TEXT_STRING: - ln, data = _read_length(cbor[1:], fb_aux) - return (data[0:ln].decode(), data[ln:]) + ln = _read_length(r, fb_aux) + return r.read(ln).decode() elif fb_type == _CBOR_ARRAY: if fb_aux == _CBOR_VAR_FOLLOWS: res: Value = [] - data = cbor[1:] while True: - item, data = _cbor_decode(data) + item = _cbor_decode(r) if item == _CBOR_PRIMITIVE + _CBOR_BREAK: break res.append(item) - return (res, data) + return res else: - ln, data = _read_length(cbor[1:], fb_aux) + ln = _read_length(r, fb_aux) res = [] - for i in range(ln): - item, data = _cbor_decode(data) + for _ in range(ln): + item = _cbor_decode(r) res.append(item) - return (res, data) + return res elif fb_type == _CBOR_MAP: res = {} if fb_aux == _CBOR_VAR_FOLLOWS: - data = cbor[1:] while True: - key, data = _cbor_decode(data) + key = _cbor_decode(r) if key in res: raise ValueError if key == _CBOR_PRIMITIVE + _CBOR_BREAK: break - value, data = _cbor_decode(data) + value = _cbor_decode(r) res[key] = value else: - ln, data = _read_length(cbor[1:], fb_aux) - for i in range(ln): - key, data = _cbor_decode(data) + ln = _read_length(r, fb_aux) + for _ in range(ln): + key = _cbor_decode(r) if key in res: raise ValueError - value, data = _cbor_decode(data) + value = _cbor_decode(r) res[key] = value - return res, data + return res elif fb_type == _CBOR_TAG: - val, data = _read_length(cbor[1:], fb_aux) - item, data = _cbor_decode(data) + val = _read_length(r, fb_aux) + item = _cbor_decode(r) if val == _CBOR_RAW_TAG: # only tag 24 (0x18) is supported - return item, data + return item else: - return Tagged(val, item), data + return Tagged(val, item) elif fb_type == _CBOR_PRIMITIVE: if fb_aux == _CBOR_FALSE: - return (False, cbor[1:]) + return False elif fb_aux == _CBOR_TRUE: - return (True, cbor[1:]) + return True elif fb_aux == _CBOR_NULL: - return (None, cbor[1:]) + return None elif fb_aux == _CBOR_BREAK: - return (cbor[0], cbor[1:]) + return fb else: raise NotImplementedError else: if __debug__: - log.debug(__name__, "not implemented (decode): %s", cbor[0]) + log.debug(__name__, "not implemented (decode): %s", fb) raise NotImplementedError @@ -246,7 +231,8 @@ def encode(value: Value) -> bytes: def decode(cbor: bytes) -> Value: - res, check = _cbor_decode(cbor) - if not (check == b""): + r = utils.BufferReader(cbor) + res = _cbor_decode(r) + if r.remaining_count(): raise ValueError return res diff --git a/core/src/apps/common/readers.py b/core/src/apps/common/readers.py index b4ed9e7a6..b75b6b56a 100644 --- a/core/src/apps/common/readers.py +++ b/core/src/apps/common/readers.py @@ -18,9 +18,20 @@ def read_bitcoin_varint(r: BufferReader) -> int: return n +def read_uint16_be(r: BufferReader) -> int: + n = r.get() + return (n << 8) + r.get() + + def read_uint32_be(r: BufferReader) -> int: - n = r.get() << 24 - n += r.get() << 16 - n += r.get() << 8 - n += r.get() + n = r.get() + for _ in range(3): + n = (n << 8) + r.get() + return n + + +def read_uint64_be(r: BufferReader) -> int: + n = r.get() + for _ in range(7): + n = (n << 8) + r.get() return n diff --git a/core/tests/test_apps.common.cbor.py b/core/tests/test_apps.common.cbor.py index 39454b54a..f41276473 100644 --- a/core/tests/test_apps.common.cbor.py +++ b/core/tests/test_apps.common.cbor.py @@ -69,9 +69,10 @@ class TestCardanoCbor(unittest.TestCase): # null (None, 'f6'), ] - for val, encoded in test_vectors: - self.assertEqual(unhexlify(encoded), encode(val)) - self.assertEqual(val, decode(unhexlify(encoded))) + for val, encoded_hex in test_vectors: + encoded = unhexlify(encoded_hex) + self.assertEqual(encode(val), encoded) + self.assertEqual(decode(encoded), val) def test_cbor_tuples(self): """ @@ -83,10 +84,11 @@ class TestCardanoCbor(unittest.TestCase): ([1, [2, 3], [4, 5]], '8301820203820405'), (list(range(1, 26)), '98190102030405060708090a0b0c0d0e0f101112131415161718181819'), ] - for val, encoded in test_vectors: + for val, encoded_hex in test_vectors: value_tuple = tuple(val) - self.assertEqual(unhexlify(encoded), encode(value_tuple)) - self.assertEqual(val, decode(unhexlify(encoded))) + encoded = unhexlify(encoded_hex) + self.assertEqual(encode(value_tuple), encoded) + self.assertEqual(decode(encoded), val) if __name__ == '__main__': unittest.main()