1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-28 16:21:03 +00:00

wire: pass Context to apps

This commit is contained in:
Jan Pochyla 2017-08-15 15:09:09 +02:00
parent b1b84fb233
commit 3562ffdc54
30 changed files with 355 additions and 308 deletions

View File

@ -7,7 +7,7 @@ signal = loop.Signal()
@unimport @unimport
async def confirm(session_id, content, code=None, *args, **kwargs): async def confirm(ctx, content, code=None, *args, **kwargs):
from trezor.ui.confirm import ConfirmDialog, CONFIRMED from trezor.ui.confirm import ConfirmDialog, CONFIRMED
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import Other from trezor.messages.ButtonRequestType import Other
@ -19,12 +19,12 @@ async def confirm(session_id, content, code=None, *args, **kwargs):
if code is None: if code is None:
code = Other code = Other
await wire.call(session_id, ButtonRequest(code=code), ButtonAck) await ctx.call(ButtonRequest(code=code), ButtonAck)
return await loop.Wait((signal, dialog)) == CONFIRMED return await loop.Wait((signal, dialog)) == CONFIRMED
@unimport @unimport
async def hold_to_confirm(session_id, content, code=None, *args, **kwargs): async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import Other from trezor.messages.ButtonRequestType import Other
@ -36,7 +36,7 @@ async def hold_to_confirm(session_id, content, code=None, *args, **kwargs):
if code is None: if code is None:
code = Other code = Other
await wire.call(session_id, ButtonRequest(code=code), ButtonAck) await ctx.call(ButtonRequest(code=code), ButtonAck)
return await loop.Wait((signal, dialog)) == CONFIRMED return await loop.Wait((signal, dialog)) == CONFIRMED

View File

@ -1,7 +1,7 @@
from trezor import ui, wire from trezor import ui, wire
async def request_passphrase(session_id): async def request_passphrase(ctx):
from trezor.messages.FailureType import ActionCancelled from trezor.messages.FailureType import ActionCancelled
from trezor.messages.PassphraseRequest import PassphraseRequest from trezor.messages.PassphraseRequest import PassphraseRequest
from trezor.messages.wire_types import PassphraseAck, Cancel from trezor.messages.wire_types import PassphraseAck, Cancel
@ -12,17 +12,17 @@ async def request_passphrase(session_id):
'Please enter passphrase', 'on your computer.') 'Please enter passphrase', 'on your computer.')
text.render() text.render()
ack = await wire.call(session_id, PassphraseRequest(), PassphraseAck, Cancel) ack = await ctx.call(PassphraseRequest(), PassphraseAck, Cancel)
if ack.MESSAGE_WIRE_TYPE == Cancel: if ack.MESSAGE_WIRE_TYPE == Cancel:
raise wire.FailureError(ActionCancelled, 'Passphrase cancelled') raise wire.FailureError(ActionCancelled, 'Passphrase cancelled')
return ack.passphrase return ack.passphrase
async def protect_by_passphrase(session_id): async def protect_by_passphrase(ctx):
from apps.common import storage from apps.common import storage
if storage.is_protected_by_passphrase(): if storage.is_protected_by_passphrase():
return await request_passphrase(session_id) return await request_passphrase(ctx)
else: else:
return '' return ''

View File

@ -7,7 +7,7 @@ if __debug__:
@unimport @unimport
async def request_pin_on_display(session_id: int, code: int=None) -> str: async def request_pin_on_display(ctx: wire.Context, code: int=None) -> str:
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import ProtectCall from trezor.messages.ButtonRequestType import ProtectCall
from trezor.messages.FailureType import PinCancelled from trezor.messages.FailureType import PinCancelled
@ -20,9 +20,8 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
_, label = _get_code_and_label(code) _, label = _get_code_and_label(code)
await wire.call(session_id, await ctx.call(ButtonRequest(code=ProtectCall),
ButtonRequest(code=ProtectCall), ButtonAck)
ButtonAck)
ui.display.clear() ui.display.clear()
matrix = PinMatrix(label) matrix = PinMatrix(label)
@ -36,7 +35,7 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
@unimport @unimport
async def request_pin_on_client(session_id: int, code: int=None) -> str: async def request_pin_on_client(ctx: wire.Context, code: int=None) -> str:
from trezor.messages.FailureType import PinCancelled from trezor.messages.FailureType import PinCancelled
from trezor.messages.PinMatrixRequest import PinMatrixRequest from trezor.messages.PinMatrixRequest import PinMatrixRequest
from trezor.messages.wire_types import PinMatrixAck, Cancel from trezor.messages.wire_types import PinMatrixAck, Cancel
@ -51,9 +50,8 @@ async def request_pin_on_client(session_id: int, code: int=None) -> str:
matrix = PinMatrix(label) matrix = PinMatrix(label)
matrix.render() matrix.render()
ack = await wire.call(session_id, ack = await ctx.call(PinMatrixRequest(type=code),
PinMatrixRequest(type=code), PinMatrixAck, Cancel)
PinMatrixAck, Cancel)
digits = matrix.digits digits = matrix.digits
matrix = None matrix = None
@ -66,12 +64,12 @@ request_pin = request_pin_on_client
@unimport @unimport
async def request_pin_twice(session_id: int) -> str: async def request_pin_twice(ctx: wire.Context) -> str:
from trezor.messages.FailureType import ActionCancelled from trezor.messages.FailureType import ActionCancelled
from trezor.messages import PinMatrixRequestType from trezor.messages import PinMatrixRequestType
pin_first = await request_pin(session_id, PinMatrixRequestType.NewFirst) pin_first = await request_pin(ctx, PinMatrixRequestType.NewFirst)
pin_again = await request_pin(session_id, PinMatrixRequestType.NewSecond) pin_again = await request_pin(ctx, PinMatrixRequestType.NewSecond)
if pin_first != pin_again: if pin_first != pin_again:
# changed message due to consistency with T1 msgs # changed message due to consistency with T1 msgs
raise wire.FailureError(ActionCancelled, 'PIN change failed') raise wire.FailureError(ActionCancelled, 'PIN change failed')
@ -79,22 +77,22 @@ async def request_pin_twice(session_id: int) -> str:
return pin_first return pin_first
async def protect_by_pin_repeatedly(session_id: int, at_least_once: bool=False): async def protect_by_pin_repeatedly(ctx: wire.Context, at_least_once: bool=False):
from . import storage from . import storage
locked = storage.is_locked() or at_least_once locked = storage.is_locked() or at_least_once
while locked: while locked:
pin = await request_pin(session_id) pin = await request_pin(ctx)
locked = not storage.unlock(pin, _render_pin_failure) locked = not storage.unlock(pin, _render_pin_failure)
async def protect_by_pin_or_fail(session_id: int, at_least_once: bool=False): async def protect_by_pin_or_fail(ctx: wire.Context, at_least_once: bool=False):
from trezor.messages.FailureType import PinInvalid from trezor.messages.FailureType import PinInvalid
from . import storage from . import storage
locked = storage.is_locked() or at_least_once locked = storage.is_locked() or at_least_once
if locked: if locked:
pin = await request_pin(session_id) pin = await request_pin(ctx)
if not storage.unlock(pin, _render_pin_failure): if not storage.unlock(pin, _render_pin_failure):
raise wire.FailureError(PinInvalid, 'PIN invalid') raise wire.FailureError(PinInvalid, 'PIN invalid')

View File

@ -5,20 +5,20 @@ from trezor.crypto import bip39
_DEFAULT_CURVE = 'secp256k1' _DEFAULT_CURVE = 'secp256k1'
async def get_root(session_id: int, curve_name=_DEFAULT_CURVE): async def get_root(ctx: wire.Context, curve_name=_DEFAULT_CURVE):
seed = await get_seed(session_id) seed = await get_seed(ctx)
root = bip32.from_seed(seed, curve_name) root = bip32.from_seed(seed, curve_name)
return root return root
async def get_seed(session_id: int) -> bytes: async def get_seed(ctx: wire.Context) -> bytes:
from . import cache from . import cache
if cache.seed is None: if cache.seed is None:
cache.seed = await compute_seed(session_id) cache.seed = await compute_seed(ctx)
return cache.seed return cache.seed
async def compute_seed(session_id: int) -> bytes: async def compute_seed(ctx: wire.Context) -> bytes:
from trezor.messages.FailureType import ProcessError from trezor.messages.FailureType import ProcessError
from .request_passphrase import protect_by_passphrase from .request_passphrase import protect_by_passphrase
from .request_pin import protect_by_pin from .request_pin import protect_by_pin
@ -27,9 +27,9 @@ async def compute_seed(session_id: int) -> bytes:
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.FailureError(ProcessError, 'Device is not initialized') raise wire.FailureError(ProcessError, 'Device is not initialized')
await protect_by_pin(session_id) await protect_by_pin(ctx)
passphrase = await protect_by_passphrase(session_id) passphrase = await protect_by_passphrase(ctx)
return bip39.seed(storage.get_mnemonic(), passphrase) return bip39.seed(storage.get_mnemonic(), passphrase)

View File

@ -4,13 +4,13 @@ from trezor.messages.wire_types import \
DebugLinkMemoryRead, DebugLinkMemoryWrite, DebugLinkFlashErase DebugLinkMemoryRead, DebugLinkMemoryWrite, DebugLinkFlashErase
async def dispatch_DebugLinkDecision(session_id, msg): async def dispatch_DebugLinkDecision(ctx, msg):
from trezor.ui.confirm import CONFIRMED, CANCELLED from trezor.ui.confirm import CONFIRMED, CANCELLED
from apps.common.confirm import signal from apps.common.confirm import signal
signal.send(CONFIRMED if msg.yes_no else CANCELLED) signal.send(CONFIRMED if msg.yes_no else CANCELLED)
async def dispatch_DebugLinkGetState(session_id, msg): async def dispatch_DebugLinkGetState(ctx, msg):
from trezor.messages.DebugLinkState import DebugLinkState from trezor.messages.DebugLinkState import DebugLinkState
from apps.common import storage, request_pin from apps.common import storage, request_pin
from apps.management import reset_device from apps.management import reset_device
@ -36,11 +36,11 @@ async def dispatch_DebugLinkGetState(session_id, msg):
return m return m
async def dispatch_DebugLinkStop(session_id, msg): async def dispatch_DebugLinkStop(ctx, msg):
pass pass
async def dispatch_DebugLinkMemoryRead(session_id, msg): async def dispatch_DebugLinkMemoryRead(ctx, msg):
from trezor.messages.DebugLinkMemory import DebugLinkMemory from trezor.messages.DebugLinkMemory import DebugLinkMemory
from uctypes import bytes_at from uctypes import bytes_at
m = DebugLinkMemory() m = DebugLinkMemory()
@ -48,14 +48,14 @@ async def dispatch_DebugLinkMemoryRead(session_id, msg):
return m return m
async def dispatch_DebugLinkMemoryWrite(session_id, msg): async def dispatch_DebugLinkMemoryWrite(ctx, msg):
from uctypes import bytearray_at from uctypes import bytearray_at
l = len(msg.memory) l = len(msg.memory)
data = bytearray_at(msg.address, l) data = bytearray_at(msg.address, l)
data[0:l] = msg.memory data[0:l] = msg.memory
async def dispatch_DebugLinkFlashErase(session_id, msg): async def dispatch_DebugLinkFlashErase(ctx, msg):
# TODO: erase(msg.sector) # TODO: erase(msg.sector)
pass pass

View File

@ -3,13 +3,13 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_ethereum_get_address(session_id, msg): async def layout_ethereum_get_address(ctx, msg):
from trezor.messages.EthereumAddress import EthereumAddress from trezor.messages.EthereumAddress import EthereumAddress
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
from ..common import seed from ..common import seed
node = await seed.get_root(session_id) node = await seed.get_root(ctx)
node.derive_path(msg.address_n or ()) node.derive_path(msg.address_n or ())
seckey = node.private_key() seckey = node.private_key()
@ -17,11 +17,11 @@ async def layout_ethereum_get_address(session_id, msg):
address = sha3_256(public_key[1:]).digest(True)[12:] # Keccak address = sha3_256(public_key[1:]).digest(True)[12:] # Keccak
if msg.show_display: if msg.show_display:
await _show_address(session_id, address) await _show_address(ctx, address)
return EthereumAddress(address=address) return EthereumAddress(address=address)
async def _show_address(session_id, address): async def _show_address(ctx, address):
from trezor.messages.ButtonRequestType import Address from trezor.messages.ButtonRequestType import Address
from trezor.ui.text import Text from trezor.ui.text import Text
from ..common.confirm import require_confirm from ..common.confirm import require_confirm
@ -30,7 +30,7 @@ async def _show_address(session_id, address):
content = Text('Confirm address', ui.ICON_RESET, content = Text('Confirm address', ui.ICON_RESET,
ui.MONO, *_split_address(address)) ui.MONO, *_split_address(address))
await require_confirm(session_id, content, code=Address) await require_confirm(ctx, content, code=Address)
def _split_address(address): def _split_address(address):

View File

@ -226,11 +226,11 @@ class Cmd:
return Msg(self.cid, cla, ins, p1, p2, lc, data) return Msg(self.cid, cla, ins, p1, p2, lc, data)
async def read_cmd(iface: int) -> Cmd: async def read_cmd(iface: io.HID) -> Cmd:
desc_init = frame_init() desc_init = frame_init()
desc_cont = frame_cont() desc_cont = frame_cont()
buf, = await loop.select(iface) buf, = await loop.select(iface.iface_num())
# log.debug(__name__, 'read init %s', buf) # log.debug(__name__, 'read init %s', buf)
ifrm = overlay_struct(buf, desc_init) ifrm = overlay_struct(buf, desc_init)
@ -252,7 +252,7 @@ async def read_cmd(iface: int) -> Cmd:
data = data[:bcnt] data = data[:bcnt]
while datalen < bcnt: while datalen < bcnt:
buf, = await loop.select(iface) buf, = await loop.select(iface.iface_num())
# log.debug(__name__, 'read cont %s', buf) # log.debug(__name__, 'read cont %s', buf)
cfrm = overlay_struct(buf, desc_cont) cfrm = overlay_struct(buf, desc_cont)
@ -282,7 +282,7 @@ async def read_cmd(iface: int) -> Cmd:
return Cmd(ifrm.cid, ifrm.cmd, data) return Cmd(ifrm.cid, ifrm.cmd, data)
def send_cmd(cmd: Cmd, iface: int) -> None: def send_cmd(cmd: Cmd, iface: io.HID) -> None:
init_desc = frame_init() init_desc = frame_init()
cont_desc = frame_cont() cont_desc = frame_cont()
offset = 0 offset = 0
@ -295,7 +295,7 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
frm.bcnt = datalen frm.bcnt = datalen
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen) offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
io.send(iface, buf) iface.write(buf)
# log.debug(__name__, 'send init %s', buf) # log.debug(__name__, 'send init %s', buf)
if offset < datalen: if offset < datalen:
@ -304,18 +304,17 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
while offset < datalen: while offset < datalen:
frm.seq = seq frm.seq = seq
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen) offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
utime.sleep_ms(1) # FIXME: do async send utime.sleep_ms(1) # FIXME: async write
io.send(iface, buf) iface.write(buf)
# log.debug(__name__, 'send cont %s', buf) # log.debug(__name__, 'send cont %s', buf)
seq += 1 seq += 1
def boot(): def boot(iface: io.HID):
iface = 0x03
loop.schedule_task(handle_reports(iface)) loop.schedule_task(handle_reports(iface))
async def handle_reports(iface: int): async def handle_reports(iface: io.HID):
while True: while True:
try: try:
req = await read_cmd(iface) req = await read_cmd(iface)

View File

@ -4,7 +4,7 @@ from trezor.messages.wire_types import Initialize, GetFeatures, Ping
@unimport @unimport
async def respond_Features(session_id, msg): async def respond_Features(ctx, msg):
from apps.common import storage, coins from apps.common import storage, coins
from trezor.messages.Features import Features from trezor.messages.Features import Features
@ -28,7 +28,7 @@ async def respond_Features(session_id, msg):
@unimport @unimport
async def respond_Pong(session_id, msg): async def respond_Pong(ctx, msg):
from trezor.messages.Success import Success from trezor.messages.Success import Success
s = Success() s = Success()
@ -36,11 +36,11 @@ async def respond_Pong(session_id, msg):
if msg.pin_protection: if msg.pin_protection:
from apps.common.request_pin import protect_by_pin from apps.common.request_pin import protect_by_pin
await protect_by_pin(session_id) await protect_by_pin(ctx)
if msg.passphrase_protection: if msg.passphrase_protection:
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
await protect_by_passphrase(session_id) await protect_by_passphrase(ctx)
# TODO: handle other fields: # TODO: handle other fields:
# button_protection # button_protection

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_apply_settings(session_id, msg): async def layout_apply_settings(ctx, msg):
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.messages.FailureType import ProcessError from trezor.messages.FailureType import ProcessError
from trezor.ui.text import Text from trezor.ui.text import Text
@ -11,7 +11,7 @@ async def layout_apply_settings(session_id, msg):
from ..common.request_pin import protect_by_pin from ..common.request_pin import protect_by_pin
from ..common import storage from ..common import storage
await protect_by_pin(session_id) await protect_by_pin(ctx)
if msg.homescreen is not None: if msg.homescreen is not None:
raise wire.FailureError( raise wire.FailureError(
@ -21,20 +21,20 @@ async def layout_apply_settings(session_id, msg):
raise wire.FailureError(ProcessError, 'No setting provided') raise wire.FailureError(ProcessError, 'No setting provided')
if msg.label is not None: if msg.label is not None:
await require_confirm(session_id, Text( await require_confirm(ctx, Text(
'Change label', ui.ICON_RESET, 'Change label', ui.ICON_RESET,
'Do you really want to', 'change label to', 'Do you really want to', 'change label to',
ui.BOLD, '%s' % msg.label)) ui.BOLD, '%s' % msg.label))
if msg.language is not None: if msg.language is not None:
await require_confirm(session_id, Text( await require_confirm(ctx, Text(
'Change language', ui.ICON_RESET, 'Change language', ui.ICON_RESET,
'Do you really want to', 'change language to', 'Do you really want to', 'change language to',
ui.BOLD, '%s' % msg.language, ui.BOLD, '%s' % msg.language,
ui.NORMAL, '?')) ui.NORMAL, '?'))
if msg.use_passphrase is not None: if msg.use_passphrase is not None:
await require_confirm(session_id, Text( await require_confirm(ctx, Text(
'Enable passphrase' if msg.use_passphrase else 'Disable passphrase', 'Enable passphrase' if msg.use_passphrase else 'Disable passphrase',
ui.ICON_RESET, ui.ICON_RESET,
'Do you really want to', 'Do you really want to',

View File

@ -2,52 +2,52 @@ from trezor import ui
from trezor.utils import unimport from trezor.utils import unimport
def confirm_set_pin(session_id): def confirm_set_pin(ctx):
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from trezor.ui.text import Text from trezor.ui.text import Text
return require_confirm(session_id, Text( return require_confirm(ctx, Text(
'Change PIN', ui.ICON_RESET, 'Change PIN', ui.ICON_RESET,
'Do you really want to', ui.BOLD, 'Do you really want to', ui.BOLD,
'set new PIN?')) 'set new PIN?'))
def confirm_change_pin(session_id): def confirm_change_pin(ctx):
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from trezor.ui.text import Text from trezor.ui.text import Text
return require_confirm(session_id, Text( return require_confirm(ctx, Text(
'Change PIN', ui.ICON_RESET, 'Change PIN', ui.ICON_RESET,
'Do you really want to', ui.BOLD, 'Do you really want to', ui.BOLD,
'change current PIN?')) 'change current PIN?'))
def confirm_remove_pin(session_id): def confirm_remove_pin(ctx):
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from trezor.ui.text import Text from trezor.ui.text import Text
return require_confirm(session_id, Text( return require_confirm(ctx, Text(
'Remove PIN', ui.ICON_RESET, 'Remove PIN', ui.ICON_RESET,
'Do you really want to', ui.BOLD, 'Do you really want to', ui.BOLD,
'remove current PIN?')) 'remove current PIN?'))
@unimport @unimport
async def layout_change_pin(session_id, msg): async def layout_change_pin(ctx, msg):
from trezor.messages.Success import Success from trezor.messages.Success import Success
from apps.common.request_pin import protect_by_pin, request_pin_twice from apps.common.request_pin import protect_by_pin, request_pin_twice
from apps.common import storage from apps.common import storage
if msg.remove: if msg.remove:
if storage.is_protected_by_pin(): if storage.is_protected_by_pin():
await confirm_remove_pin(session_id) await confirm_remove_pin(ctx)
await protect_by_pin(session_id, at_least_once=True) await protect_by_pin(ctx, at_least_once=True)
pin = '' pin = ''
else: else:
if storage.is_protected_by_pin(): if storage.is_protected_by_pin():
await confirm_change_pin(session_id) await confirm_change_pin(ctx)
await protect_by_pin(session_id, at_least_once=True) await protect_by_pin(ctx, at_least_once=True)
else: else:
await confirm_set_pin(session_id) await confirm_set_pin(ctx)
pin = await request_pin_twice(session_id) pin = await request_pin_twice(ctx)
storage.load_settings(pin=pin) storage.load_settings(pin=pin)
if pin: if pin:

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_load_device(session_id, msg): async def layout_load_device(ctx, msg):
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.messages.FailureType import UnexpectedMessage, ProcessError from trezor.messages.FailureType import UnexpectedMessage, ProcessError
@ -20,7 +20,7 @@ async def layout_load_device(session_id, msg):
if not msg.skip_checksum and not bip39.check(msg.mnemonic): if not msg.skip_checksum and not bip39.check(msg.mnemonic):
raise wire.FailureError(ProcessError, 'Mnemonic is not valid') raise wire.FailureError(ProcessError, 'Mnemonic is not valid')
await require_confirm(session_id, Text( await require_confirm(ctx, Text(
'Loading seed', ui.ICON_RESET, 'Loading seed', ui.ICON_RESET,
ui.BOLD, 'Loading private seed', 'is not recommended.', ui.BOLD, 'Loading private seed', 'is not recommended.',
ui.NORMAL, 'Continue only if you', 'know what you are doing!')) ui.NORMAL, 'Continue only if you', 'know what you are doing!'))

View File

@ -11,7 +11,7 @@ def nth(n):
@unimport @unimport
async def layout_recovery_device(session_id, msg): async def layout_recovery_device(ctx, msg):
msg = 'Please enter ' + nth(msg.word_count) + ' word' msg = 'Please enter ' + nth(msg.word_count) + ' word'

View File

@ -10,7 +10,7 @@ if __debug__:
@unimport @unimport
async def layout_reset_device(session_id, msg): async def layout_reset_device(ctx, msg):
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.crypto import hashlib, random, bip39 from trezor.crypto import hashlib, random, bip39
from trezor.messages.EntropyRequest import EntropyRequest from trezor.messages.EntropyRequest import EntropyRequest
@ -39,21 +39,21 @@ async def layout_reset_device(session_id, msg):
if msg.display_random: if msg.display_random:
entropy_lines = chunks(ubinascii.hexlify(internal_entropy), 16) entropy_lines = chunks(ubinascii.hexlify(internal_entropy), 16)
entropy_content = Text('Internal entropy', ui.ICON_RESET, *entropy_lines) entropy_content = Text('Internal entropy', ui.ICON_RESET, *entropy_lines)
await require_confirm(session_id, entropy_content, ButtonRequestType.ResetDevice) await require_confirm(ctx, entropy_content, ButtonRequestType.ResetDevice)
if msg.pin_protection: if msg.pin_protection:
pin = await request_pin_twice(session_id) pin = await request_pin_twice(ctx)
else: else:
pin = None pin = None
external_entropy_ack = await wire.call(session_id, EntropyRequest(), EntropyAck) external_entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
ctx = hashlib.sha256() ctx = hashlib.sha256()
ctx.update(internal_entropy) ctx.update(internal_entropy)
ctx.update(external_entropy_ack.entropy) ctx.update(external_entropy_ack.entropy)
entropy = ctx.digest() entropy = ctx.digest()
mnemonic = bip39.from_data(entropy[:msg.strength // 8]) mnemonic = bip39.from_data(entropy[:msg.strength // 8])
await show_mnemonic_by_word(session_id, mnemonic) await show_mnemonic_by_word(ctx, mnemonic)
storage.load_mnemonic(mnemonic) storage.load_mnemonic(mnemonic)
storage.load_settings(pin=pin, storage.load_settings(pin=pin,
@ -64,7 +64,7 @@ async def layout_reset_device(session_id, msg):
return Success(message='Initialized') return Success(message='Initialized')
async def show_mnemonic_by_word(session_id, mnemonic): async def show_mnemonic_by_word(ctx, mnemonic):
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.messages.ButtonRequestType import ConfirmWord from trezor.messages.ButtonRequestType import ConfirmWord
from apps.common.confirm import confirm from apps.common.confirm import confirm
@ -80,7 +80,7 @@ async def show_mnemonic_by_word(session_id, mnemonic):
while index < len(words): while index < len(words):
word = words[index] word = words[index]
current_word = word current_word = word
await confirm(session_id, await confirm(ctx,
Text( Text(
'Recovery seed setup', ui.ICON_RESET, 'Recovery seed setup', ui.ICON_RESET,
ui.NORMAL, 'Write down seed word' if recovery else 'Confirm seed word', ' ', ui.NORMAL, 'Write down seed word' if recovery else 'Confirm seed word', ' ',

View File

@ -3,13 +3,13 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_wipe_device(session_id, msg): async def layout_wipe_device(ctx, msg):
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.ui.text import Text from trezor.ui.text import Text
from ..common.confirm import hold_to_confirm from ..common.confirm import hold_to_confirm
from ..common import storage from ..common import storage
await hold_to_confirm(session_id, Text( await hold_to_confirm(ctx, Text(
'WIPE DEVICE', 'WIPE DEVICE',
ui.ICON_WIPE, ui.ICON_WIPE,
ui.NORMAL, 'Do you really want to', 'wipe the device?', ui.NORMAL, 'Do you really want to', 'wipe the device?',

View File

@ -26,7 +26,7 @@ def cipher_key_value(msg, seckey: bytes) -> bytes:
@unimport @unimport
async def layout_cipher_key_value(session_id, msg): async def layout_cipher_key_value(ctx, msg):
from trezor.messages.CipheredKeyValue import CipheredKeyValue from trezor.messages.CipheredKeyValue import CipheredKeyValue
from ..common import seed from ..common import seed
@ -38,7 +38,7 @@ async def layout_cipher_key_value(session_id, msg):
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK) ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
ui.display.text(10, 60, msg.key, ui.MONO, ui.WHITE, ui.BLACK) ui.display.text(10, 60, msg.key, ui.MONO, ui.WHITE, ui.BLACK)
node = await seed.get_root(session_id) node = await seed.get_root(ctx)
node.derive_path(msg.address_n) node.derive_path(msg.address_n)
value = cipher_key_value(msg, node.private_key()) value = cipher_key_value(msg, node.private_key())

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_get_address(session_id, msg): async def layout_get_address(ctx, msg):
from trezor.messages.Address import Address from trezor.messages.Address import Address
from trezor.messages.FailureType import ProcessError from trezor.messages.FailureType import ProcessError
from ..common import coins from ..common import coins
@ -15,18 +15,18 @@ async def layout_get_address(session_id, msg):
address_n = msg.address_n or () address_n = msg.address_n or ()
coin_name = msg.coin_name or 'Bitcoin' coin_name = msg.coin_name or 'Bitcoin'
node = await seed.get_root(session_id) node = await seed.get_root(ctx)
node.derive_path(address_n) node.derive_path(address_n)
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)
address = node.address(coin.address_type) address = node.address(coin.address_type)
if msg.show_display: if msg.show_display:
await _show_address(session_id, address) await _show_address(ctx, address)
return Address(address=address) return Address(address=address)
async def _show_address(session_id, address): async def _show_address(ctx, address):
from trezor.messages.ButtonRequestType import Address from trezor.messages.ButtonRequestType import Address
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.ui.qr import Qr from trezor.ui.qr import Qr
@ -37,7 +37,7 @@ async def _show_address(session_id, address):
content = Container( content = Container(
Qr(address, (120, 135), 3), Qr(address, (120, 135), 3),
Text('Confirm address', ui.ICON_RESET, ui.MONO, *lines)) Text('Confirm address', ui.ICON_RESET, ui.MONO, *lines))
await require_confirm(session_id, content, code=Address) await require_confirm(ctx, content, code=Address)
def _split_address(address): def _split_address(address):

View File

@ -3,18 +3,18 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_get_entropy(session_id, msg): async def layout_get_entropy(ctx, msg):
from trezor.messages.Entropy import Entropy from trezor.messages.Entropy import Entropy
from trezor.crypto import random from trezor.crypto import random
l = min(msg.size, 1024) l = min(msg.size, 1024)
await _show_entropy(session_id) await _show_entropy(ctx)
return Entropy(entropy=random.bytes(l)) return Entropy(entropy=random.bytes(l))
async def _show_entropy(session_id): async def _show_entropy(ctx):
from trezor.messages.ButtonRequestType import ProtectCall from trezor.messages.ButtonRequestType import ProtectCall
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.ui.container import Container from trezor.ui.container import Container
@ -23,4 +23,4 @@ async def _show_entropy(session_id):
content = Container( content = Container(
Text('Confirm entropy', ui.ICON_RESET, ui.MONO, 'Do you really want to send entropy?')) Text('Confirm entropy', ui.ICON_RESET, ui.MONO, 'Do you really want to send entropy?'))
await require_confirm(session_id, content, code=ProtectCall) await require_confirm(ctx, content, code=ProtectCall)

View File

@ -2,7 +2,7 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_get_public_key(session_id, msg): async def layout_get_public_key(ctx, msg):
from trezor.messages.HDNodeType import HDNodeType from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.PublicKey import PublicKey from trezor.messages.PublicKey import PublicKey
from ..common import coins from ..common import coins
@ -11,7 +11,7 @@ async def layout_get_public_key(session_id, msg):
address_n = msg.address_n or () address_n = msg.address_n or ()
coin_name = msg.coin_name or 'Bitcoin' coin_name = msg.coin_name or 'Bitcoin'
node = await seed.get_root(session_id) node = await seed.get_root(ctx)
node.derive_path(address_n) node.derive_path(address_n)
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)

View File

@ -83,7 +83,7 @@ def sign_challenge(seckey: bytes,
@unimport @unimport
async def layout_sign_identity(session_id, msg): async def layout_sign_identity(ctx, msg):
from trezor.messages.SignedIdentity import SignedIdentity from trezor.messages.SignedIdentity import SignedIdentity
from ..common import coins from ..common import coins
from ..common import seed from ..common import seed
@ -92,7 +92,7 @@ async def layout_sign_identity(session_id, msg):
display_identity(identity, msg.challenge_visual) display_identity(identity, msg.challenge_visual)
address_n = get_identity_path(identity, msg.identity.index or 0) address_n = get_identity_path(identity, msg.identity.index or 0)
node = await seed.get_root(session_id, msg.ecdsa_curve_name) node = await seed.get_root(ctx, msg.ecdsa_curve_name)
node.derive_path(address_n) node.derive_path(address_n)
coin = coins.by_name('Bitcoin') coin = coins.by_name('Bitcoin')

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_sign_message(session_id, msg): async def layout_sign_message(ctx, msg):
from trezor.messages.MessageSignature import MessageSignature from trezor.messages.MessageSignature import MessageSignature
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from ..common import coins from ..common import coins
@ -18,7 +18,7 @@ async def layout_sign_message(session_id, msg):
coin_name = msg.coin_name or 'Bitcoin' coin_name = msg.coin_name or 'Bitcoin'
coin = coins.by_name(coin_name) coin = coins.by_name(coin_name)
node = await seed.get_root(session_id) node = await seed.get_root(ctx)
node.derive_path(msg.address_n) node.derive_path(msg.address_n)
seckey = node.private_key() seckey = node.private_key()

View File

@ -3,7 +3,7 @@ from trezor import wire
@unimport @unimport
async def sign_tx(session_id, msg): async def sign_tx(ctx, msg):
from trezor.messages.RequestType import TXFINISHED from trezor.messages.RequestType import TXFINISHED
from trezor.messages.wire_types import TxAck from trezor.messages.wire_types import TxAck
@ -11,7 +11,7 @@ async def sign_tx(session_id, msg):
from . import signing from . import signing
from . import layout from . import layout
root = await seed.get_root(session_id) root = await seed.get_root(ctx)
signer = signing.sign_tx(msg, root) signer = signing.sign_tx(msg, root)
res = None res = None
@ -23,13 +23,13 @@ async def sign_tx(session_id, msg):
if req.__qualname__ == 'TxRequest': if req.__qualname__ == 'TxRequest':
if req.request_type == TXFINISHED: if req.request_type == TXFINISHED:
break break
res = await wire.call(session_id, req, TxAck) res = await ctx.call(req, TxAck)
elif req.__qualname__ == 'UiConfirmOutput': elif req.__qualname__ == 'UiConfirmOutput':
res = await layout.confirm_output(session_id, req.output, req.coin) res = await layout.confirm_output(ctx, req.output, req.coin)
elif req.__qualname__ == 'UiConfirmTotal': elif req.__qualname__ == 'UiConfirmTotal':
res = await layout.confirm_total(session_id, req.spending, req.fee, req.coin) res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin)
elif req.__qualname__ == 'UiConfirmFeeOverThreshold': elif req.__qualname__ == 'UiConfirmFeeOverThreshold':
res = await layout.confirm_feeoverthreshold(session_id, req.fee, req.coin) res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin)
else: else:
raise TypeError('Invalid signing instruction') raise TypeError('Invalid signing instruction')
return req return req

View File

@ -14,22 +14,22 @@ def split_address(address):
return chunks(address, 17) return chunks(address, 17)
async def confirm_output(session_id, output, coin): async def confirm_output(ctx, output, coin):
content = Text('Confirm output', ui.ICON_RESET, content = Text('Confirm output', ui.ICON_RESET,
ui.BOLD, format_amount(output.amount, coin), ui.BOLD, format_amount(output.amount, coin),
ui.NORMAL, 'to', ui.NORMAL, 'to',
ui.MONO, *split_address(output.address)) ui.MONO, *split_address(output.address))
return await confirm(session_id, content, ButtonRequestType.ConfirmOutput) return await confirm(ctx, content, ButtonRequestType.ConfirmOutput)
async def confirm_total(session_id, spending, fee, coin): async def confirm_total(ctx, spending, fee, coin):
content = Text('Confirm transaction', ui.ICON_RESET, content = Text('Confirm transaction', ui.ICON_RESET,
'Sending: %s' % format_amount(spending, coin), 'Sending: %s' % format_amount(spending, coin),
'Fee: %s' % format_amount(fee, coin)) 'Fee: %s' % format_amount(fee, coin))
return await hold_to_confirm(session_id, content, ButtonRequestType.SignTx) return await hold_to_confirm(ctx, content, ButtonRequestType.SignTx)
async def confirm_feeoverthreshold(session_id, fee, coin): async def confirm_feeoverthreshold(ctx, fee, coin):
content = Text('Confirm high fee:', ui.ICON_RESET, content = Text('Confirm high fee:', ui.ICON_RESET,
ui.BOLD, format_amount(fee, coin)) ui.BOLD, format_amount(fee, coin))
return await confirm(session_id, content, ButtonRequestType.FeeOverThreshold) return await confirm(ctx, content, ButtonRequestType.FeeOverThreshold)

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport @unimport
async def layout_verify_message(session_id, msg): async def layout_verify_message(ctx, msg):
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import ripemd160, sha256 from trezor.crypto.hashlib import ripemd160, sha256

View File

@ -4,26 +4,7 @@ from trezor import io
from trezor import wire from trezor import wire
from trezor import main from trezor import main
# Load applications # initialize the USB stack
from apps.common import storage
if __debug__:
from apps import debug
from apps import homescreen
from apps import management
from apps import wallet
from apps import ethereum
from apps import fido_u2f
# Boot applications
if __debug__:
debug.boot()
homescreen.boot()
management.boot()
wallet.boot()
ethereum.boot()
fido_u2f.boot()
# Intialize the USB stack
usb_wire = io.HID( usb_wire = io.HID(
iface_num=0x00, iface_num=0x00,
ep_in=0x81, ep_in=0x81,
@ -90,11 +71,30 @@ usb.add(usb_vcp)
usb.add(usb_u2f) usb.add(usb_u2f)
usb.open() usb.open()
# Initialize the wire codec pipeline # load applications
wire.setup(usb_wire.iface_num()) from apps.common import storage
if __debug__:
from apps import debug
from apps import homescreen
from apps import management
from apps import wallet
from apps import ethereum
from apps import fido_u2f
# Load default homescreen # boot applications
if __debug__:
debug.boot()
homescreen.boot()
management.boot()
wallet.boot()
ethereum.boot()
fido_u2f.boot(usb_u2f)
# initialize the wire codec pipeline
wire.setup(usb_wire)
# load default homescreen
from apps.homescreen.homescreen import layout_homescreen from apps.homescreen.homescreen import layout_homescreen
# Run main even loop and specify which screen is default # run main even loop and specify which screen is default
main.run(default_workflow=layout_homescreen) main.run(default_workflow=layout_homescreen)

View File

@ -14,7 +14,7 @@ async def load_uvarint(reader):
shift = 0 shift = 0
byte = 0x80 byte = 0x80
while byte & 0x80: while byte & 0x80:
await reader.readinto(buffer) await reader.areadinto(buffer)
byte = buffer[0] byte = buffer[0]
result += (byte & 0x7F) << shift result += (byte & 0x7F) << shift
shift += 7 shift += 7
@ -27,7 +27,7 @@ async def dump_uvarint(writer, n):
while shifted: while shifted:
shifted = n >> 7 shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
await writer.write(buffer) await writer.awrite(buffer)
n = shifted n = shifted
@ -69,11 +69,11 @@ class LimitedReader:
self.reader = reader self.reader = reader
self.limit = limit self.limit = limit
async def readinto(self, buf): async def areadinto(self, buf):
if self.limit < len(buf): if self.limit < len(buf):
raise EOFError raise EOFError
else: else:
nread = await self.reader.readinto(buf) nread = await self.reader.areadinto(buf)
self.limit -= nread self.limit -= nread
return nread return nread
@ -83,7 +83,7 @@ class CountingWriter:
def __init__(self): def __init__(self):
self.size = 0 self.size = 0
async def write(self, buf): async def awrite(self, buf):
nwritten = len(buf) nwritten = len(buf)
self.size += nwritten self.size += nwritten
return nwritten return nwritten
@ -112,7 +112,7 @@ async def load_message(reader, msg_type):
await load_uvarint(reader) await load_uvarint(reader)
elif wtype == 2: elif wtype == 2:
ivalue = await load_uvarint(reader) ivalue = await load_uvarint(reader)
await reader.readinto(bytearray(ivalue)) await reader.areadinto(bytearray(ivalue))
else: else:
raise ValueError raise ValueError
continue continue
@ -129,10 +129,10 @@ async def load_message(reader, msg_type):
fvalue = bool(ivalue) fvalue = bool(ivalue)
elif ftype is BytesType: elif ftype is BytesType:
fvalue = bytearray(ivalue) fvalue = bytearray(ivalue)
await reader.readinto(fvalue) await reader.areadinto(fvalue)
elif ftype is UnicodeType: elif ftype is UnicodeType:
fvalue = bytearray(ivalue) fvalue = bytearray(ivalue)
await reader.readinto(fvalue) await reader.areadinto(fvalue)
fvalue = str(fvalue, 'utf8') fvalue = str(fvalue, 'utf8')
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
fvalue = await load_message(LimitedReader(reader, ivalue), ftype) fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
@ -186,11 +186,11 @@ async def dump_message(writer, msg):
elif ftype is BytesType: elif ftype is BytesType:
await dump_uvarint(writer, len(svalue)) await dump_uvarint(writer, len(svalue))
await writer.write(svalue) await writer.awrite(svalue)
elif ftype is UnicodeType: elif ftype is UnicodeType:
await dump_uvarint(writer, len(svalue)) await dump_uvarint(writer, len(svalue))
await writer.write(bytes(svalue, 'utf8')) await writer.awrite(bytes(svalue, 'utf8'))
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
counter = CountingWriter() counter = CountingWriter()

View File

@ -8,58 +8,82 @@ from trezor import workflow
from . import codec_v1 from . import codec_v1
from . import codec_v2 from . import codec_v2
workflows = {} workflow_handlers = {}
def register(wire_type, handler, *args): def register(mtype, handler, *args):
if wire_type in workflows: '''Register `handler` to get scheduled after `mtype` message is received.'''
if mtype in workflow_handlers:
raise KeyError raise KeyError
workflows[wire_type] = (handler, args) workflow_handlers[mtype] = (handler, args)
def setup(interface): def setup(iface):
session_supervisor = codec_v2.SesssionSupervisor(interface, '''Initialize the wire stack on passed USB interface.'''
session_handler) session_supervisor = codec_v2.SesssionSupervisor(iface, session_handler)
session_supervisor.open(codec_v1.SESSION_ID) session_supervisor.open(codec_v1.SESSION_ID)
loop.schedule_task(session_supervisor.listen()) loop.schedule_task(session_supervisor.listen())
class Context: class Context:
def __init__(self, interface, session_id): def __init__(self, iface, sid):
self.interface = interface self.iface = iface
self.session_id = session_id self.sid = sid
def get_reader(self):
if self.session_id == codec_v1.SESSION_ID:
return codec_v1.Reader(self.interface)
else:
return codec_v2.Reader(self.interface, self.session_id)
def get_writer(self, mtype, msize):
if self.session_id == codec_v1.SESSION_ID:
return codec_v1.Writer(self.interface, mtype, msize)
else:
return codec_v2.Writer(self.interface, self.session_id, mtype, msize)
async def read(self, types):
reader = self.get_reader()
await reader.open()
if reader.type not in types:
raise UnexpectedMessageError(reader)
return await protobuf.load_message(reader,
messages.get_type(reader.type))
async def write(self, msg):
counter = protobuf.CountingWriter()
await protobuf.dump_message(counter, msg)
writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
await protobuf.dump_message(writer, msg)
await writer.close()
async def call(self, msg, types): async def call(self, msg, types):
'''
Reply with `msg` and wait for one of `types`. See `self.write()` and
`self.read()`.
'''
await self.write(msg) await self.write(msg)
return await self.read(types) return await self.read(types)
async def read(self, types):
'''
Wait for incoming message on this wire context and return it. Raises
`UnexpectedMessageError` if the message type does not match one of
`types`; and caller should always make sure to re-raise it.
'''
reader = self.getreader()
await reader.aopen() # wait for the message header
# if we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it
if reader.type not in types:
raise UnexpectedMessageError(reader)
# look up the protobuf class and parse the message
pbtype = messages.get_type(reader.type)
return await protobuf.load_message(reader, pbtype)
async def write(self, msg):
'''
Write a protobuf message to this wire context.
'''
writer = self.getwriter()
# get the message size
counter = protobuf.CountingWriter()
await protobuf.dump_message(counter, msg)
# write the message
writer.setheader(msg.MESSAGE_WIRE_TYPE, counter.size)
await protobuf.dump_message(writer, msg)
await writer.aclose()
def getreader(self):
if self.sid == codec_v1.SESSION_ID:
return codec_v1.Reader(self.iface)
else:
return codec_v2.Reader(self.iface, self.sid)
def getwriter(self):
if self.sid == codec_v1.SESSION_ID:
return codec_v1.Writer(self.iface)
else:
return codec_v2.Writer(self.iface, self.sid)
class UnexpectedMessageError(Exception): class UnexpectedMessageError(Exception):
def __init__(self, reader): def __init__(self, reader):
@ -74,60 +98,69 @@ class FailureError(Exception):
self.message = message self.message = message
class Workflow: async def session_handler(iface, sid):
def __init__(self, default): reader = None
self.handlers = {} ctx = Context(iface, sid)
self.default = default while True:
try:
async def __call__(self, interface, session_id): # wait for new message, if needed, and find handler
ctx = Context(interface, session_id) if not reader:
while True: reader = ctx.getreader()
await reader.aopen()
try: try:
reader = ctx.get_reader() handler, args = workflow_handlers[reader.type]
await reader.open() except KeyError:
try: handler, args = unexpected_msg, ()
handler = self.handlers[reader.type]
except KeyError: await handler(ctx, reader, *args)
handler = self.default
try: except UnexpectedMessageError as exc:
await handler(ctx, reader) # retry with opened reader from the exception
except UnexpectedMessageError as unexp_msg: reader = exc.reader
reader = unexp_msg.reader continue
except Exception as e: except FailureError as exc:
log.exception(__name__, e) # we log FailureError as warning, not as exception
log.warning(__name__, 'failure: %s', exc.message)
except Exception as exc:
# sessions are never closed by raised exceptions
log.exception(__name__, exc)
# read new message in next iteration
reader = None
async def protobuf_workflow(ctx, reader, handler, *args): async def protobuf_workflow(ctx, reader, handler, *args):
msg = await protobuf.load_message(reader, messages.get_type(reader.type)) from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
req = await protobuf.load_message(reader, messages.get_type(reader.type))
try: try:
res = await handler(reader.sid, msg, *args) res = await handler(ctx, req, *args)
except Exception as exc: except UnexpectedMessageError:
if not isinstance(exc, UnexpectedMessageError): # session handler takes care of this one
await ctx.write(make_failure_msg(exc))
raise raise
else: except FailureError as exc:
if res: # respond with specific code and message
await ctx.write(res) await ctx.write(Failure(code=exc.code, message=exc.message))
raise
except Exception as exc:
# respond with a generic code and message
await ctx.write(Failure(code=FirmwareError, message='Firmware error'))
raise
if res:
# respond with a specific response
await ctx.write(res)
async def handle_unexp_msg(ctx, reader): async def unexpected_msg(ctx, reader):
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import UnexpectedMessage
# receive the message and throw it away # receive the message and throw it away
while reader.size > 0: while reader.size > 0:
buf = bytearray(reader.size) buf = bytearray(reader.size)
await reader.readinto(buf) await reader.areadinto(buf)
# respond with an unknown message error # respond with an unknown message error
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import UnexpectedMessage
await ctx.write( await ctx.write(
Failure(code=UnexpectedMessage, message='Unexpected message')) Failure(code=UnexpectedMessage, message='Unexpected message'))
def make_failure_msg(exc):
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
if isinstance(exc, FailureError):
code = exc.code
message = exc.message
else:
code = FirmwareError
message = 'Firmware Error'
return Failure(code=code, message=message)

View File

@ -1,7 +1,6 @@
from micropython import const from micropython import const
import ustruct import ustruct
from trezor import io
from trezor import loop from trezor import loop
from trezor import utils from trezor import utils
@ -32,13 +31,13 @@ class Reader:
def __repr__(self): def __repr__(self):
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size) return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
async def open(self): async def aopen(self):
''' '''
Begin the message transmission by waiting for initial V2 message report Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and on this session. `self.type` and `self.size` are initialized and
available after `open()` returns. available after `aopen()` returns.
''' '''
read = loop.select(self.iface | loop.READ) read = loop.select(self.iface.iface_num() | loop.READ)
while True: while True:
# wait for initial report # wait for initial report
report = await read report = await read
@ -55,7 +54,7 @@ class Reader:
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize] self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
self.ofs = 0 self.ofs = 0
async def readinto(self, buf): async def areadinto(self, buf):
''' '''
Read exactly `len(buf)` bytes into `buf`, waiting for additional Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered reports, if needed. Raises `EOFError` if end-of-message is encountered
@ -64,7 +63,7 @@ class Reader:
if self.size < len(buf): if self.size < len(buf):
raise EOFError raise EOFError
read = loop.select(self.iface | loop.READ) read = loop.select(self.iface.iface_num() | loop.READ)
nread = 0 nread = 0
while nread < len(buf): while nread < len(buf):
if self.ofs == len(self.data): if self.ofs == len(self.data):
@ -93,20 +92,28 @@ class Writer:
async-file-like interface. async-file-like interface.
''' '''
def __init__(self, iface, mtype, msize): def __init__(self, iface):
self.iface = iface self.iface = iface
self.type = mtype self.type = None
self.size = msize self.size = None
self.data = bytearray(_REP_LEN) self.data = bytearray(_REP_LEN)
self.ofs = _REP_INIT_DATA self.ofs = 0
# load the report with initial header
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)
def __repr__(self): def __repr__(self):
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size) return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
async def write(self, buf): def setheader(self, mtype, msize):
'''
Reset the writer state and load the message header with passed type and
total message size.
'''
self.type = mtype
self.size = msize
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC,
_REP_MAGIC, mtype, msize)
self.ofs = _REP_INIT_DATA
async def awrite(self, buf):
''' '''
Encode and write every byte from `buf`. Does not need to be called in Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf` case message has zero length. Raises `EOFError` if the length of `buf`
@ -115,7 +122,7 @@ class Writer:
if self.size < len(buf): if self.size < len(buf):
raise EOFError raise EOFError
write = loop.select(self.iface | loop.WRITE) write = loop.select(self.iface.iface_num() | loop.WRITE)
nwritten = 0 nwritten = 0
while nwritten < len(buf): while nwritten < len(buf):
# copy as much as possible to report buffer # copy as much as possible to report buffer
@ -127,12 +134,12 @@ class Writer:
if self.ofs == _REP_LEN: if self.ofs == _REP_LEN:
# we are at the end of the report, flush it # we are at the end of the report, flush it
await write await write
io.write(self.iface, self.data) self.iface.write(self.data)
self.ofs = _REP_CONT_DATA self.ofs = _REP_CONT_DATA
return nwritten return nwritten
async def close(self): async def aclose(self):
'''Flush and close the message transmission.''' '''Flush and close the message transmission.'''
if self.ofs != _REP_CONT_DATA: if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned, # we didn't write anything or last write() wasn't report-aligned,
@ -141,5 +148,5 @@ class Writer:
self.data[self.ofs] = 0x00 self.data[self.ofs] = 0x00
self.ofs += 1 self.ofs += 1
await loop.select(self.iface | loop.WRITE) await loop.select(self.iface.iface_num() | loop.WRITE)
io.send(self.iface, self.data) self.iface.write(self.data)

View File

@ -1,7 +1,6 @@
from micropython import const from micropython import const
import ustruct import ustruct
from trezor import io
from trezor import loop from trezor import loop
from trezor import utils from trezor import utils
from trezor.crypto import random from trezor.crypto import random
@ -28,11 +27,11 @@ _REP_MARKER_CONT = const(0x02)
_REP_MARKER_OPEN = const(0x03) _REP_MARKER_OPEN = const(0x03)
_REP_MARKER_CLOSE = const(0x04) _REP_MARKER_CLOSE = const(0x04)
_REP = '>BL' # marker, session_id _REP = '>BL' # marker, session_id
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size _REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
_REP_CONT = '>BLL' # marker, session_id, sequence _REP_CONT = '>BLL' # marker, session_id, sequence
_REP_INIT_DATA = const(13) # offset of data in init report _REP_INIT_DATA = const(13) # offset of data in init report
_REP_CONT_DATA = const(9) # offset of data in cont report _REP_CONT_DATA = const(9) # offset of data in cont report
class Reader: class Reader:
@ -51,15 +50,16 @@ class Reader:
self.seq = 0 self.seq = 0
def __repr__(self): def __repr__(self):
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size) return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type,
self.size)
async def open(self): async def aopen(self):
''' '''
Begin the message transmission by waiting for initial V2 message report Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and on this session. `self.type` and `self.size` are initialized and
available after `open()` returns. available after `aopen()` returns.
''' '''
read = loop.select(self.iface | loop.READ) read = loop.select(self.iface.iface_num() | loop.READ)
while True: while True:
# wait for initial report # wait for initial report
report = await read report = await read
@ -74,7 +74,7 @@ class Reader:
self.ofs = 0 self.ofs = 0
self.seq = 0 self.seq = 0
async def readinto(self, buf): async def areadinto(self, buf):
''' '''
Read exactly `len(buf)` bytes into `buf`, waiting for additional Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered reports, if needed. Raises `EOFError` if end-of-message is encountered
@ -83,7 +83,7 @@ class Reader:
if self.size < len(buf): if self.size < len(buf):
raise EOFError raise EOFError
read = loop.select(self.iface | loop.READ) read = loop.select(self.iface.iface_num() | loop.READ)
nread = 0 nread = 0
while nread < len(buf): while nread < len(buf):
if self.ofs == len(self.data): if self.ofs == len(self.data):
@ -115,20 +115,28 @@ class Writer:
interface. interface.
''' '''
def __init__(self, iface, sid, mtype, msize): def __init__(self, iface, sid):
self.iface = iface self.iface = iface
self.sid = sid self.sid = sid
self.type = None
self.size = None
self.data = bytearray(_REP_LEN)
self.ofs = 0
self.seq = 0
def setheader(self, mtype, msize):
'''
Reset the writer state and load the message header with passed type and
total message size.
'''
self.type = mtype self.type = mtype
self.size = msize self.size = msize
self.data = bytearray(_REP_LEN) ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER_INIT, self.sid,
mtype, msize)
self.ofs = _REP_INIT_DATA self.ofs = _REP_INIT_DATA
self.seq = 0 self.seq = 0
# load the report with initial header async def awrite(self, buf):
ustruct.pack_into(_REP_INIT, self.data, 0,
_REP_MARKER_INIT, sid, mtype, msize)
async def write(self, buf):
''' '''
Encode and write every byte from `buf`. Does not need to be called in Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf` case message has zero length. Raises `EOFError` if the length of `buf`
@ -137,7 +145,7 @@ class Writer:
if self.size < len(buf): if self.size < len(buf):
raise EOFError raise EOFError
write = loop.select(self.iface | loop.WRITE) write = loop.select(self.iface.iface_num() | loop.WRITE)
nwritten = 0 nwritten = 0
while nwritten < len(buf): while nwritten < len(buf):
# copy as much as possible to report buffer # copy as much as possible to report buffer
@ -149,15 +157,15 @@ class Writer:
if self.ofs == _REP_LEN: if self.ofs == _REP_LEN:
# we are at the end of the report, flush it, and prepare header # we are at the end of the report, flush it, and prepare header
await write await write
io.send(self.iface, self.data) self.iface.write(self.data)
ustruct.pack_into(_REP_CONT, self.data, 0, ustruct.pack_into(_REP_CONT, self.data, 0, _REP_MARKER_CONT,
_REP_MARKER_CONT, self.sid, self.seq) self.sid, self.seq)
self.ofs = _REP_CONT_DATA self.ofs = _REP_CONT_DATA
self.seq += 1 self.seq += 1
return nwritten return nwritten
async def close(self): async def aclose(self):
'''Flush and close the message transmission.''' '''Flush and close the message transmission.'''
if self.ofs != _REP_CONT_DATA: if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned, # we didn't write anything or last write() wasn't report-aligned,
@ -166,8 +174,8 @@ class Writer:
self.data[self.ofs] = 0x00 self.data[self.ofs] = 0x00
self.ofs += 1 self.ofs += 1
await loop.select(self.iface | loop.WRITE) await loop.select(self.iface.iface_num() | loop.WRITE)
io.send(self.iface, self.data) self.iface.write(self.data)
class SesssionSupervisor: class SesssionSupervisor:
@ -186,8 +194,8 @@ class SesssionSupervisor:
After close request, the handling task is closed and session terminated. After close request, the handling task is closed and session terminated.
Both requests receive responses confirming the operation. Both requests receive responses confirming the operation.
''' '''
read = loop.select(self.iface | loop.READ) read = loop.select(self.iface.iface_num() | loop.READ)
write = loop.select(self.iface | loop.WRITE) write = loop.select(self.iface.iface_num() | loop.WRITE)
while True: while True:
report = await read report = await read
repmarker, repsid = ustruct.unpack(_REP, report) repmarker, repsid = ustruct.unpack(_REP, report)
@ -225,8 +233,8 @@ class SesssionSupervisor:
def writeopen(self, sid): def writeopen(self, sid):
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid) ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
io.write(self.iface, self.session_report) self.iface.write(self.session_report)
def writeclose(self, sid): def writeclose(self, sid):
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid) ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
io.write(self.iface, self.session_report) self.iface.write(self.session_report)

View File

@ -26,26 +26,26 @@ def test_reader():
# open, expected one read # open, expected one read
first_report = report_header + message[:rep_len - len(report_header)] first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type) assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len) assert_eq(reader.size, message_len)
# empty read # empty read
empty_buffer = bytearray() empty_buffer = bytearray()
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),]) assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
assert_eq(len(empty_buffer), 0) assert_eq(len(empty_buffer), 0)
assert_eq(reader.size, message_len) assert_eq(reader.size, message_len)
# short read, expected no read # short read, expected no read
short_buffer = bytearray(32) short_buffer = bytearray(32)
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),]) assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
assert_eq(len(short_buffer), 32) assert_eq(len(short_buffer), 32)
assert_eq(short_buffer, message[:len(short_buffer)]) assert_eq(short_buffer, message[:len(short_buffer)])
assert_eq(reader.size, message_len - len(short_buffer)) assert_eq(reader.size, message_len - len(short_buffer))
# aligned read, expected no read # aligned read, expected no read
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),]) assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
@ -53,12 +53,12 @@ def test_reader():
next_report_header = bytearray(unhexlify('3f')) next_report_header = bytearray(unhexlify('3f'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1) onebyte_buffer = bytearray(1)
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
# too long read, raises eof # too long read, raises eof
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),]) assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
# long read, expect multiple reads # long read, expect multiple reads
start_size = reader.size start_size = reader.size
@ -74,12 +74,12 @@ def test_reader():
prev_report = next_reports[i - 1] if i > 0 else None prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface))) expected_syscalls.append((prev_report, Select(READ | interface)))
expected_syscalls.append((next_reports[-1], StopIteration())) expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.readinto(long_buffer), expected_syscalls) assert_async(reader.areadinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:]) assert_eq(long_buffer, message[-start_size:])
assert_eq(reader.size, 0) assert_eq(reader.size, 0)
# one byte read, raises eof # one byte read, raises eof
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),]) assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
def test_writer(): def test_writer():
@ -87,7 +87,8 @@ def test_writer():
interface = 0xdeadbeef interface = 0xdeadbeef
message_type = 0x87654321 message_type = 0x87654321
message_len = 1024 message_len = 1024
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len) writer = codec_v1.Writer(interface, codec_v1.SESSION_ID)
writer.setheader(message_type, message_len)
# init header corresponding to the data above # init header corresponding to the data above
report_header = bytearray(unhexlify('3f2323432100000400')) report_header = bytearray(unhexlify('3f2323432100000400'))
@ -96,14 +97,14 @@ def test_writer():
# empty write # empty write
start_size = writer.size start_size = writer.size
assert_async(writer.write(bytearray()), [(None, StopIteration()),]) assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
assert_eq(writer.size, start_size) assert_eq(writer.size, start_size)
# short write, expected no report # short write, expected no report
start_size = writer.size start_size = writer.size
short_payload = bytearray(range(4)) short_payload = bytearray(range(4))
assert_async(writer.write(short_payload), [(None, StopIteration()),]) assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data, assert_eq(writer.data,
report_header report_header
@ -118,7 +119,7 @@ def test_writer():
+ short_payload + short_payload
+ aligned_payload + aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload)) assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1) msg.send.assert_called_n_times(1)
msg.send = msg.send.original msg.send = msg.send.original
@ -126,7 +127,7 @@ def test_writer():
# short write, expected no report, but data starts with correct seq and cont marker # short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('3f')) report_header = bytearray(unhexlify('3f'))
start_size = writer.size start_size = writer.size
assert_async(writer.write(short_payload), [(None, StopIteration()),]) assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data[:len(report_header) + len(short_payload)], assert_eq(writer.data[:len(report_header) + len(short_payload)],
report_header + short_payload) report_header + short_payload)
@ -142,19 +143,19 @@ def test_writer():
# test write # test write
expected_write_reports = expected_reports[:-1] expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, start_size - len(long_payload)) assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports)) msg.send.assert_called_n_times(len(expected_write_reports))
msg.send = msg.send.original msg.send = msg.send.original
# test write raises eof # test write raises eof
msg.send = mock_call(msg.send, []) msg.send = mock_call(msg.send, [])
assert_async(writer.write(bytearray(1)), [(None, EOFError())]) assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0) msg.send.assert_called_n_times(0)
msg.send = msg.send.original msg.send = msg.send.original
# test close # test close
expected_close_reports = expected_reports[-1:] expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, 0) assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports)) msg.send.assert_called_n_times(len(expected_close_reports))
msg.send = msg.send.original msg.send = msg.send.original

View File

@ -26,26 +26,26 @@ def test_reader():
# open, expected one read # open, expected one read
first_report = report_header + message[:rep_len - len(report_header)] first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type) assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len) assert_eq(reader.size, message_len)
# empty read # empty read
empty_buffer = bytearray() empty_buffer = bytearray()
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),]) assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
assert_eq(len(empty_buffer), 0) assert_eq(len(empty_buffer), 0)
assert_eq(reader.size, message_len) assert_eq(reader.size, message_len)
# short read, expected no read # short read, expected no read
short_buffer = bytearray(32) short_buffer = bytearray(32)
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),]) assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
assert_eq(len(short_buffer), 32) assert_eq(len(short_buffer), 32)
assert_eq(short_buffer, message[:len(short_buffer)]) assert_eq(short_buffer, message[:len(short_buffer)])
assert_eq(reader.size, message_len - len(short_buffer)) assert_eq(reader.size, message_len - len(short_buffer))
# aligned read, expected no read # aligned read, expected no read
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),]) assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
@ -53,12 +53,12 @@ def test_reader():
next_report_header = bytearray(unhexlify('021234567800000000')) next_report_header = bytearray(unhexlify('021234567800000000'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1) onebyte_buffer = bytearray(1)
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
# too long read, raises eof # too long read, raises eof
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),]) assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
# long read, expect multiple reads # long read, expect multiple reads
start_size = reader.size start_size = reader.size
@ -74,12 +74,12 @@ def test_reader():
prev_report = next_reports[i - 1] if i > 0 else None prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface))) expected_syscalls.append((prev_report, Select(READ | interface)))
expected_syscalls.append((next_reports[-1], StopIteration())) expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.readinto(long_buffer), expected_syscalls) assert_async(reader.areadinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:]) assert_eq(long_buffer, message[-start_size:])
assert_eq(reader.size, 0) assert_eq(reader.size, 0)
# one byte read, raises eof # one byte read, raises eof
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),]) assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
def test_writer(): def test_writer():
@ -88,7 +88,8 @@ def test_writer():
session_id = 0x12345678 session_id = 0x12345678
message_type = 0x87654321 message_type = 0x87654321
message_len = 1024 message_len = 1024
writer = codec_v2.Writer(interface, session_id, message_type, message_len) writer = codec_v2.Writer(interface, session_id)
writer.setheader(message_type, message_len)
# init header corresponding to the data above # init header corresponding to the data above
report_header = bytearray(unhexlify('01123456788765432100000400')) report_header = bytearray(unhexlify('01123456788765432100000400'))
@ -97,14 +98,14 @@ def test_writer():
# empty write # empty write
start_size = writer.size start_size = writer.size
assert_async(writer.write(bytearray()), [(None, StopIteration()),]) assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
assert_eq(writer.size, start_size) assert_eq(writer.size, start_size)
# short write, expected no report # short write, expected no report
start_size = writer.size start_size = writer.size
short_payload = bytearray(range(4)) short_payload = bytearray(range(4))
assert_async(writer.write(short_payload), [(None, StopIteration()),]) assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data, assert_eq(writer.data,
report_header report_header
@ -119,7 +120,7 @@ def test_writer():
+ short_payload + short_payload
+ aligned_payload + aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),]) assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload)) assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1) msg.send.assert_called_n_times(1)
msg.send = msg.send.original msg.send = msg.send.original
@ -127,7 +128,7 @@ def test_writer():
# short write, expected no report, but data starts with correct seq and cont marker # short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('021234567800000000')) report_header = bytearray(unhexlify('021234567800000000'))
start_size = writer.size start_size = writer.size
assert_async(writer.write(short_payload), [(None, StopIteration()),]) assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload)) assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data[:len(report_header) + len(short_payload)], assert_eq(writer.data[:len(report_header) + len(short_payload)],
report_header + short_payload) report_header + short_payload)
@ -145,19 +146,19 @@ def test_writer():
# test write # test write
expected_write_reports = expected_reports[:-1] expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, start_size - len(long_payload)) assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports)) msg.send.assert_called_n_times(len(expected_write_reports))
msg.send = msg.send.original msg.send = msg.send.original
# test write raises eof # test write raises eof
msg.send = mock_call(msg.send, []) msg.send = mock_call(msg.send, [])
assert_async(writer.write(bytearray(1)), [(None, EOFError())]) assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0) msg.send.assert_called_n_times(0)
msg.send = msg.send.original msg.send = msg.send.original
# test close # test close
expected_close_reports = expected_reports[-1:] expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, 0) assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports)) msg.send.assert_called_n_times(len(expected_close_reports))
msg.send = msg.send.original msg.send = msg.send.original