You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/core/src/apps/common/cbor.py

311 lines
9.4 KiB

"""
Minimalistic CBOR implementation, supports only what we need in cardano.
"""
import ustruct as struct
from micropython import const
from trezor import log, utils
from . import readers
# Names used only in type annotations. The `if False:` branch never executes
# at runtime; it exists so that type checkers see the real typing definitions.
if False:
    from typing import Any, Generic, Iterator, Tuple, TypeVar, Union

    K = TypeVar("K")
    V = TypeVar("V")
    # any Python value this module can encode or decode
    Value = Any
    CborSequence = Union[list[Value], Tuple[Value, ...]]
else:
    # mypy cheat: Generic[K, V] will be `object` which is a valid parent type
    # At runtime `Generic[K, V]` evaluates to `Generic[0, 0]`, i.e. looking up
    # the key (0, 0) in this dict, which yields `object`.
    # NOTE(review): `Value` / `Iterator` are not defined on this path; the
    # annotations using them presumably rely on MicroPython not evaluating
    # annotations — confirm before running this module under CPython.
    Generic = {(0, 0): object}  # type: ignore
    K = V = 0  # type: ignore
# Layout of a CBOR initial byte: the top 3 bits select the major type, the
# low 5 bits hold the "additional information" (either the argument itself
# or a marker describing how the argument follows).
_CBOR_TYPE_MASK = const(0xE0)
_CBOR_INFO_BITS = const(0x1F)

# Major types, already shifted into the top 3 bits of the initial byte.
_CBOR_UNSIGNED_INT = const(0b000 << 5)
_CBOR_NEGATIVE_INT = const(0b001 << 5)
_CBOR_BYTE_STRING = const(0b010 << 5)
_CBOR_TEXT_STRING = const(0b011 << 5)
_CBOR_ARRAY = const(0b100 << 5)
_CBOR_MAP = const(0b101 << 5)
_CBOR_TAG = const(0b110 << 5)
# major type 7; this module only uses it for false/true/null and "break"
_CBOR_PRIMITIVE = const(0b111 << 5)

# Additional-information values meaning "the argument follows the initial
# byte as a 1-, 2-, 4- or 8-byte unsigned integer" (see _read_length).
_CBOR_UINT8_FOLLOWS = const(0x18)
_CBOR_UINT16_FOLLOWS = const(0x19)
_CBOR_UINT32_FOLLOWS = const(0x1A)
_CBOR_UINT64_FOLLOWS = const(0x1B)
# Additional information marking an indefinite-length array or map, whose
# items are terminated by the "break" byte.
_CBOR_VAR_FOLLOWS = const(0x1F)

# Additional-information values combined with _CBOR_PRIMITIVE.
_CBOR_FALSE = const(0x14)
_CBOR_TRUE = const(0x15)
_CBOR_NULL = const(0x16)
# _CBOR_PRIMITIVE + _CBOR_BREAK == 0xFF, the indefinite-length terminator
_CBOR_BREAK = const(0x1F)

# Tag number 24: when decoding, the item carrying this tag is returned
# directly instead of being wrapped in Tagged.
_CBOR_RAW_TAG = const(0x18)
def _header(typ: int, l: int) -> bytes:
if l < 24:
return struct.pack(">B", typ + l)
elif l < 2 ** 8:
return struct.pack(">BB", typ + 24, l)
elif l < 2 ** 16:
return struct.pack(">BH", typ + 25, l)
elif l < 2 ** 32:
return struct.pack(">BI", typ + 26, l)
elif l < 2 ** 64:
return struct.pack(">BQ", typ + 27, l)
else:
raise NotImplementedError("Length %d not supported" % l)
def _cbor_encode(value: Value) -> Iterator[bytes]:
    """
    Encode `value` as CBOR, yielding the output piece by piece.

    Supported values: bool, int, bytes, bytearray, str, list/tuple
    (definite-length arrays), dict (entries sorted by their encoded key
    bytes), OrderedMap (insertion order), Tagged, IndefiniteLengthArray,
    Raw (its bytes are emitted verbatim) and None.

    Raises NotImplementedError for any other type.
    """
    if isinstance(value, bool):
        # Must be tested before `int`: in CPython bool is a subclass of int,
        # so True/False would otherwise be encoded as the integers 1/0 and
        # this branch would be unreachable.
        if value:
            yield bytes([_CBOR_PRIMITIVE + _CBOR_TRUE])
        else:
            yield bytes([_CBOR_PRIMITIVE + _CBOR_FALSE])
    elif isinstance(value, int):
        if value >= 0:
            yield _header(_CBOR_UNSIGNED_INT, value)
        else:
            # CBOR stores a negative integer n as the unsigned value -1 - n
            yield _header(_CBOR_NEGATIVE_INT, -1 - value)
    elif isinstance(value, bytes):
        yield _header(_CBOR_BYTE_STRING, len(value))
        yield value
    elif isinstance(value, bytearray):
        yield _header(_CBOR_BYTE_STRING, len(value))
        yield bytes(value)
    elif isinstance(value, str):
        encoded_value = value.encode()
        # the header carries the length of the encoded bytes, not of the str
        yield _header(_CBOR_TEXT_STRING, len(encoded_value))
        yield encoded_value
    elif isinstance(value, (list, tuple)):
        # definite-length valued list
        yield _header(_CBOR_ARRAY, len(value))
        for x in value:
            yield from _cbor_encode(x)
    elif isinstance(value, dict):
        yield _header(_CBOR_MAP, len(value))
        # entries are ordered by the bytewise value of their encoded keys
        sorted_map = sorted((encode(k), v) for k, v in value.items())
        for k, v in sorted_map:
            yield k
            yield from _cbor_encode(v)
    elif isinstance(value, OrderedMap):
        yield _header(_CBOR_MAP, len(value))
        for k, v in value:
            yield encode(k)
            yield from _cbor_encode(v)
    elif isinstance(value, Tagged):
        yield _header(_CBOR_TAG, value.tag)
        yield from _cbor_encode(value.value)
    elif isinstance(value, IndefiniteLengthArray):
        yield bytes([_CBOR_ARRAY + _CBOR_VAR_FOLLOWS])
        for x in value.array:
            yield from _cbor_encode(x)
        yield bytes([_CBOR_PRIMITIVE + _CBOR_BREAK])
    elif isinstance(value, Raw):
        yield value.value
    elif value is None:
        yield bytes([_CBOR_PRIMITIVE + _CBOR_NULL])
    else:
        if __debug__:
            log.debug(__name__, "not implemented (encode): %s", type(value))
        raise NotImplementedError
def _read_length(r: utils.BufferReader, aux: int) -> int:
    """
    Read the argument of a CBOR item header from `r`.

    `aux` is the additional-information field (low five bits) of the initial
    byte. Values below _CBOR_UINT8_FOLLOWS are the argument itself; the four
    *_FOLLOWS markers mean an 8-, 16-, 32- or 64-bit unsigned integer follows
    in `r`. Any other `aux` raises NotImplementedError.
    """
    if aux < _CBOR_UINT8_FOLLOWS:
        # small arguments live directly in the initial byte
        return aux
    if aux == _CBOR_UINT8_FOLLOWS:
        return r.get()
    if aux == _CBOR_UINT16_FOLLOWS:
        return readers.read_uint16_be(r)
    if aux == _CBOR_UINT32_FOLLOWS:
        return readers.read_uint32_be(r)
    if aux == _CBOR_UINT64_FOLLOWS:
        return readers.read_uint64_be(r)
    raise NotImplementedError("Length %d not supported" % aux)
def _cbor_decode(r: utils.BufferReader) -> Value:
    """
    Decode a single CBOR data item from `r` and return it.

    Raises ValueError for a "break" byte (0xFF) that does not terminate an
    indefinite-length array or map, and for duplicate map keys.
    Raises NotImplementedError for unsupported encodings.
    """
    fb = r.get()
    if fb == _CBOR_PRIMITIVE + _CBOR_BREAK:
        # "break" is only valid as the terminator of an indefinite-length
        # container; those loops consume it themselves.
        raise ValueError
    return _cbor_decode_item(r, fb)


def _cbor_decode_item(r: utils.BufferReader, fb: int) -> Value:
    """
    Decode the remainder of a CBOR data item whose initial byte `fb` has
    already been read from `r`. `fb` must not be the "break" byte.
    """
    fb_type = fb & _CBOR_TYPE_MASK
    fb_aux = fb & _CBOR_INFO_BITS
    if fb_type == _CBOR_UNSIGNED_INT:
        return _read_length(r, fb_aux)
    elif fb_type == _CBOR_NEGATIVE_INT:
        val = _read_length(r, fb_aux)
        return -1 - val
    elif fb_type == _CBOR_BYTE_STRING:
        ln = _read_length(r, fb_aux)
        return r.read(ln)
    elif fb_type == _CBOR_TEXT_STRING:
        ln = _read_length(r, fb_aux)
        return r.read(ln).decode()
    elif fb_type == _CBOR_ARRAY:
        res: Value = []
        if fb_aux == _CBOR_VAR_FOLLOWS:
            # The terminator is detected on the raw initial byte, so a decoded
            # item that merely equals 0xFF (e.g. the integer 255) does not end
            # the array early.
            while True:
                item_fb = r.get()
                if item_fb == _CBOR_PRIMITIVE + _CBOR_BREAK:
                    break
                res.append(_cbor_decode_item(r, item_fb))
        else:
            ln = _read_length(r, fb_aux)
            for _ in range(ln):
                res.append(_cbor_decode(r))
        return res
    elif fb_type == _CBOR_MAP:
        res = {}
        if fb_aux == _CBOR_VAR_FOLLOWS:
            # same raw-byte terminator detection as for arrays
            while True:
                key_fb = r.get()
                if key_fb == _CBOR_PRIMITIVE + _CBOR_BREAK:
                    break
                key = _cbor_decode_item(r, key_fb)
                if key in res:
                    raise ValueError  # duplicate keys are rejected
                res[key] = _cbor_decode(r)
        else:
            ln = _read_length(r, fb_aux)
            for _ in range(ln):
                key = _cbor_decode(r)
                if key in res:
                    raise ValueError  # duplicate keys are rejected
                res[key] = _cbor_decode(r)
        return res
    elif fb_type == _CBOR_TAG:
        val = _read_length(r, fb_aux)
        item = _cbor_decode(r)
        if val == _CBOR_RAW_TAG:
            # tag 24 is unwrapped: the tagged item is returned as-is (it is
            # not decoded further); all other tags are kept as Tagged
            return item
        else:
            return Tagged(val, item)
    elif fb_type == _CBOR_PRIMITIVE:
        if fb_aux == _CBOR_FALSE:
            return False
        elif fb_aux == _CBOR_TRUE:
            return True
        elif fb_aux == _CBOR_NULL:
            return None
        else:
            # floats and other simple values are not supported
            raise NotImplementedError
    else:
        if __debug__:
            log.debug(__name__, "not implemented (decode): %s", fb)
        raise NotImplementedError
class Tagged:
    """A CBOR tagged data item: the integer tag number `tag` applied to `value`."""

    def __init__(self, tag: int, value: Value) -> None:
        self.tag, self.value = tag, value

    def __eq__(self, other: object) -> bool:
        # only another Tagged with an equal tag and an equal value matches
        if not isinstance(other, Tagged):
            return False
        return self.tag == other.tag and self.value == other.value
class Raw:
    """
    Already-encoded CBOR. The encoder writes `value` to the output verbatim,
    without adding any header.
    """

    def __init__(self, value: Value) -> None:
        self.value = value
class IndefiniteLengthArray:
    """
    Wraps a list so that the encoder emits it as an indefinite-length CBOR
    array (start marker, items, "break" byte) rather than a definite-length one.

    Compares equal to another IndefiniteLengthArray, or to a plain list,
    holding equal items.
    """

    def __init__(self, array: list[Value]) -> None:
        self.array = array

    def __eq__(self, other: object) -> bool:
        if isinstance(other, IndefiniteLengthArray):
            other_items = other.array
        elif isinstance(other, list):
            other_items = other
        else:
            return False
        return self.array == other_items
class OrderedMap(Generic[K, V]):
    """
    A CBOR map whose entries are encoded in insertion order, without sorting.

    Assigning to a key always appends a new (key, value) entry and never
    replaces an earlier one, so duplicate keys are kept. It is up to the
    caller to avoid inserting the same key twice if a map without duplicates
    is required (the decoder in this module rejects duplicate keys).
    """

    def __init__(self) -> None:
        self._internal_list: list[tuple[K, V]] = []

    def __setitem__(self, key: K, value: V) -> None:
        entry = (key, value)
        self._internal_list.append(entry)

    def __iter__(self) -> Iterator[tuple[K, V]]:
        # (key, value) pairs, in the order they were inserted
        yield from self._internal_list

    def __len__(self) -> int:
        entry_count = len(self._internal_list)
        return entry_count
def encode(value: Value) -> bytes:
    """Serialize `value` to CBOR and return it as one contiguous bytes object."""
    pieces = _cbor_encode(value)
    return b"".join(pieces)
def encode_streamed(value: Value) -> Iterator[bytes]:
    """
    Encode `value` lazily, producing the individual CBOR pieces (item headers
    and payloads) one at a time. The complete serialization therefore never
    has to be held in a single contiguous buffer.
    """
    stream = _cbor_encode(value)
    return stream
def encode_chunked(value: Value, max_chunk_size: int) -> Iterator[bytes]:
    """
    Returns the encoded value as an iterable of chunks of a given size,
    removing the need to reserve a continuous chunk of memory for the
    full serialized representation of the value.

    Every chunk except possibly the last one is exactly `max_chunk_size`
    bytes long. Raises ValueError (when iteration starts, since this is a
    generator) if `max_chunk_size` is not positive.

    NOTE: the same bytearray object is yielded every time. It is cleared and
    refilled once the generator is resumed, so each chunk must be consumed
    (or copied) before the next one is requested.
    """
    if max_chunk_size <= 0:
        raise ValueError

    chunk_buffer = utils.empty_bytearray(max_chunk_size)
    for chunk in encode_streamed(value):
        chunk_view = utils.BufferReader(chunk)
        # a single encoded piece may fill several output chunks
        while chunk_view.remaining_count() > 0:
            num_bytes_to_write = min(
                chunk_view.remaining_count(),
                max_chunk_size - len(chunk_buffer),
            )
            chunk_buffer.extend(chunk_view.read(num_bytes_to_write))
            if len(chunk_buffer) >= max_chunk_size:
                yield chunk_buffer
                # reuse the buffer instead of allocating a new one per chunk
                chunk_buffer[:] = bytes()

    if len(chunk_buffer) > 0:
        yield chunk_buffer
def decode(cbor: bytes) -> Value:
    """
    Decode `cbor`, which must contain exactly one CBOR data item.

    Raises ValueError if any bytes remain after that item.
    """
    reader = utils.BufferReader(cbor)
    result = _cbor_decode(reader)
    if reader.remaining_count():
        # trailing bytes after the top-level item
        raise ValueError
    return result