mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-26 16:18:22 +00:00
chore(core): Use BufferReader for CBOR decoding.
This commit is contained in:
parent
ac939c94aa
commit
e5741ac308
@ -5,7 +5,9 @@ Minimalistic CBOR implementation, supports only what we need in cardano.
|
||||
import ustruct as struct
|
||||
from micropython import const
|
||||
|
||||
from trezor import log
|
||||
from trezor import log, utils
|
||||
|
||||
from . import readers
|
||||
|
||||
if False:
|
||||
from typing import Any, Iterable, List, Tuple, Union
|
||||
@ -103,110 +105,93 @@ def _cbor_encode(value: Value) -> Iterable[bytes]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _read_length(cbor: bytes, aux: int) -> Tuple[int, bytes]:
|
||||
def _read_length(r: utils.BufferReader, aux: int) -> int:
|
||||
if aux < _CBOR_UINT8_FOLLOWS:
|
||||
return (aux, cbor)
|
||||
return aux
|
||||
elif aux == _CBOR_UINT8_FOLLOWS:
|
||||
return (cbor[0], cbor[1:])
|
||||
return r.get()
|
||||
elif aux == _CBOR_UINT16_FOLLOWS:
|
||||
res = cbor[1]
|
||||
res += cbor[0] << 8
|
||||
return (res, cbor[2:])
|
||||
return readers.read_uint16_be(r)
|
||||
elif aux == _CBOR_UINT32_FOLLOWS:
|
||||
res = cbor[3]
|
||||
res += cbor[2] << 8
|
||||
res += cbor[1] << 16
|
||||
res += cbor[0] << 24
|
||||
return (res, cbor[4:])
|
||||
return readers.read_uint32_be(r)
|
||||
elif aux == _CBOR_UINT64_FOLLOWS:
|
||||
res = cbor[7]
|
||||
res += cbor[6] << 8
|
||||
res += cbor[5] << 16
|
||||
res += cbor[4] << 24
|
||||
res += cbor[3] << 32
|
||||
res += cbor[2] << 40
|
||||
res += cbor[1] << 48
|
||||
res += cbor[0] << 56
|
||||
return (res, cbor[8:])
|
||||
return readers.read_uint64_be(r)
|
||||
else:
|
||||
raise NotImplementedError("Length %d not suppported" % aux)
|
||||
|
||||
|
||||
def _cbor_decode(cbor: bytes) -> Tuple[Value, bytes]:
|
||||
fb = cbor[0]
|
||||
data = b""
|
||||
def _cbor_decode(r: utils.BufferReader) -> Value:
|
||||
fb = r.get()
|
||||
fb_type = fb & _CBOR_TYPE_MASK
|
||||
fb_aux = fb & _CBOR_INFO_BITS
|
||||
if fb_type == _CBOR_UNSIGNED_INT:
|
||||
return _read_length(cbor[1:], fb_aux)
|
||||
return _read_length(r, fb_aux)
|
||||
elif fb_type == _CBOR_NEGATIVE_INT:
|
||||
val, data = _read_length(cbor[1:], fb_aux)
|
||||
return (-1 - val, data)
|
||||
val = _read_length(r, fb_aux)
|
||||
return -1 - val
|
||||
elif fb_type == _CBOR_BYTE_STRING:
|
||||
ln, data = _read_length(cbor[1:], fb_aux)
|
||||
return (data[0:ln], data[ln:])
|
||||
ln = _read_length(r, fb_aux)
|
||||
return r.read(ln)
|
||||
elif fb_type == _CBOR_TEXT_STRING:
|
||||
ln, data = _read_length(cbor[1:], fb_aux)
|
||||
return (data[0:ln].decode(), data[ln:])
|
||||
ln = _read_length(r, fb_aux)
|
||||
return r.read(ln).decode()
|
||||
elif fb_type == _CBOR_ARRAY:
|
||||
if fb_aux == _CBOR_VAR_FOLLOWS:
|
||||
res: Value = []
|
||||
data = cbor[1:]
|
||||
while True:
|
||||
item, data = _cbor_decode(data)
|
||||
item = _cbor_decode(r)
|
||||
if item == _CBOR_PRIMITIVE + _CBOR_BREAK:
|
||||
break
|
||||
res.append(item)
|
||||
return (res, data)
|
||||
return res
|
||||
else:
|
||||
ln, data = _read_length(cbor[1:], fb_aux)
|
||||
ln = _read_length(r, fb_aux)
|
||||
res = []
|
||||
for i in range(ln):
|
||||
item, data = _cbor_decode(data)
|
||||
for _ in range(ln):
|
||||
item = _cbor_decode(r)
|
||||
res.append(item)
|
||||
return (res, data)
|
||||
return res
|
||||
elif fb_type == _CBOR_MAP:
|
||||
res = {}
|
||||
if fb_aux == _CBOR_VAR_FOLLOWS:
|
||||
data = cbor[1:]
|
||||
while True:
|
||||
key, data = _cbor_decode(data)
|
||||
key = _cbor_decode(r)
|
||||
if key in res:
|
||||
raise ValueError
|
||||
if key == _CBOR_PRIMITIVE + _CBOR_BREAK:
|
||||
break
|
||||
value, data = _cbor_decode(data)
|
||||
value = _cbor_decode(r)
|
||||
res[key] = value
|
||||
else:
|
||||
ln, data = _read_length(cbor[1:], fb_aux)
|
||||
for i in range(ln):
|
||||
key, data = _cbor_decode(data)
|
||||
ln = _read_length(r, fb_aux)
|
||||
for _ in range(ln):
|
||||
key = _cbor_decode(r)
|
||||
if key in res:
|
||||
raise ValueError
|
||||
value, data = _cbor_decode(data)
|
||||
value = _cbor_decode(r)
|
||||
res[key] = value
|
||||
return res, data
|
||||
return res
|
||||
elif fb_type == _CBOR_TAG:
|
||||
val, data = _read_length(cbor[1:], fb_aux)
|
||||
item, data = _cbor_decode(data)
|
||||
val = _read_length(r, fb_aux)
|
||||
item = _cbor_decode(r)
|
||||
if val == _CBOR_RAW_TAG: # only tag 24 (0x18) is supported
|
||||
return item, data
|
||||
return item
|
||||
else:
|
||||
return Tagged(val, item), data
|
||||
return Tagged(val, item)
|
||||
elif fb_type == _CBOR_PRIMITIVE:
|
||||
if fb_aux == _CBOR_FALSE:
|
||||
return (False, cbor[1:])
|
||||
return False
|
||||
elif fb_aux == _CBOR_TRUE:
|
||||
return (True, cbor[1:])
|
||||
return True
|
||||
elif fb_aux == _CBOR_NULL:
|
||||
return (None, cbor[1:])
|
||||
return None
|
||||
elif fb_aux == _CBOR_BREAK:
|
||||
return (cbor[0], cbor[1:])
|
||||
return fb
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
if __debug__:
|
||||
log.debug(__name__, "not implemented (decode): %s", cbor[0])
|
||||
log.debug(__name__, "not implemented (decode): %s", fb)
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -246,7 +231,8 @@ def encode(value: Value) -> bytes:
|
||||
|
||||
|
||||
def decode(cbor: bytes) -> Value:
|
||||
res, check = _cbor_decode(cbor)
|
||||
if not (check == b""):
|
||||
r = utils.BufferReader(cbor)
|
||||
res = _cbor_decode(r)
|
||||
if r.remaining_count():
|
||||
raise ValueError
|
||||
return res
|
||||
|
@ -18,9 +18,20 @@ def read_bitcoin_varint(r: BufferReader) -> int:
|
||||
return n
|
||||
|
||||
|
||||
def read_uint16_be(r: BufferReader) -> int:
|
||||
n = r.get()
|
||||
return (n << 8) + r.get()
|
||||
|
||||
|
||||
def read_uint32_be(r: BufferReader) -> int:
|
||||
n = r.get() << 24
|
||||
n += r.get() << 16
|
||||
n += r.get() << 8
|
||||
n += r.get()
|
||||
n = r.get()
|
||||
for _ in range(3):
|
||||
n = (n << 8) + r.get()
|
||||
return n
|
||||
|
||||
|
||||
def read_uint64_be(r: BufferReader) -> int:
|
||||
n = r.get()
|
||||
for _ in range(7):
|
||||
n = (n << 8) + r.get()
|
||||
return n
|
||||
|
@ -69,9 +69,10 @@ class TestCardanoCbor(unittest.TestCase):
|
||||
# null
|
||||
(None, 'f6'),
|
||||
]
|
||||
for val, encoded in test_vectors:
|
||||
self.assertEqual(unhexlify(encoded), encode(val))
|
||||
self.assertEqual(val, decode(unhexlify(encoded)))
|
||||
for val, encoded_hex in test_vectors:
|
||||
encoded = unhexlify(encoded_hex)
|
||||
self.assertEqual(encode(val), encoded)
|
||||
self.assertEqual(decode(encoded), val)
|
||||
|
||||
def test_cbor_tuples(self):
|
||||
"""
|
||||
@ -83,10 +84,11 @@ class TestCardanoCbor(unittest.TestCase):
|
||||
([1, [2, 3], [4, 5]], '8301820203820405'),
|
||||
(list(range(1, 26)), '98190102030405060708090a0b0c0d0e0f101112131415161718181819'),
|
||||
]
|
||||
for val, encoded in test_vectors:
|
||||
for val, encoded_hex in test_vectors:
|
||||
value_tuple = tuple(val)
|
||||
self.assertEqual(unhexlify(encoded), encode(value_tuple))
|
||||
self.assertEqual(val, decode(unhexlify(encoded)))
|
||||
encoded = unhexlify(encoded_hex)
|
||||
self.assertEqual(encode(value_tuple), encoded)
|
||||
self.assertEqual(decode(encoded), val)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user