mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-27 07:40:59 +00:00
wire: pass Context to apps
This commit is contained in:
parent
b1b84fb233
commit
3562ffdc54
@ -7,7 +7,7 @@ signal = loop.Signal()
|
||||
|
||||
|
||||
@unimport
|
||||
async def confirm(session_id, content, code=None, *args, **kwargs):
|
||||
async def confirm(ctx, content, code=None, *args, **kwargs):
|
||||
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
|
||||
from trezor.messages.ButtonRequest import ButtonRequest
|
||||
from trezor.messages.ButtonRequestType import Other
|
||||
@ -19,12 +19,12 @@ async def confirm(session_id, content, code=None, *args, **kwargs):
|
||||
|
||||
if code is None:
|
||||
code = Other
|
||||
await wire.call(session_id, ButtonRequest(code=code), ButtonAck)
|
||||
await ctx.call(ButtonRequest(code=code), ButtonAck)
|
||||
return await loop.Wait((signal, dialog)) == CONFIRMED
|
||||
|
||||
|
||||
@unimport
|
||||
async def hold_to_confirm(session_id, content, code=None, *args, **kwargs):
|
||||
async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
|
||||
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
|
||||
from trezor.messages.ButtonRequest import ButtonRequest
|
||||
from trezor.messages.ButtonRequestType import Other
|
||||
@ -36,7 +36,7 @@ async def hold_to_confirm(session_id, content, code=None, *args, **kwargs):
|
||||
|
||||
if code is None:
|
||||
code = Other
|
||||
await wire.call(session_id, ButtonRequest(code=code), ButtonAck)
|
||||
await ctx.call(ButtonRequest(code=code), ButtonAck)
|
||||
return await loop.Wait((signal, dialog)) == CONFIRMED
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
from trezor import ui, wire
|
||||
|
||||
|
||||
async def request_passphrase(session_id):
|
||||
async def request_passphrase(ctx):
|
||||
from trezor.messages.FailureType import ActionCancelled
|
||||
from trezor.messages.PassphraseRequest import PassphraseRequest
|
||||
from trezor.messages.wire_types import PassphraseAck, Cancel
|
||||
@ -12,17 +12,17 @@ async def request_passphrase(session_id):
|
||||
'Please enter passphrase', 'on your computer.')
|
||||
text.render()
|
||||
|
||||
ack = await wire.call(session_id, PassphraseRequest(), PassphraseAck, Cancel)
|
||||
ack = await ctx.call(PassphraseRequest(), PassphraseAck, Cancel)
|
||||
if ack.MESSAGE_WIRE_TYPE == Cancel:
|
||||
raise wire.FailureError(ActionCancelled, 'Passphrase cancelled')
|
||||
|
||||
return ack.passphrase
|
||||
|
||||
|
||||
async def protect_by_passphrase(session_id):
|
||||
async def protect_by_passphrase(ctx):
|
||||
from apps.common import storage
|
||||
|
||||
if storage.is_protected_by_passphrase():
|
||||
return await request_passphrase(session_id)
|
||||
return await request_passphrase(ctx)
|
||||
else:
|
||||
return ''
|
||||
|
@ -7,7 +7,7 @@ if __debug__:
|
||||
|
||||
|
||||
@unimport
|
||||
async def request_pin_on_display(session_id: int, code: int=None) -> str:
|
||||
async def request_pin_on_display(ctx: wire.Context, code: int=None) -> str:
|
||||
from trezor.messages.ButtonRequest import ButtonRequest
|
||||
from trezor.messages.ButtonRequestType import ProtectCall
|
||||
from trezor.messages.FailureType import PinCancelled
|
||||
@ -20,9 +20,8 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
|
||||
|
||||
_, label = _get_code_and_label(code)
|
||||
|
||||
await wire.call(session_id,
|
||||
ButtonRequest(code=ProtectCall),
|
||||
ButtonAck)
|
||||
await ctx.call(ButtonRequest(code=ProtectCall),
|
||||
ButtonAck)
|
||||
|
||||
ui.display.clear()
|
||||
matrix = PinMatrix(label)
|
||||
@ -36,7 +35,7 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
|
||||
|
||||
|
||||
@unimport
|
||||
async def request_pin_on_client(session_id: int, code: int=None) -> str:
|
||||
async def request_pin_on_client(ctx: wire.Context, code: int=None) -> str:
|
||||
from trezor.messages.FailureType import PinCancelled
|
||||
from trezor.messages.PinMatrixRequest import PinMatrixRequest
|
||||
from trezor.messages.wire_types import PinMatrixAck, Cancel
|
||||
@ -51,9 +50,8 @@ async def request_pin_on_client(session_id: int, code: int=None) -> str:
|
||||
matrix = PinMatrix(label)
|
||||
matrix.render()
|
||||
|
||||
ack = await wire.call(session_id,
|
||||
PinMatrixRequest(type=code),
|
||||
PinMatrixAck, Cancel)
|
||||
ack = await ctx.call(PinMatrixRequest(type=code),
|
||||
PinMatrixAck, Cancel)
|
||||
digits = matrix.digits
|
||||
matrix = None
|
||||
|
||||
@ -66,12 +64,12 @@ request_pin = request_pin_on_client
|
||||
|
||||
|
||||
@unimport
|
||||
async def request_pin_twice(session_id: int) -> str:
|
||||
async def request_pin_twice(ctx: wire.Context) -> str:
|
||||
from trezor.messages.FailureType import ActionCancelled
|
||||
from trezor.messages import PinMatrixRequestType
|
||||
|
||||
pin_first = await request_pin(session_id, PinMatrixRequestType.NewFirst)
|
||||
pin_again = await request_pin(session_id, PinMatrixRequestType.NewSecond)
|
||||
pin_first = await request_pin(ctx, PinMatrixRequestType.NewFirst)
|
||||
pin_again = await request_pin(ctx, PinMatrixRequestType.NewSecond)
|
||||
if pin_first != pin_again:
|
||||
# changed message due to consistency with T1 msgs
|
||||
raise wire.FailureError(ActionCancelled, 'PIN change failed')
|
||||
@ -79,22 +77,22 @@ async def request_pin_twice(session_id: int) -> str:
|
||||
return pin_first
|
||||
|
||||
|
||||
async def protect_by_pin_repeatedly(session_id: int, at_least_once: bool=False):
|
||||
async def protect_by_pin_repeatedly(ctx: wire.Context, at_least_once: bool=False):
|
||||
from . import storage
|
||||
|
||||
locked = storage.is_locked() or at_least_once
|
||||
while locked:
|
||||
pin = await request_pin(session_id)
|
||||
pin = await request_pin(ctx)
|
||||
locked = not storage.unlock(pin, _render_pin_failure)
|
||||
|
||||
|
||||
async def protect_by_pin_or_fail(session_id: int, at_least_once: bool=False):
|
||||
async def protect_by_pin_or_fail(ctx: wire.Context, at_least_once: bool=False):
|
||||
from trezor.messages.FailureType import PinInvalid
|
||||
from . import storage
|
||||
|
||||
locked = storage.is_locked() or at_least_once
|
||||
if locked:
|
||||
pin = await request_pin(session_id)
|
||||
pin = await request_pin(ctx)
|
||||
if not storage.unlock(pin, _render_pin_failure):
|
||||
raise wire.FailureError(PinInvalid, 'PIN invalid')
|
||||
|
||||
|
@ -5,20 +5,20 @@ from trezor.crypto import bip39
|
||||
_DEFAULT_CURVE = 'secp256k1'
|
||||
|
||||
|
||||
async def get_root(session_id: int, curve_name=_DEFAULT_CURVE):
|
||||
seed = await get_seed(session_id)
|
||||
async def get_root(ctx: wire.Context, curve_name=_DEFAULT_CURVE):
|
||||
seed = await get_seed(ctx)
|
||||
root = bip32.from_seed(seed, curve_name)
|
||||
return root
|
||||
|
||||
|
||||
async def get_seed(session_id: int) -> bytes:
|
||||
async def get_seed(ctx: wire.Context) -> bytes:
|
||||
from . import cache
|
||||
if cache.seed is None:
|
||||
cache.seed = await compute_seed(session_id)
|
||||
cache.seed = await compute_seed(ctx)
|
||||
return cache.seed
|
||||
|
||||
|
||||
async def compute_seed(session_id: int) -> bytes:
|
||||
async def compute_seed(ctx: wire.Context) -> bytes:
|
||||
from trezor.messages.FailureType import ProcessError
|
||||
from .request_passphrase import protect_by_passphrase
|
||||
from .request_pin import protect_by_pin
|
||||
@ -27,9 +27,9 @@ async def compute_seed(session_id: int) -> bytes:
|
||||
if not storage.is_initialized():
|
||||
raise wire.FailureError(ProcessError, 'Device is not initialized')
|
||||
|
||||
await protect_by_pin(session_id)
|
||||
await protect_by_pin(ctx)
|
||||
|
||||
passphrase = await protect_by_passphrase(session_id)
|
||||
passphrase = await protect_by_passphrase(ctx)
|
||||
return bip39.seed(storage.get_mnemonic(), passphrase)
|
||||
|
||||
|
||||
|
@ -4,13 +4,13 @@ from trezor.messages.wire_types import \
|
||||
DebugLinkMemoryRead, DebugLinkMemoryWrite, DebugLinkFlashErase
|
||||
|
||||
|
||||
async def dispatch_DebugLinkDecision(session_id, msg):
|
||||
async def dispatch_DebugLinkDecision(ctx, msg):
|
||||
from trezor.ui.confirm import CONFIRMED, CANCELLED
|
||||
from apps.common.confirm import signal
|
||||
signal.send(CONFIRMED if msg.yes_no else CANCELLED)
|
||||
|
||||
|
||||
async def dispatch_DebugLinkGetState(session_id, msg):
|
||||
async def dispatch_DebugLinkGetState(ctx, msg):
|
||||
from trezor.messages.DebugLinkState import DebugLinkState
|
||||
from apps.common import storage, request_pin
|
||||
from apps.management import reset_device
|
||||
@ -36,11 +36,11 @@ async def dispatch_DebugLinkGetState(session_id, msg):
|
||||
return m
|
||||
|
||||
|
||||
async def dispatch_DebugLinkStop(session_id, msg):
|
||||
async def dispatch_DebugLinkStop(ctx, msg):
|
||||
pass
|
||||
|
||||
|
||||
async def dispatch_DebugLinkMemoryRead(session_id, msg):
|
||||
async def dispatch_DebugLinkMemoryRead(ctx, msg):
|
||||
from trezor.messages.DebugLinkMemory import DebugLinkMemory
|
||||
from uctypes import bytes_at
|
||||
m = DebugLinkMemory()
|
||||
@ -48,14 +48,14 @@ async def dispatch_DebugLinkMemoryRead(session_id, msg):
|
||||
return m
|
||||
|
||||
|
||||
async def dispatch_DebugLinkMemoryWrite(session_id, msg):
|
||||
async def dispatch_DebugLinkMemoryWrite(ctx, msg):
|
||||
from uctypes import bytearray_at
|
||||
l = len(msg.memory)
|
||||
data = bytearray_at(msg.address, l)
|
||||
data[0:l] = msg.memory
|
||||
|
||||
|
||||
async def dispatch_DebugLinkFlashErase(session_id, msg):
|
||||
async def dispatch_DebugLinkFlashErase(ctx, msg):
|
||||
# TODO: erase(msg.sector)
|
||||
pass
|
||||
|
||||
|
@ -3,13 +3,13 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_ethereum_get_address(session_id, msg):
|
||||
async def layout_ethereum_get_address(ctx, msg):
|
||||
from trezor.messages.EthereumAddress import EthereumAddress
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from ..common import seed
|
||||
|
||||
node = await seed.get_root(session_id)
|
||||
node = await seed.get_root(ctx)
|
||||
node.derive_path(msg.address_n or ())
|
||||
|
||||
seckey = node.private_key()
|
||||
@ -17,11 +17,11 @@ async def layout_ethereum_get_address(session_id, msg):
|
||||
address = sha3_256(public_key[1:]).digest(True)[12:] # Keccak
|
||||
|
||||
if msg.show_display:
|
||||
await _show_address(session_id, address)
|
||||
await _show_address(ctx, address)
|
||||
return EthereumAddress(address=address)
|
||||
|
||||
|
||||
async def _show_address(session_id, address):
|
||||
async def _show_address(ctx, address):
|
||||
from trezor.messages.ButtonRequestType import Address
|
||||
from trezor.ui.text import Text
|
||||
from ..common.confirm import require_confirm
|
||||
@ -30,7 +30,7 @@ async def _show_address(session_id, address):
|
||||
|
||||
content = Text('Confirm address', ui.ICON_RESET,
|
||||
ui.MONO, *_split_address(address))
|
||||
await require_confirm(session_id, content, code=Address)
|
||||
await require_confirm(ctx, content, code=Address)
|
||||
|
||||
|
||||
def _split_address(address):
|
||||
|
@ -226,11 +226,11 @@ class Cmd:
|
||||
return Msg(self.cid, cla, ins, p1, p2, lc, data)
|
||||
|
||||
|
||||
async def read_cmd(iface: int) -> Cmd:
|
||||
async def read_cmd(iface: io.HID) -> Cmd:
|
||||
desc_init = frame_init()
|
||||
desc_cont = frame_cont()
|
||||
|
||||
buf, = await loop.select(iface)
|
||||
buf, = await loop.select(iface.iface_num())
|
||||
# log.debug(__name__, 'read init %s', buf)
|
||||
|
||||
ifrm = overlay_struct(buf, desc_init)
|
||||
@ -252,7 +252,7 @@ async def read_cmd(iface: int) -> Cmd:
|
||||
data = data[:bcnt]
|
||||
|
||||
while datalen < bcnt:
|
||||
buf, = await loop.select(iface)
|
||||
buf, = await loop.select(iface.iface_num())
|
||||
# log.debug(__name__, 'read cont %s', buf)
|
||||
|
||||
cfrm = overlay_struct(buf, desc_cont)
|
||||
@ -282,7 +282,7 @@ async def read_cmd(iface: int) -> Cmd:
|
||||
return Cmd(ifrm.cid, ifrm.cmd, data)
|
||||
|
||||
|
||||
def send_cmd(cmd: Cmd, iface: int) -> None:
|
||||
def send_cmd(cmd: Cmd, iface: io.HID) -> None:
|
||||
init_desc = frame_init()
|
||||
cont_desc = frame_cont()
|
||||
offset = 0
|
||||
@ -295,7 +295,7 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
|
||||
frm.bcnt = datalen
|
||||
|
||||
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
||||
io.send(iface, buf)
|
||||
iface.write(buf)
|
||||
# log.debug(__name__, 'send init %s', buf)
|
||||
|
||||
if offset < datalen:
|
||||
@ -304,18 +304,17 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
|
||||
while offset < datalen:
|
||||
frm.seq = seq
|
||||
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
||||
utime.sleep_ms(1) # FIXME: do async send
|
||||
io.send(iface, buf)
|
||||
utime.sleep_ms(1) # FIXME: async write
|
||||
iface.write(buf)
|
||||
# log.debug(__name__, 'send cont %s', buf)
|
||||
seq += 1
|
||||
|
||||
|
||||
def boot():
|
||||
iface = 0x03
|
||||
def boot(iface: io.HID):
|
||||
loop.schedule_task(handle_reports(iface))
|
||||
|
||||
|
||||
async def handle_reports(iface: int):
|
||||
async def handle_reports(iface: io.HID):
|
||||
while True:
|
||||
try:
|
||||
req = await read_cmd(iface)
|
||||
|
@ -4,7 +4,7 @@ from trezor.messages.wire_types import Initialize, GetFeatures, Ping
|
||||
|
||||
|
||||
@unimport
|
||||
async def respond_Features(session_id, msg):
|
||||
async def respond_Features(ctx, msg):
|
||||
from apps.common import storage, coins
|
||||
from trezor.messages.Features import Features
|
||||
|
||||
@ -28,7 +28,7 @@ async def respond_Features(session_id, msg):
|
||||
|
||||
|
||||
@unimport
|
||||
async def respond_Pong(session_id, msg):
|
||||
async def respond_Pong(ctx, msg):
|
||||
from trezor.messages.Success import Success
|
||||
|
||||
s = Success()
|
||||
@ -36,11 +36,11 @@ async def respond_Pong(session_id, msg):
|
||||
|
||||
if msg.pin_protection:
|
||||
from apps.common.request_pin import protect_by_pin
|
||||
await protect_by_pin(session_id)
|
||||
await protect_by_pin(ctx)
|
||||
|
||||
if msg.passphrase_protection:
|
||||
from apps.common.request_passphrase import protect_by_passphrase
|
||||
await protect_by_passphrase(session_id)
|
||||
await protect_by_passphrase(ctx)
|
||||
|
||||
# TODO: handle other fields:
|
||||
# button_protection
|
||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_apply_settings(session_id, msg):
|
||||
async def layout_apply_settings(ctx, msg):
|
||||
from trezor.messages.Success import Success
|
||||
from trezor.messages.FailureType import ProcessError
|
||||
from trezor.ui.text import Text
|
||||
@ -11,7 +11,7 @@ async def layout_apply_settings(session_id, msg):
|
||||
from ..common.request_pin import protect_by_pin
|
||||
from ..common import storage
|
||||
|
||||
await protect_by_pin(session_id)
|
||||
await protect_by_pin(ctx)
|
||||
|
||||
if msg.homescreen is not None:
|
||||
raise wire.FailureError(
|
||||
@ -21,20 +21,20 @@ async def layout_apply_settings(session_id, msg):
|
||||
raise wire.FailureError(ProcessError, 'No setting provided')
|
||||
|
||||
if msg.label is not None:
|
||||
await require_confirm(session_id, Text(
|
||||
await require_confirm(ctx, Text(
|
||||
'Change label', ui.ICON_RESET,
|
||||
'Do you really want to', 'change label to',
|
||||
ui.BOLD, '%s' % msg.label))
|
||||
|
||||
if msg.language is not None:
|
||||
await require_confirm(session_id, Text(
|
||||
await require_confirm(ctx, Text(
|
||||
'Change language', ui.ICON_RESET,
|
||||
'Do you really want to', 'change language to',
|
||||
ui.BOLD, '%s' % msg.language,
|
||||
ui.NORMAL, '?'))
|
||||
|
||||
if msg.use_passphrase is not None:
|
||||
await require_confirm(session_id, Text(
|
||||
await require_confirm(ctx, Text(
|
||||
'Enable passphrase' if msg.use_passphrase else 'Disable passphrase',
|
||||
ui.ICON_RESET,
|
||||
'Do you really want to',
|
||||
|
@ -2,52 +2,52 @@ from trezor import ui
|
||||
from trezor.utils import unimport
|
||||
|
||||
|
||||
def confirm_set_pin(session_id):
|
||||
def confirm_set_pin(ctx):
|
||||
from apps.common.confirm import require_confirm
|
||||
from trezor.ui.text import Text
|
||||
return require_confirm(session_id, Text(
|
||||
return require_confirm(ctx, Text(
|
||||
'Change PIN', ui.ICON_RESET,
|
||||
'Do you really want to', ui.BOLD,
|
||||
'set new PIN?'))
|
||||
|
||||
|
||||
def confirm_change_pin(session_id):
|
||||
def confirm_change_pin(ctx):
|
||||
from apps.common.confirm import require_confirm
|
||||
from trezor.ui.text import Text
|
||||
return require_confirm(session_id, Text(
|
||||
return require_confirm(ctx, Text(
|
||||
'Change PIN', ui.ICON_RESET,
|
||||
'Do you really want to', ui.BOLD,
|
||||
'change current PIN?'))
|
||||
|
||||
|
||||
def confirm_remove_pin(session_id):
|
||||
def confirm_remove_pin(ctx):
|
||||
from apps.common.confirm import require_confirm
|
||||
from trezor.ui.text import Text
|
||||
return require_confirm(session_id, Text(
|
||||
return require_confirm(ctx, Text(
|
||||
'Remove PIN', ui.ICON_RESET,
|
||||
'Do you really want to', ui.BOLD,
|
||||
'remove current PIN?'))
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_change_pin(session_id, msg):
|
||||
async def layout_change_pin(ctx, msg):
|
||||
from trezor.messages.Success import Success
|
||||
from apps.common.request_pin import protect_by_pin, request_pin_twice
|
||||
from apps.common import storage
|
||||
|
||||
if msg.remove:
|
||||
if storage.is_protected_by_pin():
|
||||
await confirm_remove_pin(session_id)
|
||||
await protect_by_pin(session_id, at_least_once=True)
|
||||
await confirm_remove_pin(ctx)
|
||||
await protect_by_pin(ctx, at_least_once=True)
|
||||
pin = ''
|
||||
|
||||
else:
|
||||
if storage.is_protected_by_pin():
|
||||
await confirm_change_pin(session_id)
|
||||
await protect_by_pin(session_id, at_least_once=True)
|
||||
await confirm_change_pin(ctx)
|
||||
await protect_by_pin(ctx, at_least_once=True)
|
||||
else:
|
||||
await confirm_set_pin(session_id)
|
||||
pin = await request_pin_twice(session_id)
|
||||
await confirm_set_pin(ctx)
|
||||
pin = await request_pin_twice(ctx)
|
||||
|
||||
storage.load_settings(pin=pin)
|
||||
if pin:
|
||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_load_device(session_id, msg):
|
||||
async def layout_load_device(ctx, msg):
|
||||
from trezor.crypto import bip39
|
||||
from trezor.messages.Success import Success
|
||||
from trezor.messages.FailureType import UnexpectedMessage, ProcessError
|
||||
@ -20,7 +20,7 @@ async def layout_load_device(session_id, msg):
|
||||
if not msg.skip_checksum and not bip39.check(msg.mnemonic):
|
||||
raise wire.FailureError(ProcessError, 'Mnemonic is not valid')
|
||||
|
||||
await require_confirm(session_id, Text(
|
||||
await require_confirm(ctx, Text(
|
||||
'Loading seed', ui.ICON_RESET,
|
||||
ui.BOLD, 'Loading private seed', 'is not recommended.',
|
||||
ui.NORMAL, 'Continue only if you', 'know what you are doing!'))
|
||||
|
@ -11,7 +11,7 @@ def nth(n):
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_recovery_device(session_id, msg):
|
||||
async def layout_recovery_device(ctx, msg):
|
||||
|
||||
msg = 'Please enter ' + nth(msg.word_count) + ' word'
|
||||
|
||||
|
@ -10,7 +10,7 @@ if __debug__:
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_reset_device(session_id, msg):
|
||||
async def layout_reset_device(ctx, msg):
|
||||
from trezor.ui.text import Text
|
||||
from trezor.crypto import hashlib, random, bip39
|
||||
from trezor.messages.EntropyRequest import EntropyRequest
|
||||
@ -39,21 +39,21 @@ async def layout_reset_device(session_id, msg):
|
||||
if msg.display_random:
|
||||
entropy_lines = chunks(ubinascii.hexlify(internal_entropy), 16)
|
||||
entropy_content = Text('Internal entropy', ui.ICON_RESET, *entropy_lines)
|
||||
await require_confirm(session_id, entropy_content, ButtonRequestType.ResetDevice)
|
||||
await require_confirm(ctx, entropy_content, ButtonRequestType.ResetDevice)
|
||||
|
||||
if msg.pin_protection:
|
||||
pin = await request_pin_twice(session_id)
|
||||
pin = await request_pin_twice(ctx)
|
||||
else:
|
||||
pin = None
|
||||
|
||||
external_entropy_ack = await wire.call(session_id, EntropyRequest(), EntropyAck)
|
||||
external_entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
|
||||
ctx = hashlib.sha256()
|
||||
ctx.update(internal_entropy)
|
||||
ctx.update(external_entropy_ack.entropy)
|
||||
entropy = ctx.digest()
|
||||
mnemonic = bip39.from_data(entropy[:msg.strength // 8])
|
||||
|
||||
await show_mnemonic_by_word(session_id, mnemonic)
|
||||
await show_mnemonic_by_word(ctx, mnemonic)
|
||||
|
||||
storage.load_mnemonic(mnemonic)
|
||||
storage.load_settings(pin=pin,
|
||||
@ -64,7 +64,7 @@ async def layout_reset_device(session_id, msg):
|
||||
return Success(message='Initialized')
|
||||
|
||||
|
||||
async def show_mnemonic_by_word(session_id, mnemonic):
|
||||
async def show_mnemonic_by_word(ctx, mnemonic):
|
||||
from trezor.ui.text import Text
|
||||
from trezor.messages.ButtonRequestType import ConfirmWord
|
||||
from apps.common.confirm import confirm
|
||||
@ -80,7 +80,7 @@ async def show_mnemonic_by_word(session_id, mnemonic):
|
||||
while index < len(words):
|
||||
word = words[index]
|
||||
current_word = word
|
||||
await confirm(session_id,
|
||||
await confirm(ctx,
|
||||
Text(
|
||||
'Recovery seed setup', ui.ICON_RESET,
|
||||
ui.NORMAL, 'Write down seed word' if recovery else 'Confirm seed word', ' ',
|
||||
|
@ -3,13 +3,13 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_wipe_device(session_id, msg):
|
||||
async def layout_wipe_device(ctx, msg):
|
||||
from trezor.messages.Success import Success
|
||||
from trezor.ui.text import Text
|
||||
from ..common.confirm import hold_to_confirm
|
||||
from ..common import storage
|
||||
|
||||
await hold_to_confirm(session_id, Text(
|
||||
await hold_to_confirm(ctx, Text(
|
||||
'WIPE DEVICE',
|
||||
ui.ICON_WIPE,
|
||||
ui.NORMAL, 'Do you really want to', 'wipe the device?',
|
||||
|
@ -26,7 +26,7 @@ def cipher_key_value(msg, seckey: bytes) -> bytes:
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_cipher_key_value(session_id, msg):
|
||||
async def layout_cipher_key_value(ctx, msg):
|
||||
from trezor.messages.CipheredKeyValue import CipheredKeyValue
|
||||
from ..common import seed
|
||||
|
||||
@ -38,7 +38,7 @@ async def layout_cipher_key_value(session_id, msg):
|
||||
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
|
||||
ui.display.text(10, 60, msg.key, ui.MONO, ui.WHITE, ui.BLACK)
|
||||
|
||||
node = await seed.get_root(session_id)
|
||||
node = await seed.get_root(ctx)
|
||||
node.derive_path(msg.address_n)
|
||||
|
||||
value = cipher_key_value(msg, node.private_key())
|
||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_get_address(session_id, msg):
|
||||
async def layout_get_address(ctx, msg):
|
||||
from trezor.messages.Address import Address
|
||||
from trezor.messages.FailureType import ProcessError
|
||||
from ..common import coins
|
||||
@ -15,18 +15,18 @@ async def layout_get_address(session_id, msg):
|
||||
address_n = msg.address_n or ()
|
||||
coin_name = msg.coin_name or 'Bitcoin'
|
||||
|
||||
node = await seed.get_root(session_id)
|
||||
node = await seed.get_root(ctx)
|
||||
node.derive_path(address_n)
|
||||
coin = coins.by_name(coin_name)
|
||||
address = node.address(coin.address_type)
|
||||
|
||||
if msg.show_display:
|
||||
await _show_address(session_id, address)
|
||||
await _show_address(ctx, address)
|
||||
|
||||
return Address(address=address)
|
||||
|
||||
|
||||
async def _show_address(session_id, address):
|
||||
async def _show_address(ctx, address):
|
||||
from trezor.messages.ButtonRequestType import Address
|
||||
from trezor.ui.text import Text
|
||||
from trezor.ui.qr import Qr
|
||||
@ -37,7 +37,7 @@ async def _show_address(session_id, address):
|
||||
content = Container(
|
||||
Qr(address, (120, 135), 3),
|
||||
Text('Confirm address', ui.ICON_RESET, ui.MONO, *lines))
|
||||
await require_confirm(session_id, content, code=Address)
|
||||
await require_confirm(ctx, content, code=Address)
|
||||
|
||||
|
||||
def _split_address(address):
|
||||
|
@ -3,18 +3,18 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_get_entropy(session_id, msg):
|
||||
async def layout_get_entropy(ctx, msg):
|
||||
from trezor.messages.Entropy import Entropy
|
||||
from trezor.crypto import random
|
||||
|
||||
l = min(msg.size, 1024)
|
||||
|
||||
await _show_entropy(session_id)
|
||||
await _show_entropy(ctx)
|
||||
|
||||
return Entropy(entropy=random.bytes(l))
|
||||
|
||||
|
||||
async def _show_entropy(session_id):
|
||||
async def _show_entropy(ctx):
|
||||
from trezor.messages.ButtonRequestType import ProtectCall
|
||||
from trezor.ui.text import Text
|
||||
from trezor.ui.container import Container
|
||||
@ -23,4 +23,4 @@ async def _show_entropy(session_id):
|
||||
content = Container(
|
||||
Text('Confirm entropy', ui.ICON_RESET, ui.MONO, 'Do you really want to send entropy?'))
|
||||
|
||||
await require_confirm(session_id, content, code=ProtectCall)
|
||||
await require_confirm(ctx, content, code=ProtectCall)
|
||||
|
@ -2,7 +2,7 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_get_public_key(session_id, msg):
|
||||
async def layout_get_public_key(ctx, msg):
|
||||
from trezor.messages.HDNodeType import HDNodeType
|
||||
from trezor.messages.PublicKey import PublicKey
|
||||
from ..common import coins
|
||||
@ -11,7 +11,7 @@ async def layout_get_public_key(session_id, msg):
|
||||
address_n = msg.address_n or ()
|
||||
coin_name = msg.coin_name or 'Bitcoin'
|
||||
|
||||
node = await seed.get_root(session_id)
|
||||
node = await seed.get_root(ctx)
|
||||
node.derive_path(address_n)
|
||||
coin = coins.by_name(coin_name)
|
||||
|
||||
|
@ -83,7 +83,7 @@ def sign_challenge(seckey: bytes,
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_sign_identity(session_id, msg):
|
||||
async def layout_sign_identity(ctx, msg):
|
||||
from trezor.messages.SignedIdentity import SignedIdentity
|
||||
from ..common import coins
|
||||
from ..common import seed
|
||||
@ -92,7 +92,7 @@ async def layout_sign_identity(session_id, msg):
|
||||
display_identity(identity, msg.challenge_visual)
|
||||
|
||||
address_n = get_identity_path(identity, msg.identity.index or 0)
|
||||
node = await seed.get_root(session_id, msg.ecdsa_curve_name)
|
||||
node = await seed.get_root(ctx, msg.ecdsa_curve_name)
|
||||
node.derive_path(address_n)
|
||||
|
||||
coin = coins.by_name('Bitcoin')
|
||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_sign_message(session_id, msg):
|
||||
async def layout_sign_message(ctx, msg):
|
||||
from trezor.messages.MessageSignature import MessageSignature
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from ..common import coins
|
||||
@ -18,7 +18,7 @@ async def layout_sign_message(session_id, msg):
|
||||
coin_name = msg.coin_name or 'Bitcoin'
|
||||
coin = coins.by_name(coin_name)
|
||||
|
||||
node = await seed.get_root(session_id)
|
||||
node = await seed.get_root(ctx)
|
||||
node.derive_path(msg.address_n)
|
||||
|
||||
seckey = node.private_key()
|
||||
|
@ -3,7 +3,7 @@ from trezor import wire
|
||||
|
||||
|
||||
@unimport
|
||||
async def sign_tx(session_id, msg):
|
||||
async def sign_tx(ctx, msg):
|
||||
from trezor.messages.RequestType import TXFINISHED
|
||||
from trezor.messages.wire_types import TxAck
|
||||
|
||||
@ -11,7 +11,7 @@ async def sign_tx(session_id, msg):
|
||||
from . import signing
|
||||
from . import layout
|
||||
|
||||
root = await seed.get_root(session_id)
|
||||
root = await seed.get_root(ctx)
|
||||
|
||||
signer = signing.sign_tx(msg, root)
|
||||
res = None
|
||||
@ -23,13 +23,13 @@ async def sign_tx(session_id, msg):
|
||||
if req.__qualname__ == 'TxRequest':
|
||||
if req.request_type == TXFINISHED:
|
||||
break
|
||||
res = await wire.call(session_id, req, TxAck)
|
||||
res = await ctx.call(req, TxAck)
|
||||
elif req.__qualname__ == 'UiConfirmOutput':
|
||||
res = await layout.confirm_output(session_id, req.output, req.coin)
|
||||
res = await layout.confirm_output(ctx, req.output, req.coin)
|
||||
elif req.__qualname__ == 'UiConfirmTotal':
|
||||
res = await layout.confirm_total(session_id, req.spending, req.fee, req.coin)
|
||||
res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin)
|
||||
elif req.__qualname__ == 'UiConfirmFeeOverThreshold':
|
||||
res = await layout.confirm_feeoverthreshold(session_id, req.fee, req.coin)
|
||||
res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin)
|
||||
else:
|
||||
raise TypeError('Invalid signing instruction')
|
||||
return req
|
||||
|
@ -14,22 +14,22 @@ def split_address(address):
|
||||
return chunks(address, 17)
|
||||
|
||||
|
||||
async def confirm_output(session_id, output, coin):
|
||||
async def confirm_output(ctx, output, coin):
|
||||
content = Text('Confirm output', ui.ICON_RESET,
|
||||
ui.BOLD, format_amount(output.amount, coin),
|
||||
ui.NORMAL, 'to',
|
||||
ui.MONO, *split_address(output.address))
|
||||
return await confirm(session_id, content, ButtonRequestType.ConfirmOutput)
|
||||
return await confirm(ctx, content, ButtonRequestType.ConfirmOutput)
|
||||
|
||||
|
||||
async def confirm_total(session_id, spending, fee, coin):
|
||||
async def confirm_total(ctx, spending, fee, coin):
|
||||
content = Text('Confirm transaction', ui.ICON_RESET,
|
||||
'Sending: %s' % format_amount(spending, coin),
|
||||
'Fee: %s' % format_amount(fee, coin))
|
||||
return await hold_to_confirm(session_id, content, ButtonRequestType.SignTx)
|
||||
return await hold_to_confirm(ctx, content, ButtonRequestType.SignTx)
|
||||
|
||||
|
||||
async def confirm_feeoverthreshold(session_id, fee, coin):
|
||||
async def confirm_feeoverthreshold(ctx, fee, coin):
|
||||
content = Text('Confirm high fee:', ui.ICON_RESET,
|
||||
ui.BOLD, format_amount(fee, coin))
|
||||
return await confirm(session_id, content, ButtonRequestType.FeeOverThreshold)
|
||||
return await confirm(ctx, content, ButtonRequestType.FeeOverThreshold)
|
||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
||||
|
||||
|
||||
@unimport
|
||||
async def layout_verify_message(session_id, msg):
|
||||
async def layout_verify_message(ctx, msg):
|
||||
from trezor.messages.Success import Success
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import ripemd160, sha256
|
||||
|
48
src/main.py
48
src/main.py
@ -4,26 +4,7 @@ from trezor import io
|
||||
from trezor import wire
|
||||
from trezor import main
|
||||
|
||||
# Load applications
|
||||
from apps.common import storage
|
||||
if __debug__:
|
||||
from apps import debug
|
||||
from apps import homescreen
|
||||
from apps import management
|
||||
from apps import wallet
|
||||
from apps import ethereum
|
||||
from apps import fido_u2f
|
||||
|
||||
# Boot applications
|
||||
if __debug__:
|
||||
debug.boot()
|
||||
homescreen.boot()
|
||||
management.boot()
|
||||
wallet.boot()
|
||||
ethereum.boot()
|
||||
fido_u2f.boot()
|
||||
|
||||
# Intialize the USB stack
|
||||
# initialize the USB stack
|
||||
usb_wire = io.HID(
|
||||
iface_num=0x00,
|
||||
ep_in=0x81,
|
||||
@ -90,11 +71,30 @@ usb.add(usb_vcp)
|
||||
usb.add(usb_u2f)
|
||||
usb.open()
|
||||
|
||||
# Initialize the wire codec pipeline
|
||||
wire.setup(usb_wire.iface_num())
|
||||
# load applications
|
||||
from apps.common import storage
|
||||
if __debug__:
|
||||
from apps import debug
|
||||
from apps import homescreen
|
||||
from apps import management
|
||||
from apps import wallet
|
||||
from apps import ethereum
|
||||
from apps import fido_u2f
|
||||
|
||||
# Load default homescreen
|
||||
# boot applications
|
||||
if __debug__:
|
||||
debug.boot()
|
||||
homescreen.boot()
|
||||
management.boot()
|
||||
wallet.boot()
|
||||
ethereum.boot()
|
||||
fido_u2f.boot(usb_u2f)
|
||||
|
||||
# initialize the wire codec pipeline
|
||||
wire.setup(usb_wire)
|
||||
|
||||
# load default homescreen
|
||||
from apps.homescreen.homescreen import layout_homescreen
|
||||
|
||||
# Run main even loop and specify which screen is default
|
||||
# run main even loop and specify which screen is default
|
||||
main.run(default_workflow=layout_homescreen)
|
||||
|
@ -14,7 +14,7 @@ async def load_uvarint(reader):
|
||||
shift = 0
|
||||
byte = 0x80
|
||||
while byte & 0x80:
|
||||
await reader.readinto(buffer)
|
||||
await reader.areadinto(buffer)
|
||||
byte = buffer[0]
|
||||
result += (byte & 0x7F) << shift
|
||||
shift += 7
|
||||
@ -27,7 +27,7 @@ async def dump_uvarint(writer, n):
|
||||
while shifted:
|
||||
shifted = n >> 7
|
||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||
await writer.write(buffer)
|
||||
await writer.awrite(buffer)
|
||||
n = shifted
|
||||
|
||||
|
||||
@ -69,11 +69,11 @@ class LimitedReader:
|
||||
self.reader = reader
|
||||
self.limit = limit
|
||||
|
||||
async def readinto(self, buf):
|
||||
async def areadinto(self, buf):
|
||||
if self.limit < len(buf):
|
||||
raise EOFError
|
||||
else:
|
||||
nread = await self.reader.readinto(buf)
|
||||
nread = await self.reader.areadinto(buf)
|
||||
self.limit -= nread
|
||||
return nread
|
||||
|
||||
@ -83,7 +83,7 @@ class CountingWriter:
|
||||
def __init__(self):
|
||||
self.size = 0
|
||||
|
||||
async def write(self, buf):
|
||||
async def awrite(self, buf):
|
||||
nwritten = len(buf)
|
||||
self.size += nwritten
|
||||
return nwritten
|
||||
@ -112,7 +112,7 @@ async def load_message(reader, msg_type):
|
||||
await load_uvarint(reader)
|
||||
elif wtype == 2:
|
||||
ivalue = await load_uvarint(reader)
|
||||
await reader.readinto(bytearray(ivalue))
|
||||
await reader.areadinto(bytearray(ivalue))
|
||||
else:
|
||||
raise ValueError
|
||||
continue
|
||||
@ -129,10 +129,10 @@ async def load_message(reader, msg_type):
|
||||
fvalue = bool(ivalue)
|
||||
elif ftype is BytesType:
|
||||
fvalue = bytearray(ivalue)
|
||||
await reader.readinto(fvalue)
|
||||
await reader.areadinto(fvalue)
|
||||
elif ftype is UnicodeType:
|
||||
fvalue = bytearray(ivalue)
|
||||
await reader.readinto(fvalue)
|
||||
await reader.areadinto(fvalue)
|
||||
fvalue = str(fvalue, 'utf8')
|
||||
elif issubclass(ftype, MessageType):
|
||||
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
|
||||
@ -186,11 +186,11 @@ async def dump_message(writer, msg):
|
||||
|
||||
elif ftype is BytesType:
|
||||
await dump_uvarint(writer, len(svalue))
|
||||
await writer.write(svalue)
|
||||
await writer.awrite(svalue)
|
||||
|
||||
elif ftype is UnicodeType:
|
||||
await dump_uvarint(writer, len(svalue))
|
||||
await writer.write(bytes(svalue, 'utf8'))
|
||||
await writer.awrite(bytes(svalue, 'utf8'))
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
counter = CountingWriter()
|
||||
|
@ -8,58 +8,82 @@ from trezor import workflow
|
||||
from . import codec_v1
|
||||
from . import codec_v2
|
||||
|
||||
workflows = {}
|
||||
workflow_handlers = {}
|
||||
|
||||
|
||||
def register(wire_type, handler, *args):
|
||||
if wire_type in workflows:
|
||||
def register(mtype, handler, *args):
|
||||
'''Register `handler` to get scheduled after `mtype` message is received.'''
|
||||
if mtype in workflow_handlers:
|
||||
raise KeyError
|
||||
workflows[wire_type] = (handler, args)
|
||||
workflow_handlers[mtype] = (handler, args)
|
||||
|
||||
|
||||
def setup(interface):
|
||||
session_supervisor = codec_v2.SesssionSupervisor(interface,
|
||||
session_handler)
|
||||
def setup(iface):
|
||||
'''Initialize the wire stack on passed USB interface.'''
|
||||
session_supervisor = codec_v2.SesssionSupervisor(iface, session_handler)
|
||||
session_supervisor.open(codec_v1.SESSION_ID)
|
||||
loop.schedule_task(session_supervisor.listen())
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, interface, session_id):
|
||||
self.interface = interface
|
||||
self.session_id = session_id
|
||||
|
||||
def get_reader(self):
|
||||
if self.session_id == codec_v1.SESSION_ID:
|
||||
return codec_v1.Reader(self.interface)
|
||||
else:
|
||||
return codec_v2.Reader(self.interface, self.session_id)
|
||||
|
||||
def get_writer(self, mtype, msize):
|
||||
if self.session_id == codec_v1.SESSION_ID:
|
||||
return codec_v1.Writer(self.interface, mtype, msize)
|
||||
else:
|
||||
return codec_v2.Writer(self.interface, self.session_id, mtype, msize)
|
||||
|
||||
async def read(self, types):
|
||||
reader = self.get_reader()
|
||||
await reader.open()
|
||||
if reader.type not in types:
|
||||
raise UnexpectedMessageError(reader)
|
||||
return await protobuf.load_message(reader,
|
||||
messages.get_type(reader.type))
|
||||
|
||||
async def write(self, msg):
|
||||
counter = protobuf.CountingWriter()
|
||||
await protobuf.dump_message(counter, msg)
|
||||
writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
|
||||
await protobuf.dump_message(writer, msg)
|
||||
await writer.close()
|
||||
def __init__(self, iface, sid):
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
|
||||
async def call(self, msg, types):
|
||||
'''
|
||||
Reply with `msg` and wait for one of `types`. See `self.write()` and
|
||||
`self.read()`.
|
||||
'''
|
||||
await self.write(msg)
|
||||
return await self.read(types)
|
||||
|
||||
async def read(self, types):
|
||||
'''
|
||||
Wait for incoming message on this wire context and return it. Raises
|
||||
`UnexpectedMessageError` if the message type does not match one of
|
||||
`types`; and caller should always make sure to re-raise it.
|
||||
'''
|
||||
reader = self.getreader()
|
||||
|
||||
await reader.aopen() # wait for the message header
|
||||
|
||||
# if we got a message with unexpected type, raise the reader via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it
|
||||
if reader.type not in types:
|
||||
raise UnexpectedMessageError(reader)
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
pbtype = messages.get_type(reader.type)
|
||||
return await protobuf.load_message(reader, pbtype)
|
||||
|
||||
async def write(self, msg):
|
||||
'''
|
||||
Write a protobuf message to this wire context.
|
||||
'''
|
||||
writer = self.getwriter()
|
||||
|
||||
# get the message size
|
||||
counter = protobuf.CountingWriter()
|
||||
await protobuf.dump_message(counter, msg)
|
||||
|
||||
# write the message
|
||||
writer.setheader(msg.MESSAGE_WIRE_TYPE, counter.size)
|
||||
await protobuf.dump_message(writer, msg)
|
||||
await writer.aclose()
|
||||
|
||||
def getreader(self):
|
||||
if self.sid == codec_v1.SESSION_ID:
|
||||
return codec_v1.Reader(self.iface)
|
||||
else:
|
||||
return codec_v2.Reader(self.iface, self.sid)
|
||||
|
||||
def getwriter(self):
|
||||
if self.sid == codec_v1.SESSION_ID:
|
||||
return codec_v1.Writer(self.iface)
|
||||
else:
|
||||
return codec_v2.Writer(self.iface, self.sid)
|
||||
|
||||
|
||||
class UnexpectedMessageError(Exception):
|
||||
def __init__(self, reader):
|
||||
@ -74,60 +98,69 @@ class FailureError(Exception):
|
||||
self.message = message
|
||||
|
||||
|
||||
class Workflow:
|
||||
def __init__(self, default):
|
||||
self.handlers = {}
|
||||
self.default = default
|
||||
|
||||
async def __call__(self, interface, session_id):
|
||||
ctx = Context(interface, session_id)
|
||||
while True:
|
||||
async def session_handler(iface, sid):
|
||||
reader = None
|
||||
ctx = Context(iface, sid)
|
||||
while True:
|
||||
try:
|
||||
# wait for new message, if needed, and find handler
|
||||
if not reader:
|
||||
reader = ctx.getreader()
|
||||
await reader.aopen()
|
||||
try:
|
||||
reader = ctx.get_reader()
|
||||
await reader.open()
|
||||
try:
|
||||
handler = self.handlers[reader.type]
|
||||
except KeyError:
|
||||
handler = self.default
|
||||
try:
|
||||
await handler(ctx, reader)
|
||||
except UnexpectedMessageError as unexp_msg:
|
||||
reader = unexp_msg.reader
|
||||
except Exception as e:
|
||||
log.exception(__name__, e)
|
||||
handler, args = workflow_handlers[reader.type]
|
||||
except KeyError:
|
||||
handler, args = unexpected_msg, ()
|
||||
|
||||
await handler(ctx, reader, *args)
|
||||
|
||||
except UnexpectedMessageError as exc:
|
||||
# retry with opened reader from the exception
|
||||
reader = exc.reader
|
||||
continue
|
||||
except FailureError as exc:
|
||||
# we log FailureError as warning, not as exception
|
||||
log.warning(__name__, 'failure: %s', exc.message)
|
||||
except Exception as exc:
|
||||
# sessions are never closed by raised exceptions
|
||||
log.exception(__name__, exc)
|
||||
|
||||
# read new message in next iteration
|
||||
reader = None
|
||||
|
||||
|
||||
async def protobuf_workflow(ctx, reader, handler, *args):
|
||||
msg = await protobuf.load_message(reader, messages.get_type(reader.type))
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.messages.FailureType import FirmwareError
|
||||
|
||||
req = await protobuf.load_message(reader, messages.get_type(reader.type))
|
||||
try:
|
||||
res = await handler(reader.sid, msg, *args)
|
||||
except Exception as exc:
|
||||
if not isinstance(exc, UnexpectedMessageError):
|
||||
await ctx.write(make_failure_msg(exc))
|
||||
res = await handler(ctx, req, *args)
|
||||
except UnexpectedMessageError:
|
||||
# session handler takes care of this one
|
||||
raise
|
||||
else:
|
||||
if res:
|
||||
await ctx.write(res)
|
||||
except FailureError as exc:
|
||||
# respond with specific code and message
|
||||
await ctx.write(Failure(code=exc.code, message=exc.message))
|
||||
raise
|
||||
except Exception as exc:
|
||||
# respond with a generic code and message
|
||||
await ctx.write(Failure(code=FirmwareError, message='Firmware error'))
|
||||
raise
|
||||
if res:
|
||||
# respond with a specific response
|
||||
await ctx.write(res)
|
||||
|
||||
|
||||
async def handle_unexp_msg(ctx, reader):
|
||||
async def unexpected_msg(ctx, reader):
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.messages.FailureType import UnexpectedMessage
|
||||
|
||||
# receive the message and throw it away
|
||||
while reader.size > 0:
|
||||
buf = bytearray(reader.size)
|
||||
await reader.readinto(buf)
|
||||
await reader.areadinto(buf)
|
||||
|
||||
# respond with an unknown message error
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.messages.FailureType import UnexpectedMessage
|
||||
await ctx.write(
|
||||
Failure(code=UnexpectedMessage, message='Unexpected message'))
|
||||
|
||||
def make_failure_msg(exc):
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.messages.FailureType import FirmwareError
|
||||
if isinstance(exc, FailureError):
|
||||
code = exc.code
|
||||
message = exc.message
|
||||
else:
|
||||
code = FirmwareError
|
||||
message = 'Firmware Error'
|
||||
return Failure(code=code, message=message)
|
||||
|
@ -1,7 +1,6 @@
|
||||
from micropython import const
|
||||
import ustruct
|
||||
|
||||
from trezor import io
|
||||
from trezor import loop
|
||||
from trezor import utils
|
||||
|
||||
@ -32,13 +31,13 @@ class Reader:
|
||||
def __repr__(self):
|
||||
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
|
||||
|
||||
async def open(self):
|
||||
async def aopen(self):
|
||||
'''
|
||||
Begin the message transmission by waiting for initial V2 message report
|
||||
on this session. `self.type` and `self.size` are initialized and
|
||||
available after `open()` returns.
|
||||
available after `aopen()` returns.
|
||||
'''
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||
while True:
|
||||
# wait for initial report
|
||||
report = await read
|
||||
@ -55,7 +54,7 @@ class Reader:
|
||||
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
|
||||
self.ofs = 0
|
||||
|
||||
async def readinto(self, buf):
|
||||
async def areadinto(self, buf):
|
||||
'''
|
||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||
@ -64,7 +63,7 @@ class Reader:
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||
nread = 0
|
||||
while nread < len(buf):
|
||||
if self.ofs == len(self.data):
|
||||
@ -93,20 +92,28 @@ class Writer:
|
||||
async-file-like interface.
|
||||
'''
|
||||
|
||||
def __init__(self, iface, mtype, msize):
|
||||
def __init__(self, iface):
|
||||
self.iface = iface
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
self.type = None
|
||||
self.size = None
|
||||
self.data = bytearray(_REP_LEN)
|
||||
self.ofs = _REP_INIT_DATA
|
||||
|
||||
# load the report with initial header
|
||||
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)
|
||||
self.ofs = 0
|
||||
|
||||
def __repr__(self):
|
||||
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
|
||||
|
||||
async def write(self, buf):
|
||||
def setheader(self, mtype, msize):
|
||||
'''
|
||||
Reset the writer state and load the message header with passed type and
|
||||
total message size.
|
||||
'''
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC,
|
||||
_REP_MAGIC, mtype, msize)
|
||||
self.ofs = _REP_INIT_DATA
|
||||
|
||||
async def awrite(self, buf):
|
||||
'''
|
||||
Encode and write every byte from `buf`. Does not need to be called in
|
||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||
@ -115,7 +122,7 @@ class Writer:
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
write = loop.select(self.iface | loop.WRITE)
|
||||
write = loop.select(self.iface.iface_num() | loop.WRITE)
|
||||
nwritten = 0
|
||||
while nwritten < len(buf):
|
||||
# copy as much as possible to report buffer
|
||||
@ -127,12 +134,12 @@ class Writer:
|
||||
if self.ofs == _REP_LEN:
|
||||
# we are at the end of the report, flush it
|
||||
await write
|
||||
io.write(self.iface, self.data)
|
||||
self.iface.write(self.data)
|
||||
self.ofs = _REP_CONT_DATA
|
||||
|
||||
return nwritten
|
||||
|
||||
async def close(self):
|
||||
async def aclose(self):
|
||||
'''Flush and close the message transmission.'''
|
||||
if self.ofs != _REP_CONT_DATA:
|
||||
# we didn't write anything or last write() wasn't report-aligned,
|
||||
@ -141,5 +148,5 @@ class Writer:
|
||||
self.data[self.ofs] = 0x00
|
||||
self.ofs += 1
|
||||
|
||||
await loop.select(self.iface | loop.WRITE)
|
||||
io.send(self.iface, self.data)
|
||||
await loop.select(self.iface.iface_num() | loop.WRITE)
|
||||
self.iface.write(self.data)
|
||||
|
@ -1,7 +1,6 @@
|
||||
from micropython import const
|
||||
import ustruct
|
||||
|
||||
from trezor import io
|
||||
from trezor import loop
|
||||
from trezor import utils
|
||||
from trezor.crypto import random
|
||||
@ -28,11 +27,11 @@ _REP_MARKER_CONT = const(0x02)
|
||||
_REP_MARKER_OPEN = const(0x03)
|
||||
_REP_MARKER_CLOSE = const(0x04)
|
||||
|
||||
_REP = '>BL' # marker, session_id
|
||||
_REP = '>BL' # marker, session_id
|
||||
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
|
||||
_REP_CONT = '>BLL' # marker, session_id, sequence
|
||||
_REP_CONT = '>BLL' # marker, session_id, sequence
|
||||
_REP_INIT_DATA = const(13) # offset of data in init report
|
||||
_REP_CONT_DATA = const(9) # offset of data in cont report
|
||||
_REP_CONT_DATA = const(9) # offset of data in cont report
|
||||
|
||||
|
||||
class Reader:
|
||||
@ -51,15 +50,16 @@ class Reader:
|
||||
self.seq = 0
|
||||
|
||||
def __repr__(self):
|
||||
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size)
|
||||
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type,
|
||||
self.size)
|
||||
|
||||
async def open(self):
|
||||
async def aopen(self):
|
||||
'''
|
||||
Begin the message transmission by waiting for initial V2 message report
|
||||
on this session. `self.type` and `self.size` are initialized and
|
||||
available after `open()` returns.
|
||||
available after `aopen()` returns.
|
||||
'''
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||
while True:
|
||||
# wait for initial report
|
||||
report = await read
|
||||
@ -74,7 +74,7 @@ class Reader:
|
||||
self.ofs = 0
|
||||
self.seq = 0
|
||||
|
||||
async def readinto(self, buf):
|
||||
async def areadinto(self, buf):
|
||||
'''
|
||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||
@ -83,7 +83,7 @@ class Reader:
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||
nread = 0
|
||||
while nread < len(buf):
|
||||
if self.ofs == len(self.data):
|
||||
@ -115,20 +115,28 @@ class Writer:
|
||||
interface.
|
||||
'''
|
||||
|
||||
def __init__(self, iface, sid, mtype, msize):
|
||||
def __init__(self, iface, sid):
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.type = None
|
||||
self.size = None
|
||||
self.data = bytearray(_REP_LEN)
|
||||
self.ofs = 0
|
||||
self.seq = 0
|
||||
|
||||
def setheader(self, mtype, msize):
|
||||
'''
|
||||
Reset the writer state and load the message header with passed type and
|
||||
total message size.
|
||||
'''
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
self.data = bytearray(_REP_LEN)
|
||||
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER_INIT, self.sid,
|
||||
mtype, msize)
|
||||
self.ofs = _REP_INIT_DATA
|
||||
self.seq = 0
|
||||
|
||||
# load the report with initial header
|
||||
ustruct.pack_into(_REP_INIT, self.data, 0,
|
||||
_REP_MARKER_INIT, sid, mtype, msize)
|
||||
|
||||
async def write(self, buf):
|
||||
async def awrite(self, buf):
|
||||
'''
|
||||
Encode and write every byte from `buf`. Does not need to be called in
|
||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||
@ -137,7 +145,7 @@ class Writer:
|
||||
if self.size < len(buf):
|
||||
raise EOFError
|
||||
|
||||
write = loop.select(self.iface | loop.WRITE)
|
||||
write = loop.select(self.iface.iface_num() | loop.WRITE)
|
||||
nwritten = 0
|
||||
while nwritten < len(buf):
|
||||
# copy as much as possible to report buffer
|
||||
@ -149,15 +157,15 @@ class Writer:
|
||||
if self.ofs == _REP_LEN:
|
||||
# we are at the end of the report, flush it, and prepare header
|
||||
await write
|
||||
io.send(self.iface, self.data)
|
||||
ustruct.pack_into(_REP_CONT, self.data, 0,
|
||||
_REP_MARKER_CONT, self.sid, self.seq)
|
||||
self.iface.write(self.data)
|
||||
ustruct.pack_into(_REP_CONT, self.data, 0, _REP_MARKER_CONT,
|
||||
self.sid, self.seq)
|
||||
self.ofs = _REP_CONT_DATA
|
||||
self.seq += 1
|
||||
|
||||
return nwritten
|
||||
|
||||
async def close(self):
|
||||
async def aclose(self):
|
||||
'''Flush and close the message transmission.'''
|
||||
if self.ofs != _REP_CONT_DATA:
|
||||
# we didn't write anything or last write() wasn't report-aligned,
|
||||
@ -166,8 +174,8 @@ class Writer:
|
||||
self.data[self.ofs] = 0x00
|
||||
self.ofs += 1
|
||||
|
||||
await loop.select(self.iface | loop.WRITE)
|
||||
io.send(self.iface, self.data)
|
||||
await loop.select(self.iface.iface_num() | loop.WRITE)
|
||||
self.iface.write(self.data)
|
||||
|
||||
|
||||
class SesssionSupervisor:
|
||||
@ -186,8 +194,8 @@ class SesssionSupervisor:
|
||||
After close request, the handling task is closed and session terminated.
|
||||
Both requests receive responses confirming the operation.
|
||||
'''
|
||||
read = loop.select(self.iface | loop.READ)
|
||||
write = loop.select(self.iface | loop.WRITE)
|
||||
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||
write = loop.select(self.iface.iface_num() | loop.WRITE)
|
||||
while True:
|
||||
report = await read
|
||||
repmarker, repsid = ustruct.unpack(_REP, report)
|
||||
@ -225,8 +233,8 @@ class SesssionSupervisor:
|
||||
|
||||
def writeopen(self, sid):
|
||||
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
|
||||
io.write(self.iface, self.session_report)
|
||||
self.iface.write(self.session_report)
|
||||
|
||||
def writeclose(self, sid):
|
||||
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
|
||||
io.write(self.iface, self.session_report)
|
||||
self.iface.write(self.session_report)
|
||||
|
@ -26,26 +26,26 @@ def test_reader():
|
||||
|
||||
# open, expected one read
|
||||
first_report = report_header + message[:rep_len - len(report_header)]
|
||||
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||
assert_eq(reader.type, message_type)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
# empty read
|
||||
empty_buffer = bytearray()
|
||||
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
||||
assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(empty_buffer), 0)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
# short read, expected no read
|
||||
short_buffer = bytearray(32)
|
||||
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
||||
assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(short_buffer), 32)
|
||||
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer))
|
||||
|
||||
# aligned read, expected no read
|
||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
||||
assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
|
||||
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||
|
||||
@ -53,12 +53,12 @@ def test_reader():
|
||||
next_report_header = bytearray(unhexlify('3f'))
|
||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||
onebyte_buffer = bytearray(1)
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||
|
||||
# too long read, raises eof
|
||||
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||
assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||
|
||||
# long read, expect multiple reads
|
||||
start_size = reader.size
|
||||
@ -74,12 +74,12 @@ def test_reader():
|
||||
prev_report = next_reports[i - 1] if i > 0 else None
|
||||
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
||||
assert_async(reader.areadinto(long_buffer), expected_syscalls)
|
||||
assert_eq(long_buffer, message[-start_size:])
|
||||
assert_eq(reader.size, 0)
|
||||
|
||||
# one byte read, raises eof
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
||||
assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
|
||||
|
||||
|
||||
def test_writer():
|
||||
@ -87,7 +87,8 @@ def test_writer():
|
||||
interface = 0xdeadbeef
|
||||
message_type = 0x87654321
|
||||
message_len = 1024
|
||||
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len)
|
||||
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID)
|
||||
writer.setheader(message_type, message_len)
|
||||
|
||||
# init header corresponding to the data above
|
||||
report_header = bytearray(unhexlify('3f2323432100000400'))
|
||||
@ -96,14 +97,14 @@ def test_writer():
|
||||
|
||||
# empty write
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
||||
assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
|
||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||
assert_eq(writer.size, start_size)
|
||||
|
||||
# short write, expected no report
|
||||
start_size = writer.size
|
||||
short_payload = bytearray(range(4))
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data,
|
||||
report_header
|
||||
@ -118,7 +119,7 @@ def test_writer():
|
||||
+ short_payload
|
||||
+ aligned_payload
|
||||
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||
msg.send.assert_called_n_times(1)
|
||||
msg.send = msg.send.original
|
||||
@ -126,7 +127,7 @@ def test_writer():
|
||||
# short write, expected no report, but data starts with correct seq and cont marker
|
||||
report_header = bytearray(unhexlify('3f'))
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||
report_header + short_payload)
|
||||
@ -142,19 +143,19 @@ def test_writer():
|
||||
# test write
|
||||
expected_write_reports = expected_reports[:-1]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, start_size - len(long_payload))
|
||||
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||
msg.send = msg.send.original
|
||||
# test write raises eof
|
||||
msg.send = mock_call(msg.send, [])
|
||||
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
||||
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
|
||||
msg.send.assert_called_n_times(0)
|
||||
msg.send = msg.send.original
|
||||
# test close
|
||||
expected_close_reports = expected_reports[-1:]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, 0)
|
||||
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||
msg.send = msg.send.original
|
||||
|
@ -26,26 +26,26 @@ def test_reader():
|
||||
|
||||
# open, expected one read
|
||||
first_report = report_header + message[:rep_len - len(report_header)]
|
||||
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||
assert_eq(reader.type, message_type)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
# empty read
|
||||
empty_buffer = bytearray()
|
||||
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
||||
assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(empty_buffer), 0)
|
||||
assert_eq(reader.size, message_len)
|
||||
|
||||
# short read, expected no read
|
||||
short_buffer = bytearray(32)
|
||||
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
||||
assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
|
||||
assert_eq(len(short_buffer), 32)
|
||||
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer))
|
||||
|
||||
# aligned read, expected no read
|
||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
||||
assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
|
||||
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||
|
||||
@ -53,12 +53,12 @@ def test_reader():
|
||||
next_report_header = bytearray(unhexlify('021234567800000000'))
|
||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||
onebyte_buffer = bytearray(1)
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||
|
||||
# too long read, raises eof
|
||||
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||
assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||
|
||||
# long read, expect multiple reads
|
||||
start_size = reader.size
|
||||
@ -74,12 +74,12 @@ def test_reader():
|
||||
prev_report = next_reports[i - 1] if i > 0 else None
|
||||
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
||||
assert_async(reader.areadinto(long_buffer), expected_syscalls)
|
||||
assert_eq(long_buffer, message[-start_size:])
|
||||
assert_eq(reader.size, 0)
|
||||
|
||||
# one byte read, raises eof
|
||||
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
||||
assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
|
||||
|
||||
|
||||
def test_writer():
|
||||
@ -88,7 +88,8 @@ def test_writer():
|
||||
session_id = 0x12345678
|
||||
message_type = 0x87654321
|
||||
message_len = 1024
|
||||
writer = codec_v2.Writer(interface, session_id, message_type, message_len)
|
||||
writer = codec_v2.Writer(interface, session_id)
|
||||
writer.setheader(message_type, message_len)
|
||||
|
||||
# init header corresponding to the data above
|
||||
report_header = bytearray(unhexlify('01123456788765432100000400'))
|
||||
@ -97,14 +98,14 @@ def test_writer():
|
||||
|
||||
# empty write
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
||||
assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
|
||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||
assert_eq(writer.size, start_size)
|
||||
|
||||
# short write, expected no report
|
||||
start_size = writer.size
|
||||
short_payload = bytearray(range(4))
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data,
|
||||
report_header
|
||||
@ -119,7 +120,7 @@ def test_writer():
|
||||
+ short_payload
|
||||
+ aligned_payload
|
||||
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||
msg.send.assert_called_n_times(1)
|
||||
msg.send = msg.send.original
|
||||
@ -127,7 +128,7 @@ def test_writer():
|
||||
# short write, expected no report, but data starts with correct seq and cont marker
|
||||
report_header = bytearray(unhexlify('021234567800000000'))
|
||||
start_size = writer.size
|
||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
||||
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||
assert_eq(writer.size, start_size - len(short_payload))
|
||||
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||
report_header + short_payload)
|
||||
@ -145,19 +146,19 @@ def test_writer():
|
||||
# test write
|
||||
expected_write_reports = expected_reports[:-1]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, start_size - len(long_payload))
|
||||
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||
msg.send = msg.send.original
|
||||
# test write raises eof
|
||||
msg.send = mock_call(msg.send, [])
|
||||
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
||||
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
|
||||
msg.send.assert_called_n_times(0)
|
||||
msg.send = msg.send.original
|
||||
# test close
|
||||
expected_close_reports = expected_reports[-1:]
|
||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||
assert_eq(writer.size, 0)
|
||||
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||
msg.send = msg.send.original
|
||||
|
Loading…
Reference in New Issue
Block a user