mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
tests(core): add thp cache tests, clean thp tests, add message to set_bool assertion error
[no changelog]
This commit is contained in:
parent
dddf926442
commit
ef33422ab3
@ -105,7 +105,7 @@ class DataCache:
|
||||
|
||||
def set_bool(self, key: int, value: bool) -> None:
|
||||
utils.ensure(
|
||||
self._get_length(key) == 0
|
||||
self._get_length(key) == 0, "Field does not have zero length!"
|
||||
) # skipping get_length in production build
|
||||
if value:
|
||||
self.set(key, b"")
|
||||
|
@ -1,17 +1,22 @@
|
||||
from common import * # isort:skip # noqa: F403
|
||||
|
||||
from mock_storage import mock_storage
|
||||
|
||||
from storage import cache, cache_codec, cache_thp
|
||||
from trezor.messages import Initialize
|
||||
from trezor.messages import EndSession
|
||||
from storage import cache, cache_codec
|
||||
from trezor.messages import EndSession, Initialize
|
||||
|
||||
from apps.base import handle_EndSession, handle_Initialize
|
||||
|
||||
KEY = 0
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
from mock_wire_interface import MockHID
|
||||
from storage import cache_thp
|
||||
from trezor.wire.thp import ChannelState
|
||||
from trezor.wire.thp.session_context import ManagementSessionContext, SessionContext
|
||||
|
||||
_PROTOCOL_CACHE = cache_thp
|
||||
|
||||
else:
|
||||
_PROTOCOL_CACHE = cache_codec
|
||||
|
||||
@ -28,7 +33,284 @@ class TestStorageCache(
|
||||
def setUp(self):
|
||||
cache.clear_all()
|
||||
|
||||
if not utils.USE_THP:
|
||||
if utils.USE_THP:
|
||||
|
||||
def __init__(self):
|
||||
thp_common.suppres_debug_log()
|
||||
# xthp_common.prepare_context()
|
||||
# config.init()
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
cache.clear_all()
|
||||
|
||||
def test_new_channel_and_session(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
# Assert that channel is created with one management session
|
||||
self.assertEqual(len(channel.sessions), 1)
|
||||
self.assertIsInstance(channel.sessions[0], ManagementSessionContext)
|
||||
|
||||
cid_1 = channel.channel_id
|
||||
session_cache_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_1 = SessionContext(channel, session_cache_1)
|
||||
self.assertEqual(session_1.channel_id, cid_1)
|
||||
|
||||
session_cache_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2 = SessionContext(channel, session_cache_2)
|
||||
self.assertEqual(session_2.channel_id, cid_1)
|
||||
self.assertEqual(session_1.channel_id, session_2.channel_id)
|
||||
self.assertNotEqual(session_1.session_id, session_2.session_id)
|
||||
|
||||
channel_2 = thp_common.get_new_channel(self.interface)
|
||||
cid_2 = channel_2.channel_id
|
||||
self.assertNotEqual(cid_1, cid_2)
|
||||
|
||||
session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache)
|
||||
session_3 = SessionContext(channel_2, session_cache_3)
|
||||
self.assertEqual(session_3.channel_id, cid_2)
|
||||
|
||||
# Sessions 1 and 3 should have different channel_id, but the same session_id
|
||||
self.assertNotEqual(session_1.channel_id, session_3.channel_id)
|
||||
self.assertEqual(session_1.session_id, session_3.session_id)
|
||||
|
||||
self.assertEqual(cache_thp._SESSIONS[0], session_cache_1)
|
||||
self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2)
|
||||
self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id)
|
||||
|
||||
# Check that session data IS in cache for created sessions ONLY
|
||||
for i in range(3):
|
||||
self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"")
|
||||
self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"")
|
||||
self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0)
|
||||
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
|
||||
self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"")
|
||||
self.assertEqual(cache_thp._SESSIONS[i].session_id, b"")
|
||||
self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0)
|
||||
|
||||
# Check that session data IS NOT in cache after cache.clear_all()
|
||||
cache.clear_all()
|
||||
for session in cache_thp._SESSIONS:
|
||||
self.assertEqual(session.channel_id, b"")
|
||||
self.assertEqual(session.session_id, b"")
|
||||
self.assertEqual(session.last_usage, 0)
|
||||
self.assertEqual(session.state, b"\x00")
|
||||
|
||||
def test_channel_capacity_in_cache(self):
|
||||
self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3)
|
||||
channels = []
|
||||
for i in range(cache_thp._MAX_CHANNELS_COUNT):
|
||||
channels.append(thp_common.get_new_channel(self.interface))
|
||||
channel_ids = [channel.channel_cache.channel_id for channel in channels]
|
||||
|
||||
# Assert that each channel_id is unique and that cache and list of channels
|
||||
# have the same "channels" on the same indexes
|
||||
for i in range(len(channel_ids)):
|
||||
self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i])
|
||||
for j in range(i + 1, len(channel_ids)):
|
||||
self.assertNotEqual(channel_ids[i], channel_ids[j])
|
||||
|
||||
# Create a new channel that is over the capacity
|
||||
new_channel = thp_common.get_new_channel(self.interface)
|
||||
for c in channels:
|
||||
self.assertNotEqual(c.channel_id, new_channel.channel_id)
|
||||
|
||||
# Test that the oldest (least used) channel was replaced (_CHANNELS[0])
|
||||
self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0])
|
||||
self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id)
|
||||
|
||||
# Update the "last used" value of the second channel in cache (_CHANNELS[1]) and
|
||||
# assert that it is not replaced when creating a new channel
|
||||
cache_thp.update_channel_last_used(channel_ids[1])
|
||||
new_new_channel = thp_common.get_new_channel(self.interface)
|
||||
self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1])
|
||||
|
||||
# Assert that it was in fact the _CHANNEL[2] that was replaced
|
||||
self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2])
|
||||
self.assertEqual(
|
||||
cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id
|
||||
)
|
||||
|
||||
def test_session_capacity_in_cache(self):
|
||||
self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4)
|
||||
channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache
|
||||
channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache
|
||||
|
||||
sesions_A = []
|
||||
cid = []
|
||||
sid = []
|
||||
for i in range(3):
|
||||
sesions_A.append(cache_thp.get_new_session(channel_cache_A))
|
||||
cid.append(sesions_A[i].channel_id)
|
||||
sid.append(sesions_A[i].session_id)
|
||||
|
||||
sessions_B = []
|
||||
for i in range(cache_thp._MAX_SESSIONS_COUNT - 3):
|
||||
sessions_B.append(cache_thp.get_new_session(channel_cache_B))
|
||||
|
||||
for i in range(3):
|
||||
self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i])
|
||||
self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id)
|
||||
self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id)
|
||||
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
|
||||
self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i])
|
||||
|
||||
# Assert that new session replaces the oldest (least used) one (_SESSOIONS[0])
|
||||
new_session = cache_thp.get_new_session(channel_cache_B)
|
||||
self.assertEqual(new_session, cache_thp._SESSIONS[0])
|
||||
self.assertNotEqual(new_session.channel_id, cid[0])
|
||||
self.assertNotEqual(new_session.session_id, sid[0])
|
||||
|
||||
# Assert that updating "last used" for session on channel A increases also
|
||||
# the "last usage" of channel A.
|
||||
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
|
||||
cache_thp.update_session_last_used(
|
||||
channel_cache_A.channel_id, sesions_A[1].session_id
|
||||
)
|
||||
self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage)
|
||||
|
||||
new_new_session = cache_thp.get_new_session(channel_cache_B)
|
||||
|
||||
# Assert that creating a new session on channel B shifts the "last usage" again
|
||||
# and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced
|
||||
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
|
||||
self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1])
|
||||
self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2])
|
||||
self.assertEqual(new_new_session, cache_thp._SESSIONS[2])
|
||||
|
||||
def test_clear(self):
|
||||
channel_A = thp_common.get_new_channel(self.interface)
|
||||
channel_B = thp_common.get_new_channel(self.interface)
|
||||
cid_A = channel_A.channel_id
|
||||
cid_B = channel_B.channel_id
|
||||
sessions = []
|
||||
|
||||
for i in range(3):
|
||||
sessions.append(cache_thp.get_new_session(channel_A.channel_cache))
|
||||
sessions.append(cache_thp.get_new_session(channel_B.channel_cache))
|
||||
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A)
|
||||
self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
|
||||
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
|
||||
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
|
||||
|
||||
# Assert that clearing of channel A works
|
||||
self.assertNotEqual(channel_A.channel_cache.channel_id, b"")
|
||||
self.assertNotEqual(channel_A.channel_cache.last_usage, 0)
|
||||
self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1)
|
||||
|
||||
channel_A.clear()
|
||||
|
||||
self.assertEqual(channel_A.channel_cache.channel_id, b"")
|
||||
self.assertEqual(channel_A.channel_cache.last_usage, 0)
|
||||
self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED)
|
||||
|
||||
# Assert that clearing channel A also cleared all its sessions
|
||||
for i in range(3):
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"")
|
||||
|
||||
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
|
||||
|
||||
cache.clear_all()
|
||||
for session in cache_thp._SESSIONS:
|
||||
self.assertEqual(session.last_usage, 0)
|
||||
self.assertEqual(session.channel_id, b"")
|
||||
for channel in cache_thp._CHANNELS:
|
||||
self.assertEqual(channel.channel_id, b"")
|
||||
self.assertEqual(channel.last_usage, 0)
|
||||
self.assertEqual(
|
||||
cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED
|
||||
)
|
||||
|
||||
def test_get_set(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_1.set(KEY, b"hello")
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2.set(KEY, b"world")
|
||||
self.assertEqual(session_2.get(KEY), b"world")
|
||||
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
|
||||
cache.clear_all()
|
||||
self.assertIsNone(session_1.get(KEY))
|
||||
self.assertIsNone(session_2.get(KEY))
|
||||
|
||||
def test_get_set_int(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_1.set_int(KEY, 1234)
|
||||
|
||||
self.assertEqual(session_1.get_int(KEY), 1234)
|
||||
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2.set_int(KEY, 5678)
|
||||
self.assertEqual(session_2.get_int(KEY), 5678)
|
||||
|
||||
self.assertEqual(session_1.get_int(KEY), 1234)
|
||||
|
||||
cache.clear_all()
|
||||
self.assertIsNone(session_1.get_int(KEY))
|
||||
self.assertIsNone(session_2.get_int(KEY))
|
||||
|
||||
def test_get_set_bool(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
with self.assertRaises(AssertionError) as e:
|
||||
session_1.set_bool(KEY, True)
|
||||
self.assertEqual(e.value.value, "Field does not have zero length!")
|
||||
|
||||
# Change length of first session field to 0 so that the length check passes
|
||||
session_1.fields = (0,) + session_1.fields[1:]
|
||||
|
||||
# with self.assertRaises(AssertionError) as e:
|
||||
session_1.set_bool(KEY, True)
|
||||
self.assertEqual(session_1.get_bool(KEY), True)
|
||||
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2.fields = session_2.fields = (0,) + session_2.fields[1:]
|
||||
session_2.set_bool(KEY, False)
|
||||
self.assertEqual(session_2.get_bool(KEY), False)
|
||||
|
||||
self.assertEqual(session_1.get_bool(KEY), True)
|
||||
|
||||
cache.clear_all()
|
||||
|
||||
# Default value is False
|
||||
self.assertFalse(session_1.get_bool(KEY))
|
||||
self.assertFalse(session_2.get_bool(KEY))
|
||||
|
||||
def test_delete(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
|
||||
self.assertIsNone(session_1.get(KEY))
|
||||
session_1.set(KEY, b"hello")
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
session_1.delete(KEY)
|
||||
self.assertIsNone(session_1.get(KEY))
|
||||
|
||||
session_1.set(KEY, b"hello")
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
|
||||
self.assertIsNone(session_2.get(KEY))
|
||||
session_2.set(KEY, b"hello")
|
||||
self.assertEqual(session_2.get(KEY), b"hello")
|
||||
session_2.delete(KEY)
|
||||
self.assertIsNone(session_2.get(KEY))
|
||||
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
|
||||
else:
|
||||
|
||||
def __init__(self):
|
||||
# Context is needed to test decorators and handleInitialize
|
||||
@ -102,7 +384,7 @@ class TestStorageCache(
|
||||
cache_codec.start_session(session_id1)
|
||||
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
|
||||
|
||||
cache_codec.clear_all()
|
||||
cache.clear_all()
|
||||
self.assertIsNone(cache_codec.get_active_session())
|
||||
|
||||
def test_get_set_int(self):
|
||||
@ -119,7 +401,7 @@ class TestStorageCache(
|
||||
cache_codec.start_session(session_id1)
|
||||
self.assertEqual(get_active_session().get_int(KEY), 1234)
|
||||
|
||||
cache_codec.clear_all()
|
||||
cache.clear_all()
|
||||
self.assertIsNone(get_active_session())
|
||||
|
||||
def test_delete(self):
|
||||
|
@ -7,6 +7,8 @@ from trezor.wire.errors import UnexpectedMessage
|
||||
from trezor.wire.protocol_common import Message
|
||||
|
||||
if utils.USE_THP:
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import thp_common
|
||||
from storage import cache_thp
|
||||
from storage.cache_common import (
|
||||
@ -27,15 +29,56 @@ if utils.USE_THP:
|
||||
ThpStartPairingRequest,
|
||||
)
|
||||
from trezor.wire import thp_main
|
||||
from trezor.wire.thp import ChannelState, interface_manager
|
||||
from trezor.wire.thp import ChannelState, checksum, interface_manager
|
||||
from trezor.wire.thp.crypto import Handshake
|
||||
from trezor.wire.thp.pairing_context import PairingContext
|
||||
|
||||
from apps.thp import pairing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire import WireInterface
|
||||
|
||||
def get_dummy_key() -> bytes:
|
||||
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
|
||||
def get_dummy_key() -> bytes:
|
||||
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
|
||||
|
||||
def send_channel_allocation_request(
|
||||
interface: WireInterface, nonce: bytes | None = None
|
||||
) -> bytes:
|
||||
if nonce is None or len(nonce) != 8:
|
||||
nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77"
|
||||
header = b"\x40\xff\xff\x00\x0c"
|
||||
chksum = checksum.compute(header + nonce)
|
||||
cid_req = header + nonce + chksum
|
||||
gen = thp_main.thp_main_loop(interface, is_debug_session=True)
|
||||
gen.send(None)
|
||||
gen.send(cid_req)
|
||||
gen.send(None)
|
||||
response_data = (
|
||||
b"\x0a\x04\x54\x32\x54\x31\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04"
|
||||
)
|
||||
response_without_crc = (
|
||||
b"\x41\xff\xff\x00\x20"
|
||||
+ nonce
|
||||
+ cache_thp.cid_counter.to_bytes(2, "big")
|
||||
+ response_data
|
||||
)
|
||||
chkcsum = checksum.compute(response_without_crc)
|
||||
expected_response = response_without_crc + chkcsum + b"\x00" * 27
|
||||
return expected_response
|
||||
|
||||
def get_channel_id_from_response(channel_allocation_response: bytes) -> int:
|
||||
return int.from_bytes(channel_allocation_response[13:15], "big")
|
||||
|
||||
def get_ack(channel_id: bytes) -> bytes:
|
||||
if len(channel_id) != 2:
|
||||
raise Exception("Channel id should by two bytes long")
|
||||
return (
|
||||
b"\x20"
|
||||
+ channel_id
|
||||
+ b"\x00\x04"
|
||||
+ checksum.compute(b"\x20" + channel_id + b"\x00\x04")
|
||||
+ b"\x00" * 55
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
|
||||
@ -52,34 +95,33 @@ class TestTrezorHostProtocol(unittest.TestCase):
|
||||
thp_main.set_buffer(buffer)
|
||||
interface_manager.decode_iface = thp_common.dummy_decode_iface
|
||||
|
||||
def test_channel_allocation(self):
|
||||
cid_req = (
|
||||
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
|
||||
)
|
||||
expected_response = "41ffff0020001122334455667712340a045432543110001800200228022803280428af9907000000000000000000000000000000000000000000000000000000"
|
||||
test_counter = cache_thp.cid_counter + 1
|
||||
self.assertEqual(len(thp_main._CHANNELS), 0)
|
||||
self.assertFalse(test_counter in thp_main._CHANNELS)
|
||||
def test_codec_message(self):
|
||||
self.assertEqual(len(self.interface.data), 0)
|
||||
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(cid_req)
|
||||
gen.send(None)
|
||||
self.assertEqual(
|
||||
utils.get_bytes_as_str(self.interface.data[-1]),
|
||||
expected_response,
|
||||
|
||||
# There should be a failiure response to received init packet (starts with "?##")
|
||||
test_codec_message = b"?## Some data"
|
||||
gen.send(test_codec_message)
|
||||
gen.send(None)
|
||||
self.assertEqual(len(self.interface.data), 1)
|
||||
|
||||
expected_response = (
|
||||
b"?##\x00\x03\x00\x00\x00\x14\x08\x01\x12\x10Invalid protocol"
|
||||
)
|
||||
self.assertEqual(
|
||||
self.interface.data[-1][: len(expected_response)], expected_response
|
||||
)
|
||||
self.assertTrue(test_counter in thp_main._CHANNELS)
|
||||
self.assertEqual(len(thp_main._CHANNELS), 1)
|
||||
gen.send(cid_req)
|
||||
gen.send(None)
|
||||
gen.send(cid_req)
|
||||
gen.send(None)
|
||||
|
||||
def test_channel_default_state_is_TH1(self):
|
||||
self.assertEqual(thp_main._CHANNELS[4660].get_channel_state(), ChannelState.TH1)
|
||||
# There should be no response for continuation packet (starts with "?" only)
|
||||
test_codec_message_2 = b"? Cont packet"
|
||||
gen.send(test_codec_message_2)
|
||||
with self.assertRaises(TypeError) as e:
|
||||
gen.send(None)
|
||||
self.assertEqual(e.value.value, "object with buffer protocol required")
|
||||
self.assertEqual(len(self.interface.data), 1)
|
||||
|
||||
def test_channel_errors(self):
|
||||
def test_message_on_unallocated_channel(self):
|
||||
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
@ -93,31 +135,118 @@ class TestTrezorHostProtocol(unittest.TestCase):
|
||||
utils.get_bytes_as_str(self.interface.data[-1]),
|
||||
unallocated_chanel_error_on_channel_789a,
|
||||
)
|
||||
|
||||
def test_channel_allocation(self):
|
||||
test_counter = cache_thp.cid_counter + 1
|
||||
self.assertEqual(len(thp_main._CHANNELS), 0)
|
||||
self.assertFalse(test_counter in thp_main._CHANNELS)
|
||||
|
||||
expected_response = send_channel_allocation_request(self.interface)
|
||||
self.assertEqual(self.interface.data[-1], expected_response)
|
||||
|
||||
self.assertTrue(test_counter in thp_main._CHANNELS)
|
||||
self.assertEqual(len(thp_main._CHANNELS), 1)
|
||||
|
||||
# test channel's default state is TH1:
|
||||
cid = get_channel_id_from_response(self.interface.data[-1])
|
||||
self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1)
|
||||
|
||||
def test_invalid_encrypted_tag(self):
|
||||
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
|
||||
gen.send(None)
|
||||
# prepare 2 new channels
|
||||
expected_response_1 = send_channel_allocation_request(self.interface)
|
||||
expected_response_2 = send_channel_allocation_request(self.interface)
|
||||
self.assertEqual(self.interface.data[-2], expected_response_1)
|
||||
self.assertEqual(self.interface.data[-1], expected_response_2)
|
||||
|
||||
# test invalid encryption tag
|
||||
config.init()
|
||||
config.wipe()
|
||||
channel = thp_main._CHANNELS[4661]
|
||||
cid_1 = get_channel_id_from_response(expected_response_1)
|
||||
channel = thp_main._CHANNELS[cid_1]
|
||||
channel.iface = self.interface
|
||||
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
message_with_invalid_tag = b"\x04\x12\x35\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\xe1\xfc\xc6\xe0"
|
||||
header = b"\x04" + channel.channel_id + b"\x00\x14"
|
||||
|
||||
tag = b"\x00" * 16
|
||||
chksum = checksum.compute(header + tag)
|
||||
message_with_invalid_tag = header + tag + chksum
|
||||
|
||||
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
|
||||
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
|
||||
|
||||
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
|
||||
expected_ack_on_received_message = get_ack(cid_1_bytes)
|
||||
|
||||
gen.send(message_with_invalid_tag)
|
||||
gen.send(None)
|
||||
ack_on_received_message = "2012350004d83ea46f00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
|
||||
self.assertEqual(
|
||||
utils.get_bytes_as_str(self.interface.data[-1]),
|
||||
ack_on_received_message,
|
||||
self.interface.data[-1],
|
||||
expected_ack_on_received_message,
|
||||
)
|
||||
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
|
||||
chksum_err = checksum.compute(error_without_crc)
|
||||
gen.send(None)
|
||||
decryption_failed_error_on_channel_1235 = "421235000503caf9634a000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
|
||||
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
|
||||
|
||||
self.assertEqual(
|
||||
utils.get_bytes_as_str(self.interface.data[-1]),
|
||||
decryption_failed_error_on_channel_1235,
|
||||
self.interface.data[-1],
|
||||
decryption_failed_error,
|
||||
)
|
||||
|
||||
channel = thp_main._CHANNELS[4662]
|
||||
def test_channel_errors(self):
|
||||
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
|
||||
gen.send(None)
|
||||
# prepare 2 new channels
|
||||
expected_response_1 = send_channel_allocation_request(self.interface)
|
||||
expected_response_2 = send_channel_allocation_request(self.interface)
|
||||
self.assertEqual(self.interface.data[-2], expected_response_1)
|
||||
self.assertEqual(self.interface.data[-1], expected_response_2)
|
||||
|
||||
# test invalid encryption tag
|
||||
config.init()
|
||||
config.wipe()
|
||||
cid_1 = get_channel_id_from_response(expected_response_1)
|
||||
channel = thp_main._CHANNELS[cid_1]
|
||||
channel.iface = self.interface
|
||||
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
header = b"\x04" + channel.channel_id + b"\x00\x14"
|
||||
|
||||
tag = b"\x00" * 16
|
||||
chksum = checksum.compute(header + tag)
|
||||
message_with_invalid_tag = header + tag + chksum
|
||||
|
||||
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
|
||||
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
|
||||
|
||||
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
|
||||
expected_ack_on_received_message = get_ack(cid_1_bytes)
|
||||
|
||||
gen.send(message_with_invalid_tag)
|
||||
gen.send(None)
|
||||
|
||||
self.assertEqual(
|
||||
self.interface.data[-1],
|
||||
expected_ack_on_received_message,
|
||||
)
|
||||
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
|
||||
chksum_err = checksum.compute(error_without_crc)
|
||||
gen.send(None)
|
||||
|
||||
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
|
||||
|
||||
self.assertEqual(
|
||||
self.interface.data[-1],
|
||||
decryption_failed_error,
|
||||
)
|
||||
|
||||
# test invalid tag in handshake phase
|
||||
cid_2 = get_channel_id_from_response(expected_response_1)
|
||||
cid_2_bytes = cid_2.to_bytes(2, "big")
|
||||
channel = thp_main._CHANNELS[cid_2]
|
||||
channel.iface = self.interface
|
||||
|
||||
channel.set_channel_state(ChannelState.TH2)
|
||||
@ -126,13 +255,12 @@ class TestTrezorHostProtocol(unittest.TestCase):
|
||||
|
||||
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
|
||||
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
|
||||
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
|
||||
|
||||
# gen.send(message_with_invalid_tag)
|
||||
# gen.send(None)
|
||||
# gen.send(None)
|
||||
for i in self.interface.data:
|
||||
print(utils.get_bytes_as_str(i))
|
||||
# for i in self.interface.data:
|
||||
# print(utils.get_bytes_as_str(i))
|
||||
|
||||
def test_skip_pairing(self):
|
||||
config.init()
|
||||
@ -159,7 +287,10 @@ class TestTrezorHostProtocol(unittest.TestCase):
|
||||
def test_pairing(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
channel = thp_main._CHANNELS[4660]
|
||||
cid = get_channel_id_from_response(
|
||||
send_channel_allocation_request(self.interface)
|
||||
)
|
||||
channel = thp_main._CHANNELS[cid]
|
||||
channel.selected_pairing_methods = [
|
||||
ThpPairingMethod.CodeEntry,
|
||||
ThpPairingMethod.NFC_Unidirectional,
|
||||
@ -209,8 +340,7 @@ class TestTrezorHostProtocol(unittest.TestCase):
|
||||
user_message = Message(MessageType.ThpCodeEntryCpaceHost, buffer)
|
||||
gen.send(user_message)
|
||||
|
||||
tag_ent = b"\x56\x34\xc5\x36\x60\xcc\x75\xbc\x58\x24\x76\x87\x74\xd2\x5f\x48\x80\xc0\x8c\x65\xab\x00\xe9\xf7\x0e\xb0\x10\x15\xe5\x8b\x4f\x6a"
|
||||
|
||||
tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81"
|
||||
msg = ThpCodeEntryTag(tag=tag_ent)
|
||||
|
||||
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
|
||||
|
@ -1,6 +1,10 @@
|
||||
from trezor import utils
|
||||
from trezor.wire.thp import ChannelState
|
||||
|
||||
if utils.USE_THP:
|
||||
import unittest
|
||||
from typing import TYPE_CHECKING, Any, Awaitable
|
||||
|
||||
from mock_wire_interface import MockHID
|
||||
from storage import cache_thp
|
||||
from trezor.wire import context
|
||||
@ -9,14 +13,24 @@ if utils.USE_THP:
|
||||
from trezor.wire.thp.interface_manager import _MOCK_INTERFACE_HID
|
||||
from trezor.wire.thp.session_context import SessionContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire import WireInterface
|
||||
|
||||
def dummy_decode_iface(cached_iface: bytes):
|
||||
return MockHID(0xDEADBEEF)
|
||||
|
||||
def prepare_context() -> None:
|
||||
def get_new_channel(channel_iface: WireInterface | None = None) -> Channel:
|
||||
interface_manager.decode_iface = dummy_decode_iface
|
||||
channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID)
|
||||
channel = Channel(channel_cache)
|
||||
session_cache = cache_thp.get_new_session(channel_cache)
|
||||
channel.set_channel_state(ChannelState.TH1)
|
||||
if channel_iface is not None:
|
||||
channel.iface = channel_iface
|
||||
return channel
|
||||
|
||||
def prepare_context() -> None:
|
||||
channel = get_new_channel()
|
||||
session_cache = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_ctx = SessionContext(channel, session_cache)
|
||||
context.CURRENT_CONTEXT = session_ctx
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user