diff --git a/src/apps/common/storage.py b/src/apps/common/storage.py index 0d1c3fd813..7692d9741e 100644 --- a/src/apps/common/storage.py +++ b/src/apps/common/storage.py @@ -11,6 +11,8 @@ HOMESCREEN_MAXSIZE = 16384 _STORAGE_VERSION = b"\x01" _FALSE_BYTE = b"\x00" _TRUE_BYTE = b"\x01" +_COUNTER_HEAD_LEN = 4 +_COUNTER_TAIL_LEN = 8 # fmt: off _APP = const(0x01) # app namespace @@ -41,6 +43,46 @@ def _get_bool(app: int, key: int, public: bool = False) -> bool: return config.get(app, key, public) == _TRUE_BYTE +def _set_counter(app: int, key: int, count: int, public: bool = False) -> None: + value = count.to_bytes(_COUNTER_HEAD_LEN, "big") + if public: + value += _COUNTER_TAIL_LEN * b"\xff" + config.set(app, key, value, public) + + +def _next_counter(app: int, key: int, public: bool = False) -> int: +# If the counter value is public, then it is stored as a four byte integer in +# big endian byte order, called the "head", followed an eight byte "tail". The +# counter value is equal to the integer value of the head plus the number of +# zero bits in the tail. The counter value 0 is stored as 00000000FFFFFFFFFFFFFFFF. +# With each increment the tail is shifted to the right by one bit. Thus after +# three increments the stored value is 000000001FFFFFFFFFFFFFFF. Once all the +# bits in the tail are set to zero, the next counter value is stored as +# 00000021FFFFFFFFFFFFFFFF. + + value = config.get(app, key, public) + if value is None: + _set_counter(app, key, 0, public) + return 0 + + head = value[: _COUNTER_HEAD_LEN] + tail = value[_COUNTER_HEAD_LEN :] + i = tail.rfind(b"\x00") + 1 + count = int.from_bytes(head, "big") + 1 + 8*i + if i == len(tail): + _set_counter(app, key, count, public) + return count + + zero_count = 0 + while (tail[i] << zero_count) < 128: + zero_count += 1 + count += zero_count + + tail = tail[:i] + bytes([tail[i] >> 1]) + tail[i+1:] + config.set(app, key, head + tail, public) + return count + + def _new_device_id() -> str: return hexlify(random.bytes(12)).decode().upper() @@ -171,20 +213,11 @@ def set_autolock_delay_ms(delay_ms: int) -> None: def next_u2f_counter() -> int: - b = config.get(_APP, _U2F_COUNTER) - if not b: - b = 0 - else: - b = int.from_bytes(b, "big") + 1 - set_u2f_counter(b) - return b + return _next_counter(_APP, _U2F_COUNTER) -def set_u2f_counter(cntr: int): - if cntr: - config.set(_APP, _U2F_COUNTER, cntr.to_bytes(4, "big")) - else: - config.set(_APP, _U2F_COUNTER, b"") +def set_u2f_counter(cntr: int) -> None: + _set_counter(_APP, _U2F_COUNTER, cntr) def wipe(): diff --git a/tests/test_apps.common.storage.py b/tests/test_apps.common.storage.py new file mode 100644 index 0000000000..f530c0b1c1 --- /dev/null +++ b/tests/test_apps.common.storage.py @@ -0,0 +1,21 @@ +from common import * +from trezor.pin import pin_to_int +from trezor import config +from apps.common import storage + + +class TestConfig(unittest.TestCase): + + def test_counter(self): + config.init() + config.wipe() + self.assertEqual(config.unlock(pin_to_int('')), True) + for i in range(150): + self.assertEqual(storage.next_u2f_counter(), i) + storage.set_u2f_counter(350) + for i in range(351, 500): + self.assertEqual(storage.next_u2f_counter(), i) + + +if __name__ == '__main__': + unittest.main()