Implement Pairing flow

M1nd3r/thp5
M1nd3r 1 month ago
parent 54661fb5f9
commit 5f4c9f5666

@ -62,7 +62,7 @@ message ThpNewSession{
* @next ThpNfcUnidirectionalTag // Sent by the Host
*/
message ThpStartPairingRequest{
optional bytes host_name = 1; // Human-readable host name
optional string host_name = 1; // Human-readable host name
}
/**

@ -1,33 +1,31 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from trezor import log
from trezor.enums import ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
ThpCodeEntryCpaceHost,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCodeEntryTag,
ThpNfcUnideirectionalSecret,
ThpNfcUnidirectionalTag,
ThpQrCodeSecret,
ThpQrCodeTag,
ThpStartPairingRequest,
)
from trezor.wire.errors import UnexpectedMessage
from trezor.wire.thp import ChannelState
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.thp_session import ThpError
if TYPE_CHECKING:
from trezor.enums import ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
ThpCodeEntryCpaceHost,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCodeEntryTag,
ThpNfcUnideirectionalSecret,
ThpNfcUnidirectionalTag,
ThpQrCodeSecret,
ThpQrCodeTag,
ThpStartPairingRequest,
)
# TODO implement the following handlers
async def handle_pairing_request(
channel: Channel, message: ThpStartPairingRequest
) -> ThpCodeEntryCommitment | None:
if __debug__:
log.debug(__name__, "handle_pairing_request")
_check_state(channel, ChannelState.TP1)
if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry):
channel.set_channel_state(ChannelState.TP2)

@ -6167,12 +6167,12 @@ if TYPE_CHECKING:
return isinstance(msg, cls)
class ThpStartPairingRequest(protobuf.MessageType):
host_name: "bytes | None"
host_name: "str | None"
def __init__(
self,
*,
host_name: "bytes | None" = None,
host_name: "str | None" = None,
) -> None:
pass

@ -7,7 +7,7 @@ from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType, MessageType # , ThpPairingMethod
from trezor.messages import Failure
from trezor.messages import Failure, ThpDeviceProperties
from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.handler_provider import get_handler
@ -61,9 +61,7 @@ class Channel(Context):
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.selected_pairing_methods = (
[]
) # TODO better # ThpPairingMethod.PairingMethod_NoMethod
self.selected_pairing_methods = []
from trezor.wire.thp.session_context import load_cached_sessions
self.connection_context = None
@ -300,6 +298,22 @@ class Channel(Context):
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
device_properties = thp_messages.decode_message(
self.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpDeviceProperties",
)
if TYPE_CHECKING:
assert isinstance(device_properties, ThpDeviceProperties)
for i in device_properties.pairing_methods:
self.selected_pairing_methods.append(i)
if __debug__:
log.debug(
__name__,
@ -308,13 +322,20 @@ class Channel(Context):
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
paired: bool = False # TODO should be output from credential check
# send hanshake completion response
loop.schedule(
self._write_encrypted_payload_loop(
HANDSHAKE_COMP_RES, thp_messages.get_handshake_completion_response()
HANDSHAKE_COMP_RES,
thp_messages.get_handshake_completion_response(paired=paired),
)
)
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
# TODO add credential recognition
if paired:
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
self.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
if __debug__:
@ -354,7 +375,11 @@ class Channel(Context):
if self.connection_context is None:
self.connection_context = PairingContext(self)
loop.schedule(self.connection_context.handle())
print("TEST selected methods")
for i in self.selected_pairing_methods:
print("method:", i)
self._decrypt_buffer(message_length)
message_type = ustruct.unpack(

@ -13,7 +13,7 @@ from apps.thp.pairing import handle_pairing_request
from .channel import Channel
if TYPE_CHECKING:
from typing import Container # pyright:ignore[reportShadowedImports]
from typing import Container, Generator # pyright:ignore[reportShadowedImports]
pass
@ -34,9 +34,6 @@ class PairingContext:
take = self.incoming_message.take()
next_message: MessageWithType | None = None
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
# TODO modules = utils.unimport_begin()
while True:
try:
if next_message is None:
@ -74,8 +71,6 @@ class PairingContext:
if next_message is None:
# Shut down the loop if there is no next message waiting.
# Let the session be restarted from `main`.
loop.clear()
return # pylint: disable=lost-exception
except Exception as exc:
@ -146,7 +141,6 @@ async def handle_pairing_message(
if TYPE_CHECKING:
assert isinstance(req_msg, ThpStartPairingRequest) # TODO remove
task = handler(ctx.channel, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
@ -154,7 +148,7 @@ async def handle_pairing_message(
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
# res_msg = await workflow.spawn(context.with_context(ctx, task))
res_msg = await workflow.spawn(with_context(ctx, task))
pass # TODO
else:
# For debug messages, ignore workflow processing and just await
@ -187,7 +181,6 @@ async def handle_pairing_message(
else:
log.exception(__name__, exc)
res_msg = message_handler.failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
@ -197,3 +190,34 @@ async def handle_pairing_message(
def get_handler(messageType: int):
return handle_pairing_request
def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator:
"""Run a workflow in a particular context.
Stores the context in a closure and installs it into the global variable every time
the closure is resumed, thus making sure that all calls to `wire.context.*` will
work as expected.
"""
global CURRENT_CONTEXT
send_val = None
send_exc = None
while True:
CURRENT_CONTEXT = ctx
try:
if send_exc is not None:
res = workflow.throw(send_exc)
else:
res = workflow.send(send_val)
except StopIteration as st:
return st.value
finally:
CURRENT_CONTEXT = None
try:
send_val = yield res
except BaseException as e:
send_exc = e
else:
send_exc = None

@ -1,7 +1,7 @@
import ustruct # pyright:ignore[reportMissingModuleSource]
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf
from trezor import log, protobuf
from .. import message_handler
from ..protocol_common import Message
@ -90,16 +90,27 @@ def get_handshake_init_response() -> bytes:
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x40\x41\x42\x43\x44\x45\x46\x47\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
def get_handshake_completion_response() -> bytes:
def get_handshake_completion_response(paired: bool) -> bytes:
if paired:
return (
TREZOR_STATE_PAIRED
+ b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
)
return (
TREZOR_STATE_PAIRED
TREZOR_STATE_UNPAIRED
+ b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
)
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:
print("decode message")
expected_type = protobuf.type_for_wire(msg_type)
def decode_message(
buffer: bytes, msg_type: int, message_name: str | None = None
) -> protobuf.MessageType:
if __debug__:
log.debug(__name__, "decode message")
if message_name is not None:
expected_type = protobuf.type_for_name(message_name)
else:
expected_type = protobuf.type_for_wire(msg_type)
x = message_handler.wrap_protobuf_load(buffer, expected_type)
print("result decoded", x)
return x

@ -7804,13 +7804,13 @@ class ThpNewSession(protobuf.MessageType):
class ThpStartPairingRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008
FIELDS = {
1: protobuf.Field("host_name", "bytes", repeated=False, required=False, default=None),
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["bytes"] = None,
host_name: Optional["str"] = None,
) -> None:
self.host_name = host_name

@ -835,7 +835,7 @@ impl ::protobuf::reflect::ProtobufValue for ThpNewSession {
pub struct ThpStartPairingRequest {
// message fields
// @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpStartPairingRequest.host_name)
pub host_name: ::std::option::Option<::std::vec::Vec<u8>>,
pub host_name: ::std::option::Option<::std::string::String>,
// special fields
// @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpStartPairingRequest.special_fields)
pub special_fields: ::protobuf::SpecialFields,
@ -852,12 +852,12 @@ impl ThpStartPairingRequest {
::std::default::Default::default()
}
// optional bytes host_name = 1;
// optional string host_name = 1;
pub fn host_name(&self) -> &[u8] {
pub fn host_name(&self) -> &str {
match self.host_name.as_ref() {
Some(v) => v,
None => &[],
None => "",
}
}
@ -870,22 +870,22 @@ impl ThpStartPairingRequest {
}
// Param is passed by value, moved
pub fn set_host_name(&mut self, v: ::std::vec::Vec<u8>) {
pub fn set_host_name(&mut self, v: ::std::string::String) {
self.host_name = ::std::option::Option::Some(v);
}
// Mutable pointer to the field.
// If field is not initialized, it is initialized with default value first.
pub fn mut_host_name(&mut self) -> &mut ::std::vec::Vec<u8> {
pub fn mut_host_name(&mut self) -> &mut ::std::string::String {
if self.host_name.is_none() {
self.host_name = ::std::option::Option::Some(::std::vec::Vec::new());
self.host_name = ::std::option::Option::Some(::std::string::String::new());
}
self.host_name.as_mut().unwrap()
}
// Take field
pub fn take_host_name(&mut self) -> ::std::vec::Vec<u8> {
self.host_name.take().unwrap_or_else(|| ::std::vec::Vec::new())
pub fn take_host_name(&mut self) -> ::std::string::String {
self.host_name.take().unwrap_or_else(|| ::std::string::String::new())
}
fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData {
@ -915,7 +915,7 @@ impl ::protobuf::Message for ThpStartPairingRequest {
while let Some(tag) = is.read_raw_tag_or_eof()? {
match tag {
10 => {
self.host_name = ::std::option::Option::Some(is.read_bytes()?);
self.host_name = ::std::option::Option::Some(is.read_string()?);
},
tag => {
::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?;
@ -930,7 +930,7 @@ impl ::protobuf::Message for ThpStartPairingRequest {
fn compute_size(&self) -> u64 {
let mut my_size = 0;
if let Some(v) = self.host_name.as_ref() {
my_size += ::protobuf::rt::bytes_size(1, &v);
my_size += ::protobuf::rt::string_size(1, &v);
}
my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields());
self.special_fields.cached_size().set(my_size as u32);
@ -939,7 +939,7 @@ impl ::protobuf::Message for ThpStartPairingRequest {
fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> {
if let Some(v) = self.host_name.as_ref() {
os.write_bytes(1, v)?;
os.write_string(1, v)?;
}
os.write_unknown_fields(self.special_fields.unknown_fields())?;
::std::result::Result::Ok(())
@ -3236,9 +3236,9 @@ static file_descriptor_proto_data: &'static [u8] = b"\
\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x1b\n\ton_dev\
ice\x18\x02\x20\x01(\x08R\x08onDevice\"5\n\rThpNewSession\x12$\n\x0enew_\
session_id\x18\x01\x20\x01(\rR\x0cnewSessionId\"5\n\x16ThpStartPairingRe\
quest\x12\x1b\n\thost_name\x18\x01\x20\x01(\x0cR\x08hostName\"8\n\x16Thp\
CodeEntryCommitment\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitm\
ent\"5\n\x15ThpCodeEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\
quest\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\x08hostName\"8\n\x16ThpCo\
deEntryCommitment\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitmen\
t\"5\n\x15ThpCodeEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\
\x0cR\tchallenge\"J\n\x15ThpCodeEntryCpaceHost\x121\n\x15cpace_host_publ\
ic_key\x18\x01\x20\x01(\x0cR\x12cpaceHostPublicKey\"P\n\x17ThpCodeEntryC\
paceTrezor\x125\n\x17cpace_trezor_public_key\x18\x01\x20\x01(\x0cR\x14cp\

Loading…
Cancel
Save