1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-27 07:40:59 +00:00

wire: pass Context to apps

This commit is contained in:
Jan Pochyla 2017-08-15 15:09:09 +02:00
parent b1b84fb233
commit 3562ffdc54
30 changed files with 355 additions and 308 deletions

View File

@ -7,7 +7,7 @@ signal = loop.Signal()
@unimport
async def confirm(session_id, content, code=None, *args, **kwargs):
async def confirm(ctx, content, code=None, *args, **kwargs):
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import Other
@ -19,12 +19,12 @@ async def confirm(session_id, content, code=None, *args, **kwargs):
if code is None:
code = Other
await wire.call(session_id, ButtonRequest(code=code), ButtonAck)
await ctx.call(ButtonRequest(code=code), ButtonAck)
return await loop.Wait((signal, dialog)) == CONFIRMED
@unimport
async def hold_to_confirm(session_id, content, code=None, *args, **kwargs):
async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import Other
@ -36,7 +36,7 @@ async def hold_to_confirm(session_id, content, code=None, *args, **kwargs):
if code is None:
code = Other
await wire.call(session_id, ButtonRequest(code=code), ButtonAck)
await ctx.call(ButtonRequest(code=code), ButtonAck)
return await loop.Wait((signal, dialog)) == CONFIRMED

View File

@ -1,7 +1,7 @@
from trezor import ui, wire
async def request_passphrase(session_id):
async def request_passphrase(ctx):
from trezor.messages.FailureType import ActionCancelled
from trezor.messages.PassphraseRequest import PassphraseRequest
from trezor.messages.wire_types import PassphraseAck, Cancel
@ -12,17 +12,17 @@ async def request_passphrase(session_id):
'Please enter passphrase', 'on your computer.')
text.render()
ack = await wire.call(session_id, PassphraseRequest(), PassphraseAck, Cancel)
ack = await ctx.call(PassphraseRequest(), PassphraseAck, Cancel)
if ack.MESSAGE_WIRE_TYPE == Cancel:
raise wire.FailureError(ActionCancelled, 'Passphrase cancelled')
return ack.passphrase
async def protect_by_passphrase(session_id):
async def protect_by_passphrase(ctx):
from apps.common import storage
if storage.is_protected_by_passphrase():
return await request_passphrase(session_id)
return await request_passphrase(ctx)
else:
return ''

View File

@ -7,7 +7,7 @@ if __debug__:
@unimport
async def request_pin_on_display(session_id: int, code: int=None) -> str:
async def request_pin_on_display(ctx: wire.Context, code: int=None) -> str:
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import ProtectCall
from trezor.messages.FailureType import PinCancelled
@ -20,9 +20,8 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
_, label = _get_code_and_label(code)
await wire.call(session_id,
ButtonRequest(code=ProtectCall),
ButtonAck)
await ctx.call(ButtonRequest(code=ProtectCall),
ButtonAck)
ui.display.clear()
matrix = PinMatrix(label)
@ -36,7 +35,7 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
@unimport
async def request_pin_on_client(session_id: int, code: int=None) -> str:
async def request_pin_on_client(ctx: wire.Context, code: int=None) -> str:
from trezor.messages.FailureType import PinCancelled
from trezor.messages.PinMatrixRequest import PinMatrixRequest
from trezor.messages.wire_types import PinMatrixAck, Cancel
@ -51,9 +50,8 @@ async def request_pin_on_client(session_id: int, code: int=None) -> str:
matrix = PinMatrix(label)
matrix.render()
ack = await wire.call(session_id,
PinMatrixRequest(type=code),
PinMatrixAck, Cancel)
ack = await ctx.call(PinMatrixRequest(type=code),
PinMatrixAck, Cancel)
digits = matrix.digits
matrix = None
@ -66,12 +64,12 @@ request_pin = request_pin_on_client
@unimport
async def request_pin_twice(session_id: int) -> str:
async def request_pin_twice(ctx: wire.Context) -> str:
from trezor.messages.FailureType import ActionCancelled
from trezor.messages import PinMatrixRequestType
pin_first = await request_pin(session_id, PinMatrixRequestType.NewFirst)
pin_again = await request_pin(session_id, PinMatrixRequestType.NewSecond)
pin_first = await request_pin(ctx, PinMatrixRequestType.NewFirst)
pin_again = await request_pin(ctx, PinMatrixRequestType.NewSecond)
if pin_first != pin_again:
# changed message due to consistency with T1 msgs
raise wire.FailureError(ActionCancelled, 'PIN change failed')
@ -79,22 +77,22 @@ async def request_pin_twice(session_id: int) -> str:
return pin_first
async def protect_by_pin_repeatedly(session_id: int, at_least_once: bool=False):
async def protect_by_pin_repeatedly(ctx: wire.Context, at_least_once: bool=False):
from . import storage
locked = storage.is_locked() or at_least_once
while locked:
pin = await request_pin(session_id)
pin = await request_pin(ctx)
locked = not storage.unlock(pin, _render_pin_failure)
async def protect_by_pin_or_fail(session_id: int, at_least_once: bool=False):
async def protect_by_pin_or_fail(ctx: wire.Context, at_least_once: bool=False):
from trezor.messages.FailureType import PinInvalid
from . import storage
locked = storage.is_locked() or at_least_once
if locked:
pin = await request_pin(session_id)
pin = await request_pin(ctx)
if not storage.unlock(pin, _render_pin_failure):
raise wire.FailureError(PinInvalid, 'PIN invalid')

View File

@ -5,20 +5,20 @@ from trezor.crypto import bip39
_DEFAULT_CURVE = 'secp256k1'
async def get_root(session_id: int, curve_name=_DEFAULT_CURVE):
seed = await get_seed(session_id)
async def get_root(ctx: wire.Context, curve_name=_DEFAULT_CURVE):
seed = await get_seed(ctx)
root = bip32.from_seed(seed, curve_name)
return root
async def get_seed(session_id: int) -> bytes:
async def get_seed(ctx: wire.Context) -> bytes:
from . import cache
if cache.seed is None:
cache.seed = await compute_seed(session_id)
cache.seed = await compute_seed(ctx)
return cache.seed
async def compute_seed(session_id: int) -> bytes:
async def compute_seed(ctx: wire.Context) -> bytes:
from trezor.messages.FailureType import ProcessError
from .request_passphrase import protect_by_passphrase
from .request_pin import protect_by_pin
@ -27,9 +27,9 @@ async def compute_seed(session_id: int) -> bytes:
if not storage.is_initialized():
raise wire.FailureError(ProcessError, 'Device is not initialized')
await protect_by_pin(session_id)
await protect_by_pin(ctx)
passphrase = await protect_by_passphrase(session_id)
passphrase = await protect_by_passphrase(ctx)
return bip39.seed(storage.get_mnemonic(), passphrase)

View File

@ -4,13 +4,13 @@ from trezor.messages.wire_types import \
DebugLinkMemoryRead, DebugLinkMemoryWrite, DebugLinkFlashErase
async def dispatch_DebugLinkDecision(session_id, msg):
async def dispatch_DebugLinkDecision(ctx, msg):
from trezor.ui.confirm import CONFIRMED, CANCELLED
from apps.common.confirm import signal
signal.send(CONFIRMED if msg.yes_no else CANCELLED)
async def dispatch_DebugLinkGetState(session_id, msg):
async def dispatch_DebugLinkGetState(ctx, msg):
from trezor.messages.DebugLinkState import DebugLinkState
from apps.common import storage, request_pin
from apps.management import reset_device
@ -36,11 +36,11 @@ async def dispatch_DebugLinkGetState(session_id, msg):
return m
async def dispatch_DebugLinkStop(session_id, msg):
async def dispatch_DebugLinkStop(ctx, msg):
pass
async def dispatch_DebugLinkMemoryRead(session_id, msg):
async def dispatch_DebugLinkMemoryRead(ctx, msg):
from trezor.messages.DebugLinkMemory import DebugLinkMemory
from uctypes import bytes_at
m = DebugLinkMemory()
@ -48,14 +48,14 @@ async def dispatch_DebugLinkMemoryRead(session_id, msg):
return m
async def dispatch_DebugLinkMemoryWrite(session_id, msg):
async def dispatch_DebugLinkMemoryWrite(ctx, msg):
from uctypes import bytearray_at
l = len(msg.memory)
data = bytearray_at(msg.address, l)
data[0:l] = msg.memory
async def dispatch_DebugLinkFlashErase(session_id, msg):
async def dispatch_DebugLinkFlashErase(ctx, msg):
# TODO: erase(msg.sector)
pass

View File

@ -3,13 +3,13 @@ from trezor.utils import unimport
@unimport
async def layout_ethereum_get_address(session_id, msg):
async def layout_ethereum_get_address(ctx, msg):
from trezor.messages.EthereumAddress import EthereumAddress
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from ..common import seed
node = await seed.get_root(session_id)
node = await seed.get_root(ctx)
node.derive_path(msg.address_n or ())
seckey = node.private_key()
@ -17,11 +17,11 @@ async def layout_ethereum_get_address(session_id, msg):
address = sha3_256(public_key[1:]).digest(True)[12:] # Keccak
if msg.show_display:
await _show_address(session_id, address)
await _show_address(ctx, address)
return EthereumAddress(address=address)
async def _show_address(session_id, address):
async def _show_address(ctx, address):
from trezor.messages.ButtonRequestType import Address
from trezor.ui.text import Text
from ..common.confirm import require_confirm
@ -30,7 +30,7 @@ async def _show_address(session_id, address):
content = Text('Confirm address', ui.ICON_RESET,
ui.MONO, *_split_address(address))
await require_confirm(session_id, content, code=Address)
await require_confirm(ctx, content, code=Address)
def _split_address(address):

View File

@ -226,11 +226,11 @@ class Cmd:
return Msg(self.cid, cla, ins, p1, p2, lc, data)
async def read_cmd(iface: int) -> Cmd:
async def read_cmd(iface: io.HID) -> Cmd:
desc_init = frame_init()
desc_cont = frame_cont()
buf, = await loop.select(iface)
buf, = await loop.select(iface.iface_num())
# log.debug(__name__, 'read init %s', buf)
ifrm = overlay_struct(buf, desc_init)
@ -252,7 +252,7 @@ async def read_cmd(iface: int) -> Cmd:
data = data[:bcnt]
while datalen < bcnt:
buf, = await loop.select(iface)
buf, = await loop.select(iface.iface_num())
# log.debug(__name__, 'read cont %s', buf)
cfrm = overlay_struct(buf, desc_cont)
@ -282,7 +282,7 @@ async def read_cmd(iface: int) -> Cmd:
return Cmd(ifrm.cid, ifrm.cmd, data)
def send_cmd(cmd: Cmd, iface: int) -> None:
def send_cmd(cmd: Cmd, iface: io.HID) -> None:
init_desc = frame_init()
cont_desc = frame_cont()
offset = 0
@ -295,7 +295,7 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
frm.bcnt = datalen
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
io.send(iface, buf)
iface.write(buf)
# log.debug(__name__, 'send init %s', buf)
if offset < datalen:
@ -304,18 +304,17 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
while offset < datalen:
frm.seq = seq
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
utime.sleep_ms(1) # FIXME: do async send
io.send(iface, buf)
utime.sleep_ms(1) # FIXME: async write
iface.write(buf)
# log.debug(__name__, 'send cont %s', buf)
seq += 1
def boot():
iface = 0x03
def boot(iface: io.HID):
loop.schedule_task(handle_reports(iface))
async def handle_reports(iface: int):
async def handle_reports(iface: io.HID):
while True:
try:
req = await read_cmd(iface)

View File

@ -4,7 +4,7 @@ from trezor.messages.wire_types import Initialize, GetFeatures, Ping
@unimport
async def respond_Features(session_id, msg):
async def respond_Features(ctx, msg):
from apps.common import storage, coins
from trezor.messages.Features import Features
@ -28,7 +28,7 @@ async def respond_Features(session_id, msg):
@unimport
async def respond_Pong(session_id, msg):
async def respond_Pong(ctx, msg):
from trezor.messages.Success import Success
s = Success()
@ -36,11 +36,11 @@ async def respond_Pong(session_id, msg):
if msg.pin_protection:
from apps.common.request_pin import protect_by_pin
await protect_by_pin(session_id)
await protect_by_pin(ctx)
if msg.passphrase_protection:
from apps.common.request_passphrase import protect_by_passphrase
await protect_by_passphrase(session_id)
await protect_by_passphrase(ctx)
# TODO: handle other fields:
# button_protection

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport
async def layout_apply_settings(session_id, msg):
async def layout_apply_settings(ctx, msg):
from trezor.messages.Success import Success
from trezor.messages.FailureType import ProcessError
from trezor.ui.text import Text
@ -11,7 +11,7 @@ async def layout_apply_settings(session_id, msg):
from ..common.request_pin import protect_by_pin
from ..common import storage
await protect_by_pin(session_id)
await protect_by_pin(ctx)
if msg.homescreen is not None:
raise wire.FailureError(
@ -21,20 +21,20 @@ async def layout_apply_settings(session_id, msg):
raise wire.FailureError(ProcessError, 'No setting provided')
if msg.label is not None:
await require_confirm(session_id, Text(
await require_confirm(ctx, Text(
'Change label', ui.ICON_RESET,
'Do you really want to', 'change label to',
ui.BOLD, '%s' % msg.label))
if msg.language is not None:
await require_confirm(session_id, Text(
await require_confirm(ctx, Text(
'Change language', ui.ICON_RESET,
'Do you really want to', 'change language to',
ui.BOLD, '%s' % msg.language,
ui.NORMAL, '?'))
if msg.use_passphrase is not None:
await require_confirm(session_id, Text(
await require_confirm(ctx, Text(
'Enable passphrase' if msg.use_passphrase else 'Disable passphrase',
ui.ICON_RESET,
'Do you really want to',

View File

@ -2,52 +2,52 @@ from trezor import ui
from trezor.utils import unimport
def confirm_set_pin(session_id):
def confirm_set_pin(ctx):
from apps.common.confirm import require_confirm
from trezor.ui.text import Text
return require_confirm(session_id, Text(
return require_confirm(ctx, Text(
'Change PIN', ui.ICON_RESET,
'Do you really want to', ui.BOLD,
'set new PIN?'))
def confirm_change_pin(session_id):
def confirm_change_pin(ctx):
from apps.common.confirm import require_confirm
from trezor.ui.text import Text
return require_confirm(session_id, Text(
return require_confirm(ctx, Text(
'Change PIN', ui.ICON_RESET,
'Do you really want to', ui.BOLD,
'change current PIN?'))
def confirm_remove_pin(session_id):
def confirm_remove_pin(ctx):
from apps.common.confirm import require_confirm
from trezor.ui.text import Text
return require_confirm(session_id, Text(
return require_confirm(ctx, Text(
'Remove PIN', ui.ICON_RESET,
'Do you really want to', ui.BOLD,
'remove current PIN?'))
@unimport
async def layout_change_pin(session_id, msg):
async def layout_change_pin(ctx, msg):
from trezor.messages.Success import Success
from apps.common.request_pin import protect_by_pin, request_pin_twice
from apps.common import storage
if msg.remove:
if storage.is_protected_by_pin():
await confirm_remove_pin(session_id)
await protect_by_pin(session_id, at_least_once=True)
await confirm_remove_pin(ctx)
await protect_by_pin(ctx, at_least_once=True)
pin = ''
else:
if storage.is_protected_by_pin():
await confirm_change_pin(session_id)
await protect_by_pin(session_id, at_least_once=True)
await confirm_change_pin(ctx)
await protect_by_pin(ctx, at_least_once=True)
else:
await confirm_set_pin(session_id)
pin = await request_pin_twice(session_id)
await confirm_set_pin(ctx)
pin = await request_pin_twice(ctx)
storage.load_settings(pin=pin)
if pin:

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport
async def layout_load_device(session_id, msg):
async def layout_load_device(ctx, msg):
from trezor.crypto import bip39
from trezor.messages.Success import Success
from trezor.messages.FailureType import UnexpectedMessage, ProcessError
@ -20,7 +20,7 @@ async def layout_load_device(session_id, msg):
if not msg.skip_checksum and not bip39.check(msg.mnemonic):
raise wire.FailureError(ProcessError, 'Mnemonic is not valid')
await require_confirm(session_id, Text(
await require_confirm(ctx, Text(
'Loading seed', ui.ICON_RESET,
ui.BOLD, 'Loading private seed', 'is not recommended.',
ui.NORMAL, 'Continue only if you', 'know what you are doing!'))

View File

@ -11,7 +11,7 @@ def nth(n):
@unimport
async def layout_recovery_device(session_id, msg):
async def layout_recovery_device(ctx, msg):
msg = 'Please enter ' + nth(msg.word_count) + ' word'

View File

@ -10,7 +10,7 @@ if __debug__:
@unimport
async def layout_reset_device(session_id, msg):
async def layout_reset_device(ctx, msg):
from trezor.ui.text import Text
from trezor.crypto import hashlib, random, bip39
from trezor.messages.EntropyRequest import EntropyRequest
@ -39,21 +39,21 @@ async def layout_reset_device(session_id, msg):
if msg.display_random:
entropy_lines = chunks(ubinascii.hexlify(internal_entropy), 16)
entropy_content = Text('Internal entropy', ui.ICON_RESET, *entropy_lines)
await require_confirm(session_id, entropy_content, ButtonRequestType.ResetDevice)
await require_confirm(ctx, entropy_content, ButtonRequestType.ResetDevice)
if msg.pin_protection:
pin = await request_pin_twice(session_id)
pin = await request_pin_twice(ctx)
else:
pin = None
external_entropy_ack = await wire.call(session_id, EntropyRequest(), EntropyAck)
external_entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
ctx = hashlib.sha256()
ctx.update(internal_entropy)
ctx.update(external_entropy_ack.entropy)
entropy = ctx.digest()
mnemonic = bip39.from_data(entropy[:msg.strength // 8])
await show_mnemonic_by_word(session_id, mnemonic)
await show_mnemonic_by_word(ctx, mnemonic)
storage.load_mnemonic(mnemonic)
storage.load_settings(pin=pin,
@ -64,7 +64,7 @@ async def layout_reset_device(session_id, msg):
return Success(message='Initialized')
async def show_mnemonic_by_word(session_id, mnemonic):
async def show_mnemonic_by_word(ctx, mnemonic):
from trezor.ui.text import Text
from trezor.messages.ButtonRequestType import ConfirmWord
from apps.common.confirm import confirm
@ -80,7 +80,7 @@ async def show_mnemonic_by_word(session_id, mnemonic):
while index < len(words):
word = words[index]
current_word = word
await confirm(session_id,
await confirm(ctx,
Text(
'Recovery seed setup', ui.ICON_RESET,
ui.NORMAL, 'Write down seed word' if recovery else 'Confirm seed word', ' ',

View File

@ -3,13 +3,13 @@ from trezor.utils import unimport
@unimport
async def layout_wipe_device(session_id, msg):
async def layout_wipe_device(ctx, msg):
from trezor.messages.Success import Success
from trezor.ui.text import Text
from ..common.confirm import hold_to_confirm
from ..common import storage
await hold_to_confirm(session_id, Text(
await hold_to_confirm(ctx, Text(
'WIPE DEVICE',
ui.ICON_WIPE,
ui.NORMAL, 'Do you really want to', 'wipe the device?',

View File

@ -26,7 +26,7 @@ def cipher_key_value(msg, seckey: bytes) -> bytes:
@unimport
async def layout_cipher_key_value(session_id, msg):
async def layout_cipher_key_value(ctx, msg):
from trezor.messages.CipheredKeyValue import CipheredKeyValue
from ..common import seed
@ -38,7 +38,7 @@ async def layout_cipher_key_value(session_id, msg):
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
ui.display.text(10, 60, msg.key, ui.MONO, ui.WHITE, ui.BLACK)
node = await seed.get_root(session_id)
node = await seed.get_root(ctx)
node.derive_path(msg.address_n)
value = cipher_key_value(msg, node.private_key())

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport
async def layout_get_address(session_id, msg):
async def layout_get_address(ctx, msg):
from trezor.messages.Address import Address
from trezor.messages.FailureType import ProcessError
from ..common import coins
@ -15,18 +15,18 @@ async def layout_get_address(session_id, msg):
address_n = msg.address_n or ()
coin_name = msg.coin_name or 'Bitcoin'
node = await seed.get_root(session_id)
node = await seed.get_root(ctx)
node.derive_path(address_n)
coin = coins.by_name(coin_name)
address = node.address(coin.address_type)
if msg.show_display:
await _show_address(session_id, address)
await _show_address(ctx, address)
return Address(address=address)
async def _show_address(session_id, address):
async def _show_address(ctx, address):
from trezor.messages.ButtonRequestType import Address
from trezor.ui.text import Text
from trezor.ui.qr import Qr
@ -37,7 +37,7 @@ async def _show_address(session_id, address):
content = Container(
Qr(address, (120, 135), 3),
Text('Confirm address', ui.ICON_RESET, ui.MONO, *lines))
await require_confirm(session_id, content, code=Address)
await require_confirm(ctx, content, code=Address)
def _split_address(address):

View File

@ -3,18 +3,18 @@ from trezor.utils import unimport
@unimport
async def layout_get_entropy(session_id, msg):
async def layout_get_entropy(ctx, msg):
from trezor.messages.Entropy import Entropy
from trezor.crypto import random
l = min(msg.size, 1024)
await _show_entropy(session_id)
await _show_entropy(ctx)
return Entropy(entropy=random.bytes(l))
async def _show_entropy(session_id):
async def _show_entropy(ctx):
from trezor.messages.ButtonRequestType import ProtectCall
from trezor.ui.text import Text
from trezor.ui.container import Container
@ -23,4 +23,4 @@ async def _show_entropy(session_id):
content = Container(
Text('Confirm entropy', ui.ICON_RESET, ui.MONO, 'Do you really want to send entropy?'))
await require_confirm(session_id, content, code=ProtectCall)
await require_confirm(ctx, content, code=ProtectCall)

View File

@ -2,7 +2,7 @@ from trezor.utils import unimport
@unimport
async def layout_get_public_key(session_id, msg):
async def layout_get_public_key(ctx, msg):
from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.PublicKey import PublicKey
from ..common import coins
@ -11,7 +11,7 @@ async def layout_get_public_key(session_id, msg):
address_n = msg.address_n or ()
coin_name = msg.coin_name or 'Bitcoin'
node = await seed.get_root(session_id)
node = await seed.get_root(ctx)
node.derive_path(address_n)
coin = coins.by_name(coin_name)

View File

@ -83,7 +83,7 @@ def sign_challenge(seckey: bytes,
@unimport
async def layout_sign_identity(session_id, msg):
async def layout_sign_identity(ctx, msg):
from trezor.messages.SignedIdentity import SignedIdentity
from ..common import coins
from ..common import seed
@ -92,7 +92,7 @@ async def layout_sign_identity(session_id, msg):
display_identity(identity, msg.challenge_visual)
address_n = get_identity_path(identity, msg.identity.index or 0)
node = await seed.get_root(session_id, msg.ecdsa_curve_name)
node = await seed.get_root(ctx, msg.ecdsa_curve_name)
node.derive_path(address_n)
coin = coins.by_name('Bitcoin')

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport
async def layout_sign_message(session_id, msg):
async def layout_sign_message(ctx, msg):
from trezor.messages.MessageSignature import MessageSignature
from trezor.crypto.curve import secp256k1
from ..common import coins
@ -18,7 +18,7 @@ async def layout_sign_message(session_id, msg):
coin_name = msg.coin_name or 'Bitcoin'
coin = coins.by_name(coin_name)
node = await seed.get_root(session_id)
node = await seed.get_root(ctx)
node.derive_path(msg.address_n)
seckey = node.private_key()

View File

@ -3,7 +3,7 @@ from trezor import wire
@unimport
async def sign_tx(session_id, msg):
async def sign_tx(ctx, msg):
from trezor.messages.RequestType import TXFINISHED
from trezor.messages.wire_types import TxAck
@ -11,7 +11,7 @@ async def sign_tx(session_id, msg):
from . import signing
from . import layout
root = await seed.get_root(session_id)
root = await seed.get_root(ctx)
signer = signing.sign_tx(msg, root)
res = None
@ -23,13 +23,13 @@ async def sign_tx(session_id, msg):
if req.__qualname__ == 'TxRequest':
if req.request_type == TXFINISHED:
break
res = await wire.call(session_id, req, TxAck)
res = await ctx.call(req, TxAck)
elif req.__qualname__ == 'UiConfirmOutput':
res = await layout.confirm_output(session_id, req.output, req.coin)
res = await layout.confirm_output(ctx, req.output, req.coin)
elif req.__qualname__ == 'UiConfirmTotal':
res = await layout.confirm_total(session_id, req.spending, req.fee, req.coin)
res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin)
elif req.__qualname__ == 'UiConfirmFeeOverThreshold':
res = await layout.confirm_feeoverthreshold(session_id, req.fee, req.coin)
res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin)
else:
raise TypeError('Invalid signing instruction')
return req

View File

@ -14,22 +14,22 @@ def split_address(address):
return chunks(address, 17)
async def confirm_output(session_id, output, coin):
async def confirm_output(ctx, output, coin):
content = Text('Confirm output', ui.ICON_RESET,
ui.BOLD, format_amount(output.amount, coin),
ui.NORMAL, 'to',
ui.MONO, *split_address(output.address))
return await confirm(session_id, content, ButtonRequestType.ConfirmOutput)
return await confirm(ctx, content, ButtonRequestType.ConfirmOutput)
async def confirm_total(session_id, spending, fee, coin):
async def confirm_total(ctx, spending, fee, coin):
content = Text('Confirm transaction', ui.ICON_RESET,
'Sending: %s' % format_amount(spending, coin),
'Fee: %s' % format_amount(fee, coin))
return await hold_to_confirm(session_id, content, ButtonRequestType.SignTx)
return await hold_to_confirm(ctx, content, ButtonRequestType.SignTx)
async def confirm_feeoverthreshold(session_id, fee, coin):
async def confirm_feeoverthreshold(ctx, fee, coin):
content = Text('Confirm high fee:', ui.ICON_RESET,
ui.BOLD, format_amount(fee, coin))
return await confirm(session_id, content, ButtonRequestType.FeeOverThreshold)
return await confirm(ctx, content, ButtonRequestType.FeeOverThreshold)

View File

@ -3,7 +3,7 @@ from trezor.utils import unimport
@unimport
async def layout_verify_message(session_id, msg):
async def layout_verify_message(ctx, msg):
from trezor.messages.Success import Success
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import ripemd160, sha256

View File

@ -4,26 +4,7 @@ from trezor import io
from trezor import wire
from trezor import main
# Load applications
from apps.common import storage
if __debug__:
from apps import debug
from apps import homescreen
from apps import management
from apps import wallet
from apps import ethereum
from apps import fido_u2f
# Boot applications
if __debug__:
debug.boot()
homescreen.boot()
management.boot()
wallet.boot()
ethereum.boot()
fido_u2f.boot()
# Intialize the USB stack
# initialize the USB stack
usb_wire = io.HID(
iface_num=0x00,
ep_in=0x81,
@ -90,11 +71,30 @@ usb.add(usb_vcp)
usb.add(usb_u2f)
usb.open()
# Initialize the wire codec pipeline
wire.setup(usb_wire.iface_num())
# load applications
from apps.common import storage
if __debug__:
from apps import debug
from apps import homescreen
from apps import management
from apps import wallet
from apps import ethereum
from apps import fido_u2f
# Load default homescreen
# boot applications
if __debug__:
debug.boot()
homescreen.boot()
management.boot()
wallet.boot()
ethereum.boot()
fido_u2f.boot(usb_u2f)
# initialize the wire codec pipeline
wire.setup(usb_wire)
# load default homescreen
from apps.homescreen.homescreen import layout_homescreen
# Run main even loop and specify which screen is default
# run main even loop and specify which screen is default
main.run(default_workflow=layout_homescreen)

View File

@ -14,7 +14,7 @@ async def load_uvarint(reader):
shift = 0
byte = 0x80
while byte & 0x80:
await reader.readinto(buffer)
await reader.areadinto(buffer)
byte = buffer[0]
result += (byte & 0x7F) << shift
shift += 7
@ -27,7 +27,7 @@ async def dump_uvarint(writer, n):
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
await writer.write(buffer)
await writer.awrite(buffer)
n = shifted
@ -69,11 +69,11 @@ class LimitedReader:
self.reader = reader
self.limit = limit
async def readinto(self, buf):
async def areadinto(self, buf):
if self.limit < len(buf):
raise EOFError
else:
nread = await self.reader.readinto(buf)
nread = await self.reader.areadinto(buf)
self.limit -= nread
return nread
@ -83,7 +83,7 @@ class CountingWriter:
def __init__(self):
self.size = 0
async def write(self, buf):
async def awrite(self, buf):
nwritten = len(buf)
self.size += nwritten
return nwritten
@ -112,7 +112,7 @@ async def load_message(reader, msg_type):
await load_uvarint(reader)
elif wtype == 2:
ivalue = await load_uvarint(reader)
await reader.readinto(bytearray(ivalue))
await reader.areadinto(bytearray(ivalue))
else:
raise ValueError
continue
@ -129,10 +129,10 @@ async def load_message(reader, msg_type):
fvalue = bool(ivalue)
elif ftype is BytesType:
fvalue = bytearray(ivalue)
await reader.readinto(fvalue)
await reader.areadinto(fvalue)
elif ftype is UnicodeType:
fvalue = bytearray(ivalue)
await reader.readinto(fvalue)
await reader.areadinto(fvalue)
fvalue = str(fvalue, 'utf8')
elif issubclass(ftype, MessageType):
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
@ -186,11 +186,11 @@ async def dump_message(writer, msg):
elif ftype is BytesType:
await dump_uvarint(writer, len(svalue))
await writer.write(svalue)
await writer.awrite(svalue)
elif ftype is UnicodeType:
await dump_uvarint(writer, len(svalue))
await writer.write(bytes(svalue, 'utf8'))
await writer.awrite(bytes(svalue, 'utf8'))
elif issubclass(ftype, MessageType):
counter = CountingWriter()

View File

@ -8,58 +8,82 @@ from trezor import workflow
from . import codec_v1
from . import codec_v2
workflows = {}
workflow_handlers = {}
def register(wire_type, handler, *args):
if wire_type in workflows:
def register(mtype, handler, *args):
'''Register `handler` to get scheduled after `mtype` message is received.'''
if mtype in workflow_handlers:
raise KeyError
workflows[wire_type] = (handler, args)
workflow_handlers[mtype] = (handler, args)
def setup(interface):
session_supervisor = codec_v2.SesssionSupervisor(interface,
session_handler)
def setup(iface):
'''Initialize the wire stack on passed USB interface.'''
session_supervisor = codec_v2.SesssionSupervisor(iface, session_handler)
session_supervisor.open(codec_v1.SESSION_ID)
loop.schedule_task(session_supervisor.listen())
class Context:
def __init__(self, interface, session_id):
self.interface = interface
self.session_id = session_id
def get_reader(self):
if self.session_id == codec_v1.SESSION_ID:
return codec_v1.Reader(self.interface)
else:
return codec_v2.Reader(self.interface, self.session_id)
def get_writer(self, mtype, msize):
if self.session_id == codec_v1.SESSION_ID:
return codec_v1.Writer(self.interface, mtype, msize)
else:
return codec_v2.Writer(self.interface, self.session_id, mtype, msize)
async def read(self, types):
reader = self.get_reader()
await reader.open()
if reader.type not in types:
raise UnexpectedMessageError(reader)
return await protobuf.load_message(reader,
messages.get_type(reader.type))
async def write(self, msg):
counter = protobuf.CountingWriter()
await protobuf.dump_message(counter, msg)
writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
await protobuf.dump_message(writer, msg)
await writer.close()
def __init__(self, iface, sid):
self.iface = iface
self.sid = sid
async def call(self, msg, types):
'''
Reply with `msg` and wait for one of `types`. See `self.write()` and
`self.read()`.
'''
await self.write(msg)
return await self.read(types)
async def read(self, types):
'''
Wait for incoming message on this wire context and return it. Raises
`UnexpectedMessageError` if the message type does not match one of
`types`; and caller should always make sure to re-raise it.
'''
reader = self.getreader()
await reader.aopen() # wait for the message header
# if we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it
if reader.type not in types:
raise UnexpectedMessageError(reader)
# look up the protobuf class and parse the message
pbtype = messages.get_type(reader.type)
return await protobuf.load_message(reader, pbtype)
async def write(self, msg):
'''
Write a protobuf message to this wire context.
'''
writer = self.getwriter()
# get the message size
counter = protobuf.CountingWriter()
await protobuf.dump_message(counter, msg)
# write the message
writer.setheader(msg.MESSAGE_WIRE_TYPE, counter.size)
await protobuf.dump_message(writer, msg)
await writer.aclose()
def getreader(self):
if self.sid == codec_v1.SESSION_ID:
return codec_v1.Reader(self.iface)
else:
return codec_v2.Reader(self.iface, self.sid)
def getwriter(self):
if self.sid == codec_v1.SESSION_ID:
return codec_v1.Writer(self.iface)
else:
return codec_v2.Writer(self.iface, self.sid)
class UnexpectedMessageError(Exception):
def __init__(self, reader):
@ -74,60 +98,69 @@ class FailureError(Exception):
self.message = message
class Workflow:
def __init__(self, default):
self.handlers = {}
self.default = default
async def __call__(self, interface, session_id):
ctx = Context(interface, session_id)
while True:
async def session_handler(iface, sid):
reader = None
ctx = Context(iface, sid)
while True:
try:
# wait for new message, if needed, and find handler
if not reader:
reader = ctx.getreader()
await reader.aopen()
try:
reader = ctx.get_reader()
await reader.open()
try:
handler = self.handlers[reader.type]
except KeyError:
handler = self.default
try:
await handler(ctx, reader)
except UnexpectedMessageError as unexp_msg:
reader = unexp_msg.reader
except Exception as e:
log.exception(__name__, e)
handler, args = workflow_handlers[reader.type]
except KeyError:
handler, args = unexpected_msg, ()
await handler(ctx, reader, *args)
except UnexpectedMessageError as exc:
# retry with opened reader from the exception
reader = exc.reader
continue
except FailureError as exc:
# we log FailureError as warning, not as exception
log.warning(__name__, 'failure: %s', exc.message)
except Exception as exc:
# sessions are never closed by raised exceptions
log.exception(__name__, exc)
# read new message in next iteration
reader = None
async def protobuf_workflow(ctx, reader, handler, *args):
msg = await protobuf.load_message(reader, messages.get_type(reader.type))
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
req = await protobuf.load_message(reader, messages.get_type(reader.type))
try:
res = await handler(reader.sid, msg, *args)
except Exception as exc:
if not isinstance(exc, UnexpectedMessageError):
await ctx.write(make_failure_msg(exc))
res = await handler(ctx, req, *args)
except UnexpectedMessageError:
# session handler takes care of this one
raise
else:
if res:
await ctx.write(res)
except FailureError as exc:
# respond with specific code and message
await ctx.write(Failure(code=exc.code, message=exc.message))
raise
except Exception as exc:
# respond with a generic code and message
await ctx.write(Failure(code=FirmwareError, message='Firmware error'))
raise
if res:
# respond with a specific response
await ctx.write(res)
async def handle_unexp_msg(ctx, reader):
async def unexpected_msg(ctx, reader):
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import UnexpectedMessage
# receive the message and throw it away
while reader.size > 0:
buf = bytearray(reader.size)
await reader.readinto(buf)
await reader.areadinto(buf)
# respond with an unknown message error
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import UnexpectedMessage
await ctx.write(
Failure(code=UnexpectedMessage, message='Unexpected message'))
def make_failure_msg(exc):
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
if isinstance(exc, FailureError):
code = exc.code
message = exc.message
else:
code = FirmwareError
message = 'Firmware Error'
return Failure(code=code, message=message)

View File

@ -1,7 +1,6 @@
from micropython import const
import ustruct
from trezor import io
from trezor import loop
from trezor import utils
@ -32,13 +31,13 @@ class Reader:
def __repr__(self):
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
async def open(self):
async def aopen(self):
'''
Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and
available after `open()` returns.
available after `aopen()` returns.
'''
read = loop.select(self.iface | loop.READ)
read = loop.select(self.iface.iface_num() | loop.READ)
while True:
# wait for initial report
report = await read
@ -55,7 +54,7 @@ class Reader:
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
self.ofs = 0
async def readinto(self, buf):
async def areadinto(self, buf):
'''
Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered
@ -64,7 +63,7 @@ class Reader:
if self.size < len(buf):
raise EOFError
read = loop.select(self.iface | loop.READ)
read = loop.select(self.iface.iface_num() | loop.READ)
nread = 0
while nread < len(buf):
if self.ofs == len(self.data):
@ -93,20 +92,28 @@ class Writer:
async-file-like interface.
'''
def __init__(self, iface, mtype, msize):
def __init__(self, iface):
self.iface = iface
self.type = mtype
self.size = msize
self.type = None
self.size = None
self.data = bytearray(_REP_LEN)
self.ofs = _REP_INIT_DATA
# load the report with initial header
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)
self.ofs = 0
def __repr__(self):
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
async def write(self, buf):
def setheader(self, mtype, msize):
'''
Reset the writer state and load the message header with passed type and
total message size.
'''
self.type = mtype
self.size = msize
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC,
_REP_MAGIC, mtype, msize)
self.ofs = _REP_INIT_DATA
async def awrite(self, buf):
'''
Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf`
@ -115,7 +122,7 @@ class Writer:
if self.size < len(buf):
raise EOFError
write = loop.select(self.iface | loop.WRITE)
write = loop.select(self.iface.iface_num() | loop.WRITE)
nwritten = 0
while nwritten < len(buf):
# copy as much as possible to report buffer
@ -127,12 +134,12 @@ class Writer:
if self.ofs == _REP_LEN:
# we are at the end of the report, flush it
await write
io.write(self.iface, self.data)
self.iface.write(self.data)
self.ofs = _REP_CONT_DATA
return nwritten
async def close(self):
async def aclose(self):
'''Flush and close the message transmission.'''
if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned,
@ -141,5 +148,5 @@ class Writer:
self.data[self.ofs] = 0x00
self.ofs += 1
await loop.select(self.iface | loop.WRITE)
io.send(self.iface, self.data)
await loop.select(self.iface.iface_num() | loop.WRITE)
self.iface.write(self.data)

View File

@ -1,7 +1,6 @@
from micropython import const
import ustruct
from trezor import io
from trezor import loop
from trezor import utils
from trezor.crypto import random
@ -28,11 +27,11 @@ _REP_MARKER_CONT = const(0x02)
_REP_MARKER_OPEN = const(0x03)
_REP_MARKER_CLOSE = const(0x04)
_REP = '>BL' # marker, session_id
_REP = '>BL' # marker, session_id
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
_REP_CONT = '>BLL' # marker, session_id, sequence
_REP_CONT = '>BLL' # marker, session_id, sequence
_REP_INIT_DATA = const(13) # offset of data in init report
_REP_CONT_DATA = const(9) # offset of data in cont report
_REP_CONT_DATA = const(9) # offset of data in cont report
class Reader:
@ -51,15 +50,16 @@ class Reader:
self.seq = 0
def __repr__(self):
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size)
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type,
self.size)
async def open(self):
async def aopen(self):
'''
Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and
available after `open()` returns.
available after `aopen()` returns.
'''
read = loop.select(self.iface | loop.READ)
read = loop.select(self.iface.iface_num() | loop.READ)
while True:
# wait for initial report
report = await read
@ -74,7 +74,7 @@ class Reader:
self.ofs = 0
self.seq = 0
async def readinto(self, buf):
async def areadinto(self, buf):
'''
Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered
@ -83,7 +83,7 @@ class Reader:
if self.size < len(buf):
raise EOFError
read = loop.select(self.iface | loop.READ)
read = loop.select(self.iface.iface_num() | loop.READ)
nread = 0
while nread < len(buf):
if self.ofs == len(self.data):
@ -115,20 +115,28 @@ class Writer:
interface.
'''
def __init__(self, iface, sid, mtype, msize):
def __init__(self, iface, sid):
self.iface = iface
self.sid = sid
self.type = None
self.size = None
self.data = bytearray(_REP_LEN)
self.ofs = 0
self.seq = 0
def setheader(self, mtype, msize):
'''
Reset the writer state and load the message header with passed type and
total message size.
'''
self.type = mtype
self.size = msize
self.data = bytearray(_REP_LEN)
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER_INIT, self.sid,
mtype, msize)
self.ofs = _REP_INIT_DATA
self.seq = 0
# load the report with initial header
ustruct.pack_into(_REP_INIT, self.data, 0,
_REP_MARKER_INIT, sid, mtype, msize)
async def write(self, buf):
async def awrite(self, buf):
'''
Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf`
@ -137,7 +145,7 @@ class Writer:
if self.size < len(buf):
raise EOFError
write = loop.select(self.iface | loop.WRITE)
write = loop.select(self.iface.iface_num() | loop.WRITE)
nwritten = 0
while nwritten < len(buf):
# copy as much as possible to report buffer
@ -149,15 +157,15 @@ class Writer:
if self.ofs == _REP_LEN:
# we are at the end of the report, flush it, and prepare header
await write
io.send(self.iface, self.data)
ustruct.pack_into(_REP_CONT, self.data, 0,
_REP_MARKER_CONT, self.sid, self.seq)
self.iface.write(self.data)
ustruct.pack_into(_REP_CONT, self.data, 0, _REP_MARKER_CONT,
self.sid, self.seq)
self.ofs = _REP_CONT_DATA
self.seq += 1
return nwritten
async def close(self):
async def aclose(self):
'''Flush and close the message transmission.'''
if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned,
@ -166,8 +174,8 @@ class Writer:
self.data[self.ofs] = 0x00
self.ofs += 1
await loop.select(self.iface | loop.WRITE)
io.send(self.iface, self.data)
await loop.select(self.iface.iface_num() | loop.WRITE)
self.iface.write(self.data)
class SesssionSupervisor:
@ -186,8 +194,8 @@ class SesssionSupervisor:
After close request, the handling task is closed and session terminated.
Both requests receive responses confirming the operation.
'''
read = loop.select(self.iface | loop.READ)
write = loop.select(self.iface | loop.WRITE)
read = loop.select(self.iface.iface_num() | loop.READ)
write = loop.select(self.iface.iface_num() | loop.WRITE)
while True:
report = await read
repmarker, repsid = ustruct.unpack(_REP, report)
@ -225,8 +233,8 @@ class SesssionSupervisor:
def writeopen(self, sid):
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
io.write(self.iface, self.session_report)
self.iface.write(self.session_report)
def writeclose(self, sid):
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
io.write(self.iface, self.session_report)
self.iface.write(self.session_report)

View File

@ -26,26 +26,26 @@ def test_reader():
# open, expected one read
first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len)
# empty read
empty_buffer = bytearray()
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
assert_eq(len(empty_buffer), 0)
assert_eq(reader.size, message_len)
# short read, expected no read
short_buffer = bytearray(32)
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
assert_eq(len(short_buffer), 32)
assert_eq(short_buffer, message[:len(short_buffer)])
assert_eq(reader.size, message_len - len(short_buffer))
# aligned read, expected no read
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
@ -53,12 +53,12 @@ def test_reader():
next_report_header = bytearray(unhexlify('3f'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1)
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
# too long read, raises eof
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
# long read, expect multiple reads
start_size = reader.size
@ -74,12 +74,12 @@ def test_reader():
prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface)))
expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.readinto(long_buffer), expected_syscalls)
assert_async(reader.areadinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:])
assert_eq(reader.size, 0)
# one byte read, raises eof
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
def test_writer():
@ -87,7 +87,8 @@ def test_writer():
interface = 0xdeadbeef
message_type = 0x87654321
message_len = 1024
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len)
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID)
writer.setheader(message_type, message_len)
# init header corresponding to the data above
report_header = bytearray(unhexlify('3f2323432100000400'))
@ -96,14 +97,14 @@ def test_writer():
# empty write
start_size = writer.size
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
assert_eq(writer.size, start_size)
# short write, expected no report
start_size = writer.size
short_payload = bytearray(range(4))
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data,
report_header
@ -118,7 +119,7 @@ def test_writer():
+ short_payload
+ aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1)
msg.send = msg.send.original
@ -126,7 +127,7 @@ def test_writer():
# short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('3f'))
start_size = writer.size
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data[:len(report_header) + len(short_payload)],
report_header + short_payload)
@ -142,19 +143,19 @@ def test_writer():
# test write
expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports))
msg.send = msg.send.original
# test write raises eof
msg.send = mock_call(msg.send, [])
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0)
msg.send = msg.send.original
# test close
expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports))
msg.send = msg.send.original

View File

@ -26,26 +26,26 @@ def test_reader():
# open, expected one read
first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len)
# empty read
empty_buffer = bytearray()
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
assert_eq(len(empty_buffer), 0)
assert_eq(reader.size, message_len)
# short read, expected no read
short_buffer = bytearray(32)
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
assert_eq(len(short_buffer), 32)
assert_eq(short_buffer, message[:len(short_buffer)])
assert_eq(reader.size, message_len - len(short_buffer))
# aligned read, expected no read
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
@ -53,12 +53,12 @@ def test_reader():
next_report_header = bytearray(unhexlify('021234567800000000'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1)
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
# too long read, raises eof
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
# long read, expect multiple reads
start_size = reader.size
@ -74,12 +74,12 @@ def test_reader():
prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface)))
expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.readinto(long_buffer), expected_syscalls)
assert_async(reader.areadinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:])
assert_eq(reader.size, 0)
# one byte read, raises eof
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
def test_writer():
@ -88,7 +88,8 @@ def test_writer():
session_id = 0x12345678
message_type = 0x87654321
message_len = 1024
writer = codec_v2.Writer(interface, session_id, message_type, message_len)
writer = codec_v2.Writer(interface, session_id)
writer.setheader(message_type, message_len)
# init header corresponding to the data above
report_header = bytearray(unhexlify('01123456788765432100000400'))
@ -97,14 +98,14 @@ def test_writer():
# empty write
start_size = writer.size
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
assert_eq(writer.size, start_size)
# short write, expected no report
start_size = writer.size
short_payload = bytearray(range(4))
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data,
report_header
@ -119,7 +120,7 @@ def test_writer():
+ short_payload
+ aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1)
msg.send = msg.send.original
@ -127,7 +128,7 @@ def test_writer():
# short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('021234567800000000'))
start_size = writer.size
assert_async(writer.write(short_payload), [(None, StopIteration()),])
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
assert_eq(writer.size, start_size - len(short_payload))
assert_eq(writer.data[:len(report_header) + len(short_payload)],
report_header + short_payload)
@ -145,19 +146,19 @@ def test_writer():
# test write
expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports))
msg.send = msg.send.original
# test write raises eof
msg.send = mock_call(msg.send, [])
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0)
msg.send = msg.send.original
# test close
expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports))
msg.send = msg.send.original