diff --git a/src/apps/common/confirm.py b/src/apps/common/confirm.py index 6ab5032bfb..aeeed28b6e 100644 --- a/src/apps/common/confirm.py +++ b/src/apps/common/confirm.py @@ -7,7 +7,7 @@ signal = loop.Signal() @unimport -async def confirm(session_id, content, code=None, *args, **kwargs): +async def confirm(ctx, content, code=None, *args, **kwargs): from trezor.ui.confirm import ConfirmDialog, CONFIRMED from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequestType import Other @@ -19,12 +19,12 @@ async def confirm(session_id, content, code=None, *args, **kwargs): if code is None: code = Other - await wire.call(session_id, ButtonRequest(code=code), ButtonAck) + await ctx.call(ButtonRequest(code=code), ButtonAck) return await loop.Wait((signal, dialog)) == CONFIRMED @unimport -async def hold_to_confirm(session_id, content, code=None, *args, **kwargs): +async def hold_to_confirm(ctx, content, code=None, *args, **kwargs): from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequestType import Other @@ -36,7 +36,7 @@ async def hold_to_confirm(session_id, content, code=None, *args, **kwargs): if code is None: code = Other - await wire.call(session_id, ButtonRequest(code=code), ButtonAck) + await ctx.call(ButtonRequest(code=code), ButtonAck) return await loop.Wait((signal, dialog)) == CONFIRMED diff --git a/src/apps/common/request_passphrase.py b/src/apps/common/request_passphrase.py index 88cb418ceb..6aacf7abb9 100644 --- a/src/apps/common/request_passphrase.py +++ b/src/apps/common/request_passphrase.py @@ -1,7 +1,7 @@ from trezor import ui, wire -async def request_passphrase(session_id): +async def request_passphrase(ctx): from trezor.messages.FailureType import ActionCancelled from trezor.messages.PassphraseRequest import PassphraseRequest from trezor.messages.wire_types import PassphraseAck, Cancel @@ -12,17 +12,17 @@ async def request_passphrase(session_id): 'Please enter passphrase', 'on your computer.') text.render() - ack = await wire.call(session_id, PassphraseRequest(), PassphraseAck, Cancel) + ack = await ctx.call(PassphraseRequest(), PassphraseAck, Cancel) if ack.MESSAGE_WIRE_TYPE == Cancel: raise wire.FailureError(ActionCancelled, 'Passphrase cancelled') return ack.passphrase -async def protect_by_passphrase(session_id): +async def protect_by_passphrase(ctx): from apps.common import storage if storage.is_protected_by_passphrase(): - return await request_passphrase(session_id) + return await request_passphrase(ctx) else: return '' diff --git a/src/apps/common/request_pin.py b/src/apps/common/request_pin.py index f85356c34b..06a8967dc4 100644 --- a/src/apps/common/request_pin.py +++ b/src/apps/common/request_pin.py @@ -7,7 +7,7 @@ if __debug__: @unimport -async def request_pin_on_display(session_id: int, code: int=None) -> str: +async def request_pin_on_display(ctx: wire.Context, code: int=None) -> str: from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequestType import ProtectCall from trezor.messages.FailureType import PinCancelled @@ -20,9 +20,8 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str: _, label = _get_code_and_label(code) - await wire.call(session_id, - ButtonRequest(code=ProtectCall), - ButtonAck) + await ctx.call(ButtonRequest(code=ProtectCall), + ButtonAck) ui.display.clear() matrix = PinMatrix(label) @@ -36,7 +35,7 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str: @unimport -async def request_pin_on_client(session_id: int, code: int=None) -> str: +async def request_pin_on_client(ctx: wire.Context, code: int=None) -> str: from trezor.messages.FailureType import PinCancelled from trezor.messages.PinMatrixRequest import PinMatrixRequest from trezor.messages.wire_types import PinMatrixAck, Cancel @@ -51,9 +50,8 @@ async def request_pin_on_client(session_id: int, code: int=None) -> str: matrix = PinMatrix(label) matrix.render() - ack = await wire.call(session_id, - PinMatrixRequest(type=code), - PinMatrixAck, Cancel) + ack = await ctx.call(PinMatrixRequest(type=code), + PinMatrixAck, Cancel) digits = matrix.digits matrix = None @@ -66,12 +64,12 @@ request_pin = request_pin_on_client @unimport -async def request_pin_twice(session_id: int) -> str: +async def request_pin_twice(ctx: wire.Context) -> str: from trezor.messages.FailureType import ActionCancelled from trezor.messages import PinMatrixRequestType - pin_first = await request_pin(session_id, PinMatrixRequestType.NewFirst) - pin_again = await request_pin(session_id, PinMatrixRequestType.NewSecond) + pin_first = await request_pin(ctx, PinMatrixRequestType.NewFirst) + pin_again = await request_pin(ctx, PinMatrixRequestType.NewSecond) if pin_first != pin_again: # changed message due to consistency with T1 msgs raise wire.FailureError(ActionCancelled, 'PIN change failed') @@ -79,22 +77,22 @@ async def request_pin_twice(session_id: int) -> str: return pin_first -async def protect_by_pin_repeatedly(session_id: int, at_least_once: bool=False): +async def protect_by_pin_repeatedly(ctx: wire.Context, at_least_once: bool=False): from . import storage locked = storage.is_locked() or at_least_once while locked: - pin = await request_pin(session_id) + pin = await request_pin(ctx) locked = not storage.unlock(pin, _render_pin_failure) -async def protect_by_pin_or_fail(session_id: int, at_least_once: bool=False): +async def protect_by_pin_or_fail(ctx: wire.Context, at_least_once: bool=False): from trezor.messages.FailureType import PinInvalid from . import storage locked = storage.is_locked() or at_least_once if locked: - pin = await request_pin(session_id) + pin = await request_pin(ctx) if not storage.unlock(pin, _render_pin_failure): raise wire.FailureError(PinInvalid, 'PIN invalid') diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index 1d1bb12664..af8a15ddd2 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -5,20 +5,20 @@ from trezor.crypto import bip39 _DEFAULT_CURVE = 'secp256k1' -async def get_root(session_id: int, curve_name=_DEFAULT_CURVE): - seed = await get_seed(session_id) +async def get_root(ctx: wire.Context, curve_name=_DEFAULT_CURVE): + seed = await get_seed(ctx) root = bip32.from_seed(seed, curve_name) return root -async def get_seed(session_id: int) -> bytes: +async def get_seed(ctx: wire.Context) -> bytes: from . import cache if cache.seed is None: - cache.seed = await compute_seed(session_id) + cache.seed = await compute_seed(ctx) return cache.seed -async def compute_seed(session_id: int) -> bytes: +async def compute_seed(ctx: wire.Context) -> bytes: from trezor.messages.FailureType import ProcessError from .request_passphrase import protect_by_passphrase from .request_pin import protect_by_pin @@ -27,9 +27,9 @@ async def compute_seed(session_id: int) -> bytes: if not storage.is_initialized(): raise wire.FailureError(ProcessError, 'Device is not initialized') - await protect_by_pin(session_id) + await protect_by_pin(ctx) - passphrase = await protect_by_passphrase(session_id) + passphrase = await protect_by_passphrase(ctx) return bip39.seed(storage.get_mnemonic(), passphrase) diff --git a/src/apps/debug/__init__.py b/src/apps/debug/__init__.py index 7874e86a2b..0cca27e460 100644 --- a/src/apps/debug/__init__.py +++ b/src/apps/debug/__init__.py @@ -4,13 +4,13 @@ from trezor.messages.wire_types import \ DebugLinkMemoryRead, DebugLinkMemoryWrite, DebugLinkFlashErase -async def dispatch_DebugLinkDecision(session_id, msg): +async def dispatch_DebugLinkDecision(ctx, msg): from trezor.ui.confirm import CONFIRMED, CANCELLED from apps.common.confirm import signal signal.send(CONFIRMED if msg.yes_no else CANCELLED) -async def dispatch_DebugLinkGetState(session_id, msg): +async def dispatch_DebugLinkGetState(ctx, msg): from trezor.messages.DebugLinkState import DebugLinkState from apps.common import storage, request_pin from apps.management import reset_device @@ -36,11 +36,11 @@ async def dispatch_DebugLinkGetState(session_id, msg): return m -async def dispatch_DebugLinkStop(session_id, msg): +async def dispatch_DebugLinkStop(ctx, msg): pass -async def dispatch_DebugLinkMemoryRead(session_id, msg): +async def dispatch_DebugLinkMemoryRead(ctx, msg): from trezor.messages.DebugLinkMemory import DebugLinkMemory from uctypes import bytes_at m = DebugLinkMemory() @@ -48,14 +48,14 @@ async def dispatch_DebugLinkMemoryRead(session_id, msg): return m -async def dispatch_DebugLinkMemoryWrite(session_id, msg): +async def dispatch_DebugLinkMemoryWrite(ctx, msg): from uctypes import bytearray_at l = len(msg.memory) data = bytearray_at(msg.address, l) data[0:l] = msg.memory -async def dispatch_DebugLinkFlashErase(session_id, msg): +async def dispatch_DebugLinkFlashErase(ctx, msg): # TODO: erase(msg.sector) pass diff --git a/src/apps/ethereum/ethereum_get_address.py b/src/apps/ethereum/ethereum_get_address.py index a5b2a0ee6a..a00d522727 100644 --- a/src/apps/ethereum/ethereum_get_address.py +++ b/src/apps/ethereum/ethereum_get_address.py @@ -3,13 +3,13 @@ from trezor.utils import unimport @unimport -async def layout_ethereum_get_address(session_id, msg): +async def layout_ethereum_get_address(ctx, msg): from trezor.messages.EthereumAddress import EthereumAddress from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import sha3_256 from ..common import seed - node = await seed.get_root(session_id) + node = await seed.get_root(ctx) node.derive_path(msg.address_n or ()) seckey = node.private_key() @@ -17,11 +17,11 @@ async def layout_ethereum_get_address(session_id, msg): address = sha3_256(public_key[1:]).digest(True)[12:] # Keccak if msg.show_display: - await _show_address(session_id, address) + await _show_address(ctx, address) return EthereumAddress(address=address) -async def _show_address(session_id, address): +async def _show_address(ctx, address): from trezor.messages.ButtonRequestType import Address from trezor.ui.text import Text from ..common.confirm import require_confirm @@ -30,7 +30,7 @@ async def _show_address(session_id, address): content = Text('Confirm address', ui.ICON_RESET, ui.MONO, *_split_address(address)) - await require_confirm(session_id, content, code=Address) + await require_confirm(ctx, content, code=Address) def _split_address(address): diff --git a/src/apps/fido_u2f/__init__.py b/src/apps/fido_u2f/__init__.py index 4b9594a7d8..b7da0c7e86 100644 --- a/src/apps/fido_u2f/__init__.py +++ b/src/apps/fido_u2f/__init__.py @@ -226,11 +226,11 @@ class Cmd: return Msg(self.cid, cla, ins, p1, p2, lc, data) -async def read_cmd(iface: int) -> Cmd: +async def read_cmd(iface: io.HID) -> Cmd: desc_init = frame_init() desc_cont = frame_cont() - buf, = await loop.select(iface) + buf, = await loop.select(iface.iface_num()) # log.debug(__name__, 'read init %s', buf) ifrm = overlay_struct(buf, desc_init) @@ -252,7 +252,7 @@ async def read_cmd(iface: int) -> Cmd: data = data[:bcnt] while datalen < bcnt: - buf, = await loop.select(iface) + buf, = await loop.select(iface.iface_num()) # log.debug(__name__, 'read cont %s', buf) cfrm = overlay_struct(buf, desc_cont) @@ -282,7 +282,7 @@ async def read_cmd(iface: int) -> Cmd: return Cmd(ifrm.cid, ifrm.cmd, data) -def send_cmd(cmd: Cmd, iface: int) -> None: +def send_cmd(cmd: Cmd, iface: io.HID) -> None: init_desc = frame_init() cont_desc = frame_cont() offset = 0 @@ -295,7 +295,7 @@ def send_cmd(cmd: Cmd, iface: int) -> None: frm.bcnt = datalen offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen) - io.send(iface, buf) + iface.write(buf) # log.debug(__name__, 'send init %s', buf) if offset < datalen: @@ -304,18 +304,17 @@ def send_cmd(cmd: Cmd, iface: int) -> None: while offset < datalen: frm.seq = seq offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen) - utime.sleep_ms(1) # FIXME: do async send - io.send(iface, buf) + utime.sleep_ms(1) # FIXME: async write + iface.write(buf) # log.debug(__name__, 'send cont %s', buf) seq += 1 -def boot(): - iface = 0x03 +def boot(iface: io.HID): loop.schedule_task(handle_reports(iface)) -async def handle_reports(iface: int): +async def handle_reports(iface: io.HID): while True: try: req = await read_cmd(iface) diff --git a/src/apps/homescreen/__init__.py b/src/apps/homescreen/__init__.py index 5a1aebe423..3f506ec2a4 100644 --- a/src/apps/homescreen/__init__.py +++ b/src/apps/homescreen/__init__.py @@ -4,7 +4,7 @@ from trezor.messages.wire_types import Initialize, GetFeatures, Ping @unimport -async def respond_Features(session_id, msg): +async def respond_Features(ctx, msg): from apps.common import storage, coins from trezor.messages.Features import Features @@ -28,7 +28,7 @@ async def respond_Features(session_id, msg): @unimport -async def respond_Pong(session_id, msg): +async def respond_Pong(ctx, msg): from trezor.messages.Success import Success s = Success() @@ -36,11 +36,11 @@ async def respond_Pong(session_id, msg): if msg.pin_protection: from apps.common.request_pin import protect_by_pin - await protect_by_pin(session_id) + await protect_by_pin(ctx) if msg.passphrase_protection: from apps.common.request_passphrase import protect_by_passphrase - await protect_by_passphrase(session_id) + await protect_by_passphrase(ctx) # TODO: handle other fields: # button_protection diff --git a/src/apps/management/apply_settings.py b/src/apps/management/apply_settings.py index 20e7916bd8..815f105927 100644 --- a/src/apps/management/apply_settings.py +++ b/src/apps/management/apply_settings.py @@ -3,7 +3,7 @@ from trezor.utils import unimport @unimport -async def layout_apply_settings(session_id, msg): +async def layout_apply_settings(ctx, msg): from trezor.messages.Success import Success from trezor.messages.FailureType import ProcessError from trezor.ui.text import Text @@ -11,7 +11,7 @@ async def layout_apply_settings(session_id, msg): from ..common.request_pin import protect_by_pin from ..common import storage - await protect_by_pin(session_id) + await protect_by_pin(ctx) if msg.homescreen is not None: raise wire.FailureError( @@ -21,20 +21,20 @@ async def layout_apply_settings(session_id, msg): raise wire.FailureError(ProcessError, 'No setting provided') if msg.label is not None: - await require_confirm(session_id, Text( + await require_confirm(ctx, Text( 'Change label', ui.ICON_RESET, 'Do you really want to', 'change label to', ui.BOLD, '%s' % msg.label)) if msg.language is not None: - await require_confirm(session_id, Text( + await require_confirm(ctx, Text( 'Change language', ui.ICON_RESET, 'Do you really want to', 'change language to', ui.BOLD, '%s' % msg.language, ui.NORMAL, '?')) if msg.use_passphrase is not None: - await require_confirm(session_id, Text( + await require_confirm(ctx, Text( 'Enable passphrase' if msg.use_passphrase else 'Disable passphrase', ui.ICON_RESET, 'Do you really want to', diff --git a/src/apps/management/change_pin.py b/src/apps/management/change_pin.py index cbb03a3618..f432465f89 100644 --- a/src/apps/management/change_pin.py +++ b/src/apps/management/change_pin.py @@ -2,52 +2,52 @@ from trezor import ui from trezor.utils import unimport -def confirm_set_pin(session_id): +def confirm_set_pin(ctx): from apps.common.confirm import require_confirm from trezor.ui.text import Text - return require_confirm(session_id, Text( + return require_confirm(ctx, Text( 'Change PIN', ui.ICON_RESET, 'Do you really want to', ui.BOLD, 'set new PIN?')) -def confirm_change_pin(session_id): +def confirm_change_pin(ctx): from apps.common.confirm import require_confirm from trezor.ui.text import Text - return require_confirm(session_id, Text( + return require_confirm(ctx, Text( 'Change PIN', ui.ICON_RESET, 'Do you really want to', ui.BOLD, 'change current PIN?')) -def confirm_remove_pin(session_id): +def confirm_remove_pin(ctx): from apps.common.confirm import require_confirm from trezor.ui.text import Text - return require_confirm(session_id, Text( + return require_confirm(ctx, Text( 'Remove PIN', ui.ICON_RESET, 'Do you really want to', ui.BOLD, 'remove current PIN?')) @unimport -async def layout_change_pin(session_id, msg): +async def layout_change_pin(ctx, msg): from trezor.messages.Success import Success from apps.common.request_pin import protect_by_pin, request_pin_twice from apps.common import storage if msg.remove: if storage.is_protected_by_pin(): - await confirm_remove_pin(session_id) - await protect_by_pin(session_id, at_least_once=True) + await confirm_remove_pin(ctx) + await protect_by_pin(ctx, at_least_once=True) pin = '' else: if storage.is_protected_by_pin(): - await confirm_change_pin(session_id) - await protect_by_pin(session_id, at_least_once=True) + await confirm_change_pin(ctx) + await protect_by_pin(ctx, at_least_once=True) else: - await confirm_set_pin(session_id) - pin = await request_pin_twice(session_id) + await confirm_set_pin(ctx) + pin = await request_pin_twice(ctx) storage.load_settings(pin=pin) if pin: diff --git a/src/apps/management/load_device.py b/src/apps/management/load_device.py index 29b9cddbb3..89c850cca8 100644 --- a/src/apps/management/load_device.py +++ b/src/apps/management/load_device.py @@ -3,7 +3,7 @@ from trezor.utils import unimport @unimport -async def layout_load_device(session_id, msg): +async def layout_load_device(ctx, msg): from trezor.crypto import bip39 from trezor.messages.Success import Success from trezor.messages.FailureType import UnexpectedMessage, ProcessError @@ -20,7 +20,7 @@ async def layout_load_device(session_id, msg): if not msg.skip_checksum and not bip39.check(msg.mnemonic): raise wire.FailureError(ProcessError, 'Mnemonic is not valid') - await require_confirm(session_id, Text( + await require_confirm(ctx, Text( 'Loading seed', ui.ICON_RESET, ui.BOLD, 'Loading private seed', 'is not recommended.', ui.NORMAL, 'Continue only if you', 'know what you are doing!')) diff --git a/src/apps/management/recovery_device.py b/src/apps/management/recovery_device.py index e9fdb9f2e0..3dc2dd2cc5 100644 --- a/src/apps/management/recovery_device.py +++ b/src/apps/management/recovery_device.py @@ -11,7 +11,7 @@ def nth(n): @unimport -async def layout_recovery_device(session_id, msg): +async def layout_recovery_device(ctx, msg): msg = 'Please enter ' + nth(msg.word_count) + ' word' diff --git a/src/apps/management/reset_device.py b/src/apps/management/reset_device.py index 5bbd8c340a..0990b1f903 100644 --- a/src/apps/management/reset_device.py +++ b/src/apps/management/reset_device.py @@ -10,7 +10,7 @@ if __debug__: @unimport -async def layout_reset_device(session_id, msg): +async def layout_reset_device(ctx, msg): from trezor.ui.text import Text from trezor.crypto import hashlib, random, bip39 from trezor.messages.EntropyRequest import EntropyRequest @@ -39,21 +39,21 @@ async def layout_reset_device(session_id, msg): if msg.display_random: entropy_lines = chunks(ubinascii.hexlify(internal_entropy), 16) entropy_content = Text('Internal entropy', ui.ICON_RESET, *entropy_lines) - await require_confirm(session_id, entropy_content, ButtonRequestType.ResetDevice) + await require_confirm(ctx, entropy_content, ButtonRequestType.ResetDevice) if msg.pin_protection: - pin = await request_pin_twice(session_id) + pin = await request_pin_twice(ctx) else: pin = None - external_entropy_ack = await wire.call(session_id, EntropyRequest(), EntropyAck) + external_entropy_ack = await ctx.call(EntropyRequest(), EntropyAck) ctx = hashlib.sha256() ctx.update(internal_entropy) ctx.update(external_entropy_ack.entropy) entropy = ctx.digest() mnemonic = bip39.from_data(entropy[:msg.strength // 8]) - await show_mnemonic_by_word(session_id, mnemonic) + await show_mnemonic_by_word(ctx, mnemonic) storage.load_mnemonic(mnemonic) storage.load_settings(pin=pin, @@ -64,7 +64,7 @@ async def layout_reset_device(session_id, msg): return Success(message='Initialized') -async def show_mnemonic_by_word(session_id, mnemonic): +async def show_mnemonic_by_word(ctx, mnemonic): from trezor.ui.text import Text from trezor.messages.ButtonRequestType import ConfirmWord from apps.common.confirm import confirm @@ -80,7 +80,7 @@ async def show_mnemonic_by_word(session_id, mnemonic): while index < len(words): word = words[index] current_word = word - await confirm(session_id, + await confirm(ctx, Text( 'Recovery seed setup', ui.ICON_RESET, ui.NORMAL, 'Write down seed word' if recovery else 'Confirm seed word', ' ', diff --git a/src/apps/management/wipe_device.py b/src/apps/management/wipe_device.py index 732683f38d..b86c6d1be3 100644 --- a/src/apps/management/wipe_device.py +++ b/src/apps/management/wipe_device.py @@ -3,13 +3,13 @@ from trezor.utils import unimport @unimport -async def layout_wipe_device(session_id, msg): +async def layout_wipe_device(ctx, msg): from trezor.messages.Success import Success from trezor.ui.text import Text from ..common.confirm import hold_to_confirm from ..common import storage - await hold_to_confirm(session_id, Text( + await hold_to_confirm(ctx, Text( 'WIPE DEVICE', ui.ICON_WIPE, ui.NORMAL, 'Do you really want to', 'wipe the device?', diff --git a/src/apps/wallet/cipher_key_value.py b/src/apps/wallet/cipher_key_value.py index 7657e7177d..f8eded02f0 100644 --- a/src/apps/wallet/cipher_key_value.py +++ b/src/apps/wallet/cipher_key_value.py @@ -26,7 +26,7 @@ def cipher_key_value(msg, seckey: bytes) -> bytes: @unimport -async def layout_cipher_key_value(session_id, msg): +async def layout_cipher_key_value(ctx, msg): from trezor.messages.CipheredKeyValue import CipheredKeyValue from ..common import seed @@ -38,7 +38,7 @@ async def layout_cipher_key_value(session_id, msg): ui.BOLD, ui.LIGHT_GREEN, ui.BLACK) ui.display.text(10, 60, msg.key, ui.MONO, ui.WHITE, ui.BLACK) - node = await seed.get_root(session_id) + node = await seed.get_root(ctx) node.derive_path(msg.address_n) value = cipher_key_value(msg, node.private_key()) diff --git a/src/apps/wallet/get_address.py b/src/apps/wallet/get_address.py index 7dadae8073..63adac96fd 100644 --- a/src/apps/wallet/get_address.py +++ b/src/apps/wallet/get_address.py @@ -3,7 +3,7 @@ from trezor.utils import unimport @unimport -async def layout_get_address(session_id, msg): +async def layout_get_address(ctx, msg): from trezor.messages.Address import Address from trezor.messages.FailureType import ProcessError from ..common import coins @@ -15,18 +15,18 @@ async def layout_get_address(session_id, msg): address_n = msg.address_n or () coin_name = msg.coin_name or 'Bitcoin' - node = await seed.get_root(session_id) + node = await seed.get_root(ctx) node.derive_path(address_n) coin = coins.by_name(coin_name) address = node.address(coin.address_type) if msg.show_display: - await _show_address(session_id, address) + await _show_address(ctx, address) return Address(address=address) -async def _show_address(session_id, address): +async def _show_address(ctx, address): from trezor.messages.ButtonRequestType import Address from trezor.ui.text import Text from trezor.ui.qr import Qr @@ -37,7 +37,7 @@ async def _show_address(session_id, address): content = Container( Qr(address, (120, 135), 3), Text('Confirm address', ui.ICON_RESET, ui.MONO, *lines)) - await require_confirm(session_id, content, code=Address) + await require_confirm(ctx, content, code=Address) def _split_address(address): diff --git a/src/apps/wallet/get_entropy.py b/src/apps/wallet/get_entropy.py index b7e606bb31..648a922dd4 100644 --- a/src/apps/wallet/get_entropy.py +++ b/src/apps/wallet/get_entropy.py @@ -3,18 +3,18 @@ from trezor.utils import unimport @unimport -async def layout_get_entropy(session_id, msg): +async def layout_get_entropy(ctx, msg): from trezor.messages.Entropy import Entropy from trezor.crypto import random l = min(msg.size, 1024) - await _show_entropy(session_id) + await _show_entropy(ctx) return Entropy(entropy=random.bytes(l)) -async def _show_entropy(session_id): +async def _show_entropy(ctx): from trezor.messages.ButtonRequestType import ProtectCall from trezor.ui.text import Text from trezor.ui.container import Container @@ -23,4 +23,4 @@ async def _show_entropy(session_id): content = Container( Text('Confirm entropy', ui.ICON_RESET, ui.MONO, 'Do you really want to send entropy?')) - await require_confirm(session_id, content, code=ProtectCall) + await require_confirm(ctx, content, code=ProtectCall) diff --git a/src/apps/wallet/get_public_key.py b/src/apps/wallet/get_public_key.py index 66a2beff2f..03784fafc0 100644 --- a/src/apps/wallet/get_public_key.py +++ b/src/apps/wallet/get_public_key.py @@ -2,7 +2,7 @@ from trezor.utils import unimport @unimport -async def layout_get_public_key(session_id, msg): +async def layout_get_public_key(ctx, msg): from trezor.messages.HDNodeType import HDNodeType from trezor.messages.PublicKey import PublicKey from ..common import coins @@ -11,7 +11,7 @@ async def layout_get_public_key(session_id, msg): address_n = msg.address_n or () coin_name = msg.coin_name or 'Bitcoin' - node = await seed.get_root(session_id) + node = await seed.get_root(ctx) node.derive_path(address_n) coin = coins.by_name(coin_name) diff --git a/src/apps/wallet/sign_identity.py b/src/apps/wallet/sign_identity.py index 1696d7a442..fe8ab5e463 100644 --- a/src/apps/wallet/sign_identity.py +++ b/src/apps/wallet/sign_identity.py @@ -83,7 +83,7 @@ def sign_challenge(seckey: bytes, @unimport -async def layout_sign_identity(session_id, msg): +async def layout_sign_identity(ctx, msg): from trezor.messages.SignedIdentity import SignedIdentity from ..common import coins from ..common import seed @@ -92,7 +92,7 @@ async def layout_sign_identity(session_id, msg): display_identity(identity, msg.challenge_visual) address_n = get_identity_path(identity, msg.identity.index or 0) - node = await seed.get_root(session_id, msg.ecdsa_curve_name) + node = await seed.get_root(ctx, msg.ecdsa_curve_name) node.derive_path(address_n) coin = coins.by_name('Bitcoin') diff --git a/src/apps/wallet/sign_message.py b/src/apps/wallet/sign_message.py index b812d9d573..26b54f6112 100644 --- a/src/apps/wallet/sign_message.py +++ b/src/apps/wallet/sign_message.py @@ -3,7 +3,7 @@ from trezor.utils import unimport @unimport -async def layout_sign_message(session_id, msg): +async def layout_sign_message(ctx, msg): from trezor.messages.MessageSignature import MessageSignature from trezor.crypto.curve import secp256k1 from ..common import coins @@ -18,7 +18,7 @@ async def layout_sign_message(session_id, msg): coin_name = msg.coin_name or 'Bitcoin' coin = coins.by_name(coin_name) - node = await seed.get_root(session_id) + node = await seed.get_root(ctx) node.derive_path(msg.address_n) seckey = node.private_key() diff --git a/src/apps/wallet/sign_tx/__init__.py b/src/apps/wallet/sign_tx/__init__.py index e753d1334b..02cc913a4c 100644 --- a/src/apps/wallet/sign_tx/__init__.py +++ b/src/apps/wallet/sign_tx/__init__.py @@ -3,7 +3,7 @@ from trezor import wire @unimport -async def sign_tx(session_id, msg): +async def sign_tx(ctx, msg): from trezor.messages.RequestType import TXFINISHED from trezor.messages.wire_types import TxAck @@ -11,7 +11,7 @@ async def sign_tx(session_id, msg): from . import signing from . import layout - root = await seed.get_root(session_id) + root = await seed.get_root(ctx) signer = signing.sign_tx(msg, root) res = None @@ -23,13 +23,13 @@ async def sign_tx(session_id, msg): if req.__qualname__ == 'TxRequest': if req.request_type == TXFINISHED: break - res = await wire.call(session_id, req, TxAck) + res = await ctx.call(req, TxAck) elif req.__qualname__ == 'UiConfirmOutput': - res = await layout.confirm_output(session_id, req.output, req.coin) + res = await layout.confirm_output(ctx, req.output, req.coin) elif req.__qualname__ == 'UiConfirmTotal': - res = await layout.confirm_total(session_id, req.spending, req.fee, req.coin) + res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin) elif req.__qualname__ == 'UiConfirmFeeOverThreshold': - res = await layout.confirm_feeoverthreshold(session_id, req.fee, req.coin) + res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin) else: raise TypeError('Invalid signing instruction') return req diff --git a/src/apps/wallet/sign_tx/layout.py b/src/apps/wallet/sign_tx/layout.py index 4d36bb34d0..991b714497 100644 --- a/src/apps/wallet/sign_tx/layout.py +++ b/src/apps/wallet/sign_tx/layout.py @@ -14,22 +14,22 @@ def split_address(address): return chunks(address, 17) -async def confirm_output(session_id, output, coin): +async def confirm_output(ctx, output, coin): content = Text('Confirm output', ui.ICON_RESET, ui.BOLD, format_amount(output.amount, coin), ui.NORMAL, 'to', ui.MONO, *split_address(output.address)) - return await confirm(session_id, content, ButtonRequestType.ConfirmOutput) + return await confirm(ctx, content, ButtonRequestType.ConfirmOutput) -async def confirm_total(session_id, spending, fee, coin): +async def confirm_total(ctx, spending, fee, coin): content = Text('Confirm transaction', ui.ICON_RESET, 'Sending: %s' % format_amount(spending, coin), 'Fee: %s' % format_amount(fee, coin)) - return await hold_to_confirm(session_id, content, ButtonRequestType.SignTx) + return await hold_to_confirm(ctx, content, ButtonRequestType.SignTx) -async def confirm_feeoverthreshold(session_id, fee, coin): +async def confirm_feeoverthreshold(ctx, fee, coin): content = Text('Confirm high fee:', ui.ICON_RESET, ui.BOLD, format_amount(fee, coin)) - return await confirm(session_id, content, ButtonRequestType.FeeOverThreshold) + return await confirm(ctx, content, ButtonRequestType.FeeOverThreshold) diff --git a/src/apps/wallet/verify_message.py b/src/apps/wallet/verify_message.py index f6cc905cff..c3731cbb8a 100644 --- a/src/apps/wallet/verify_message.py +++ b/src/apps/wallet/verify_message.py @@ -3,7 +3,7 @@ from trezor.utils import unimport @unimport -async def layout_verify_message(session_id, msg): +async def layout_verify_message(ctx, msg): from trezor.messages.Success import Success from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import ripemd160, sha256 diff --git a/src/main.py b/src/main.py index 36b3bb1218..38ee70a3c8 100644 --- a/src/main.py +++ b/src/main.py @@ -4,26 +4,7 @@ from trezor import io from trezor import wire from trezor import main -# Load applications -from apps.common import storage -if __debug__: - from apps import debug -from apps import homescreen -from apps import management -from apps import wallet -from apps import ethereum -from apps import fido_u2f - -# Boot applications -if __debug__: - debug.boot() -homescreen.boot() -management.boot() -wallet.boot() -ethereum.boot() -fido_u2f.boot() - -# Intialize the USB stack +# initialize the USB stack usb_wire = io.HID( iface_num=0x00, ep_in=0x81, @@ -90,11 +71,30 @@ usb.add(usb_vcp) usb.add(usb_u2f) usb.open() -# Initialize the wire codec pipeline -wire.setup(usb_wire.iface_num()) +# load applications +from apps.common import storage +if __debug__: + from apps import debug +from apps import homescreen +from apps import management +from apps import wallet +from apps import ethereum +from apps import fido_u2f -# Load default homescreen +# boot applications +if __debug__: + debug.boot() +homescreen.boot() +management.boot() +wallet.boot() +ethereum.boot() +fido_u2f.boot(usb_u2f) + +# initialize the wire codec pipeline +wire.setup(usb_wire) + +# load default homescreen from apps.homescreen.homescreen import layout_homescreen -# Run main even loop and specify which screen is default +# run main even loop and specify which screen is default main.run(default_workflow=layout_homescreen) diff --git a/src/protobuf.py b/src/protobuf.py index f97bcd7fc6..7c0d166323 100644 --- a/src/protobuf.py +++ b/src/protobuf.py @@ -14,7 +14,7 @@ async def load_uvarint(reader): shift = 0 byte = 0x80 while byte & 0x80: - await reader.readinto(buffer) + await reader.areadinto(buffer) byte = buffer[0] result += (byte & 0x7F) << shift shift += 7 @@ -27,7 +27,7 @@ async def dump_uvarint(writer, n): while shifted: shifted = n >> 7 buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) - await writer.write(buffer) + await writer.awrite(buffer) n = shifted @@ -69,11 +69,11 @@ class LimitedReader: self.reader = reader self.limit = limit - async def readinto(self, buf): + async def areadinto(self, buf): if self.limit < len(buf): raise EOFError else: - nread = await self.reader.readinto(buf) + nread = await self.reader.areadinto(buf) self.limit -= nread return nread @@ -83,7 +83,7 @@ class CountingWriter: def __init__(self): self.size = 0 - async def write(self, buf): + async def awrite(self, buf): nwritten = len(buf) self.size += nwritten return nwritten @@ -112,7 +112,7 @@ async def load_message(reader, msg_type): await load_uvarint(reader) elif wtype == 2: ivalue = await load_uvarint(reader) - await reader.readinto(bytearray(ivalue)) + await reader.areadinto(bytearray(ivalue)) else: raise ValueError continue @@ -129,10 +129,10 @@ async def load_message(reader, msg_type): fvalue = bool(ivalue) elif ftype is BytesType: fvalue = bytearray(ivalue) - await reader.readinto(fvalue) + await reader.areadinto(fvalue) elif ftype is UnicodeType: fvalue = bytearray(ivalue) - await reader.readinto(fvalue) + await reader.areadinto(fvalue) fvalue = str(fvalue, 'utf8') elif issubclass(ftype, MessageType): fvalue = await load_message(LimitedReader(reader, ivalue), ftype) @@ -186,11 +186,11 @@ async def dump_message(writer, msg): elif ftype is BytesType: await dump_uvarint(writer, len(svalue)) - await writer.write(svalue) + await writer.awrite(svalue) elif ftype is UnicodeType: await dump_uvarint(writer, len(svalue)) - await writer.write(bytes(svalue, 'utf8')) + await writer.awrite(bytes(svalue, 'utf8')) elif issubclass(ftype, MessageType): counter = CountingWriter() diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index 07435e987c..22305a4098 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -8,58 +8,82 @@ from trezor import workflow from . import codec_v1 from . import codec_v2 -workflows = {} +workflow_handlers = {} -def register(wire_type, handler, *args): - if wire_type in workflows: +def register(mtype, handler, *args): + '''Register `handler` to get scheduled after `mtype` message is received.''' + if mtype in workflow_handlers: raise KeyError - workflows[wire_type] = (handler, args) + workflow_handlers[mtype] = (handler, args) -def setup(interface): - session_supervisor = codec_v2.SesssionSupervisor(interface, - session_handler) +def setup(iface): + '''Initialize the wire stack on passed USB interface.''' + session_supervisor = codec_v2.SesssionSupervisor(iface, session_handler) session_supervisor.open(codec_v1.SESSION_ID) loop.schedule_task(session_supervisor.listen()) class Context: - def __init__(self, interface, session_id): - self.interface = interface - self.session_id = session_id - - def get_reader(self): - if self.session_id == codec_v1.SESSION_ID: - return codec_v1.Reader(self.interface) - else: - return codec_v2.Reader(self.interface, self.session_id) - - def get_writer(self, mtype, msize): - if self.session_id == codec_v1.SESSION_ID: - return codec_v1.Writer(self.interface, mtype, msize) - else: - return codec_v2.Writer(self.interface, self.session_id, mtype, msize) - - async def read(self, types): - reader = self.get_reader() - await reader.open() - if reader.type not in types: - raise UnexpectedMessageError(reader) - return await protobuf.load_message(reader, - messages.get_type(reader.type)) - - async def write(self, msg): - counter = protobuf.CountingWriter() - await protobuf.dump_message(counter, msg) - writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size) - await protobuf.dump_message(writer, msg) - await writer.close() + def __init__(self, iface, sid): + self.iface = iface + self.sid = sid async def call(self, msg, types): + ''' + Reply with `msg` and wait for one of `types`. See `self.write()` and + `self.read()`. + ''' await self.write(msg) return await self.read(types) + async def read(self, types): + ''' + Wait for incoming message on this wire context and return it. Raises + `UnexpectedMessageError` if the message type does not match one of + `types`; and caller should always make sure to re-raise it. + ''' + reader = self.getreader() + + await reader.aopen() # wait for the message header + + # if we got a message with unexpected type, raise the reader via + # `UnexpectedMessageError` and let the session handler deal with it + if reader.type not in types: + raise UnexpectedMessageError(reader) + + # look up the protobuf class and parse the message + pbtype = messages.get_type(reader.type) + return await protobuf.load_message(reader, pbtype) + + async def write(self, msg): + ''' + Write a protobuf message to this wire context. + ''' + writer = self.getwriter() + + # get the message size + counter = protobuf.CountingWriter() + await protobuf.dump_message(counter, msg) + + # write the message + writer.setheader(msg.MESSAGE_WIRE_TYPE, counter.size) + await protobuf.dump_message(writer, msg) + await writer.aclose() + + def getreader(self): + if self.sid == codec_v1.SESSION_ID: + return codec_v1.Reader(self.iface) + else: + return codec_v2.Reader(self.iface, self.sid) + + def getwriter(self): + if self.sid == codec_v1.SESSION_ID: + return codec_v1.Writer(self.iface) + else: + return codec_v2.Writer(self.iface, self.sid) + class UnexpectedMessageError(Exception): def __init__(self, reader): @@ -74,60 +98,69 @@ class FailureError(Exception): self.message = message -class Workflow: - def __init__(self, default): - self.handlers = {} - self.default = default - - async def __call__(self, interface, session_id): - ctx = Context(interface, session_id) - while True: +async def session_handler(iface, sid): + reader = None + ctx = Context(iface, sid) + while True: + try: + # wait for new message, if needed, and find handler + if not reader: + reader = ctx.getreader() + await reader.aopen() try: - reader = ctx.get_reader() - await reader.open() - try: - handler = self.handlers[reader.type] - except KeyError: - handler = self.default - try: - await handler(ctx, reader) - except UnexpectedMessageError as unexp_msg: - reader = unexp_msg.reader - except Exception as e: - log.exception(__name__, e) + handler, args = workflow_handlers[reader.type] + except KeyError: + handler, args = unexpected_msg, () + + await handler(ctx, reader, *args) + + except UnexpectedMessageError as exc: + # retry with opened reader from the exception + reader = exc.reader + continue + except FailureError as exc: + # we log FailureError as warning, not as exception + log.warning(__name__, 'failure: %s', exc.message) + except Exception as exc: + # sessions are never closed by raised exceptions + log.exception(__name__, exc) + + # read new message in next iteration + reader = None async def protobuf_workflow(ctx, reader, handler, *args): - msg = await protobuf.load_message(reader, messages.get_type(reader.type)) + from trezor.messages.Failure import Failure + from trezor.messages.FailureType import FirmwareError + + req = await protobuf.load_message(reader, messages.get_type(reader.type)) try: - res = await handler(reader.sid, msg, *args) - except Exception as exc: - if not isinstance(exc, UnexpectedMessageError): - await ctx.write(make_failure_msg(exc)) + res = await handler(ctx, req, *args) + except UnexpectedMessageError: + # session handler takes care of this one raise - else: - if res: - await ctx.write(res) + except FailureError as exc: + # respond with specific code and message + await ctx.write(Failure(code=exc.code, message=exc.message)) + raise + except Exception as exc: + # respond with a generic code and message + await ctx.write(Failure(code=FirmwareError, message='Firmware error')) + raise + if res: + # respond with a specific response + await ctx.write(res) -async def handle_unexp_msg(ctx, reader): +async def unexpected_msg(ctx, reader): + from trezor.messages.Failure import Failure + from trezor.messages.FailureType import UnexpectedMessage + # receive the message and throw it away while reader.size > 0: buf = bytearray(reader.size) - await reader.readinto(buf) + await reader.areadinto(buf) + # respond with an unknown message error - from trezor.messages.Failure import Failure - from trezor.messages.FailureType import UnexpectedMessage await ctx.write( Failure(code=UnexpectedMessage, message='Unexpected message')) - -def make_failure_msg(exc): - from trezor.messages.Failure import Failure - from trezor.messages.FailureType import FirmwareError - if isinstance(exc, FailureError): - code = exc.code - message = exc.message - else: - code = FirmwareError - message = 'Firmware Error' - return Failure(code=code, message=message) diff --git a/src/trezor/wire/codec_v1.py b/src/trezor/wire/codec_v1.py index c296a1b809..488a81a0af 100644 --- a/src/trezor/wire/codec_v1.py +++ b/src/trezor/wire/codec_v1.py @@ -1,7 +1,6 @@ from micropython import const import ustruct -from trezor import io from trezor import loop from trezor import utils @@ -32,13 +31,13 @@ class Reader: def __repr__(self): return '' % (self.type, self.size) - async def open(self): + async def aopen(self): ''' Begin the message transmission by waiting for initial V2 message report on this session. `self.type` and `self.size` are initialized and - available after `open()` returns. + available after `aopen()` returns. ''' - read = loop.select(self.iface | loop.READ) + read = loop.select(self.iface.iface_num() | loop.READ) while True: # wait for initial report report = await read @@ -55,7 +54,7 @@ class Reader: self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize] self.ofs = 0 - async def readinto(self, buf): + async def areadinto(self, buf): ''' Read exactly `len(buf)` bytes into `buf`, waiting for additional reports, if needed. Raises `EOFError` if end-of-message is encountered @@ -64,7 +63,7 @@ class Reader: if self.size < len(buf): raise EOFError - read = loop.select(self.iface | loop.READ) + read = loop.select(self.iface.iface_num() | loop.READ) nread = 0 while nread < len(buf): if self.ofs == len(self.data): @@ -93,20 +92,28 @@ class Writer: async-file-like interface. ''' - def __init__(self, iface, mtype, msize): + def __init__(self, iface): self.iface = iface - self.type = mtype - self.size = msize + self.type = None + self.size = None self.data = bytearray(_REP_LEN) - self.ofs = _REP_INIT_DATA - - # load the report with initial header - ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize) + self.ofs = 0 def __repr__(self): return '' % (self.type, self.size) - async def write(self, buf): + def setheader(self, mtype, msize): + ''' + Reset the writer state and load the message header with passed type and + total message size. + ''' + self.type = mtype + self.size = msize + ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, + _REP_MAGIC, mtype, msize) + self.ofs = _REP_INIT_DATA + + async def awrite(self, buf): ''' Encode and write every byte from `buf`. Does not need to be called in case message has zero length. Raises `EOFError` if the length of `buf` @@ -115,7 +122,7 @@ class Writer: if self.size < len(buf): raise EOFError - write = loop.select(self.iface | loop.WRITE) + write = loop.select(self.iface.iface_num() | loop.WRITE) nwritten = 0 while nwritten < len(buf): # copy as much as possible to report buffer @@ -127,12 +134,12 @@ class Writer: if self.ofs == _REP_LEN: # we are at the end of the report, flush it await write - io.write(self.iface, self.data) + self.iface.write(self.data) self.ofs = _REP_CONT_DATA return nwritten - async def close(self): + async def aclose(self): '''Flush and close the message transmission.''' if self.ofs != _REP_CONT_DATA: # we didn't write anything or last write() wasn't report-aligned, @@ -141,5 +148,5 @@ class Writer: self.data[self.ofs] = 0x00 self.ofs += 1 - await loop.select(self.iface | loop.WRITE) - io.send(self.iface, self.data) + await loop.select(self.iface.iface_num() | loop.WRITE) + self.iface.write(self.data) diff --git a/src/trezor/wire/codec_v2.py b/src/trezor/wire/codec_v2.py index bdd79415f7..31176e49b0 100644 --- a/src/trezor/wire/codec_v2.py +++ b/src/trezor/wire/codec_v2.py @@ -1,7 +1,6 @@ from micropython import const import ustruct -from trezor import io from trezor import loop from trezor import utils from trezor.crypto import random @@ -28,11 +27,11 @@ _REP_MARKER_CONT = const(0x02) _REP_MARKER_OPEN = const(0x03) _REP_MARKER_CLOSE = const(0x04) -_REP = '>BL' # marker, session_id +_REP = '>BL' # marker, session_id _REP_INIT = '>BLLL' # marker, session_id, message_type, message_size -_REP_CONT = '>BLL' # marker, session_id, sequence +_REP_CONT = '>BLL' # marker, session_id, sequence _REP_INIT_DATA = const(13) # offset of data in init report -_REP_CONT_DATA = const(9) # offset of data in cont report +_REP_CONT_DATA = const(9) # offset of data in cont report class Reader: @@ -51,15 +50,16 @@ class Reader: self.seq = 0 def __repr__(self): - return '' % (self.sid, self.type, self.size) + return '' % (self.sid, self.type, + self.size) - async def open(self): + async def aopen(self): ''' Begin the message transmission by waiting for initial V2 message report on this session. `self.type` and `self.size` are initialized and - available after `open()` returns. + available after `aopen()` returns. ''' - read = loop.select(self.iface | loop.READ) + read = loop.select(self.iface.iface_num() | loop.READ) while True: # wait for initial report report = await read @@ -74,7 +74,7 @@ class Reader: self.ofs = 0 self.seq = 0 - async def readinto(self, buf): + async def areadinto(self, buf): ''' Read exactly `len(buf)` bytes into `buf`, waiting for additional reports, if needed. Raises `EOFError` if end-of-message is encountered @@ -83,7 +83,7 @@ class Reader: if self.size < len(buf): raise EOFError - read = loop.select(self.iface | loop.READ) + read = loop.select(self.iface.iface_num() | loop.READ) nread = 0 while nread < len(buf): if self.ofs == len(self.data): @@ -115,20 +115,28 @@ class Writer: interface. ''' - def __init__(self, iface, sid, mtype, msize): + def __init__(self, iface, sid): self.iface = iface self.sid = sid + self.type = None + self.size = None + self.data = bytearray(_REP_LEN) + self.ofs = 0 + self.seq = 0 + + def setheader(self, mtype, msize): + ''' + Reset the writer state and load the message header with passed type and + total message size. + ''' self.type = mtype self.size = msize - self.data = bytearray(_REP_LEN) + ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER_INIT, self.sid, + mtype, msize) self.ofs = _REP_INIT_DATA self.seq = 0 - # load the report with initial header - ustruct.pack_into(_REP_INIT, self.data, 0, - _REP_MARKER_INIT, sid, mtype, msize) - - async def write(self, buf): + async def awrite(self, buf): ''' Encode and write every byte from `buf`. Does not need to be called in case message has zero length. Raises `EOFError` if the length of `buf` @@ -137,7 +145,7 @@ class Writer: if self.size < len(buf): raise EOFError - write = loop.select(self.iface | loop.WRITE) + write = loop.select(self.iface.iface_num() | loop.WRITE) nwritten = 0 while nwritten < len(buf): # copy as much as possible to report buffer @@ -149,15 +157,15 @@ class Writer: if self.ofs == _REP_LEN: # we are at the end of the report, flush it, and prepare header await write - io.send(self.iface, self.data) - ustruct.pack_into(_REP_CONT, self.data, 0, - _REP_MARKER_CONT, self.sid, self.seq) + self.iface.write(self.data) + ustruct.pack_into(_REP_CONT, self.data, 0, _REP_MARKER_CONT, + self.sid, self.seq) self.ofs = _REP_CONT_DATA self.seq += 1 return nwritten - async def close(self): + async def aclose(self): '''Flush and close the message transmission.''' if self.ofs != _REP_CONT_DATA: # we didn't write anything or last write() wasn't report-aligned, @@ -166,8 +174,8 @@ class Writer: self.data[self.ofs] = 0x00 self.ofs += 1 - await loop.select(self.iface | loop.WRITE) - io.send(self.iface, self.data) + await loop.select(self.iface.iface_num() | loop.WRITE) + self.iface.write(self.data) class SesssionSupervisor: @@ -186,8 +194,8 @@ class SesssionSupervisor: After close request, the handling task is closed and session terminated. Both requests receive responses confirming the operation. ''' - read = loop.select(self.iface | loop.READ) - write = loop.select(self.iface | loop.WRITE) + read = loop.select(self.iface.iface_num() | loop.READ) + write = loop.select(self.iface.iface_num() | loop.WRITE) while True: report = await read repmarker, repsid = ustruct.unpack(_REP, report) @@ -225,8 +233,8 @@ class SesssionSupervisor: def writeopen(self, sid): ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid) - io.write(self.iface, self.session_report) + self.iface.write(self.session_report) def writeclose(self, sid): ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid) - io.write(self.iface, self.session_report) + self.iface.write(self.session_report) diff --git a/tests/test_trezor.wire.codec_v1.py b/tests/test_trezor.wire.codec_v1.py index 915eb0f2a4..eee14b708e 100644 --- a/tests/test_trezor.wire.codec_v1.py +++ b/tests/test_trezor.wire.codec_v1.py @@ -26,26 +26,26 @@ def test_reader(): # open, expected one read first_report = report_header + message[:rep_len - len(report_header)] - assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) + assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) assert_eq(reader.type, message_type) assert_eq(reader.size, message_len) # empty read empty_buffer = bytearray() - assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),]) + assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),]) assert_eq(len(empty_buffer), 0) assert_eq(reader.size, message_len) # short read, expected no read short_buffer = bytearray(32) - assert_async(reader.readinto(short_buffer), [(None, StopIteration()),]) + assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),]) assert_eq(len(short_buffer), 32) assert_eq(short_buffer, message[:len(short_buffer)]) assert_eq(reader.size, message_len - len(short_buffer)) # aligned read, expected no read aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) - assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),]) + assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),]) assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) @@ -53,12 +53,12 @@ def test_reader(): next_report_header = bytearray(unhexlify('3f')) next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] onebyte_buffer = bytearray(1) - assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) + assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) # too long read, raises eof - assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),]) + assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),]) # long read, expect multiple reads start_size = reader.size @@ -74,12 +74,12 @@ def test_reader(): prev_report = next_reports[i - 1] if i > 0 else None expected_syscalls.append((prev_report, Select(READ | interface))) expected_syscalls.append((next_reports[-1], StopIteration())) - assert_async(reader.readinto(long_buffer), expected_syscalls) + assert_async(reader.areadinto(long_buffer), expected_syscalls) assert_eq(long_buffer, message[-start_size:]) assert_eq(reader.size, 0) # one byte read, raises eof - assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),]) + assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),]) def test_writer(): @@ -87,7 +87,8 @@ def test_writer(): interface = 0xdeadbeef message_type = 0x87654321 message_len = 1024 - writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len) + writer = codec_v1.Writer(interface, codec_v1.SESSION_ID) + writer.setheader(message_type, message_len) # init header corresponding to the data above report_header = bytearray(unhexlify('3f2323432100000400')) @@ -96,14 +97,14 @@ def test_writer(): # empty write start_size = writer.size - assert_async(writer.write(bytearray()), [(None, StopIteration()),]) + assert_async(writer.awrite(bytearray()), [(None, StopIteration()),]) assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) assert_eq(writer.size, start_size) # short write, expected no report start_size = writer.size short_payload = bytearray(range(4)) - assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_async(writer.awrite(short_payload), [(None, StopIteration()),]) assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.data, report_header @@ -118,7 +119,7 @@ def test_writer(): + short_payload + aligned_payload + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) - assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) + assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) assert_eq(writer.size, start_size - len(aligned_payload)) msg.send.assert_called_n_times(1) msg.send = msg.send.original @@ -126,7 +127,7 @@ def test_writer(): # short write, expected no report, but data starts with correct seq and cont marker report_header = bytearray(unhexlify('3f')) start_size = writer.size - assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_async(writer.awrite(short_payload), [(None, StopIteration()),]) assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.data[:len(report_header) + len(short_payload)], report_header + short_payload) @@ -142,19 +143,19 @@ def test_writer(): # test write expected_write_reports = expected_reports[:-1] msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) - assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(writer.size, start_size - len(long_payload)) msg.send.assert_called_n_times(len(expected_write_reports)) msg.send = msg.send.original # test write raises eof msg.send = mock_call(msg.send, []) - assert_async(writer.write(bytearray(1)), [(None, EOFError())]) + assert_async(writer.awrite(bytearray(1)), [(None, EOFError())]) msg.send.assert_called_n_times(0) msg.send = msg.send.original # test close expected_close_reports = expected_reports[-1:] msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) - assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(writer.size, 0) msg.send.assert_called_n_times(len(expected_close_reports)) msg.send = msg.send.original diff --git a/tests/test_trezor.wire.codec_v2.py b/tests/test_trezor.wire.codec_v2.py index 56fc6ab03a..684f35caba 100644 --- a/tests/test_trezor.wire.codec_v2.py +++ b/tests/test_trezor.wire.codec_v2.py @@ -26,26 +26,26 @@ def test_reader(): # open, expected one read first_report = report_header + message[:rep_len - len(report_header)] - assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) + assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) assert_eq(reader.type, message_type) assert_eq(reader.size, message_len) # empty read empty_buffer = bytearray() - assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),]) + assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),]) assert_eq(len(empty_buffer), 0) assert_eq(reader.size, message_len) # short read, expected no read short_buffer = bytearray(32) - assert_async(reader.readinto(short_buffer), [(None, StopIteration()),]) + assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),]) assert_eq(len(short_buffer), 32) assert_eq(short_buffer, message[:len(short_buffer)]) assert_eq(reader.size, message_len - len(short_buffer)) # aligned read, expected no read aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) - assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),]) + assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),]) assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) @@ -53,12 +53,12 @@ def test_reader(): next_report_header = bytearray(unhexlify('021234567800000000')) next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] onebyte_buffer = bytearray(1) - assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) + assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) # too long read, raises eof - assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),]) + assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),]) # long read, expect multiple reads start_size = reader.size @@ -74,12 +74,12 @@ def test_reader(): prev_report = next_reports[i - 1] if i > 0 else None expected_syscalls.append((prev_report, Select(READ | interface))) expected_syscalls.append((next_reports[-1], StopIteration())) - assert_async(reader.readinto(long_buffer), expected_syscalls) + assert_async(reader.areadinto(long_buffer), expected_syscalls) assert_eq(long_buffer, message[-start_size:]) assert_eq(reader.size, 0) # one byte read, raises eof - assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),]) + assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),]) def test_writer(): @@ -88,7 +88,8 @@ def test_writer(): session_id = 0x12345678 message_type = 0x87654321 message_len = 1024 - writer = codec_v2.Writer(interface, session_id, message_type, message_len) + writer = codec_v2.Writer(interface, session_id) + writer.setheader(message_type, message_len) # init header corresponding to the data above report_header = bytearray(unhexlify('01123456788765432100000400')) @@ -97,14 +98,14 @@ def test_writer(): # empty write start_size = writer.size - assert_async(writer.write(bytearray()), [(None, StopIteration()),]) + assert_async(writer.awrite(bytearray()), [(None, StopIteration()),]) assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) assert_eq(writer.size, start_size) # short write, expected no report start_size = writer.size short_payload = bytearray(range(4)) - assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_async(writer.awrite(short_payload), [(None, StopIteration()),]) assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.data, report_header @@ -119,7 +120,7 @@ def test_writer(): + short_payload + aligned_payload + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) - assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) + assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) assert_eq(writer.size, start_size - len(aligned_payload)) msg.send.assert_called_n_times(1) msg.send = msg.send.original @@ -127,7 +128,7 @@ def test_writer(): # short write, expected no report, but data starts with correct seq and cont marker report_header = bytearray(unhexlify('021234567800000000')) start_size = writer.size - assert_async(writer.write(short_payload), [(None, StopIteration()),]) + assert_async(writer.awrite(short_payload), [(None, StopIteration()),]) assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.data[:len(report_header) + len(short_payload)], report_header + short_payload) @@ -145,19 +146,19 @@ def test_writer(): # test write expected_write_reports = expected_reports[:-1] msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) - assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(writer.size, start_size - len(long_payload)) msg.send.assert_called_n_times(len(expected_write_reports)) msg.send = msg.send.original # test write raises eof msg.send = mock_call(msg.send, []) - assert_async(writer.write(bytearray(1)), [(None, EOFError())]) + assert_async(writer.awrite(bytearray(1)), [(None, EOFError())]) msg.send.assert_called_n_times(0) msg.send = msg.send.original # test close expected_close_reports = expected_reports[-1:] msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) - assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) + assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(writer.size, 0) msg.send.assert_called_n_times(len(expected_close_reports)) msg.send = msg.send.original