You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
trezor-firmware/core/src/apps/common/cbor.py

280 lines
8.4 KiB

"""
Minimalistic CBOR implementation; supports only what we need for Cardano.
"""
import ustruct as struct
from micropython import const
from trezor import log, utils
from . import readers, writers
if False:
from typing import Any, Union, Iterator, Tuple
Value = Any
CborSequence = Union[list[Value], Tuple[Value, ...]]
# The decoder splits the first byte of an item into a major type (top three
# bits) and additional information (low five bits) using these two masks.
_CBOR_TYPE_MASK = const(0xE0)
_CBOR_INFO_BITS = const(0x1F)
# Major types, already shifted into the top three bits of the first byte.
_CBOR_UNSIGNED_INT = const(0b000 << 5)
_CBOR_NEGATIVE_INT = const(0b001 << 5)
_CBOR_BYTE_STRING = const(0b010 << 5)
_CBOR_TEXT_STRING = const(0b011 << 5)
_CBOR_ARRAY = const(0b100 << 5)
_CBOR_MAP = const(0b101 << 5)
_CBOR_TAG = const(0b110 << 5)
_CBOR_PRIMITIVE = const(0b111 << 5)
# Additional-information values announcing that the argument follows the
# first byte as a big-endian unsigned integer of the given width
# (see _read_length).
_CBOR_UINT8_FOLLOWS = const(0x18)
_CBOR_UINT16_FOLLOWS = const(0x19)
_CBOR_UINT32_FOLLOWS = const(0x1A)
_CBOR_UINT64_FOLLOWS = const(0x1B)
# Additional-information value marking an indefinite-length array or map.
_CBOR_VAR_FOLLOWS = const(0x1F)
# Additional-information values of the supported _CBOR_PRIMITIVE items.
_CBOR_FALSE = const(0x14)
_CBOR_TRUE = const(0x15)
_CBOR_NULL = const(0x16)
_CBOR_BREAK = const(0x1F)
# Tag number (24) whose content the decoder returns without a Tagged wrapper.
_CBOR_RAW_TAG = const(0x18)
def _header(typ: int, l: int) -> bytes:
if l < 24:
return struct.pack(">B", typ + l)
elif l < 2 ** 8:
return struct.pack(">BB", typ + 24, l)
elif l < 2 ** 16:
return struct.pack(">BH", typ + 25, l)
elif l < 2 ** 32:
return struct.pack(">BI", typ + 26, l)
elif l < 2 ** 64:
return struct.pack(">BQ", typ + 27, l)
else:
raise NotImplementedError("Length %d not suppported" % l)
def _cbor_encode(value: Value) -> Iterator[bytes]:
    """
    Lazily encode `value`, yielding the CBOR output piece by piece.

    Supported inputs: bool, int, bytes, bytearray, str, list/tuple
    (definite-length array), dict (definite-length map with entries ordered
    by their encoded key), Tagged, IndefiniteLengthArray, Raw (emitted
    verbatim) and None.

    Raises NotImplementedError for any other type.
    """
    if isinstance(value, bool):
        # Must be tested before `int`: where bool is a subclass of int
        # (CPython) the integer branch would emit True/False as 1/0.
        if value:
            yield bytes([_CBOR_PRIMITIVE + _CBOR_TRUE])
        else:
            yield bytes([_CBOR_PRIMITIVE + _CBOR_FALSE])
    elif isinstance(value, int):
        if value >= 0:
            yield _header(_CBOR_UNSIGNED_INT, value)
        else:
            # Negative integers are stored as the unsigned value -1 - n.
            yield _header(_CBOR_NEGATIVE_INT, -1 - value)
    elif isinstance(value, bytes):
        yield _header(_CBOR_BYTE_STRING, len(value))
        yield value
    elif isinstance(value, bytearray):
        yield _header(_CBOR_BYTE_STRING, len(value))
        yield bytes(value)
    elif isinstance(value, str):
        # The header carries the length of the encoded bytes, not of the str.
        encoded_value = value.encode()
        yield _header(_CBOR_TEXT_STRING, len(encoded_value))
        yield encoded_value
    elif isinstance(value, (list, tuple)):
        # definite-length valued list
        yield _header(_CBOR_ARRAY, len(value))
        for x in value:
            yield from _cbor_encode(x)
    elif isinstance(value, dict):
        yield _header(_CBOR_MAP, len(value))
        # Keys are encoded up front so entries can be sorted by encoded key.
        sorted_map = sorted((encode(k), v) for k, v in value.items())
        for k, v in sorted_map:
            yield k
            yield from _cbor_encode(v)
    elif isinstance(value, Tagged):
        yield _header(_CBOR_TAG, value.tag)
        yield from _cbor_encode(value.value)
    elif isinstance(value, IndefiniteLengthArray):
        # Opening byte, the items, then the terminating break byte.
        yield bytes([_CBOR_ARRAY + _CBOR_VAR_FOLLOWS])
        for x in value.array:
            yield from _cbor_encode(x)
        yield bytes([_CBOR_PRIMITIVE + _CBOR_BREAK])
    elif isinstance(value, Raw):
        # Already encoded by the caller; passed through untouched.
        yield value.value
    elif value is None:
        yield bytes([_CBOR_PRIMITIVE + _CBOR_NULL])
    else:
        if __debug__:
            log.debug(__name__, "not implemented (encode): %s", type(value))
        raise NotImplementedError
def _read_length(r: utils.BufferReader, aux: int) -> int:
    """
    Return the argument of a CBOR item whose additional-information bits
    are `aux`, reading any extra bytes it needs from `r`.

    Values below 24 are the argument itself. Values 24-27 announce that a
    1, 2, 4 or 8-byte big-endian unsigned integer follows.

    Raises NotImplementedError for any other `aux` value.
    """
    if aux < _CBOR_UINT8_FOLLOWS:
        return aux
    elif aux == _CBOR_UINT8_FOLLOWS:
        return r.get()
    elif aux == _CBOR_UINT16_FOLLOWS:
        return readers.read_uint16_be(r)
    elif aux == _CBOR_UINT32_FOLLOWS:
        return readers.read_uint32_be(r)
    elif aux == _CBOR_UINT64_FOLLOWS:
        return readers.read_uint64_be(r)
    else:
        raise NotImplementedError("Length %d not supported" % aux)
# Unique marker standing for the CBOR "break" stop code. Using a dedicated
# object (instead of the raw byte 0xFF) guarantees it can never compare equal
# to a decoded value such as the integer 255.
_BREAK_MARKER = object()


def _cbor_decode(r: utils.BufferReader, allow_break: bool = False) -> Value:
    """
    Decode a single CBOR item from `r` and return it as a Python value.

    Integers become int, byte strings are returned as read from `r`, text
    strings become str, arrays (definite or indefinite) become list, maps
    become dict. Tag 24 yields its content directly, any other tag yields a
    Tagged instance. The primitives false/true/null become False/True/None.

    `allow_break` is for internal use while reading indefinite-length
    containers: when true, a break stop code is reported by returning
    `_BREAK_MARKER`.

    Raises ValueError on a duplicate map key or on a break stop code outside
    of an indefinite-length container, and NotImplementedError on any
    unsupported item.
    """
    fb = r.get()
    fb_type = fb & _CBOR_TYPE_MASK
    fb_aux = fb & _CBOR_INFO_BITS
    if fb_type == _CBOR_UNSIGNED_INT:
        return _read_length(r, fb_aux)
    elif fb_type == _CBOR_NEGATIVE_INT:
        val = _read_length(r, fb_aux)
        return -1 - val
    elif fb_type == _CBOR_BYTE_STRING:
        ln = _read_length(r, fb_aux)
        return r.read(ln)
    elif fb_type == _CBOR_TEXT_STRING:
        ln = _read_length(r, fb_aux)
        return r.read(ln).decode()
    elif fb_type == _CBOR_ARRAY:
        res: Value = []
        if fb_aux == _CBOR_VAR_FOLLOWS:
            # Indefinite length: read items until the break stop code.
            while True:
                item = _cbor_decode(r, True)
                if item is _BREAK_MARKER:
                    break
                res.append(item)
        else:
            ln = _read_length(r, fb_aux)
            for _ in range(ln):
                res.append(_cbor_decode(r))
        return res
    elif fb_type == _CBOR_MAP:
        res = {}
        if fb_aux == _CBOR_VAR_FOLLOWS:
            while True:
                key = _cbor_decode(r, True)
                # The terminator is checked before the duplicate test so it
                # is never treated as a key.
                if key is _BREAK_MARKER:
                    break
                if key in res:
                    raise ValueError
                res[key] = _cbor_decode(r)
        else:
            ln = _read_length(r, fb_aux)
            for _ in range(ln):
                key = _cbor_decode(r)
                if key in res:
                    raise ValueError
                res[key] = _cbor_decode(r)
        return res
    elif fb_type == _CBOR_TAG:
        val = _read_length(r, fb_aux)
        item = _cbor_decode(r)
        if val == _CBOR_RAW_TAG:  # only tag 24 (0x18) is supported
            return item
        else:
            return Tagged(val, item)
    elif fb_type == _CBOR_PRIMITIVE:
        if fb_aux == _CBOR_FALSE:
            return False
        elif fb_aux == _CBOR_TRUE:
            return True
        elif fb_aux == _CBOR_NULL:
            return None
        elif fb_aux == _CBOR_BREAK:
            if not allow_break:
                # A break is only meaningful as a container terminator.
                raise ValueError
            return _BREAK_MARKER
        else:
            raise NotImplementedError
    else:
        if __debug__:
            log.debug(__name__, "not implemented (decode): %s", fb)
        raise NotImplementedError
class Tagged:
    """
    A value paired with a CBOR tag number.

    Two instances are equal when both their tags and their values are equal.
    """

    def __init__(self, tag: int, value: Value) -> None:
        self.tag = tag
        self.value = value

    def __eq__(self, other: object) -> bool:
        # Anything that is not a Tagged instance never compares equal.
        if not isinstance(other, Tagged):
            return False
        return self.tag == other.tag and self.value == other.value
class Raw:
    """
    Wrapper around an already encoded value; the encoder emits `value`
    as is, without adding any header.
    """

    def __init__(self, value: Value):
        self.value = value
class IndefiniteLengthArray:
    """
    Wrapper marking a list that the encoder writes as an indefinite-length
    array (terminated by a break byte) instead of a definite-length one.

    Compares equal to another wrapper or to a plain list holding equal items.
    """

    def __init__(self, array: list[Value]) -> None:
        self.array = array

    def __eq__(self, other: object) -> bool:
        # Reduce both accepted operand kinds to a plain list, then compare.
        if isinstance(other, IndefiniteLengthArray):
            other = other.array
        elif not isinstance(other, list):
            return False
        return self.array == other
def encode(value: Value) -> bytes:
    """
    Serialize `value` to CBOR and return the result as one bytes object.
    """
    pieces = _cbor_encode(value)
    return b"".join(pieces)
def encode_streamed(value: Value) -> Iterator[bytes]:
    """
    Serialize `value` to CBOR lazily.

    The result is an iterator over the individual encoded pieces, so the
    caller never needs one contiguous buffer holding the whole serialized
    value.
    """
    piece_iterator = _cbor_encode(value)
    return piece_iterator
def encode_chunked(value: Value, max_chunk_size: int) -> Iterator[bytes]:
    """
    Serialize `value` to CBOR and yield the output in chunks of exactly
    `max_chunk_size` bytes; only the final chunk may be shorter.

    A single buffer is reused for all chunks: the yielded object is emptied
    as soon as the generator is resumed, so the caller has to consume or
    copy each chunk before asking for the next one.

    Raises ValueError (on first iteration) if `max_chunk_size` is not
    positive.
    """
    if max_chunk_size <= 0:
        raise ValueError

    out = writers.empty_bytearray(max_chunk_size)

    for piece in encode_streamed(value):
        piece_reader = utils.BufferReader(piece)
        # Each piece is processed at least once, even when it is empty.
        while True:
            free_space = max_chunk_size - len(out)
            take = min(piece_reader.remaining_count(), free_space)
            out.extend(piece_reader.read(take))

            if len(out) >= max_chunk_size:
                yield out
                out[:] = bytes()

            if piece_reader.remaining_count() == 0:
                break

    # Flush whatever is left once all pieces have been consumed.
    if len(out) > 0:
        yield out
def decode(cbor: bytes) -> Value:
    """
    Decode exactly one CBOR item from `cbor` and return it.

    Raises ValueError if any bytes remain after the item.
    """
    reader = utils.BufferReader(cbor)
    decoded = _cbor_decode(reader)
    if reader.remaining_count() != 0:
        raise ValueError
    return decoded