Compare commits

...

6 Commits

@ -231,6 +231,8 @@ trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler import trezor.wire.thp.received_message_handler
trezor.wire.thp.retransmission
import trezor.wire.thp.retransmission
trezor.wire.thp.session_context trezor.wire.thp.session_context
import trezor.wire.thp.session_context import trezor.wire.thp.session_context
trezor.wire.thp.session_manager trezor.wire.thp.session_manager

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from trezor import log, loop from trezor import log, loop
from trezor.messages import ThpCreateNewSession, ThpNewSession from trezor.messages import ThpCreateNewSession, ThpNewSession
from trezor.wire.thp import ChannelContext, SessionState from trezor.wire.thp import SessionState
if TYPE_CHECKING:
from trezor.wire.thp import ChannelContext
async def create_new_session( async def create_new_session(

@ -1,4 +1,4 @@
from trezor import log, protobuf from trezor import protobuf
from trezor.enums import MessageType, ThpPairingMethod from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import ( from trezor.messages import (
ThpCodeEntryChallenge, ThpCodeEntryChallenge,
@ -25,6 +25,9 @@ from trezor.wire.thp.thp_session import ThpError
# TODO implement the following handlers # TODO implement the following handlers
if __debug__:
from trezor import log
async def handle_pairing_request( async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType

@ -1,5 +1,5 @@
from micropython import const # pyright: ignore[reportMissingModuleSource] from micropython import const
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from typing import TYPE_CHECKING
from storage.cache_common import InvalidSessionError from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow from trezor import log, loop, protobuf, utils, workflow

@ -3,12 +3,43 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING: if TYPE_CHECKING:
from enum import IntEnum from enum import IntEnum
from trezorio import WireInterface from trezorio import WireInterface
from typing import Protocol
from storage.cache_thp import ChannelCache from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils from trezor import loop, protobuf, utils
from trezor.enums import FailureType from trezor.enums import FailureType
from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.session_context import SessionContext from trezor.wire.thp.session_context import SessionContext
class ChannelContext(Protocol):
buffer: utils.BufferType
iface: WireInterface
channel_id: bytes
channel_cache: ChannelCache
selected_pairing_methods = [] # TODO add type
sessions: dict[int, SessionContext]
waiting_for_ack_timeout: loop.spawn | None
write_task_spawn: loop.spawn | None
connection_context: PairingContext | None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: "ChannelState") -> None: ...
async def write(
self, msg: protobuf.MessageType, session_id: int = 0
) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(
self, ctrl_byte: int, payload: bytes
) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int: ...
else: else:
IntEnum = object IntEnum = object
@ -36,29 +67,6 @@ class WireInterfaceType(IntEnum):
BLE = 2 BLE = 2
class ChannelContext:
def __init__(self, iface: WireInterface, channel_cache: ChannelCache):
self.buffer: utils.BufferType
self.iface: WireInterface = iface
self.channel_id: bytes = channel_cache.channel_id
self.channel_cache: ChannelCache = channel_cache
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: ChannelState) -> None: ...
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def is_channel_state_pairing(state: int) -> bool: def is_channel_state_pairing(state: int) -> bool:
if state in ( if state in (
ChannelState.TP1, ChannelState.TP1,

@ -6,14 +6,7 @@ from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType from trezor.enums import FailureType
from trezor.wire.thp import interface_manager, received_message_handler from trezor.wire.thp import interface_manager, received_message_handler
from . import ( from . import ChannelState, checksum, control_byte, crypto, memory_manager
ChannelContext,
ChannelState,
checksum,
control_byte,
crypto,
memory_manager,
)
from . import thp_session as THP from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH from .checksum import CHECKSUM_LENGTH
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
@ -31,19 +24,32 @@ if __debug__:
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports] from trezorio import WireInterface # pyright: ignore[reportMissingImports]
from . import ChannelContext, PairingContext
from .session_context import SessionContext
else:
ChannelContext = object
class Channel(ChannelContext): class Channel:
def __init__(self, channel_cache: ChannelCache) -> None: def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__: if __debug__:
log.debug(__name__, "channel initialization") log.debug(__name__, "channel initialization")
iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache) self.channel_cache: ChannelCache = channel_cache
self.channel_cache = channel_cache
self.is_cont_packet_expected: bool = False self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0 self.expected_payload_length: int = 0
self.bytes_read: int = 0 self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
# ACCESS TO CHANNEL_DATA # ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def get_channel_state(self) -> int: def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big") state = int.from_bytes(self.channel_cache.state, "big")
@ -168,7 +174,7 @@ class Channel(ChannelContext):
if __debug__: if __debug__:
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME) log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
noise_payload_len = memory_manager.encode_into_buffer( noise_payload_len = memory_manager.encode_into_buffer(
memoryview(self.buffer), msg, session_id self.buffer, msg, session_id
) )
await self.write_and_encrypt(self.buffer[:noise_payload_len]) await self.write_and_encrypt(self.buffer[:noise_payload_len])

@ -46,7 +46,7 @@ def select_buffer(
def encode_into_buffer( def encode_into_buffer(
buffer: memoryview, msg: protobuf.MessageType, session_id: int buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int: ) -> int:
# cannot write message without wire type # cannot write message without wire type
@ -58,7 +58,7 @@ def encode_into_buffer(
if required_min_size > len(buffer): if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer # message is too big, we need to allocate a new buffer
buffer = memoryview(bytearray(required_min_size)) buffer = bytearray(required_min_size)
_encode_session_into_buffer(memoryview(buffer), session_id) _encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer( _encode_message_type_into_buffer(

@ -1,19 +1,23 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from typing import TYPE_CHECKING
from trezor import log, loop, protobuf, workflow from trezor import loop, protobuf, workflow
from trezor.wire import context, message_handler, protocol_common from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import Context, MessageWithType from trezor.wire.protocol_common import Context, MessageWithType
from . import ChannelContext
from .session_context import UnexpectedMessageWithType from .session_context import UnexpectedMessageWithType
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Container # pyright:ignore[reportShadowedImports] from typing import Container
from . import ChannelContext
pass pass
if __debug__:
from trezor import log
class PairingContext(Context): class PairingContext(Context):
def __init__(self, channel_ctx: ChannelContext) -> None: def __init__(self, channel_ctx: ChannelContext) -> None:

@ -19,7 +19,6 @@ from trezor.wire.thp.thp_messages import (
) )
from . import ( from . import (
ChannelContext,
ChannelState, ChannelState,
SessionState, SessionState,
checksum, checksum,
@ -33,6 +32,8 @@ from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from . import ChannelContext
if __debug__: if __debug__:
from . import state_to_str from . import state_to_str

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.wire.thp import ChannelContext
class Retransmission:
def __init__(
self, channel_context: ChannelContext, ctrl_byte: int, payload: memoryview
) -> None:
self.channel_context: ChannelContext = channel_context
self.ctrl_byte: int = ctrl_byte
self.payload: memoryview = payload
def start(self):
pass
def stop(self):
pass
def change_ctrl_byte(self, ctrl_byte: int) -> None:
self.ctrl_byte = ctrl_byte

@ -6,7 +6,7 @@ from trezor.wire import message_handler, protocol_common
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
from ..protocol_common import Context, MessageWithType from ..protocol_common import Context, MessageWithType
from . import ChannelContext, SessionState from . import SessionState
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import ( # pyright: ignore[reportShadowedImports] from typing import ( # pyright: ignore[reportShadowedImports]
@ -15,6 +15,8 @@ if TYPE_CHECKING:
Container, Container,
) )
from . import ChannelContext
pass pass
_EXIT_LOOP = True _EXIT_LOOP = True

@ -1,7 +1,15 @@
from typing import TYPE_CHECKING
from storage import cache_thp from storage import cache_thp
from trezor import log, loop from trezor import loop
from trezor.wire.thp import ChannelContext
from trezor.wire.thp.session_context import SessionContext from .session_context import SessionContext
if __debug__:
from trezor import log
if TYPE_CHECKING:
from . import ChannelContext
def create_new_session(channel_ctx: ChannelContext) -> SessionContext: def create_new_session(channel_ctx: ChannelContext) -> SessionContext:

@ -19,7 +19,7 @@ _WEBAUTHN_PORT_OFFSET = const(2)
_VCP_PORT_OFFSET = const(3) _VCP_PORT_OFFSET = const(3)
if utils.EMULATOR: if utils.EMULATOR:
import uos # pyright: ignore[reportMissingModuleSource] import uos
UDP_PORT = int(uos.getenv("TREZOR_UDP_PORT") or "21324") UDP_PORT = int(uos.getenv("TREZOR_UDP_PORT") or "21324")

@ -0,0 +1,71 @@
from common import *
from trezor import io
from trezor.loop import wait
from trezor.wire import thp_v1
from trezor.wire.thp import channel
from storage import cache_thp
from ubinascii import hexlify
from trezor.wire.thp import ChannelState
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def getBytes(a):
return hexlify(a).decode("utf-8")
class TestTrezorHostProtocol(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
thp_v1.set_buffer(buffer)
channel._decode_iface = dummy_decode_iface
def test_simple(self):
self.assertTrue(True)
def test_channel_allocation(self):
cid_req = (
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
expected_response = "41ffff001e001122334455667712340a0454335731100518002001280128026dcad4ba0000000000000000000000000000000000000000000000000000000000"
test_counter = cache_thp.cid_counter + 1
self.assertEqual(len(thp_v1.CHANNELS), 0)
self.assertFalse(test_counter in thp_v1.CHANNELS)
gen = thp_v1.thp_main_loop(self.interface, is_debug_session=True)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req)
gen.send(None)
self.assertEqual(
getBytes(self.interface.data[-1]),
expected_response,
)
self.assertTrue(test_counter in thp_v1.CHANNELS)
self.assertEqual(len(thp_v1.CHANNELS), 1)
def test_channel_default_state_is_TH1(self):
self.assertEqual(thp_v1.CHANNELS[4660].get_channel_state(), ChannelState.TH1)
if __name__ == "__main__":
unittest.main()

@ -1,12 +1,13 @@
from common import * from common import *
from ubinascii import hexlify, unhexlify from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor.wire.thp.writer import REPORT_LENGTH
from ubinascii import hexlify
import ustruct import ustruct
from trezor import io, utils from trezor import io, utils
from trezor.loop import wait from trezor.loop import wait
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import thp_v1 from trezor.wire import thp_v1
from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import MessageWithId from trezor.wire.protocol_common import MessageWithId
import trezor.wire.thp.thp_session as THP import trezor.wire.thp.thp_session as THP
from trezor.wire.thp import checksum from trezor.wire.thp import checksum
@ -39,9 +40,7 @@ CONT = 0x80
HEADER_INIT_LENGTH = 5 HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3 HEADER_CONT_LENGTH = 3
INIT_MESSAGE_DATA_LENGTH = ( INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
)
def make_header(ctrl_byte, cid, length): def make_header(ctrl_byte, cid, length):
@ -81,7 +80,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
if not utils.USE_THP: if not utils.USE_THP:
import storage.cache_thp # noqa: F401 import storage.cache_thp # noqa: F401
def test_simple(self): def _simple(self):
cid_req_header = make_header( cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12 ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
) )
@ -116,7 +115,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer # message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header) self.assertEqual(buffer_without_zeroes, message_without_header)
def test_read_one_packet(self): def _read_one_packet(self):
# zero length message - just a header # zero length message - just a header
PLAINTEXT = getPlaintext() PLAINTEXT = getPlaintext()
header = make_header( header = make_header(
@ -142,7 +141,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer # message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58) self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def test_read_many_packets(self): def _read_many_packets(self):
message = bytes(range(256)) message = bytes(range(256))
header = make_header( header = make_header(
getPlaintext(), getPlaintext(),
@ -182,7 +181,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer ) # message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum) self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def test_read_large_message(self): def _read_large_message(self):
message = b"hello world" message = b"hello world"
header = make_header( header = make_header(
getPlaintext(), getPlaintext(),
@ -218,7 +217,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# read should have allocated its own buffer and not touch ours # read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00") self.assertEqual(buffer, b"\x00")
def test_roundtrip(self): def _roundtrip(self):
message_payload = bytes(range(256)) message_payload = bytes(range(256))
message = MessageWithId( message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -244,7 +243,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.type, MESSAGE_TYPE) self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data) self.assertEqual(result.data, message.data)
def test_write_one_packet(self): def _write_one_packet(self):
message = MessageWithId( message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
) )
@ -266,7 +265,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
) )
self.assertTrue(self.interface.data == [expected_message]) self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self): def _write_multiple_packets(self):
message_payload = bytes(range(256)) message_payload = bytes(range(256))
message = MessageWithId( message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -310,14 +309,11 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
last_packet = packets[-1] + packets[-2][len(packets[-1]) :] last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1]) self.assertEqual(last_packet, self.interface.data[-1])
def test_read_huge_packet(self): def _read_huge_packet(self):
PACKET_COUNT = 1180 PACKET_COUNT = 1180
# message that takes up 1 180 USB packets # message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * ( message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
- HEADER_CONT_LENGTH
- CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH ) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory # ensure that a message this big won't fit into memory

@ -2,6 +2,10 @@ import sys
from trezor.utils import ensure from trezor.utils import ensure
DEFAULT_COLOR = "\033[0m"
ERROR_COLOR = "\033[31m"
OK_COLOR = "\033[32m"
class SkipTest(Exception): class SkipTest(Exception):
pass pass
@ -252,16 +256,16 @@ def run_class(c, test_result):
raise RuntimeError(f"{name} should not return a result.") raise RuntimeError(f"{name} should not return a result.")
finally: finally:
tear_down() tear_down()
print("\033[32mok\033[0m") print(f"{OK_COLOR} ok{DEFAULT_COLOR}")
except SkipTest as e: except SkipTest as e:
print(" skipped:", e.args[0]) print(" skipped:", e.args[0])
test_result.skippedNum += 1 test_result.skippedNum += 1
except AssertionError as e: except AssertionError as e:
print("\033[31mfailed\033[0m") print(f"{ERROR_COLOR} failed{DEFAULT_COLOR}")
sys.print_exception(e) sys.print_exception(e)
test_result.failuresNum += 1 test_result.failuresNum += 1
except BaseException as e: except BaseException as e:
print("\033[31merrored:\033[0m", e) print(f"{ERROR_COLOR} errored:{DEFAULT_COLOR}", e)
sys.print_exception(e) sys.print_exception(e)
test_result.errorsNum += 1 test_result.errorsNum += 1

Loading…
Cancel
Save