Compare commits
98 Commits
6a43513bde
...
6f89cb4c07
Author | SHA1 | Date |
---|---|---|
M1nd3r | 6f89cb4c07 | 2 weeks ago |
M1nd3r | dfaf01b86b | 2 weeks ago |
M1nd3r | 416c63df90 | 2 weeks ago |
M1nd3r | 9acb08e985 | 2 weeks ago |
M1nd3r | cf25a0fcc7 | 2 weeks ago |
M1nd3r | 8dceb3bb10 | 2 weeks ago |
M1nd3r | 68e40e4893 | 2 weeks ago |
M1nd3r | e2ec7065d2 | 2 weeks ago |
M1nd3r | 7d4619362b | 2 weeks ago |
M1nd3r | db3bb8bbe0 | 2 weeks ago |
M1nd3r | cf41ec9995 | 2 weeks ago |
M1nd3r | 25e164a2fd | 2 weeks ago |
M1nd3r | 3a86c5f948 | 2 weeks ago |
M1nd3r | 3ebee69ce1 | 2 weeks ago |
M1nd3r | 919e4a544c | 2 weeks ago |
M1nd3r | 22f9f03e9e | 2 weeks ago |
M1nd3r | 06657dbf6e | 2 weeks ago |
M1nd3r | c61531d69b | 2 weeks ago |
M1nd3r | 9b64a3f51f | 2 weeks ago |
M1nd3r | eb6cd2d438 | 2 weeks ago |
M1nd3r | 55153fa245 | 2 weeks ago |
M1nd3r | f64e51f5b2 | 2 weeks ago |
M1nd3r | 329152c1d9 | 2 weeks ago |
M1nd3r | 9e0c6b7ad5 | 2 weeks ago |
M1nd3r | 4d60687be0 | 2 weeks ago |
M1nd3r | fc55f45c0f | 2 weeks ago |
M1nd3r | 30930474b0 | 2 weeks ago |
M1nd3r | 750375c5e3 | 2 weeks ago |
M1nd3r | ac4edb4b17 | 2 weeks ago |
M1nd3r | 01f5fdef73 | 2 weeks ago |
M1nd3r | 03224c95b4 | 2 weeks ago |
M1nd3r | f6563ce34d | 2 weeks ago |
M1nd3r | e3aa720c6a | 2 weeks ago |
M1nd3r | 1eb6d824c1 | 2 weeks ago |
M1nd3r | 2fed976c72 | 2 weeks ago |
M1nd3r | a1ed2444d9 | 2 weeks ago |
M1nd3r | eb7c256106 | 2 weeks ago |
M1nd3r | 057dcd6af8 | 2 weeks ago |
M1nd3r | fc5df2af9f | 2 weeks ago |
M1nd3r | ffe6edc9e9 | 2 weeks ago |
M1nd3r | ab091a8516 | 2 weeks ago |
M1nd3r | b13c9db682 | 2 weeks ago |
M1nd3r | ded9c4ddbe | 2 weeks ago |
M1nd3r | ca788c2437 | 2 weeks ago |
M1nd3r | ff6cf3f4ab | 2 weeks ago |
M1nd3r | 6c57a11ad4 | 2 weeks ago |
M1nd3r | 3ee201d69e | 2 weeks ago |
M1nd3r | 870ce35f15 | 2 weeks ago |
M1nd3r | ea46e27b56 | 2 weeks ago |
M1nd3r | 7fc59c6d18 | 2 weeks ago |
M1nd3r | 62e0bc65ae | 2 weeks ago |
M1nd3r | b1439f2b9d | 2 weeks ago |
M1nd3r | aefa245dc4 | 2 weeks ago |
M1nd3r | fd40f9004e | 2 weeks ago |
M1nd3r | 9b4f56cfdd | 2 weeks ago |
M1nd3r | 9cb4c0f7c2 | 2 weeks ago |
M1nd3r | b667e8e033 | 2 weeks ago |
M1nd3r | 1607b41a26 | 2 weeks ago |
M1nd3r | 8f776fcced | 2 weeks ago |
M1nd3r | 750c37697e | 2 weeks ago |
M1nd3r | 29629ffb9e | 2 weeks ago |
M1nd3r | ca5ddb5c66 | 2 weeks ago |
M1nd3r | 00d67f8b0b | 2 weeks ago |
M1nd3r | 7f7e42a8ce | 2 weeks ago |
M1nd3r | aa615de463 | 2 weeks ago |
M1nd3r | 685cadf8c9 | 2 weeks ago |
M1nd3r | 5493cd90f1 | 2 weeks ago |
M1nd3r | 78b09ef2b5 | 2 weeks ago |
M1nd3r | 15c0c537e0 | 2 weeks ago |
M1nd3r | 650d38a6fb | 2 weeks ago |
M1nd3r | ca6c1cca74 | 2 weeks ago |
M1nd3r | 6a97c7e88a | 2 weeks ago |
M1nd3r | fe72c472ee | 2 weeks ago |
M1nd3r | 6a89178368 | 2 weeks ago |
M1nd3r | dbd29cf0aa | 2 weeks ago |
M1nd3r | d4622b1b15 | 2 weeks ago |
M1nd3r | 714a949919 | 2 weeks ago |
M1nd3r | 5bdd2e7fa5 | 2 weeks ago |
M1nd3r | a245ef195e | 2 weeks ago |
M1nd3r | 912c85e21e | 2 weeks ago |
M1nd3r | 37547b19da | 2 weeks ago |
M1nd3r | 6f3db981ec | 2 weeks ago |
M1nd3r | 7c447ac5d1 | 2 weeks ago |
M1nd3r | 42873b1c30 | 2 weeks ago |
M1nd3r | f75ee29ffa | 2 weeks ago |
M1nd3r | fb99d1dbe6 | 2 weeks ago |
M1nd3r | ef1b429c62 | 2 weeks ago |
M1nd3r | 4f9b3944ab | 2 weeks ago |
M1nd3r | 45b0293371 | 2 weeks ago |
M1nd3r | 947cd8fa1d | 2 weeks ago |
M1nd3r | 92cbacba9e | 2 weeks ago |
M1nd3r | 9683039111 | 2 weeks ago |
M1nd3r | 2066790d6a | 2 weeks ago |
matejcik | fa6c9322e2 | 2 weeks ago |
M1nd3r | eabf2e62e3 | 2 weeks ago |
M1nd3r | 2923cc99c5 | 2 weeks ago |
M1nd3r | cab2ec2d34 | 2 weeks ago |
M1nd3r | 54221797f9 | 2 weeks ago |
@ -0,0 +1,190 @@
|
||||
syntax = "proto2";
|
||||
package hw.trezor.messages.thp;
|
||||
|
||||
// Sugar for easier handling in Java
|
||||
option java_package = "com.satoshilabs.trezor.lib.protobuf";
|
||||
option java_outer_classname = "TrezorMessageThp";
|
||||
|
||||
|
||||
/**
|
||||
* Numeric identifiers of pairing methods.
|
||||
* @embed
|
||||
*/
|
||||
enum ThpPairingMethod {
|
||||
PairingMethod_NoMethod = 1; // Trust without MITM protection.
|
||||
PairingMethod_CodeEntry = 2; // User types code diplayed on Trezor into the host application.
|
||||
PairingMethod_QrCode = 3; // User scans code displayed on Trezor into host application.
|
||||
PairingMethod_NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC.
|
||||
}
|
||||
|
||||
/**
|
||||
* @embed
|
||||
*/
|
||||
message ThpDeviceProperties {
|
||||
optional string internal_model = 1; // Internal model name e.g. "T2B1".
|
||||
optional uint32 model_variant = 2; // Encodes the device properties such as color.
|
||||
optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode.
|
||||
optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware.
|
||||
repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor.
|
||||
}
|
||||
|
||||
/**
|
||||
* @embed
|
||||
*/
|
||||
message ThpHandshakeCompletionReqNoisePayload {
|
||||
optional bytes host_pairing_credential = 1; // Host's pairing credential
|
||||
repeated ThpPairingMethod pairing_methods = 2; // The pairing methods chosen by the host
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Ask device for a new session with given passphrase.
|
||||
* @start
|
||||
* @next ThpNewSession
|
||||
*/
|
||||
message ThpCreateNewSession{
|
||||
optional string passphrase = 1;
|
||||
optional bool on_device = 2; // User wants to enter passphrase on the device
|
||||
optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Contains session_id of the newly created session.
|
||||
* @end
|
||||
*/
|
||||
message ThpNewSession{
|
||||
optional uint32 new_session_id = 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Start pairing process.
|
||||
* @start
|
||||
* @next ThpCodeEntryCommitment
|
||||
* @next ThpPairingPreparationsFinished
|
||||
*/
|
||||
message ThpStartPairingRequest{
|
||||
optional string host_name = 1; // Human-readable host name
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Pairing is ready for user input / OOB communication.
|
||||
* @next ThpCodeEntryCpace
|
||||
* @next ThpQrCodeTag
|
||||
* @next ThpNfcUnidirectionalTag
|
||||
*/
|
||||
message ThpPairingPreparationsFinished{
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment.
|
||||
* @next ThpCodeEntryChallenge
|
||||
*/
|
||||
message ThpCodeEntryCommitment {
|
||||
optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Host responds to Trezor's Code Entry commitment with a challenge.
|
||||
* @next ThpPairingPreparationsFinished
|
||||
*/
|
||||
message ThpCodeEntryChallenge {
|
||||
optional bytes challenge = 1; // host's random 32-byte challenge
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor.
|
||||
* @next ThpCodeEntryCpaceTrezor
|
||||
*/
|
||||
message ThpCodeEntryCpaceHost {
|
||||
optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor continues with the CPACE protocol.
|
||||
* @next ThpCodeEntryTag
|
||||
*/
|
||||
message ThpCodeEntryCpaceTrezor {
|
||||
optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Host continues with the CPACE protocol.
|
||||
* @next ThpCodeEntrySecret
|
||||
*/
|
||||
message ThpCodeEntryTag {
|
||||
optional bytes tag = 2; // SHA-256 of shared secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor finishes the CPACE protocol.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpCodeEntrySecret {
|
||||
optional bytes secret = 1; // Trezor's secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: User selected QR Code pairing option. Host sends a QR Tag.
|
||||
* @next ThpQrCodeSecret
|
||||
*/
|
||||
message ThpQrCodeTag {
|
||||
optional bytes tag = 1; // SHA-256 of shared secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor sends the QR secret.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpQrCodeSecret {
|
||||
optional bytes secret = 1; // Trezor's secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag.
|
||||
* @next ThpNfcUnideirectionalSecret
|
||||
*/
|
||||
message ThpNfcUnidirectionalTag {
|
||||
optional bytes tag = 1; // SHA-256 of shared secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor sends the Unidirectioal NFC secret.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpNfcUnideirectionalSecret {
|
||||
optional bytes secret = 1; // Trezor's secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Host requests issuance of a new pairing credential.
|
||||
* @start
|
||||
* @next ThpCredentialResponse
|
||||
*/
|
||||
message ThpCredentialRequest {
|
||||
optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake.
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor issues a new pairing credential.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpCredentialResponse {
|
||||
optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake.
|
||||
optional bytes credential = 2; // The pairing credential issued by the Trezor to the host.
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Host requests transition to the encrypted traffic phase.
|
||||
* @start
|
||||
* @next ThpEndResponse
|
||||
*/
|
||||
message ThpEndRequest {}
|
||||
|
||||
/**
|
||||
* Response: Trezor approves transition to the encrypted traffic phase
|
||||
* @end
|
||||
*/
|
||||
message ThpEndResponse {}
|
@ -0,0 +1,33 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import log, loop
|
||||
from trezor.messages import ThpCreateNewSession, ThpNewSession
|
||||
from trezor.wire.thp import SessionState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire.thp import ChannelContext
|
||||
|
||||
|
||||
async def create_new_session(
|
||||
channel: ChannelContext, message: ThpCreateNewSession
|
||||
) -> ThpNewSession:
|
||||
# from apps.common.seed import get_seed TODO
|
||||
from trezor.wire.thp.session_manager import create_new_session
|
||||
|
||||
session = create_new_session(channel)
|
||||
session.set_session_state(SessionState.ALLOCATED)
|
||||
channel.sessions[session.session_id] = session
|
||||
loop.schedule(session.handle())
|
||||
new_session_id: int = session.session_id
|
||||
# await get_seed() TODO
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"create_new_session - new session created. Passphrase: %s, Session id: %d",
|
||||
message.passphrase if message.passphrase is not None else "",
|
||||
session.session_id,
|
||||
)
|
||||
print(channel.sessions)
|
||||
|
||||
return ThpNewSession(new_session_id=new_session_id)
|
@ -0,0 +1,199 @@
|
||||
from trezor import protobuf
|
||||
from trezor.enums import MessageType, ThpPairingMethod
|
||||
from trezor.messages import (
|
||||
ThpCodeEntryChallenge,
|
||||
ThpCodeEntryCommitment,
|
||||
ThpCodeEntryCpaceHost,
|
||||
ThpCodeEntryCpaceTrezor,
|
||||
ThpCodeEntrySecret,
|
||||
ThpCodeEntryTag,
|
||||
ThpCredentialRequest,
|
||||
ThpCredentialResponse,
|
||||
ThpEndRequest,
|
||||
ThpEndResponse,
|
||||
ThpNfcUnideirectionalSecret,
|
||||
ThpNfcUnidirectionalTag,
|
||||
ThpPairingPreparationsFinished,
|
||||
ThpQrCodeSecret,
|
||||
ThpQrCodeTag,
|
||||
ThpStartPairingRequest,
|
||||
)
|
||||
from trezor.wire.errors import UnexpectedMessage
|
||||
from trezor.wire.thp import ChannelState
|
||||
from trezor.wire.thp.pairing_context import PairingContext
|
||||
from trezor.wire.thp.thp_session import ThpError
|
||||
|
||||
# TODO implement the following handlers
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
|
||||
async def handle_pairing_request(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpStartPairingRequest.is_type_of(message)
|
||||
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_pairing_request")
|
||||
|
||||
_check_state(ctx, ChannelState.TP1)
|
||||
|
||||
if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TP2)
|
||||
|
||||
response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
|
||||
return await _handle_code_entry_challenge(ctx, response)
|
||||
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
|
||||
response = await ctx.call_any(
|
||||
ThpPairingPreparationsFinished(),
|
||||
MessageType.ThpQrCodeTag,
|
||||
MessageType.ThpNfcUnidirectionalTag,
|
||||
)
|
||||
if ThpQrCodeTag.is_type_of(response):
|
||||
return await _handle_qr_code_tag(ctx, response)
|
||||
if ThpNfcUnidirectionalTag.is_type_of(response):
|
||||
return await _handle_nfc_unidirectional_tag(ctx, response)
|
||||
raise Exception(
|
||||
"TODO Change this exception message and type. This exception should result in channel destruction."
|
||||
)
|
||||
|
||||
|
||||
async def _handle_code_entry_challenge(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpCodeEntryChallenge.is_type_of(message)
|
||||
|
||||
_check_state(ctx, ChannelState.TP2)
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
|
||||
response = await ctx.call_any(
|
||||
ThpPairingPreparationsFinished(),
|
||||
MessageType.ThpCodeEntryCpaceHost,
|
||||
MessageType.ThpQrCodeTag,
|
||||
MessageType.ThpNfcUnidirectionalTag,
|
||||
)
|
||||
if ThpCodeEntryCpaceHost.is_type_of(response):
|
||||
return await _handle_code_entry_cpace(ctx, response)
|
||||
if ThpQrCodeTag.is_type_of(response):
|
||||
return await _handle_qr_code_tag(ctx, response)
|
||||
if ThpNfcUnidirectionalTag.is_type_of(response):
|
||||
return await _handle_nfc_unidirectional_tag(ctx, response)
|
||||
raise Exception(
|
||||
"TODO Change this exception message and type. This exception should result in channel destruction."
|
||||
)
|
||||
|
||||
|
||||
async def _handle_code_entry_cpace(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpCodeEntryCpaceHost.is_type_of(message)
|
||||
|
||||
_check_state(ctx, ChannelState.TP3)
|
||||
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TP4)
|
||||
response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
|
||||
return await _handle_code_entry_tag(ctx, response)
|
||||
|
||||
|
||||
async def _handle_code_entry_tag(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpCodeEntryTag.is_type_of(message)
|
||||
return await _handle_tag_message(
|
||||
ctx,
|
||||
expected_state=ChannelState.TP4,
|
||||
used_method=ThpPairingMethod.PairingMethod_CodeEntry,
|
||||
msg=ThpCodeEntrySecret(),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_qr_code_tag(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpQrCodeTag.is_type_of(message)
|
||||
return await _handle_tag_message(
|
||||
ctx,
|
||||
expected_state=ChannelState.TP3,
|
||||
used_method=ThpPairingMethod.PairingMethod_QrCode,
|
||||
msg=ThpQrCodeSecret(),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_nfc_unidirectional_tag(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpNfcUnidirectionalTag.is_type_of(message)
|
||||
return await _handle_tag_message(
|
||||
ctx,
|
||||
expected_state=ChannelState.TP3,
|
||||
used_method=ThpPairingMethod.PairingMethod_NFC_Unidirectional,
|
||||
msg=ThpNfcUnideirectionalSecret(),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_credential_request(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpCredentialRequest.is_type_of(message)
|
||||
|
||||
_check_state(ctx, ChannelState.TC1)
|
||||
response = await ctx.call_any(
|
||||
ThpCredentialResponse(),
|
||||
MessageType.ThpCredentialRequest,
|
||||
MessageType.ThpEndRequest,
|
||||
)
|
||||
return await _handle_credential_request_or_end_request(ctx, response)
|
||||
|
||||
|
||||
async def _handle_end_request(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
assert ThpEndRequest.is_type_of(message)
|
||||
|
||||
_check_state(ctx, ChannelState.TC1)
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
return ThpEndResponse()
|
||||
|
||||
|
||||
async def _handle_tag_message(
|
||||
ctx: PairingContext,
|
||||
expected_state: ChannelState,
|
||||
used_method: ThpPairingMethod,
|
||||
msg: protobuf.MessageType,
|
||||
) -> ThpEndResponse:
|
||||
_check_state(ctx, expected_state)
|
||||
_check_method_is_allowed(ctx, used_method)
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TC1)
|
||||
response = await ctx.call_any(
|
||||
msg,
|
||||
MessageType.ThpCredentialRequest,
|
||||
MessageType.ThpEndRequest,
|
||||
)
|
||||
return await _handle_credential_request_or_end_request(ctx, response)
|
||||
|
||||
|
||||
def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
|
||||
if expected_state is not ctx.channel_ctx.get_channel_state():
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
|
||||
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
|
||||
if not _is_method_included(ctx, method):
|
||||
raise ThpError("Unexpected pairing method")
|
||||
|
||||
|
||||
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
|
||||
return method in ctx.channel_ctx.selected_pairing_methods
|
||||
|
||||
|
||||
async def _handle_credential_request_or_end_request(
|
||||
ctx: PairingContext, response: protobuf.MessageType | None
|
||||
) -> ThpEndResponse:
|
||||
if ThpCredentialRequest.is_type_of(response):
|
||||
return await _handle_credential_request(ctx, response)
|
||||
if ThpEndRequest.is_type_of(response):
|
||||
return await _handle_end_request(ctx, response)
|
||||
raise UnexpectedMessage(
|
||||
"Received message is not credential request or end request."
|
||||
)
|
@ -0,0 +1,145 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
SESSION_ID_LENGTH = const(32)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
1, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
1, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _SESSIONS
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def get_active_session() -> SessionCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
return _SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS:
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
@ -0,0 +1,67 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None: # noqa: F811
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
)
|
||||
super().__init__()
|
@ -0,0 +1,451 @@
|
||||
import builtins
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar # pyright: ignore[reportShadowedImports]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
# THP specific constants
|
||||
_MAX_CHANNELS_COUNT = 10
|
||||
_MAX_SESSIONS_COUNT = const(20)
|
||||
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove
|
||||
|
||||
|
||||
_CHANNEL_STATE_LENGTH = const(1)
|
||||
_WIRE_INTERFACE_LENGTH = const(1)
|
||||
_SESSION_STATE_LENGTH = const(1)
|
||||
_CHANNEL_ID_LENGTH = const(2)
|
||||
SESSION_ID_LENGTH = const(1)
|
||||
BROADCAST_CHANNEL_ID = const(65535)
|
||||
KEY_LENGTH = const(32)
|
||||
TAG_LENGTH = const(16)
|
||||
_UNALLOCATED_STATE = const(0)
|
||||
|
||||
|
||||
class ConnectionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.channel_id[:] = b""
|
||||
self.last_usage = 0
|
||||
super().clear()
|
||||
|
||||
|
||||
class ChannelCache(ConnectionCache):
|
||||
def __init__(self) -> None:
|
||||
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
|
||||
self.enc_key = bytearray(KEY_LENGTH)
|
||||
self.dec_key = bytearray(KEY_LENGTH)
|
||||
self.state = bytearray(_CHANNEL_STATE_LENGTH)
|
||||
self.iface = bytearray(1) # TODO add decoding
|
||||
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
|
||||
self.session_id_counter = 0x00
|
||||
self.fields = ()
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.state[:] = bytearray(
|
||||
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
|
||||
) # Set state to UNALLOCATED
|
||||
# TODO clear all sessions that are under this channel
|
||||
super().clear()
|
||||
|
||||
|
||||
class SessionThpCache(ConnectionCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||
self.state = bytearray(_SESSION_STATE_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
1, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
1, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
|
||||
self.session_id[:] = b""
|
||||
super().clear()
|
||||
|
||||
|
||||
_CHANNELS: list[ChannelCache] = []
|
||||
_SESSIONS: list[SessionThpCache] = []
|
||||
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _CHANNELS
|
||||
global _SESSIONS
|
||||
global _UNAUTHENTICATED_SESSIONS
|
||||
|
||||
for _ in range(_MAX_CHANNELS_COUNT):
|
||||
_CHANNELS.append(ChannelCache())
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionThpCache())
|
||||
|
||||
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
|
||||
|
||||
for channel in _CHANNELS:
|
||||
channel.clear()
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
for session in _UNAUTHENTICATED_SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
# THP vars
|
||||
_next_unauthenicated_session_index: int = 0 # TODO remove
|
||||
|
||||
# First unauthenticated channel will have index 0
|
||||
_is_active_session_authenticated: bool
|
||||
_active_session_idx: int | None = None
|
||||
_usage_counter = 0
|
||||
|
||||
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
|
||||
cid_counter: int = 4659 # TODO change to random value on start
|
||||
|
||||
|
||||
def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache:
|
||||
if len(iface) != _WIRE_INTERFACE_LENGTH:
|
||||
raise Exception("Invalid WireInterface (encoded) length")
|
||||
|
||||
new_cid = get_next_channel_id()
|
||||
index = _get_next_unauthenticated_channel_index()
|
||||
|
||||
# clear sessions from replaced channel
|
||||
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
|
||||
old_cid = _CHANNELS[index].channel_id
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == old_cid:
|
||||
session.clear()
|
||||
|
||||
_CHANNELS[index] = ChannelCache()
|
||||
_CHANNELS[index].channel_id[:] = new_cid
|
||||
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
|
||||
_CHANNELS[index].state[:] = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
|
||||
)
|
||||
_CHANNELS[index].iface[:] = bytearray(iface)
|
||||
return _CHANNELS[index]
|
||||
|
||||
|
||||
def get_all_allocated_channels() -> list[ChannelCache]:
|
||||
_list: list[ChannelCache] = []
|
||||
for channel in _CHANNELS:
|
||||
if _get_channel_state(channel) != _UNALLOCATED_STATE:
|
||||
_list.append(channel)
|
||||
return _list
|
||||
|
||||
|
||||
def get_all_allocated_sessions() -> list[SessionThpCache]:
|
||||
_list: list[SessionThpCache] = []
|
||||
for session in _SESSIONS:
|
||||
if _get_session_state(session) != _UNALLOCATED_STATE:
|
||||
_list.append(session)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "session %s is not in UNALLOCATED state", str(session)
|
||||
)
|
||||
elif __debug__:
|
||||
log.debug(__name__, "session %s is in UNALLOCATED state", str(session))
|
||||
return _list
|
||||
|
||||
|
||||
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
|
||||
if len(key) != KEY_LENGTH:
|
||||
raise Exception("Invalid key length")
|
||||
channel.host_ephemeral_pubkey = key
|
||||
|
||||
|
||||
def get_new_session(channel: ChannelCache):
|
||||
new_sid = get_next_session_id(channel)
|
||||
index = _get_next_session_index()
|
||||
|
||||
_SESSIONS[index] = SessionThpCache()
|
||||
_SESSIONS[index].channel_id[:] = channel.channel_id
|
||||
_SESSIONS[index].session_id[:] = new_sid
|
||||
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
|
||||
channel.last_usage = (
|
||||
_get_usage_counter_and_increment()
|
||||
) # increment also use of the channel so it does not get replaced
|
||||
_SESSIONS[index].state[:] = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
|
||||
)
|
||||
return _SESSIONS[index]
|
||||
|
||||
|
||||
def _get_usage_counter() -> int:
|
||||
global _usage_counter
|
||||
return _usage_counter
|
||||
|
||||
|
||||
def _get_usage_counter_and_increment() -> int:
|
||||
global _usage_counter
|
||||
_usage_counter += 1
|
||||
return _usage_counter
|
||||
|
||||
|
||||
def _get_next_unauthenticated_channel_index() -> int:
|
||||
idx = _get_unallocated_channel_index()
|
||||
if idx is not None:
|
||||
return idx
|
||||
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
|
||||
|
||||
|
||||
def _get_next_session_index() -> int:
|
||||
idx = _get_unallocated_session_index()
|
||||
if idx is not None:
|
||||
return idx
|
||||
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
|
||||
|
||||
|
||||
def _get_unallocated_channel_index() -> int | None:
|
||||
for i in range(_MAX_CHANNELS_COUNT):
|
||||
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def _get_unallocated_session_index() -> int | None:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def _get_channel_state(channel: ChannelCache) -> int:
|
||||
if channel is None:
|
||||
return _UNALLOCATED_STATE
|
||||
return int.from_bytes(channel.state, "big")
|
||||
|
||||
|
||||
def _get_session_state(session: SessionThpCache) -> int:
|
||||
if session is None:
|
||||
return _UNALLOCATED_STATE
|
||||
return int.from_bytes(session.state, "big")
|
||||
|
||||
|
||||
def get_active_session_id() -> bytearray | None:
|
||||
active_session = get_active_session()
|
||||
|
||||
if active_session is None:
|
||||
return None
|
||||
return active_session.session_id
|
||||
|
||||
|
||||
def get_active_session() -> SessionThpCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
if _is_active_session_authenticated:
|
||||
return _SESSIONS[_active_session_idx]
|
||||
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def get_next_channel_id() -> bytes:
|
||||
global cid_counter
|
||||
while True:
|
||||
cid_counter += 1
|
||||
if cid_counter >= BROADCAST_CHANNEL_ID:
|
||||
cid_counter = 1
|
||||
if _is_cid_unique():
|
||||
break
|
||||
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def get_next_session_id(channel: ChannelCache) -> bytes:
|
||||
while True:
|
||||
if channel.session_id_counter >= 255:
|
||||
channel.session_id_counter = 1
|
||||
else:
|
||||
channel.session_id_counter += 1
|
||||
if _is_session_id_unique(channel):
|
||||
break
|
||||
new_sid = channel.session_id_counter
|
||||
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def _is_session_id_unique(channel: ChannelCache) -> bool:
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == channel.channel_id:
|
||||
if session.session_id == channel.session_id_counter:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_cid_unique() -> bool:
|
||||
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||
if cid_counter == _get_cid(session):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_cid(session: SessionThpCache) -> int:
|
||||
return int.from_bytes(session.session_id[2:], "big")
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
|
||||
if len(session_id) != SESSION_ID_LENGTH:
|
||||
raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH)
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
global _next_unauthenicated_session_index
|
||||
|
||||
i = _next_unauthenicated_session_index
|
||||
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
|
||||
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
|
||||
_next_unauthenicated_session_index += 1
|
||||
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
|
||||
_next_unauthenicated_session_index = 0
|
||||
|
||||
# Set session as active if and only if there is no active session
|
||||
if _active_session_idx is None:
|
||||
_active_session_idx = i
|
||||
_is_active_session_authenticated = False
|
||||
return _UNAUTHENTICATED_SESSIONS[i]
|
||||
|
||||
|
||||
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
|
||||
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
|
||||
unauth_session_idx = get_unauth_session_index(unauth_session)
|
||||
if unauth_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
|
||||
# replace least recently used authenticated session by the new session
|
||||
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
|
||||
|
||||
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
|
||||
_UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear()
|
||||
|
||||
_SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment()
|
||||
return _SESSIONS[new_auth_session_index]
|
||||
|
||||
|
||||
def get_least_recently_used_authetnicated_session_index() -> int:
|
||||
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
|
||||
|
||||
|
||||
def get_least_recently_used_item(
|
||||
list: list[ChannelCache] | list[SessionThpCache], max_count: int
|
||||
):
|
||||
lru_counter = _get_usage_counter()
|
||||
lru_item_index = 0
|
||||
for i in range(max_count):
|
||||
if list[i].last_usage < lru_counter:
|
||||
lru_counter = list[i].last_usage
|
||||
lru_item_index = i
|
||||
return lru_item_index
|
||||
|
||||
|
||||
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
|
||||
def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
|
||||
if session_id is not None:
|
||||
if get_active_session_id() == session_id:
|
||||
return session_id
|
||||
for index in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = True
|
||||
return session_id
|
||||
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
|
||||
channel = get_new_unauthenticated_channel(b"\x00")
|
||||
|
||||
new_session_id = get_next_session_id(channel)
|
||||
|
||||
new_session = create_new_unauthenticated_session(new_session_id)
|
||||
|
||||
index = get_unauth_session_index(new_session)
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
|
||||
return new_session_id
|
||||
|
||||
|
||||
def start_existing_session(session_id: bytes) -> bytes:
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
|
||||
if session_id is None:
|
||||
raise ValueError("session_id cannot be None")
|
||||
if get_active_session_id() == session_id:
|
||||
return session_id
|
||||
for index in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = True
|
||||
return session_id
|
||||
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
raise ValueError("There is no active session with provided session_id")
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
active_session = get_active_session()
|
||||
if active_session is None:
|
||||
return
|
||||
active_session.clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||
session.clear()
|
@ -0,0 +1,8 @@
|
||||
# Automatically generated by pb2py
|
||||
# fmt: off
|
||||
# isort:skip_file
|
||||
|
||||
PairingMethod_NoMethod = 1
|
||||
PairingMethod_CodeEntry = 2
|
||||
PairingMethod_QrCode = 3
|
||||
PairingMethod_NFC_Unidirectional = 4
|
@ -0,0 +1,206 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire import context, protocol_common
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
# other packages.
|
||||
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
Coroutine,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
Handler = Callable[[Msg], HandlerTask]
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||
EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||
)
|
||||
return msg
|
||||
except Exception as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
if e.args:
|
||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||
else:
|
||||
raise DataError("Failed to decode message")
|
||||
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
if __debug__:
|
||||
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||
|
||||
|
||||
async def handle_single_message(
|
||||
ctx: context.Context, msg: protocol_common.MessageWithType, use_workflow: bool
|
||||
) -> protocol_common.MessageWithType | None:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
a problem is encountered at any point, write the appropriate error on the wire.
|
||||
|
||||
If the workflow finished normally or with an error, the return value is None.
|
||||
|
||||
If an unexpected message had arrived on the wire while the workflow was processing,
|
||||
the workflow is shut down with an `UnexpectedMessage` exception. This is not
|
||||
considered an "error condition" to return over the wire -- instead the message
|
||||
is processed as if starting a new workflow.
|
||||
In such case, the `UnexpectedMessage` is caught and the message is returned
|
||||
to the caller. It will then be processed in the next iteration of the message loop.
|
||||
"""
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
if ctx.channel_id is not None:
|
||||
sid = int.from_bytes(ctx.channel_id, "big")
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
sid,
|
||||
msg_type,
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:unknown_sid receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
msg_type,
|
||||
)
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type. Should not raise.
|
||||
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
|
||||
|
||||
if handler is None:
|
||||
# If no handler is found, we can skip decoding and directly
|
||||
# respond with failure.
|
||||
await ctx.write(unexpected_message())
|
||||
return None
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
if use_workflow:
|
||||
# Spawn a workflow around the task. This ensures that concurrent
|
||||
# workflows are shut down.
|
||||
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||
else:
|
||||
# For debug messages, ignore workflow processing and just await
|
||||
# results of the handler.
|
||||
res_msg = await task
|
||||
|
||||
except context.UnexpectedMessageWithId as exc:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
# TODO:
|
||||
# We might handle only the few common cases here, like
|
||||
# Initialize and Cancel.
|
||||
return exc.msg
|
||||
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await ctx.write(res_msg)
|
||||
return None
|
||||
|
||||
|
||||
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
|
||||
"""Placeholder handler lookup before a proper one is registered."""
|
||||
return None
|
||||
|
||||
|
||||
find_handler = _find_handler_placeholder
|
||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
if isinstance(exc, Error):
|
||||
return Failure(code=exc.code, message=exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||
elif isinstance(exc, InvalidSessionError):
|
||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||
else:
|
||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||
# change the `if __debug__` to `if True` to get the full error message.
|
||||
if __debug__:
|
||||
message = str(exc)
|
||||
else:
|
||||
message = "Firmware error"
|
||||
return Failure(code=FailureType.FirmwareError, message=message)
|
||||
|
||||
|
||||
def unexpected_message() -> Failure:
|
||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
@ -0,0 +1,78 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import protobuf
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import (
|
||||
Container,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(
|
||||
self,
|
||||
message_data: bytes,
|
||||
) -> None:
|
||||
self.data = message_data
|
||||
|
||||
def to_bytes(self):
|
||||
return self.data
|
||||
|
||||
|
||||
class MessageWithType(Message):
|
||||
def __init__(
|
||||
self,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
) -> None:
|
||||
self.type = message_type
|
||||
super().__init__(message_data)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.type.to_bytes(2, "big") + self.data
|
||||
|
||||
|
||||
class MessageWithId(MessageWithType):
|
||||
def __init__(
|
||||
self,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
session_id: bytearray | None = None,
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
super().__init__(message_type, message_data)
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, iface: WireInterface, channel_id: bytes) -> None:
|
||||
self.iface: WireInterface = iface
|
||||
self.channel_id: bytes = channel_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
@ -0,0 +1,90 @@
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
from trezorio import WireInterface
|
||||
from typing import Protocol
|
||||
|
||||
from storage.cache_thp import ChannelCache
|
||||
from trezor import loop, protobuf, utils
|
||||
from trezor.enums import FailureType
|
||||
from trezor.wire.thp.pairing_context import PairingContext
|
||||
from trezor.wire.thp.session_context import SessionContext
|
||||
|
||||
class ChannelContext(Protocol):
|
||||
buffer: utils.BufferType
|
||||
iface: WireInterface
|
||||
channel_id: bytes
|
||||
channel_cache: ChannelCache
|
||||
selected_pairing_methods = [] # TODO add type
|
||||
sessions: dict[int, SessionContext]
|
||||
waiting_for_ack_timeout: loop.spawn | None
|
||||
write_task_spawn: loop.spawn | None
|
||||
connection_context: PairingContext | None
|
||||
|
||||
def get_channel_state(self) -> int: ...
|
||||
|
||||
def set_channel_state(self, state: "ChannelState") -> None: ...
|
||||
|
||||
async def write(
|
||||
self, msg: protobuf.MessageType, session_id: int = 0
|
||||
) -> None: ...
|
||||
|
||||
async def write_error(self, err_type: FailureType, message: str) -> None: ...
|
||||
|
||||
async def write_handshake_message(
|
||||
self, ctrl_byte: int, payload: bytes
|
||||
) -> None: ...
|
||||
|
||||
def decrypt_buffer(self, message_length: int) -> None: ...
|
||||
|
||||
def get_channel_id_int(self) -> int: ...
|
||||
|
||||
else:
|
||||
IntEnum = object
|
||||
|
||||
|
||||
class ChannelState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
TH1 = 1
|
||||
TH2 = 2
|
||||
TP1 = 3
|
||||
TP2 = 4
|
||||
TP3 = 5
|
||||
TP4 = 6
|
||||
TC1 = 7
|
||||
ENCRYPTED_TRANSPORT = 8
|
||||
|
||||
|
||||
class SessionState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
ALLOCATED = 1
|
||||
|
||||
|
||||
class WireInterfaceType(IntEnum):
|
||||
MOCK = 0
|
||||
USB = 1
|
||||
BLE = 2
|
||||
|
||||
|
||||
def is_channel_state_pairing(state: int) -> bool:
|
||||
if state in (
|
||||
ChannelState.TP1,
|
||||
ChannelState.TP2,
|
||||
ChannelState.TP3,
|
||||
ChannelState.TP4,
|
||||
ChannelState.TC1,
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
if __debug__:
|
||||
|
||||
def state_to_str(state: int) -> str:
|
||||
name = {
|
||||
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
|
||||
}.get(state)
|
||||
if name is not None:
|
||||
return name
|
||||
return "UNKNOWN_STATE"
|
@ -0,0 +1,30 @@
|
||||
from storage.cache_thp import ChannelCache, SessionThpCache
|
||||
from trezor import log
|
||||
|
||||
from . import thp_session as THP
|
||||
|
||||
|
||||
def is_ack_valid(cache: SessionThpCache | ChannelCache, sync_bit: int) -> bool:
|
||||
if not _is_ack_expected(cache):
|
||||
return False
|
||||
|
||||
if not _has_ack_correct_sync_bit(cache, sync_bit):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _is_ack_expected(cache: SessionThpCache | ChannelCache) -> bool:
|
||||
is_expected: bool = not THP.sync_can_send_message(cache)
|
||||
if __debug__ and not is_expected:
|
||||
log.debug(__name__, "Received unexpected ACK message")
|
||||
return is_expected
|
||||
|
||||
|
||||
def _has_ack_correct_sync_bit(
|
||||
cache: SessionThpCache | ChannelCache, sync_bit: int
|
||||
) -> bool:
|
||||
is_correct: bool = THP.sync_get_send_bit(cache) == sync_bit
|
||||
if __debug__ and not is_correct:
|
||||
log.debug(__name__, "Received ACK message with wrong sync bit")
|
||||
return is_correct
|
@ -0,0 +1,276 @@
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
|
||||
|
||||
from storage.cache_thp import TAG_LENGTH, ChannelCache
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.wire.thp import interface_manager, received_message_handler
|
||||
|
||||
from . import ChannelState, checksum, control_byte, crypto, memory_manager
|
||||
from . import thp_session as THP
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
|
||||
from .thp_session import ThpError
|
||||
from .writer import (
|
||||
CONT_DATA_OFFSET,
|
||||
INIT_DATA_OFFSET,
|
||||
MESSAGE_TYPE_LENGTH,
|
||||
write_payload_to_wire,
|
||||
)
|
||||
|
||||
if __debug__:
|
||||
from . import state_to_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
||||
|
||||
from . import ChannelContext, PairingContext
|
||||
from .session_context import SessionContext
|
||||
else:
|
||||
ChannelContext = object
|
||||
|
||||
|
||||
class Channel:
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "channel initialization")
|
||||
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
|
||||
self.channel_cache: ChannelCache = channel_cache
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.expected_payload_length: int = 0
|
||||
self.bytes_read: int = 0
|
||||
self.buffer: utils.BufferType
|
||||
self.channel_id: bytes = channel_cache.channel_id
|
||||
self.selected_pairing_methods = []
|
||||
self.sessions: dict[int, SessionContext] = {}
|
||||
self.waiting_for_ack_timeout: loop.spawn | None = None
|
||||
self.write_task_spawn: loop.spawn | None = None
|
||||
self.connection_context: PairingContext | None = None
|
||||
|
||||
# ACCESS TO CHANNEL_DATA
|
||||
def get_channel_id_int(self) -> int:
|
||||
return int.from_bytes(self.channel_id, "big")
|
||||
|
||||
def get_channel_state(self) -> int:
|
||||
state = int.from_bytes(self.channel_cache.state, "big")
|
||||
if __debug__:
|
||||
log.debug(__name__, "get_channel_state: %s", state_to_str(state))
|
||||
return state
|
||||
|
||||
def set_channel_state(self, state: ChannelState) -> None:
|
||||
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||
if __debug__:
|
||||
log.debug(__name__, "set_channel_state: %s", state_to_str(state))
|
||||
|
||||
def set_buffer(self, buffer: utils.BufferType) -> None:
|
||||
self.buffer = buffer
|
||||
if __debug__:
|
||||
log.debug(__name__, "set_buffer: %s", type(self.buffer))
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
async def receive_packet(self, packet: utils.BufferType):
|
||||
if __debug__:
|
||||
log.debug(__name__, "receive_packet")
|
||||
|
||||
await self._handle_received_packet(packet)
|
||||
|
||||
if __debug__:
|
||||
log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer))
|
||||
|
||||
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
||||
self._finish_message()
|
||||
await received_message_handler.handle_received_message(self, self.buffer)
|
||||
elif self.expected_payload_length + INIT_DATA_OFFSET > self.bytes_read:
|
||||
self.is_cont_packet_expected = True
|
||||
else:
|
||||
raise ThpError(
|
||||
"Read more bytes than is the expected length of the message, this should not happen!"
|
||||
)
|
||||
|
||||
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
||||
ctrl_byte = packet[0]
|
||||
if control_byte.is_continuation(ctrl_byte):
|
||||
await self._handle_cont_packet(packet)
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
|
||||
async def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_init_packet")
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||
self.expected_payload_length = payload_length
|
||||
packet_payload = packet[5:]
|
||||
# If the channel does not "own" the buffer lock, decrypt first packet
|
||||
# TODO do it only when needed!
|
||||
if control_byte.is_encrypted_transport(ctrl_byte):
|
||||
packet_payload = self._decrypt_single_packet_payload(packet_payload)
|
||||
|
||||
self.buffer = memory_manager.select_buffer(
|
||||
self.get_channel_state(),
|
||||
self.buffer,
|
||||
packet_payload,
|
||||
payload_length,
|
||||
)
|
||||
await self._buffer_packet_data(self.buffer, packet, 0)
|
||||
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
|
||||
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
|
||||
|
||||
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_cont_packet")
|
||||
if not self.is_cont_packet_expected:
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
|
||||
|
||||
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
|
||||
payload_buffer = bytearray(payload)
|
||||
crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
|
||||
return payload_buffer
|
||||
|
||||
def decrypt_buffer(self, message_length: int) -> None:
|
||||
if not isinstance(self.buffer, bytearray):
|
||||
self.buffer = bytearray(self.buffer)
|
||||
crypto.decrypt(
|
||||
b"\x00",
|
||||
b"\x00",
|
||||
self.buffer,
|
||||
INIT_DATA_OFFSET,
|
||||
message_length - INIT_DATA_OFFSET - CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "encrypt")
|
||||
min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
if len(buffer) < min_required_length or not isinstance(buffer, bytearray):
|
||||
new_buffer = bytearray(min_required_length)
|
||||
utils.memcpy(new_buffer, 0, buffer, 0)
|
||||
buffer = new_buffer
|
||||
tag = crypto.encrypt(
|
||||
b"\x00",
|
||||
b"\x00",
|
||||
buffer,
|
||||
0,
|
||||
noise_payload_len,
|
||||
)
|
||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||
|
||||
async def _buffer_packet_data(
|
||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||
):
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
|
||||
def _finish_message(self):
|
||||
self.bytes_read = 0
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
|
||||
noise_payload_len = memory_manager.encode_into_buffer(
|
||||
self.buffer, msg, session_id
|
||||
)
|
||||
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||
|
||||
async def write_error(self, err_type: FailureType, message: str) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_error")
|
||||
msg_size = memory_manager.encode_error_into_buffer(
|
||||
memoryview(self.buffer), err_type, message
|
||||
)
|
||||
data_length = MESSAGE_TYPE_LENGTH + msg_size
|
||||
header: InitHeader = InitHeader(
|
||||
ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH
|
||||
)
|
||||
chksum = checksum.compute(
|
||||
header.to_bytes() + memoryview(self.buffer[:data_length])
|
||||
)
|
||||
|
||||
utils.memcpy(self.buffer, data_length, chksum, 0)
|
||||
await write_payload_to_wire(
|
||||
self.iface, header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH])
|
||||
)
|
||||
|
||||
async def write_and_encrypt(self, payload: bytes) -> None:
|
||||
payload_length = len(payload)
|
||||
|
||||
if not isinstance(self.buffer, bytearray):
|
||||
self.buffer = bytearray(self.buffer)
|
||||
self._encrypt(self.buffer, payload_length)
|
||||
payload_length = payload_length + TAG_LENGTH
|
||||
|
||||
if self.write_task_spawn is not None:
|
||||
self.write_task_spawn.close() # UPS TODO migh break something
|
||||
print("\nCLOSED\n")
|
||||
self._prepare_write()
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(
|
||||
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
|
||||
)
|
||||
)
|
||||
|
||||
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||
self._prepare_write()
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||
)
|
||||
|
||||
def _prepare_write(self) -> None:
|
||||
# TODO add condition that disallows to write when can_send_message is false
|
||||
THP.sync_set_can_send_message(self.channel_cache, False)
|
||||
|
||||
async def _write_encrypted_payload_loop(
|
||||
self, ctrl_byte: int, payload: bytes
|
||||
) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_encrypted_payload_loop")
|
||||
payload_len = len(payload) + CHECKSUM_LENGTH
|
||||
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
||||
ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit)
|
||||
header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
|
||||
chksum = checksum.compute(header.to_bytes() + payload)
|
||||
payload = payload + chksum
|
||||
|
||||
while True:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"write_encrypted_payload_loop - loop start, sync_bit: %d, sync_send_bit: %d",
|
||||
(header.ctrl_byte & 0x10) >> 4,
|
||||
THP.sync_get_send_bit(self.channel_cache),
|
||||
)
|
||||
await write_payload_to_wire(self.iface, header, payload)
|
||||
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
||||
try:
|
||||
if THP.sync_can_send_message(self.channel_cache):
|
||||
# TODO This can happen when ack is received before the message was sent,
|
||||
# but after it was scheduled to be sent (i.e. ACK was already expected)
|
||||
# This case should be removed or improved upon before production.
|
||||
break
|
||||
else:
|
||||
await self.waiting_for_ack_timeout
|
||||
except loop.TaskClosed:
|
||||
break
|
||||
|
||||
THP.sync_set_send_bit_to_opposite(self.channel_cache)
|
||||
|
||||
# Let the main loop be restarted and clear loop, if there is no other
|
||||
# workflow and the state is ENCRYPTED_TRANSPORT
|
||||
if self._can_clear_loop():
|
||||
if __debug__:
|
||||
log.debug(__name__, "clearing loop from channel")
|
||||
loop.clear()
|
||||
|
||||
def _can_clear_loop(self) -> bool:
|
||||
return (
|
||||
not workflow.tasks
|
||||
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
||||
|
||||
async def _wait_for_ack(self) -> None:
|
||||
await loop.sleep(1000)
|
@ -0,0 +1,30 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp
|
||||
from trezor import utils
|
||||
|
||||
from . import ChannelState, interface_manager
|
||||
from .channel import Channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
|
||||
|
||||
|
||||
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> "Channel":
|
||||
channel_cache = cache_thp.get_new_unauthenticated_channel(
|
||||
interface_manager.encode_iface(iface)
|
||||
)
|
||||
r = Channel(channel_cache)
|
||||
r.set_buffer(buffer)
|
||||
r.set_channel_state(ChannelState.TH1)
|
||||
return r
|
||||
|
||||
|
||||
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
|
||||
channels: dict[int, Channel] = {}
|
||||
cached_channels = cache_thp.get_all_allocated_channels()
|
||||
for c in cached_channels:
|
||||
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
|
||||
for c in channels.values():
|
||||
c.set_buffer(buffer)
|
||||
return channels
|
@ -0,0 +1,15 @@
|
||||
from micropython import const
|
||||
|
||||
from trezor import utils
|
||||
from trezor.crypto import crc
|
||||
|
||||
CHECKSUM_LENGTH = const(4)
|
||||
|
||||
|
||||
def compute(data: bytes | utils.BufferType) -> bytes:
|
||||
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
|
||||
|
||||
|
||||
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
|
||||
data_checksum = compute(data)
|
||||
return checksum == data_checksum
|
@ -0,0 +1,36 @@
|
||||
from trezor.wire.thp.thp_messages import (
|
||||
ACK_MESSAGE,
|
||||
CONTINUATION_PACKET,
|
||||
ENCRYPTED_TRANSPORT,
|
||||
HANDSHAKE_COMP_REQ,
|
||||
HANDSHAKE_INIT_REQ,
|
||||
)
|
||||
from trezor.wire.thp.thp_session import ThpError
|
||||
|
||||
|
||||
def add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if sync_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise ThpError("Unexpected synchronization bit")
|
||||
|
||||
|
||||
def is_ack(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == ACK_MESSAGE
|
||||
|
||||
|
||||
def is_continuation(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||
|
||||
|
||||
def is_encrypted_transport(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
|
||||
|
||||
|
||||
def is_handshake_init_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ
|
||||
|
||||
|
||||
def is_handshake_comp_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ
|
@ -0,0 +1,37 @@
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
|
||||
PUBKEY_LENGTH = const(32)
|
||||
|
||||
|
||||
# TODO implement
|
||||
|
||||
|
||||
def encrypt(
|
||||
key: bytes,
|
||||
nonce: bytes,
|
||||
buffer: bytearray,
|
||||
init_offset: int = 0,
|
||||
payload_length: int = 0,
|
||||
) -> bytes:
|
||||
"""
|
||||
Returns a 16-byte long encryption tag, the encryption itself is performed on the buffer provided.
|
||||
"""
|
||||
return DUMMY_TAG
|
||||
|
||||
|
||||
def decrypt(
|
||||
key: bytes,
|
||||
nonce: bytes,
|
||||
buffer: bytearray,
|
||||
init_offset: int = 0,
|
||||
payload_length: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Decryption in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def is_tag_valid(key: bytes, nonce: bytes, payload: bytes, noise_tag: bytes) -> bool:
|
||||
return True
|
@ -0,0 +1,16 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import protobuf
|
||||
|
||||
from apps.thp import create_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_handler_for_channel_message(
|
||||
msg: protobuf.MessageType,
|
||||
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
|
||||
return create_session.create_new_session
|
@ -0,0 +1,32 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import usb
|
||||
|
||||
_MOCK_INTERFACE_HID = b"\x00"
|
||||
_WIRE_INTERFACE_USB = b"\x01"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
|
||||
|
||||
|
||||
def decode_iface(cached_iface: bytes) -> WireInterface:
|
||||
"""Decode the cached wire interface."""
|
||||
if cached_iface == _WIRE_INTERFACE_USB:
|
||||
iface = usb.iface_wire
|
||||
if iface is None:
|
||||
raise RuntimeError("There is no valid USB WireInterface")
|
||||
return iface
|
||||
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
|
||||
raise NotImplementedError("Should return MockHID WireInterface")
|
||||
# TODO implement bluetooth interface
|
||||
raise Exception("Unknown WireInterface")
|
||||
|
||||
|
||||
def encode_iface(iface: WireInterface) -> bytes:
|
||||
"""Encode wire interface into bytes."""
|
||||
if iface is usb.iface_wire:
|
||||
return _WIRE_INTERFACE_USB
|
||||
# TODO implement bluetooth interface
|
||||
if __debug__:
|
||||
return _MOCK_INTERFACE_HID
|
||||
raise Exception("Unknown WireInterface")
|
@ -0,0 +1,128 @@
|
||||
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
|
||||
from trezor import log, protobuf, utils
|
||||
from trezor.enums import FailureType, MessageType
|
||||
from trezor.messages import Failure
|
||||
|
||||
from . import ChannelState
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .thp_session import ThpError
|
||||
from .writer import (
|
||||
INIT_DATA_OFFSET,
|
||||
MAX_PAYLOAD_LEN,
|
||||
MESSAGE_TYPE_LENGTH,
|
||||
REPORT_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
def select_buffer(
|
||||
channel_state: int,
|
||||
channel_buffer: utils.BufferType,
|
||||
packet_payload: utils.BufferType,
|
||||
payload_length: int,
|
||||
) -> utils.BufferType:
|
||||
|
||||
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
session_id = packet_payload[0]
|
||||
if session_id == 0:
|
||||
pass
|
||||
# TODO use small buffer
|
||||
else:
|
||||
pass
|
||||
# TODO use big buffer but only if the channel owns the buffer lock.
|
||||
# Otherwise send BUSY message and return
|
||||
else:
|
||||
pass
|
||||
# TODO use small buffer
|
||||
try:
|
||||
# TODO for now, we create a new big buffer every time. It should be changed
|
||||
buffer: utils.BufferType = _get_buffer_for_message(
|
||||
payload_length, channel_buffer
|
||||
)
|
||||
return buffer
|
||||
except Exception as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
raise Exception("Failed to create a buffer for channel") # TODO handle better
|
||||
|
||||
|
||||
def encode_into_buffer(
|
||||
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
|
||||
) -> int:
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
||||
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
|
||||
|
||||
if required_min_size > len(buffer):
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(required_min_size)
|
||||
|
||||
_encode_session_into_buffer(memoryview(buffer), session_id)
|
||||
_encode_message_type_into_buffer(
|
||||
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
|
||||
)
|
||||
_encode_message_into_buffer(
|
||||
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
||||
)
|
||||
|
||||
return payload_size
|
||||
|
||||
|
||||
def encode_error_into_buffer(
|
||||
buffer: memoryview, err_code: FailureType, message: str
|
||||
) -> int:
|
||||
error_message: protobuf.MessageType = Failure(code=err_code, message=message)
|
||||
_encode_message_type_into_buffer(buffer, MessageType.Failure)
|
||||
_encode_message_into_buffer(buffer, error_message, MESSAGE_TYPE_LENGTH)
|
||||
return protobuf.encoded_length(error_message)
|
||||
|
||||
|
||||
def _encode_session_into_buffer(
|
||||
buffer: memoryview, session_id: int, buffer_offset: int = 0
|
||||
) -> None:
|
||||
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
|
||||
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
|
||||
|
||||
|
||||
def _encode_message_type_into_buffer(
|
||||
buffer: memoryview, message_type: int, offset: int = 0
|
||||
) -> None:
|
||||
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
|
||||
utils.memcpy(buffer, offset, msg_type_bytes, 0)
|
||||
|
||||
|
||||
def _encode_message_into_buffer(
|
||||
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
|
||||
) -> None:
|
||||
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
|
||||
|
||||
|
||||
def _get_buffer_for_message(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
length = payload_length + INIT_DATA_OFFSET
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"get_buffer_for_message - length: %d, %s %s",
|
||||
length,
|
||||
"existing buffer type:",
|
||||
type(existing_buffer),
|
||||
)
|
||||
if length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
|
||||
if length > len(existing_buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(length)
|
||||
except MemoryError:
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
return memoryview(existing_buffer)[:length]
|
@ -0,0 +1,203 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import loop, protobuf, workflow
|
||||
from trezor.wire import context, message_handler, protocol_common
|
||||
from trezor.wire.context import UnexpectedMessageWithId
|
||||
from trezor.wire.errors import ActionCancelled
|
||||
from trezor.wire.protocol_common import Context, MessageWithType
|
||||
|
||||
from .session_context import UnexpectedMessageWithType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Container
|
||||
|
||||
from . import ChannelContext
|
||||
|
||||
pass
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
|
||||
class PairingContext(Context):
|
||||
def __init__(self, channel_ctx: ChannelContext) -> None:
|
||||
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
|
||||
self.channel_ctx = channel_ctx
|
||||
self.incoming_message = loop.chan()
|
||||
|
||||
async def handle(self, is_debug_session: bool = False) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle - start")
|
||||
if is_debug_session:
|
||||
import apps.debug
|
||||
|
||||
apps.debug.DEBUG_CONTEXT = self
|
||||
|
||||
take = self.incoming_message.take()
|
||||
next_message: MessageWithType | None = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
if next_message is None:
|
||||
# If the previous run did not keep an unprocessed message for us,
|
||||
# wait for a new one.
|
||||
try:
|
||||
message: MessageWithType = await take
|
||||
except protocol_common.WireError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
await self.write(message_handler.failure(e))
|
||||
continue
|
||||
else:
|
||||
# Process the message from previous run.
|
||||
message = next_message
|
||||
next_message = None
|
||||
|
||||
try:
|
||||
next_message = await handle_pairing_request_message(
|
||||
self, message, use_workflow=not is_debug_session
|
||||
)
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
finally:
|
||||
if not __debug__ or not is_debug_session:
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
# This is not done for the debug session because the snapshot taken
|
||||
# in a debug session would clear modules which are in use by the
|
||||
# workflow running on wire.
|
||||
# TODO utils.unimport_end(modules)
|
||||
|
||||
if next_message is None:
|
||||
|
||||
# Shut down the loop if there is no next message waiting.
|
||||
return # pylint: disable=lost-exception
|
||||
|
||||
except Exception as exc:
|
||||
# Log and try again. The session handler can only exit explicitly via
|
||||
# loop.clear() above.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__:
|
||||
exp_type: str = str(expected_type)
|
||||
if expected_type is not None:
|
||||
exp_type = expected_type.MESSAGE_NAME
|
||||
log.debug(
|
||||
__name__,
|
||||
"Read - with expected types %s and expected type %s",
|
||||
str(expected_types),
|
||||
exp_type,
|
||||
)
|
||||
|
||||
message: MessageWithType = await self.incoming_message.take()
|
||||
|
||||
if message.type not in expected_types:
|
||||
raise UnexpectedMessageWithType(message)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel_ctx.write(msg)
|
||||
|
||||
async def call(
|
||||
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
||||
) -> protobuf.MessageType:
|
||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
await self.write(msg)
|
||||
del msg
|
||||
|
||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||
|
||||
async def call_any(
|
||||
self, msg: protobuf.MessageType, *expected_types: int
|
||||
) -> protobuf.MessageType:
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read(expected_types)
|
||||
|
||||
|
||||
async def handle_pairing_request_message(
|
||||
pairing_ctx: PairingContext,
|
||||
msg: protocol_common.MessageWithType,
|
||||
use_workflow: bool,
|
||||
) -> protocol_common.MessageWithType | None:
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
from apps.thp.pairing import handle_pairing_request
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handle_pairing_request(pairing_ctx, req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
if use_workflow:
|
||||
# Spawn a workflow around the task. This ensures that concurrent
|
||||
# workflows are shut down.
|
||||
res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
|
||||
pass # TODO
|
||||
else:
|
||||
# For debug messages, ignore workflow processing and just await
|
||||
# results of the handler.
|
||||
res_msg = await task
|
||||
|
||||
except UnexpectedMessageWithId as exc:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
# TODO:
|
||||
# We might handle only the few common cases here, like
|
||||
# Initialize and Cancel.
|
||||
return exc.msg
|
||||
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = message_handler.failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await pairing_ctx.write(res_msg)
|
||||
return None
|
@ -0,0 +1,350 @@
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH
|
||||
from trezor import log, loop, protobuf, utils
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import ThpCreateNewSession
|
||||
from trezor.wire import message_handler
|
||||
from trezor.wire.protocol_common import MessageWithType
|
||||
from trezor.wire.thp import ack_handler, thp_messages
|
||||
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
|
||||
from trezor.wire.thp.crypto import PUBKEY_LENGTH
|
||||
from trezor.wire.thp.thp_messages import (
|
||||
ACK_MESSAGE,
|
||||
HANDSHAKE_COMP_RES,
|
||||
HANDSHAKE_INIT_RES,
|
||||
InitHeader,
|
||||
)
|
||||
|
||||
from . import (
|
||||
ChannelState,
|
||||
SessionState,
|
||||
checksum,
|
||||
control_byte,
|
||||
is_channel_state_pairing,
|
||||
)
|
||||
from . import thp_session as THP
|
||||
from .thp_session import ThpError
|
||||
from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
|
||||
|
||||
from . import ChannelContext
|
||||
|
||||
if __debug__:
|
||||
from . import state_to_str
|
||||
|
||||
|
||||
async def handle_received_message(
|
||||
ctx: ChannelContext, message_buffer: utils.BufferType
|
||||
) -> None:
|
||||
"""Handle a message received from the channel."""
|
||||
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_received_message")
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
|
||||
message_length = payload_length + INIT_DATA_OFFSET
|
||||
|
||||
_check_checksum(message_length, message_buffer)
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"handle_completed_message - sync bit of message: %d",
|
||||
sync_bit,
|
||||
)
|
||||
|
||||
# 1: Handle ACKs
|
||||
if control_byte.is_ack(ctrl_byte):
|
||||
await _handle_ack(ctx, sync_bit)
|
||||
return
|
||||
|
||||
if _should_have_ctrl_byte_encrypted_transport(
|
||||
ctx
|
||||
) and not control_byte.is_encrypted_transport(ctrl_byte):
|
||||
raise ThpError("Message is not encrypted. Ignoring")
|
||||
|
||||
# 2: Handle message with unexpected synchronization bit
|
||||
if sync_bit != THP.sync_get_receive_expected_bit(ctx.channel_cache):
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "Received message with an unexpected synchronization bit"
|
||||
)
|
||||
await _send_ack(ctx, sync_bit)
|
||||
raise ThpError("Received message with an unexpected synchronization bit")
|
||||
|
||||
# 3: Send ACK in response
|
||||
await _send_ack(ctx, sync_bit)
|
||||
|
||||
THP.sync_set_receive_expected_bit(ctx.channel_cache, 1 - sync_bit)
|
||||
|
||||
await _handle_message_to_app_or_channel(
|
||||
ctx, payload_length, message_length, ctrl_byte, sync_bit
|
||||
)
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_received_message - end")
|
||||
|
||||
|
||||
async def _send_ack(ctx: ChannelContext, ack_bit: int) -> None:
|
||||
ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
||||
header = InitHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
|
||||
chksum = checksum.compute(header.to_bytes())
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
||||
ctx.get_channel_id_int(),
|
||||
ack_bit,
|
||||
)
|
||||
await write_payload_to_wire(ctx.iface, header, chksum)
|
||||
|
||||
|
||||
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
|
||||
if __debug__:
|
||||
log.debug(__name__, "check_checksum")
|
||||
if not checksum.is_valid(
|
||||
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
|
||||
data=message_buffer[: message_length - CHECKSUM_LENGTH],
|
||||
):
|
||||
if __debug__:
|
||||
log.debug(__name__, "Invalid checksum, ignoring message.")
|
||||
raise ThpError("Invalid checksum, ignoring message.")
|
||||
|
||||
|
||||
# TEST THIS
|
||||
|
||||
|
||||
async def _handle_ack(ctx: ChannelContext, sync_bit: int):
|
||||
if not ack_handler.is_ack_valid(ctx.channel_cache, sync_bit):
|
||||
return
|
||||
# ACK is expected and it has correct sync bit
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received ACK message with correct sync bit")
|
||||
if ctx.waiting_for_ack_timeout is not None:
|
||||
ctx.waiting_for_ack_timeout.close()
|
||||
if __debug__:
|
||||
log.debug(__name__, 'Closed "waiting for ack" task')
|
||||
|
||||
THP.sync_set_can_send_message(ctx.channel_cache, True)
|
||||
|
||||
if ctx.write_task_spawn is not None:
|
||||
if __debug__:
|
||||
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
|
||||
await ctx.write_task_spawn
|
||||
# Note that no the write_task_spawn could result in loop.clear(),
|
||||
# which will result in terminations of this function - any code after
|
||||
# this await might not be executed
|
||||
|
||||
|
||||
async def _handle_message_to_app_or_channel(
|
||||
ctx: ChannelContext,
|
||||
payload_length: int,
|
||||
message_length: int,
|
||||
ctrl_byte: int,
|
||||
sync_bit: int,
|
||||
) -> None:
|
||||
state = ctx.get_channel_state()
|
||||
if __debug__:
|
||||
log.debug(__name__, "state: %s", state_to_str(state))
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
||||
return
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
await _handle_state_TH1(
|
||||
ctx, payload_length, message_length, ctrl_byte, sync_bit
|
||||
)
|
||||
return
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
await _handle_state_TH2(ctx, message_length, ctrl_byte, sync_bit)
|
||||
return
|
||||
|
||||
if is_channel_state_pairing(state):
|
||||
await _handle_pairing(ctx, message_length)
|
||||
return
|
||||
|
||||
raise ThpError("Unimplemented channel state")
|
||||
|
||||
|
||||
async def _handle_state_TH1(
|
||||
ctx: ChannelContext,
|
||||
payload_length: int,
|
||||
message_length: int,
|
||||
ctrl_byte: int,
|
||||
sync_bit: int,
|
||||
) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_state_TH1")
|
||||
if not control_byte.is_handshake_init_req(ctrl_byte):
|
||||
raise ThpError("Message received is not a handshake init request!")
|
||||
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
raise ThpError("Message received is not a valid handshake init request!")
|
||||
host_ephemeral_key = bytearray(
|
||||
ctx.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
|
||||
)
|
||||
cache_thp.set_channel_host_ephemeral_key(ctx.channel_cache, host_ephemeral_key)
|
||||
|
||||
# send handshake init response message
|
||||
await ctx.write_handshake_message(
|
||||
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
|
||||
)
|
||||
ctx.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
|
||||
|
||||
async def _handle_state_TH2(
|
||||
ctx: ChannelContext, message_length: int, ctrl_byte: int, sync_bit: int
|
||||
) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_state_TH2")
|
||||
if not control_byte.is_handshake_comp_req(ctrl_byte):
|
||||
raise ThpError("Message received is not a handshake completion request!")
|
||||
host_encrypted_static_pubkey = ctx.buffer[
|
||||
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
handshake_completion_request_noise_payload = ctx.buffer[
|
||||
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
|
||||
]
|
||||
|
||||
noise_payload = thp_messages.decode_message(
|
||||
ctx.buffer[
|
||||
INIT_DATA_OFFSET
|
||||
+ KEY_LENGTH
|
||||
+ TAG_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH
|
||||
],
|
||||
0,
|
||||
"ThpHandshakeCompletionReqNoisePayload",
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
|
||||
for i in noise_payload.pairing_methods:
|
||||
ctx.selected_pairing_methods.append(i)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"host static pubkey: %s, noise payload: %s",
|
||||
utils.get_bytes_as_str(host_encrypted_static_pubkey),
|
||||
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
|
||||
)
|
||||
|
||||
# TODO add credential recognition
|
||||
paired: bool = True # TODO should be output from credential check
|
||||
|
||||
# send hanshake completion response
|
||||
await ctx.write_handshake_message(
|
||||
HANDSHAKE_COMP_RES,
|
||||
thp_messages.get_handshake_completion_response(paired),
|
||||
)
|
||||
if paired:
|
||||
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
else:
|
||||
ctx.set_channel_state(ChannelState.TP1)
|
||||
|
||||
|
||||
async def _handle_state_ENCRYPTED_TRANSPORT(
|
||||
ctx: ChannelContext, message_length: int
|
||||
) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
||||
|
||||
ctx.decrypt_buffer(message_length)
|
||||
session_id, message_type = ustruct.unpack(">BH", ctx.buffer[INIT_DATA_OFFSET:])
|
||||
if session_id == 0:
|
||||
await _handle_channel_message(ctx, message_length, message_type)
|
||||
return
|
||||
if session_id not in ctx.sessions:
|
||||
await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
|
||||
raise ThpError("Unalloacted session")
|
||||
|
||||
session_state = ctx.sessions[session_id].get_session_state()
|
||||
if session_state is SessionState.UNALLOCATED:
|
||||
await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
|
||||
raise ThpError("Unalloacted session")
|
||||
ctx.sessions[session_id].incoming_message.publish(
|
||||
MessageWithType(
|
||||
message_type,
|
||||
ctx.buffer[
|
||||
INIT_DATA_OFFSET
|
||||
+ MESSAGE_TYPE_LENGTH
|
||||
+ SESSION_ID_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _handle_pairing(ctx: ChannelContext, message_length: int) -> None:
|
||||
|
||||
from .pairing_context import PairingContext
|
||||
|
||||
if ctx.connection_context is None:
|
||||
ctx.connection_context = PairingContext(ctx)
|
||||
|
||||
loop.schedule(ctx.connection_context.handle())
|
||||
ctx.decrypt_buffer(message_length)
|
||||
|
||||
message_type = ustruct.unpack(
|
||||
">H", ctx.buffer[INIT_DATA_OFFSET + SESSION_ID_LENGTH :]
|
||||
)[0]
|
||||
|
||||
ctx.connection_context.incoming_message.publish(
|
||||
MessageWithType(
|
||||
message_type,
|
||||
ctx.buffer[
|
||||
INIT_DATA_OFFSET
|
||||
+ MESSAGE_TYPE_LENGTH
|
||||
+ SESSION_ID_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH
|
||||
],
|
||||
)
|
||||
)
|
||||
# 1. Check that message is expected with respect to the current state
|
||||
# 2. Handle the message
|
||||
pass
|
||||
|
||||
|
||||
def _should_have_ctrl_byte_encrypted_transport(ctx: ChannelContext) -> bool:
|
||||
if ctx.get_channel_state() in [
|
||||
ChannelState.UNALLOCATED,
|
||||
ChannelState.TH1,
|
||||
ChannelState.TH2,
|
||||
]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _handle_channel_message(
|
||||
ctx: ChannelContext, message_length: int, message_type: int
|
||||
) -> None:
|
||||
buf = ctx.buffer[
|
||||
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
||||
]
|
||||
|
||||
expected_type = protobuf.type_for_wire(message_type)
|
||||
message = message_handler.wrap_protobuf_load(buf, expected_type)
|
||||
|
||||
if not ThpCreateNewSession.is_type_of(message):
|
||||
raise ThpError(
|
||||
"The received message cannot be handled by channel itself. It must be sent to allocated session."
|
||||
)
|
||||
# TODO handle other messages than CreateNewSession
|
||||
from trezor.wire.thp.handler_provider import get_handler_for_channel_message
|
||||
|
||||
handler = get_handler_for_channel_message(message)
|
||||
task = handler(ctx, message)
|
||||
response_message = await task
|
||||
# TODO handle
|
||||
await ctx.write(response_message)
|
||||
if __debug__:
|
||||
log.debug(__name__, "_handle_channel_message - end")
|
@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire.thp import ChannelContext
|
||||
|
||||
|
||||
class Retransmission:
|
||||
|
||||
def __init__(
|
||||
self, channel_context: ChannelContext, ctrl_byte: int, payload: memoryview
|
||||
) -> None:
|
||||
self.channel_context: ChannelContext = channel_context
|
||||
self.ctrl_byte: int = ctrl_byte
|
||||
self.payload: memoryview = payload
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def change_ctrl_byte(self, ctrl_byte: int) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
@ -0,0 +1,162 @@
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import log, loop, protobuf
|
||||
from trezor.wire import message_handler, protocol_common
|
||||
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
|
||||
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
from . import SessionState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import ( # pyright: ignore[reportShadowedImports]
|
||||
Any,
|
||||
Awaitable,
|
||||
Container,
|
||||
)
|
||||
|
||||
from . import ChannelContext
|
||||
|
||||
pass
|
||||
|
||||
_EXIT_LOOP = True
|
||||
_REPEAT_LOOP = False
|
||||
|
||||
|
||||
class UnexpectedMessageWithType(Exception):
|
||||
"""A message was received that is not part of the current workflow.
|
||||
|
||||
Utility exception to inform the session handler that the current workflow
|
||||
should be aborted and a new one started as if `msg` was the first message.
|
||||
"""
|
||||
|
||||
def __init__(self, msg: MessageWithType) -> None:
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class SessionContext(Context):
|
||||
def __init__(
|
||||
self, channel_ctx: ChannelContext, session_cache: SessionThpCache
|
||||
) -> None:
|
||||
if channel_ctx.channel_id != session_cache.channel_id:
|
||||
raise Exception(
|
||||
"The session has different channel id than the provided channel context!"
|
||||
)
|
||||
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
|
||||
self.channel_ctx = channel_ctx
|
||||
self.session_cache = session_cache
|
||||
self.session_id = int.from_bytes(session_cache.session_id, "big")
|
||||
self.incoming_message = loop.chan()
|
||||
|
||||
async def handle(self, is_debug_session: bool = False) -> None:
|
||||
if __debug__:
|
||||
self._handle_debug(is_debug_session)
|
||||
|
||||
take = self.incoming_message.take()
|
||||
next_message: MessageWithType | None = None
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
# roll back and un-import any others.
|
||||
# TODO modules = utils.unimport_begin()
|
||||
while True:
|
||||
try:
|
||||
if await self._handle_message(take, next_message, is_debug_session):
|
||||
return
|
||||
except Exception as exc:
|
||||
# Log and try again.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
def _handle_debug(self, is_debug_session: bool) -> None:
|
||||
log.debug(__name__, "handle - start (session_id: %d)", self.session_id)
|
||||
if is_debug_session:
|
||||
import apps.debug
|
||||
|
||||
apps.debug.DEBUG_CONTEXT = self
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
take: Awaitable[Any],
|
||||
next_message: MessageWithType | None,
|
||||
is_debug_session: bool,
|
||||
) -> bool:
|
||||
|
||||
try:
|
||||
message = await self._get_message(take, next_message)
|
||||
except protocol_common.WireError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
await self.write(failure(e))
|
||||
return _REPEAT_LOOP
|
||||
|
||||
try:
|
||||
next_message = await message_handler.handle_single_message(
|
||||
self, message, use_workflow=not is_debug_session
|
||||
)
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
finally:
|
||||
if not __debug__ or not is_debug_session:
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
# This is not done for the debug session because the snapshot taken
|
||||
# in a debug session would clear modules which are in use by the
|
||||
# workflow running on wire.
|
||||
# TODO utils.unimport_end(modules)
|
||||
|
||||
if next_message is None and message.type not in AVOID_RESTARTING_FOR:
|
||||
# Shut down the loop if there is no next message waiting.
|
||||
return _EXIT_LOOP # pylint: disable=lost-exception
|
||||
return _REPEAT_LOOP # pylint: disable=lost-exception
|
||||
|
||||
async def _get_message(
|
||||
self, take: Awaitable[Any], next_message: MessageWithType | None
|
||||
) -> MessageWithType:
|
||||
if next_message is None:
|
||||
# If the previous run did not keep an unprocessed message for us,
|
||||
# wait for a new one.
|
||||
message: MessageWithType = await take
|
||||
else:
|
||||
# Process the message from previous run.
|
||||
message = next_message
|
||||
next_message = None
|
||||
return message
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__:
|
||||
exp_type: str = str(expected_type)
|
||||
if expected_type is not None:
|
||||
exp_type = expected_type.MESSAGE_NAME
|
||||
log.debug(
|
||||
__name__,
|
||||
"Read - with expected types %s and expected type %s",
|
||||
str(expected_types),
|
||||
exp_type,
|
||||
)
|
||||
message: MessageWithType = await self.incoming_message.take()
|
||||
if message.type not in expected_types:
|
||||
raise UnexpectedMessageWithType(message)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel_ctx.write(msg, self.session_id)
|
||||
|
||||
# ACCESS TO SESSION DATA
|
||||
|
||||
def get_session_state(self) -> SessionState:
|
||||
state = int.from_bytes(self.session_cache.state, "big")
|
||||
return SessionState(state)
|
||||
|
||||
def set_session_state(self, state: SessionState) -> None:
|
||||
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
|
@ -0,0 +1,36 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp
|
||||
from trezor import loop
|
||||
|
||||
from .session_context import SessionContext
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import ChannelContext
|
||||
|
||||
|
||||
def create_new_session(channel_ctx: ChannelContext) -> SessionContext:
|
||||
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
|
||||
return SessionContext(channel_ctx, session_cache)
|
||||
|
||||
|
||||
def load_cached_sessions(channel_ctx: ChannelContext) -> dict[int, SessionContext]:
|
||||
if __debug__:
|
||||
log.debug(__name__, "load_cached_sessions")
|
||||
sessions: dict[int, SessionContext] = {}
|
||||
cached_sessions = cache_thp.get_all_allocated_sessions()
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"load_cached_sessions - loaded a total of %d sessions from cache",
|
||||
len(cached_sessions),
|
||||
)
|
||||
for session in cached_sessions:
|
||||
if session.channel_id == channel_ctx.channel_id:
|
||||
sid = int.from_bytes(session.session_id, "big")
|
||||
sessions[sid] = SessionContext(channel_ctx, session)
|
||||
loop.schedule(sessions[sid].handle())
|
||||
return sessions
|
@ -0,0 +1,115 @@
|
||||
import ustruct # pyright:ignore[reportMissingModuleSource]
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor import log, protobuf
|
||||
|
||||
from .. import message_handler
|
||||
from ..protocol_common import Message
|
||||
|
||||
CODEC_V1 = 0x3F
|
||||
CONTINUATION_PACKET = 0x80
|
||||
HANDSHAKE_INIT_REQ = 0x00
|
||||
HANDSHAKE_INIT_RES = 0x01
|
||||
HANDSHAKE_COMP_REQ = 0x02
|
||||
HANDSHAKE_COMP_RES = 0x03
|
||||
ENCRYPTED_TRANSPORT = 0x04
|
||||
|
||||
ACK_MESSAGE = 0x20
|
||||
ERROR = 0x42
|
||||
CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x41
|
||||
|
||||
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
TREZOR_STATE_PAIRED = b"\x01"
|
||||
|
||||
|
||||
class InitHeader:
|
||||
format_str = ">BHH"
|
||||
|
||||
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.length = length
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return ustruct.pack(
|
||||
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
|
||||
)
|
||||
|
||||
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
|
||||
ustruct.pack_into(
|
||||
InitHeader.format_str,
|
||||
buffer,
|
||||
buffer_offset,
|
||||
self.ctrl_byte,
|
||||
self.cid,
|
||||
self.length,
|
||||
)
|
||||
|
||||
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
|
||||
ustruct.pack_into(">BH", buffer, buffer_offset, CONTINUATION_PACKET, self.cid)
|
||||
|
||||
@classmethod
|
||||
def get_error_header(cls, cid, length):
|
||||
return cls(ERROR, cid, length)
|
||||
|
||||
@classmethod
|
||||
def get_channel_allocation_response_header(cls, length):
|
||||
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
|
||||
|
||||
|
||||
class InterruptingInitPacket:
|
||||
def __init__(self, report: bytes) -> None:
|
||||
self.initReport = report
|
||||
|
||||
|
||||
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
|
||||
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
|
||||
)
|
||||
|
||||
_ERROR_UNALLOCATED_SESSION = (
|
||||
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
|
||||
)
|
||||
|
||||
|
||||
def get_device_properties() -> Message:
|
||||
return Message(_ENCODED_PROTOBUF_DEVICE_PROPERTIES)
|
||||
|
||||
|
||||
def get_channel_allocation_response(nonce: bytes, new_cid: bytes) -> bytes:
|
||||
props_msg = get_device_properties()
|
||||
return nonce + new_cid + props_msg.to_bytes()
|
||||
|
||||
|
||||
def get_error_unallocated_channel() -> bytes:
|
||||
return _ERROR_UNALLOCATED_SESSION
|
||||
|
||||
|
||||
def get_handshake_init_response() -> bytes:
|
||||
# TODO implement - 32 bytes ephemeral key, 48 bytes encrypted and masked public key, 16 bytes ciphertext of empty string (i.e. noise tag)
|
||||
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x40\x41\x42\x43\x44\x45\x46\x47\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
|
||||
|
||||
|
||||
def get_handshake_completion_response(paired: bool) -> bytes:
|
||||
if paired:
|
||||
return (
|
||||
TREZOR_STATE_PAIRED
|
||||
+ b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
|
||||
)
|
||||
return (
|
||||
TREZOR_STATE_UNPAIRED
|
||||
+ b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
|
||||
)
|
||||
|
||||
|
||||
def decode_message(
|
||||
buffer: bytes, msg_type: int, message_name: str | None = None
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__:
|
||||
log.debug(__name__, "decode message")
|
||||
if message_name is not None:
|
||||
expected_type = protobuf.type_for_name(message_name)
|
||||
else:
|
||||
expected_type = protobuf.type_for_wire(msg_type)
|
||||
x = message_handler.wrap_protobuf_load(buffer, expected_type)
|
||||
return x
|
@ -0,0 +1,148 @@
|
||||
import ustruct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp as storage_thp_cache
|
||||
from storage.cache_thp import ChannelCache, SessionThpCache
|
||||
from trezor import log
|
||||
from trezor.wire.protocol_common import WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
from trezorio import WireInterface
|
||||
else:
|
||||
IntEnum = object
|
||||
|
||||
|
||||
class ThpError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
class SessionState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
|
||||
PAIRED = 2
|
||||
UNPAIRED = 3
|
||||
PAIRING = 4
|
||||
APP_TRAFFIC = 5
|
||||
|
||||
|
||||
def create_autenticated_session(unauthenticated_session: SessionThpCache):
|
||||
# storage_thp_cache.start_session() - TODO something like this but for THP
|
||||
raise NotImplementedError("Secure channel is not implemented, yet.")
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(iface: WireInterface, cid: int):
|
||||
session_id = _get_id(iface, cid)
|
||||
new_session = storage_thp_cache.create_new_unauthenticated_session(session_id)
|
||||
set_session_state(new_session, SessionState.INITIALIZED)
|
||||
|
||||
|
||||
def get_active_session() -> SessionThpCache | None:
|
||||
return storage_thp_cache.get_active_session()
|
||||
|
||||
|
||||
def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None:
|
||||
session_id = _get_id(iface, cid)
|
||||
return get_session_from_id(session_id)
|
||||
|
||||
|
||||
def get_session_from_id(session_id) -> SessionThpCache | None:
|
||||
session = _get_authenticated_session_or_none(session_id)
|
||||
if session is None:
|
||||
session = _get_unauthenticated_session_or_none(session_id)
|
||||
return session
|
||||
|
||||
|
||||
def get_state(session: SessionThpCache | None) -> int:
|
||||
if session is None:
|
||||
return SessionState.UNALLOCATED
|
||||
return _decode_session_state(session.state)
|
||||
|
||||
|
||||
def get_cid(session: SessionThpCache) -> int:
|
||||
return storage_thp_cache._get_cid(session)
|
||||
|
||||
|
||||
def sync_can_send_message(cache: SessionThpCache | ChannelCache) -> bool:
|
||||
return cache.sync & 0x80 == 0x80
|
||||
|
||||
|
||||
def sync_get_receive_expected_bit(cache: SessionThpCache | ChannelCache) -> int:
|
||||
return (cache.sync & 0x40) >> 6
|
||||
|
||||
|
||||
def sync_get_send_bit(cache: SessionThpCache | ChannelCache) -> int:
|
||||
return (cache.sync & 0x20) >> 5
|
||||
|
||||
|
||||
def sync_set_can_send_message(
|
||||
cache: SessionThpCache | ChannelCache, can_send: bool
|
||||
) -> None:
|
||||
cache.sync &= 0x7F
|
||||
if can_send:
|
||||
cache.sync |= 0x80
|
||||
|
||||
|
||||
def sync_set_receive_expected_bit(
|
||||
cache: SessionThpCache | ChannelCache, bit: int
|
||||
) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "Set sync receive expected bit to %d", bit)
|
||||
if bit not in (0, 1):
|
||||
raise ThpError("Unexpected receive sync bit")
|
||||
|
||||
# set second bit to "bit" value
|
||||
cache.sync &= 0xBF
|
||||
if bit:
|
||||
cache.sync |= 0x40
|
||||
|
||||
|
||||
def sync_set_send_bit_to_opposite(cache: SessionThpCache | ChannelCache) -> None:
|
||||
_sync_set_send_bit(cache=cache, bit=1 - sync_get_send_bit(cache))
|
||||
|
||||
|
||||
def is_active_session(session: SessionThpCache):
|
||||
if session is None:
|
||||
return False
|
||||
return session.session_id == storage_thp_cache.get_active_session_id()
|
||||
|
||||
|
||||
def set_session_state(session: SessionThpCache, new_state: SessionState):
|
||||
session.state = bytearray(new_state.to_bytes(1, "big"))
|
||||
|
||||
|
||||
def _get_id(iface: WireInterface, cid: int) -> bytes:
|
||||
return ustruct.pack(">HH", iface.iface_num(), cid)
|
||||
|
||||
|
||||
def _get_authenticated_session_or_none(session_id) -> SessionThpCache | None:
|
||||
for authenticated_session in storage_thp_cache._SESSIONS:
|
||||
if authenticated_session.session_id == session_id:
|
||||
return authenticated_session
|
||||
return None
|
||||
|
||||
|
||||
def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache | None:
|
||||
for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
|
||||
if unauthenticated_session.session_id == session_id:
|
||||
return unauthenticated_session
|
||||
return None
|
||||
|
||||
|
||||
def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None:
|
||||
if bit not in (0, 1):
|
||||
raise ThpError("Unexpected send sync bit")
|
||||
if __debug__:
|
||||
log.debug(__name__, "setting sync send bit to %d", bit)
|
||||
# set third bit to "bit" value
|
||||
cache.sync &= 0xDF
|
||||
if bit:
|
||||
cache.sync |= 0x20
|
||||
|
||||
|
||||
def _decode_session_state(state: bytearray) -> int:
|
||||
return ustruct.unpack("B", state)[0]
|
||||
|
||||
|
||||
def _encode_session_state(state: SessionState) -> bytes:
|
||||
return SessionState.to_bytes(state, 1, "big")
|
@ -0,0 +1,57 @@
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from trezor import io, log, loop, utils
|
||||
from trezor.wire.thp.thp_messages import InitHeader
|
||||
|
||||
INIT_DATA_OFFSET = const(5)
|
||||
CONT_DATA_OFFSET = const(3)
|
||||
REPORT_LENGTH = const(64)
|
||||
MAX_PAYLOAD_LEN = const(60000)
|
||||
MESSAGE_TYPE_LENGTH = const(2)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
||||
|
||||
|
||||
async def write_payload_to_wire(
|
||||
iface: WireInterface, header: InitHeader, payload: bytes
|
||||
):
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_payload_to_wire")
|
||||
# prepare the report buffer with header data
|
||||
payload_len = len(payload)
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
|
||||
await _write_report_to_wire(iface, report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_len:
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
while nwritten < payload_len:
|
||||
if nwritten >= payload_len - REPORT_LENGTH:
|
||||
# Sanitation of last report
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await _write_report_to_wire(iface, report)
|
||||
|
||||
|
||||
async def _write_report_to_wire(iface: WireInterface, report: utils.BufferType) -> None:
|
||||
while True:
|
||||
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "write_report_to_wire: %s", utils.get_bytes_as_str(report)
|
||||
)
|
||||
n = iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
@ -0,0 +1,160 @@
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor import io, log, loop, utils
|
||||
|
||||
from .protocol_common import MessageWithId
|
||||
from .thp import ChannelState, channel_manager, checksum, session_manager, thp_messages
|
||||
from .thp.channel import Channel
|
||||
from .thp.checksum import CHECKSUM_LENGTH
|
||||
from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader
|
||||
from .thp.thp_session import ThpError
|
||||
from .thp.writer import MAX_PAYLOAD_LEN, REPORT_LENGTH, write_payload_to_wire
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
||||
|
||||
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
|
||||
|
||||
|
||||
_BUFFER: bytearray
|
||||
_BUFFER_LOCK = None
|
||||
|
||||
CHANNELS: dict[int, Channel] = {}
|
||||
|
||||
|
||||
def set_buffer(buffer):
|
||||
global _BUFFER
|
||||
_BUFFER = buffer
|
||||
|
||||
|
||||
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
global CHANNELS
|
||||
global _BUFFER
|
||||
CHANNELS = channel_manager.load_cached_channels(_BUFFER)
|
||||
for ch in CHANNELS.values():
|
||||
ch.sessions = session_manager.load_cached_sessions(ch)
|
||||
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if __debug__:
|
||||
log.debug(__name__, "thp_main_loop")
|
||||
packet = await read
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
|
||||
if ctrl_byte == CODEC_V1:
|
||||
pass
|
||||
# TODO add handling of (unsupported) codec_v1 packets
|
||||
# possibly ignore continuation packets, i.e. if the
|
||||
# following bytes are not "##"", do not respond
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
await _handle_broadcast(iface, ctrl_byte, packet)
|
||||
continue
|
||||
|
||||
if cid in CHANNELS:
|
||||
await _handle_allocated(iface, cid, packet)
|
||||
else:
|
||||
await _handle_unallocated(iface, cid)
|
||||
|
||||
except ThpError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
|
||||
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
|
||||
|
||||
|
||||
async def _handle_broadcast(
|
||||
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
|
||||
) -> MessageWithId | None:
|
||||
global _BUFFER
|
||||
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
|
||||
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received valid message on broadcast channel ")
|
||||
|
||||
length, nonce = ustruct.unpack(">H8s", packet[3:])
|
||||
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
|
||||
payload = _get_buffer_for_payload(length, packet[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
|
||||
|
||||
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
new_channel: Channel = channel_manager.create_new_channel(iface, _BUFFER)
|
||||
cid = int.from_bytes(new_channel.channel_id, "big")
|
||||
CHANNELS[cid] = new_channel
|
||||
|
||||
response_data = thp_messages.get_channel_allocation_response(
|
||||
nonce, new_channel.channel_id
|
||||
)
|
||||
response_header = InitHeader.get_channel_allocation_response_header(
|
||||
len(response_data) + CHECKSUM_LENGTH,
|
||||
)
|
||||
chksum = checksum.compute(response_header.to_bytes() + response_data)
|
||||
if __debug__:
|
||||
log.debug(__name__, "New channel allocated with id %d", cid)
|
||||
|
||||
await write_payload_to_wire(iface, response_header, response_data + chksum)
|
||||
|
||||
|
||||
async def _handle_allocated(
|
||||
iface: WireInterface, cid: int, packet: utils.BufferType
|
||||
) -> None:
|
||||
channel = CHANNELS[cid]
|
||||
if channel is None:
|
||||
# TODO send error message to wire
|
||||
raise ThpError("Invalid state of a channel")
|
||||
if channel.iface is not iface:
|
||||
# TODO send error message to wire
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
await channel.receive_packet(packet)
|
||||
|
||||
|
||||
async def _handle_unallocated(iface, cid) -> MessageWithId | None:
|
||||
data = thp_messages.get_error_unallocated_channel()
|
||||
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
|
||||
chksum = checksum.compute(header.to_bytes() + data)
|
||||
await write_payload_to_wire(iface, header, data + chksum)
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
return _try_allocate_new_buffer(payload_length)
|
||||
return _reuse_existing_buffer(payload_length, existing_buffer)
|
||||
|
||||
|
||||
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
|
||||
def _reuse_existing_buffer(
|
||||
payload_length: int, existing_buffer: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
async def deprecated_read_message(
|
||||
iface: WireInterface, buffer: utils.BufferType
|
||||
) -> MessageWithId:
|
||||
return MessageWithId(-1, b"\x00")
|
||||
|
||||
|
||||
async def deprecated_write_message(
|
||||
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
||||
) -> None:
|
||||
pass
|
@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
declare -a results
|
||||
declare -i passed=0 failed=0 exit_code=0
|
||||
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
|
||||
MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}"
|
||||
print_summary() {
|
||||
echo
|
||||
echo 'Summary:'
|
||||
echo '-------------------'
|
||||
printf '%b\n' "${results[@]}"
|
||||
if [ $exit_code == 0 ]; then
|
||||
echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!"
|
||||
else
|
||||
echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!"
|
||||
fi
|
||||
}
|
||||
|
||||
trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT
|
||||
|
||||
cd $(dirname $0)
|
||||
|
||||
[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*)
|
||||
|
||||
declare -i num_of_tests=${#tests[@]}
|
||||
|
||||
for test_case in ${tests[@]}; do
|
||||
echo ${MICROPYTHON}
|
||||
echo ${test_case}
|
||||
echo
|
||||
if $MICROPYTHON $test_case; then
|
||||
results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case")
|
||||
((passed++))
|
||||
else
|
||||
results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case")
|
||||
((failed++))
|
||||
exit_code=1
|
||||
fi
|
||||
done
|
||||
|
||||
print_summary
|
||||
exit $exit_code
|
@ -0,0 +1,71 @@
|
||||
from common import *
|
||||
|
||||
from trezor import io
|
||||
from trezor.loop import wait
|
||||
from trezor.wire import thp_v1
|
||||
from trezor.wire.thp import channel
|
||||
from storage import cache_thp
|
||||
from ubinascii import hexlify
|
||||
from trezor.wire.thp import ChannelState
|
||||
|
||||
|
||||
class MockHID:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.data = []
|
||||
|
||||
def iface_num(self):
|
||||
return self.num
|
||||
|
||||
def write(self, msg):
|
||||
self.data.append(bytearray(msg))
|
||||
return len(msg)
|
||||
|
||||
def wait_object(self, mode):
|
||||
return wait(mode | self.num)
|
||||
|
||||
|
||||
def dummy_decode_iface(cached_iface: bytes):
|
||||
return MockHID(0xDEADBEEF)
|
||||
|
||||
|
||||
def getBytes(a):
|
||||
return hexlify(a).decode("utf-8")
|
||||
|
||||
|
||||
class TestTrezorHostProtocol(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
buffer = bytearray(64)
|
||||
thp_v1.set_buffer(buffer)
|
||||
channel._decode_iface = dummy_decode_iface
|
||||
|
||||
def test_simple(self):
|
||||
self.assertTrue(True)
|
||||
|
||||
def test_channel_allocation(self):
|
||||
cid_req = (
|
||||
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
|
||||
)
|
||||
expected_response = "41ffff001e001122334455667712340a0454335731100518002001280128026dcad4ba0000000000000000000000000000000000000000000000000000000000"
|
||||
test_counter = cache_thp.cid_counter + 1
|
||||
self.assertEqual(len(thp_v1.CHANNELS), 0)
|
||||
self.assertFalse(test_counter in thp_v1.CHANNELS)
|
||||
gen = thp_v1.thp_main_loop(self.interface, is_debug_session=True)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(cid_req)
|
||||
gen.send(None)
|
||||
self.assertEqual(
|
||||
getBytes(self.interface.data[-1]),
|
||||
expected_response,
|
||||
)
|
||||
self.assertTrue(test_counter in thp_v1.CHANNELS)
|
||||
self.assertEqual(len(thp_v1.CHANNELS), 1)
|
||||
|
||||
def test_channel_default_state_is_TH1(self):
|
||||
self.assertEqual(thp_v1.CHANNELS[4660].get_channel_state(), ChannelState.TH1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1,341 @@
|
||||
from common import *
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor.wire.thp.writer import REPORT_LENGTH
|
||||
from ubinascii import hexlify
|
||||
import ustruct
|
||||
|
||||
from trezor import io, utils
|
||||
from trezor.loop import wait
|
||||
from trezor.utils import chunks
|
||||
from trezor.wire import thp_v1
|
||||
from trezor.wire.protocol_common import MessageWithId
|
||||
import trezor.wire.thp.thp_session as THP
|
||||
from trezor.wire.thp import checksum
|
||||
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
|
||||
|
||||
|
||||
class MockHID:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.data = []
|
||||
|
||||
def iface_num(self):
|
||||
return self.num
|
||||
|
||||
def write(self, msg):
|
||||
self.data.append(bytearray(msg))
|
||||
return len(msg)
|
||||
|
||||
def wait_object(self, mode):
|
||||
return wait(mode | self.num)
|
||||
|
||||
|
||||
MESSAGE_TYPE = 0x4242
|
||||
MESSAGE_TYPE_BYTES = b"\x42\x42"
|
||||
_MESSAGE_TYPE_LEN = 2
|
||||
PLAINTEXT_0 = 0x01
|
||||
PLAINTEXT_1 = 0x11
|
||||
COMMON_CID = 4660
|
||||
CONT = 0x80
|
||||
|
||||
HEADER_INIT_LENGTH = 5
|
||||
HEADER_CONT_LENGTH = 3
|
||||
INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
|
||||
|
||||
|
||||
def make_header(ctrl_byte, cid, length):
|
||||
return ustruct.pack(">BHH", ctrl_byte, cid, length)
|
||||
|
||||
|
||||
def make_cont_header():
|
||||
return ustruct.pack(">BH", CONT, COMMON_CID)
|
||||
|
||||
|
||||
def makeSimpleMessage(header, message_type, message_data):
|
||||
return header + ustruct.pack(">H", message_type) + message_data
|
||||
|
||||
|
||||
def makeCidRequest(header, message_data):
|
||||
return header + message_data
|
||||
|
||||
|
||||
def printBytes(a):
|
||||
print(hexlify(a).decode("utf-8"))
|
||||
|
||||
|
||||
def getPlaintext() -> bytes:
|
||||
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
|
||||
return PLAINTEXT_1
|
||||
return PLAINTEXT_0
|
||||
|
||||
|
||||
def getCid() -> int:
|
||||
return THP.get_cid(THP.get_active_session())
|
||||
|
||||
|
||||
# This test suite is an adaptation of test_trezor.wire.codec_v1
|
||||
class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
if not utils.USE_THP:
|
||||
import storage.cache_thp # noqa: F401
|
||||
|
||||
def _simple(self):
|
||||
cid_req_header = make_header(
|
||||
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
|
||||
)
|
||||
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
|
||||
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
|
||||
|
||||
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
|
||||
cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
|
||||
message = makeSimpleMessage(
|
||||
message_header,
|
||||
MESSAGE_TYPE,
|
||||
cid_request_dummy_data + cid_request_dummy_data_checksum,
|
||||
)
|
||||
|
||||
buffer = bytearray(64)
|
||||
gen = thp_v1.deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(cid_req_message)
|
||||
gen.send(None)
|
||||
gen.send(message)
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, cid_request_dummy_data)
|
||||
|
||||
buffer_without_zeroes = buffer[: len(message) - 5]
|
||||
message_without_header = message[5:]
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer_without_zeroes, message_without_header)
|
||||
|
||||
def _read_one_packet(self):
|
||||
# zero length message - just a header
|
||||
PLAINTEXT = getPlaintext()
|
||||
header = make_header(
|
||||
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
||||
)
|
||||
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
|
||||
message = header + MESSAGE_TYPE_BYTES + chksum
|
||||
|
||||
buffer = bytearray(64)
|
||||
gen = thp_v1.deprecated_read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(message)
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, b"")
|
||||
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
|
||||
|
||||
def _read_many_packets(self):
|
||||
message = bytes(range(256))
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
|
||||
)
|
||||
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
|
||||
# message = MESSAGE_TYPE_BYTES + message + checksum
|
||||
|
||||
# first packet is init header + 59 bytes of data
|
||||
# other packets are cont header + 61 bytes of data
|
||||
cont_header = make_cont_header()
|
||||
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
|
||||
64 - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
buffer = bytearray(262)
|
||||
gen = thp_v1.deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in packets:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
query = gen.send(packet)
|
||||
|
||||
# last packet will stop
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# message should have been read into the buffer )
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
|
||||
|
||||
def _read_large_message(self):
|
||||
message = b"hello world"
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
packet = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ message
|
||||
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
|
||||
)
|
||||
|
||||
# make sure we fit into one packet, to make this easier
|
||||
self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH)
|
||||
|
||||
buffer = bytearray(1)
|
||||
self.assertTrue(len(buffer) <= len(packet))
|
||||
|
||||
gen = thp_v1.deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(packet)
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# read should have allocated its own buffer and not touch ours
|
||||
self.assertEqual(buffer, b"\x00")
|
||||
|
||||
def _roundtrip(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = MessageWithId(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.deprecated_write_message(self.interface, message)
|
||||
# exhaust the iterator:
|
||||
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||
for query in gen:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
buffer = bytearray(1024)
|
||||
gen = thp_v1.deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in self.interface.data:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
printBytes(packet)
|
||||
query = gen.send(packet)
|
||||
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message.data)
|
||||
|
||||
def _write_one_packet(self):
|
||||
message = MessageWithId(
|
||||
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.deprecated_write_message(self.interface, message)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
header = make_header(
|
||||
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
||||
)
|
||||
expected_message = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
|
||||
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
|
||||
)
|
||||
self.assertTrue(self.interface.data == [expected_message])
|
||||
|
||||
def _write_multiple_packets(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = MessageWithId(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.deprecated_write_message(self.interface, message)
|
||||
|
||||
header = make_header(
|
||||
PLAINTEXT_1,
|
||||
COMMON_CID,
|
||||
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
|
||||
)
|
||||
cont_header = make_cont_header()
|
||||
chksum = checksum.compute(
|
||||
header + message.type.to_bytes(2, "big") + message.data
|
||||
)
|
||||
packets = [
|
||||
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
|
||||
] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
|
||||
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
|
||||
for _ in packets:
|
||||
# we receive as many queries as there are packets
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
# the first sent None only started the generator. the len(packets)-th None
|
||||
# will finish writing and raise StopIteration
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
# packets must be identical up to the last one
|
||||
self.assertListEqual(packets[:-1], self.interface.data[:-1])
|
||||
# last packet must be identical up to message length. remaining bytes in
|
||||
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||
# previous packet
|
||||
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||
self.assertEqual(last_packet, self.interface.data[-1])
|
||||
|
||||
def _read_huge_packet(self):
|
||||
PACKET_COUNT = 1180
|
||||
# message that takes up 1 180 USB packets
|
||||
message_size = (PACKET_COUNT - 1) * (
|
||||
REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
|
||||
) + INIT_MESSAGE_DATA_LENGTH
|
||||
|
||||
# ensure that a message this big won't fit into memory
|
||||
# Note: this control is changed, because THP has only 2 byte length field
|
||||
self.assertTrue(message_size > thp_v1.MAX_PAYLOAD_LEN)
|
||||
# self.assertRaises(MemoryError, bytearray, message_size)
|
||||
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
|
||||
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
|
||||
buffer = bytearray(65536)
|
||||
gen = thp_v1.deprecated_read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
|
||||
# THP returns "Message too large" error after reading the message size,
|
||||
# it is different from codec_v1 as it does not allow big enough messages
|
||||
# to raise MemoryError in this test
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(thp_v1.ThpError) as e:
|
||||
query = gen.send(packet)
|
||||
|
||||
self.assertEqual(e.value.args[0], "Message too large")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1 @@
|
||||
../../vendor/trezor-common/protob/messages-thp.proto
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue