Compare commits

...

6 Commits

@ -231,6 +231,8 @@ trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler
trezor.wire.thp.retransmission
import trezor.wire.thp.retransmission
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.session_manager

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from trezor import log, loop
from trezor.messages import ThpCreateNewSession, ThpNewSession
from trezor.wire.thp import ChannelContext, SessionState
from trezor.wire.thp import SessionState
if TYPE_CHECKING:
from trezor.wire.thp import ChannelContext
async def create_new_session(

@ -1,4 +1,4 @@
from trezor import log, protobuf
from trezor import protobuf
from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
@ -25,6 +25,9 @@ from trezor.wire.thp.thp_session import ThpError
# TODO implement the following handlers
if __debug__:
from trezor import log
async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType

@ -1,5 +1,5 @@
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow

@ -3,12 +3,43 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING:
from enum import IntEnum
from trezorio import WireInterface
from typing import Protocol
from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils
from trezor.enums import FailureType
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.session_context import SessionContext
class ChannelContext(Protocol):
buffer: utils.BufferType
iface: WireInterface
channel_id: bytes
channel_cache: ChannelCache
selected_pairing_methods = [] # TODO add type
sessions: dict[int, SessionContext]
waiting_for_ack_timeout: loop.spawn | None
write_task_spawn: loop.spawn | None
connection_context: PairingContext | None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: "ChannelState") -> None: ...
async def write(
self, msg: protobuf.MessageType, session_id: int = 0
) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(
self, ctrl_byte: int, payload: bytes
) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int: ...
else:
IntEnum = object
@ -36,29 +67,6 @@ class WireInterfaceType(IntEnum):
BLE = 2
class ChannelContext:
def __init__(self, iface: WireInterface, channel_cache: ChannelCache):
self.buffer: utils.BufferType
self.iface: WireInterface = iface
self.channel_id: bytes = channel_cache.channel_id
self.channel_cache: ChannelCache = channel_cache
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: ChannelState) -> None: ...
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,

@ -6,14 +6,7 @@ from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.wire.thp import interface_manager, received_message_handler
from . import (
ChannelContext,
ChannelState,
checksum,
control_byte,
crypto,
memory_manager,
)
from . import ChannelState, checksum, control_byte, crypto, memory_manager
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
@ -31,19 +24,32 @@ if __debug__:
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
from . import ChannelContext, PairingContext
from .session_context import SessionContext
else:
ChannelContext = object
class Channel(ChannelContext):
class Channel:
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__:
log.debug(__name__, "channel initialization")
iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache)
self.channel_cache = channel_cache
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
self.channel_cache: ChannelCache = channel_cache
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
# ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
@ -168,7 +174,7 @@ class Channel(ChannelContext):
if __debug__:
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
noise_payload_len = memory_manager.encode_into_buffer(
memoryview(self.buffer), msg, session_id
self.buffer, msg, session_id
)
await self.write_and_encrypt(self.buffer[:noise_payload_len])

@ -46,7 +46,7 @@ def select_buffer(
def encode_into_buffer(
buffer: memoryview, msg: protobuf.MessageType, session_id: int
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
@ -58,7 +58,7 @@ def encode_into_buffer(
if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer
buffer = memoryview(bytearray(required_min_size))
buffer = bytearray(required_min_size)
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(

@ -1,19 +1,23 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from typing import TYPE_CHECKING
from trezor import log, loop, protobuf, workflow
from trezor import loop, protobuf, workflow
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import Context, MessageWithType
from . import ChannelContext
from .session_context import UnexpectedMessageWithType
if TYPE_CHECKING:
from typing import Container # pyright:ignore[reportShadowedImports]
from typing import Container
from . import ChannelContext
pass
if __debug__:
from trezor import log
class PairingContext(Context):
def __init__(self, channel_ctx: ChannelContext) -> None:

@ -19,7 +19,6 @@ from trezor.wire.thp.thp_messages import (
)
from . import (
ChannelContext,
ChannelState,
SessionState,
checksum,
@ -33,6 +32,8 @@ from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire
if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from . import ChannelContext
if __debug__:
from . import state_to_str

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.wire.thp import ChannelContext
class Retransmission:
def __init__(
self, channel_context: ChannelContext, ctrl_byte: int, payload: memoryview
) -> None:
self.channel_context: ChannelContext = channel_context
self.ctrl_byte: int = ctrl_byte
self.payload: memoryview = payload
def start(self):
pass
def stop(self):
pass
def change_ctrl_byte(self, ctrl_byte: int) -> None:
self.ctrl_byte = ctrl_byte

@ -6,7 +6,7 @@ from trezor.wire import message_handler, protocol_common
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
from ..protocol_common import Context, MessageWithType
from . import ChannelContext, SessionState
from . import SessionState
if TYPE_CHECKING:
from typing import ( # pyright: ignore[reportShadowedImports]
@ -15,6 +15,8 @@ if TYPE_CHECKING:
Container,
)
from . import ChannelContext
pass
_EXIT_LOOP = True

@ -1,7 +1,15 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import log, loop
from trezor.wire.thp import ChannelContext
from trezor.wire.thp.session_context import SessionContext
from trezor import loop
from .session_context import SessionContext
if __debug__:
from trezor import log
if TYPE_CHECKING:
from . import ChannelContext
def create_new_session(channel_ctx: ChannelContext) -> SessionContext:

@ -19,7 +19,7 @@ _WEBAUTHN_PORT_OFFSET = const(2)
_VCP_PORT_OFFSET = const(3)
if utils.EMULATOR:
import uos # pyright: ignore[reportMissingModuleSource]
import uos
UDP_PORT = int(uos.getenv("TREZOR_UDP_PORT") or "21324")

@ -0,0 +1,71 @@
from common import *
from trezor import io
from trezor.loop import wait
from trezor.wire import thp_v1
from trezor.wire.thp import channel
from storage import cache_thp
from ubinascii import hexlify
from trezor.wire.thp import ChannelState
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def getBytes(a):
return hexlify(a).decode("utf-8")
class TestTrezorHostProtocol(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
thp_v1.set_buffer(buffer)
channel._decode_iface = dummy_decode_iface
def test_simple(self):
self.assertTrue(True)
def test_channel_allocation(self):
cid_req = (
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
expected_response = "41ffff001e001122334455667712340a0454335731100518002001280128026dcad4ba0000000000000000000000000000000000000000000000000000000000"
test_counter = cache_thp.cid_counter + 1
self.assertEqual(len(thp_v1.CHANNELS), 0)
self.assertFalse(test_counter in thp_v1.CHANNELS)
gen = thp_v1.thp_main_loop(self.interface, is_debug_session=True)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req)
gen.send(None)
self.assertEqual(
getBytes(self.interface.data[-1]),
expected_response,
)
self.assertTrue(test_counter in thp_v1.CHANNELS)
self.assertEqual(len(thp_v1.CHANNELS), 1)
def test_channel_default_state_is_TH1(self):
self.assertEqual(thp_v1.CHANNELS[4660].get_channel_state(), ChannelState.TH1)
if __name__ == "__main__":
unittest.main()

@ -1,12 +1,13 @@
from common import *
from ubinascii import hexlify, unhexlify
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor.wire.thp.writer import REPORT_LENGTH
from ubinascii import hexlify
import ustruct
from trezor import io, utils
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import thp_v1
from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import MessageWithId
import trezor.wire.thp.thp_session as THP
from trezor.wire.thp import checksum
@ -39,9 +40,7 @@ CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
INIT_MESSAGE_DATA_LENGTH = (
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
)
INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
def make_header(ctrl_byte, cid, length):
@ -81,7 +80,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
if not utils.USE_THP:
import storage.cache_thp # noqa: F401
def test_simple(self):
def _simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
@ -116,7 +115,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def test_read_one_packet(self):
def _read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
@ -142,7 +141,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def test_read_many_packets(self):
def _read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
@ -182,7 +181,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def test_read_large_message(self):
def _read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
@ -218,7 +217,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def test_roundtrip(self):
def _roundtrip(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -244,7 +243,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_write_one_packet(self):
def _write_one_packet(self):
message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
)
@ -266,7 +265,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
)
self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self):
def _write_multiple_packets(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -310,14 +309,11 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def test_read_huge_packet(self):
def _read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH
- HEADER_CONT_LENGTH
- CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN
REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory

@ -2,6 +2,10 @@ import sys
from trezor.utils import ensure
DEFAULT_COLOR = "\033[0m"
ERROR_COLOR = "\033[31m"
OK_COLOR = "\033[32m"
class SkipTest(Exception):
pass
@ -252,16 +256,16 @@ def run_class(c, test_result):
raise RuntimeError(f"{name} should not return a result.")
finally:
tear_down()
print("\033[32mok\033[0m")
print(f"{OK_COLOR} ok{DEFAULT_COLOR}")
except SkipTest as e:
print(" skipped:", e.args[0])
test_result.skippedNum += 1
except AssertionError as e:
print("\033[31mfailed\033[0m")
print(f"{ERROR_COLOR} failed{DEFAULT_COLOR}")
sys.print_exception(e)
test_result.failuresNum += 1
except BaseException as e:
print("\033[31merrored:\033[0m", e)
print(f"{ERROR_COLOR} errored:{DEFAULT_COLOR}", e)
sys.print_exception(e)
test_result.errorsNum += 1

Loading…
Cancel
Save