chore(core): Use BufferReader for CBOR decoding.

pull/1470/head
Andrew Kozlik 3 years ago committed by Andrew Kozlik
parent ac939c94aa
commit e5741ac308

@ -5,7 +5,9 @@ Minimalistic CBOR implementation, supports only what we need in cardano.
import ustruct as struct
from micropython import const
from trezor import log
from trezor import log, utils
from . import readers
if False:
from typing import Any, Iterable, List, Tuple, Union
@ -103,110 +105,93 @@ def _cbor_encode(value: Value) -> Iterable[bytes]:
raise NotImplementedError
def _read_length(cbor: bytes, aux: int) -> Tuple[int, bytes]:
def _read_length(r: utils.BufferReader, aux: int) -> int:
if aux < _CBOR_UINT8_FOLLOWS:
return (aux, cbor)
return aux
elif aux == _CBOR_UINT8_FOLLOWS:
return (cbor[0], cbor[1:])
return r.get()
elif aux == _CBOR_UINT16_FOLLOWS:
res = cbor[1]
res += cbor[0] << 8
return (res, cbor[2:])
return readers.read_uint16_be(r)
elif aux == _CBOR_UINT32_FOLLOWS:
res = cbor[3]
res += cbor[2] << 8
res += cbor[1] << 16
res += cbor[0] << 24
return (res, cbor[4:])
return readers.read_uint32_be(r)
elif aux == _CBOR_UINT64_FOLLOWS:
res = cbor[7]
res += cbor[6] << 8
res += cbor[5] << 16
res += cbor[4] << 24
res += cbor[3] << 32
res += cbor[2] << 40
res += cbor[1] << 48
res += cbor[0] << 56
return (res, cbor[8:])
return readers.read_uint64_be(r)
else:
raise NotImplementedError("Length %d not suppported" % aux)
def _cbor_decode(cbor: bytes) -> Tuple[Value, bytes]:
fb = cbor[0]
data = b""
def _cbor_decode(r: utils.BufferReader) -> Value:
fb = r.get()
fb_type = fb & _CBOR_TYPE_MASK
fb_aux = fb & _CBOR_INFO_BITS
if fb_type == _CBOR_UNSIGNED_INT:
return _read_length(cbor[1:], fb_aux)
return _read_length(r, fb_aux)
elif fb_type == _CBOR_NEGATIVE_INT:
val, data = _read_length(cbor[1:], fb_aux)
return (-1 - val, data)
val = _read_length(r, fb_aux)
return -1 - val
elif fb_type == _CBOR_BYTE_STRING:
ln, data = _read_length(cbor[1:], fb_aux)
return (data[0:ln], data[ln:])
ln = _read_length(r, fb_aux)
return r.read(ln)
elif fb_type == _CBOR_TEXT_STRING:
ln, data = _read_length(cbor[1:], fb_aux)
return (data[0:ln].decode(), data[ln:])
ln = _read_length(r, fb_aux)
return r.read(ln).decode()
elif fb_type == _CBOR_ARRAY:
if fb_aux == _CBOR_VAR_FOLLOWS:
res: Value = []
data = cbor[1:]
while True:
item, data = _cbor_decode(data)
item = _cbor_decode(r)
if item == _CBOR_PRIMITIVE + _CBOR_BREAK:
break
res.append(item)
return (res, data)
return res
else:
ln, data = _read_length(cbor[1:], fb_aux)
ln = _read_length(r, fb_aux)
res = []
for i in range(ln):
item, data = _cbor_decode(data)
for _ in range(ln):
item = _cbor_decode(r)
res.append(item)
return (res, data)
return res
elif fb_type == _CBOR_MAP:
res = {}
if fb_aux == _CBOR_VAR_FOLLOWS:
data = cbor[1:]
while True:
key, data = _cbor_decode(data)
key = _cbor_decode(r)
if key in res:
raise ValueError
if key == _CBOR_PRIMITIVE + _CBOR_BREAK:
break
value, data = _cbor_decode(data)
value = _cbor_decode(r)
res[key] = value
else:
ln, data = _read_length(cbor[1:], fb_aux)
for i in range(ln):
key, data = _cbor_decode(data)
ln = _read_length(r, fb_aux)
for _ in range(ln):
key = _cbor_decode(r)
if key in res:
raise ValueError
value, data = _cbor_decode(data)
value = _cbor_decode(r)
res[key] = value
return res, data
return res
elif fb_type == _CBOR_TAG:
val, data = _read_length(cbor[1:], fb_aux)
item, data = _cbor_decode(data)
val = _read_length(r, fb_aux)
item = _cbor_decode(r)
if val == _CBOR_RAW_TAG: # only tag 24 (0x18) is supported
return item, data
return item
else:
return Tagged(val, item), data
return Tagged(val, item)
elif fb_type == _CBOR_PRIMITIVE:
if fb_aux == _CBOR_FALSE:
return (False, cbor[1:])
return False
elif fb_aux == _CBOR_TRUE:
return (True, cbor[1:])
return True
elif fb_aux == _CBOR_NULL:
return (None, cbor[1:])
return None
elif fb_aux == _CBOR_BREAK:
return (cbor[0], cbor[1:])
return fb
else:
raise NotImplementedError
else:
if __debug__:
log.debug(__name__, "not implemented (decode): %s", cbor[0])
log.debug(__name__, "not implemented (decode): %s", fb)
raise NotImplementedError
@ -246,7 +231,8 @@ def encode(value: Value) -> bytes:
def decode(cbor: bytes) -> Value:
res, check = _cbor_decode(cbor)
if not (check == b""):
r = utils.BufferReader(cbor)
res = _cbor_decode(r)
if r.remaining_count():
raise ValueError
return res

@ -18,9 +18,20 @@ def read_bitcoin_varint(r: BufferReader) -> int:
return n
def read_uint16_be(r: BufferReader) -> int:
n = r.get()
return (n << 8) + r.get()
def read_uint32_be(r: BufferReader) -> int:
n = r.get() << 24
n += r.get() << 16
n += r.get() << 8
n += r.get()
n = r.get()
for _ in range(3):
n = (n << 8) + r.get()
return n
def read_uint64_be(r: BufferReader) -> int:
n = r.get()
for _ in range(7):
n = (n << 8) + r.get()
return n

@ -69,9 +69,10 @@ class TestCardanoCbor(unittest.TestCase):
# null
(None, 'f6'),
]
for val, encoded in test_vectors:
self.assertEqual(unhexlify(encoded), encode(val))
self.assertEqual(val, decode(unhexlify(encoded)))
for val, encoded_hex in test_vectors:
encoded = unhexlify(encoded_hex)
self.assertEqual(encode(val), encoded)
self.assertEqual(decode(encoded), val)
def test_cbor_tuples(self):
"""
@ -83,10 +84,11 @@ class TestCardanoCbor(unittest.TestCase):
([1, [2, 3], [4, 5]], '8301820203820405'),
(list(range(1, 26)), '98190102030405060708090a0b0c0d0e0f101112131415161718181819'),
]
for val, encoded in test_vectors:
for val, encoded_hex in test_vectors:
value_tuple = tuple(val)
self.assertEqual(unhexlify(encoded), encode(value_tuple))
self.assertEqual(val, decode(unhexlify(encoded)))
encoded = unhexlify(encoded_hex)
self.assertEqual(encode(value_tuple), encoded)
self.assertEqual(decode(encoded), val)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save