mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-28 16:21:03 +00:00
wire: pass Context to apps
This commit is contained in:
parent
b1b84fb233
commit
3562ffdc54
@ -7,7 +7,7 @@ signal = loop.Signal()
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def confirm(session_id, content, code=None, *args, **kwargs):
|
async def confirm(ctx, content, code=None, *args, **kwargs):
|
||||||
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
|
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
|
||||||
from trezor.messages.ButtonRequest import ButtonRequest
|
from trezor.messages.ButtonRequest import ButtonRequest
|
||||||
from trezor.messages.ButtonRequestType import Other
|
from trezor.messages.ButtonRequestType import Other
|
||||||
@ -19,12 +19,12 @@ async def confirm(session_id, content, code=None, *args, **kwargs):
|
|||||||
|
|
||||||
if code is None:
|
if code is None:
|
||||||
code = Other
|
code = Other
|
||||||
await wire.call(session_id, ButtonRequest(code=code), ButtonAck)
|
await ctx.call(ButtonRequest(code=code), ButtonAck)
|
||||||
return await loop.Wait((signal, dialog)) == CONFIRMED
|
return await loop.Wait((signal, dialog)) == CONFIRMED
|
||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def hold_to_confirm(session_id, content, code=None, *args, **kwargs):
|
async def hold_to_confirm(ctx, content, code=None, *args, **kwargs):
|
||||||
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
|
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
|
||||||
from trezor.messages.ButtonRequest import ButtonRequest
|
from trezor.messages.ButtonRequest import ButtonRequest
|
||||||
from trezor.messages.ButtonRequestType import Other
|
from trezor.messages.ButtonRequestType import Other
|
||||||
@ -36,7 +36,7 @@ async def hold_to_confirm(session_id, content, code=None, *args, **kwargs):
|
|||||||
|
|
||||||
if code is None:
|
if code is None:
|
||||||
code = Other
|
code = Other
|
||||||
await wire.call(session_id, ButtonRequest(code=code), ButtonAck)
|
await ctx.call(ButtonRequest(code=code), ButtonAck)
|
||||||
return await loop.Wait((signal, dialog)) == CONFIRMED
|
return await loop.Wait((signal, dialog)) == CONFIRMED
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from trezor import ui, wire
|
from trezor import ui, wire
|
||||||
|
|
||||||
|
|
||||||
async def request_passphrase(session_id):
|
async def request_passphrase(ctx):
|
||||||
from trezor.messages.FailureType import ActionCancelled
|
from trezor.messages.FailureType import ActionCancelled
|
||||||
from trezor.messages.PassphraseRequest import PassphraseRequest
|
from trezor.messages.PassphraseRequest import PassphraseRequest
|
||||||
from trezor.messages.wire_types import PassphraseAck, Cancel
|
from trezor.messages.wire_types import PassphraseAck, Cancel
|
||||||
@ -12,17 +12,17 @@ async def request_passphrase(session_id):
|
|||||||
'Please enter passphrase', 'on your computer.')
|
'Please enter passphrase', 'on your computer.')
|
||||||
text.render()
|
text.render()
|
||||||
|
|
||||||
ack = await wire.call(session_id, PassphraseRequest(), PassphraseAck, Cancel)
|
ack = await ctx.call(PassphraseRequest(), PassphraseAck, Cancel)
|
||||||
if ack.MESSAGE_WIRE_TYPE == Cancel:
|
if ack.MESSAGE_WIRE_TYPE == Cancel:
|
||||||
raise wire.FailureError(ActionCancelled, 'Passphrase cancelled')
|
raise wire.FailureError(ActionCancelled, 'Passphrase cancelled')
|
||||||
|
|
||||||
return ack.passphrase
|
return ack.passphrase
|
||||||
|
|
||||||
|
|
||||||
async def protect_by_passphrase(session_id):
|
async def protect_by_passphrase(ctx):
|
||||||
from apps.common import storage
|
from apps.common import storage
|
||||||
|
|
||||||
if storage.is_protected_by_passphrase():
|
if storage.is_protected_by_passphrase():
|
||||||
return await request_passphrase(session_id)
|
return await request_passphrase(ctx)
|
||||||
else:
|
else:
|
||||||
return ''
|
return ''
|
||||||
|
@ -7,7 +7,7 @@ if __debug__:
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def request_pin_on_display(session_id: int, code: int=None) -> str:
|
async def request_pin_on_display(ctx: wire.Context, code: int=None) -> str:
|
||||||
from trezor.messages.ButtonRequest import ButtonRequest
|
from trezor.messages.ButtonRequest import ButtonRequest
|
||||||
from trezor.messages.ButtonRequestType import ProtectCall
|
from trezor.messages.ButtonRequestType import ProtectCall
|
||||||
from trezor.messages.FailureType import PinCancelled
|
from trezor.messages.FailureType import PinCancelled
|
||||||
@ -20,9 +20,8 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
|
|||||||
|
|
||||||
_, label = _get_code_and_label(code)
|
_, label = _get_code_and_label(code)
|
||||||
|
|
||||||
await wire.call(session_id,
|
await ctx.call(ButtonRequest(code=ProtectCall),
|
||||||
ButtonRequest(code=ProtectCall),
|
ButtonAck)
|
||||||
ButtonAck)
|
|
||||||
|
|
||||||
ui.display.clear()
|
ui.display.clear()
|
||||||
matrix = PinMatrix(label)
|
matrix = PinMatrix(label)
|
||||||
@ -36,7 +35,7 @@ async def request_pin_on_display(session_id: int, code: int=None) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def request_pin_on_client(session_id: int, code: int=None) -> str:
|
async def request_pin_on_client(ctx: wire.Context, code: int=None) -> str:
|
||||||
from trezor.messages.FailureType import PinCancelled
|
from trezor.messages.FailureType import PinCancelled
|
||||||
from trezor.messages.PinMatrixRequest import PinMatrixRequest
|
from trezor.messages.PinMatrixRequest import PinMatrixRequest
|
||||||
from trezor.messages.wire_types import PinMatrixAck, Cancel
|
from trezor.messages.wire_types import PinMatrixAck, Cancel
|
||||||
@ -51,9 +50,8 @@ async def request_pin_on_client(session_id: int, code: int=None) -> str:
|
|||||||
matrix = PinMatrix(label)
|
matrix = PinMatrix(label)
|
||||||
matrix.render()
|
matrix.render()
|
||||||
|
|
||||||
ack = await wire.call(session_id,
|
ack = await ctx.call(PinMatrixRequest(type=code),
|
||||||
PinMatrixRequest(type=code),
|
PinMatrixAck, Cancel)
|
||||||
PinMatrixAck, Cancel)
|
|
||||||
digits = matrix.digits
|
digits = matrix.digits
|
||||||
matrix = None
|
matrix = None
|
||||||
|
|
||||||
@ -66,12 +64,12 @@ request_pin = request_pin_on_client
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def request_pin_twice(session_id: int) -> str:
|
async def request_pin_twice(ctx: wire.Context) -> str:
|
||||||
from trezor.messages.FailureType import ActionCancelled
|
from trezor.messages.FailureType import ActionCancelled
|
||||||
from trezor.messages import PinMatrixRequestType
|
from trezor.messages import PinMatrixRequestType
|
||||||
|
|
||||||
pin_first = await request_pin(session_id, PinMatrixRequestType.NewFirst)
|
pin_first = await request_pin(ctx, PinMatrixRequestType.NewFirst)
|
||||||
pin_again = await request_pin(session_id, PinMatrixRequestType.NewSecond)
|
pin_again = await request_pin(ctx, PinMatrixRequestType.NewSecond)
|
||||||
if pin_first != pin_again:
|
if pin_first != pin_again:
|
||||||
# changed message due to consistency with T1 msgs
|
# changed message due to consistency with T1 msgs
|
||||||
raise wire.FailureError(ActionCancelled, 'PIN change failed')
|
raise wire.FailureError(ActionCancelled, 'PIN change failed')
|
||||||
@ -79,22 +77,22 @@ async def request_pin_twice(session_id: int) -> str:
|
|||||||
return pin_first
|
return pin_first
|
||||||
|
|
||||||
|
|
||||||
async def protect_by_pin_repeatedly(session_id: int, at_least_once: bool=False):
|
async def protect_by_pin_repeatedly(ctx: wire.Context, at_least_once: bool=False):
|
||||||
from . import storage
|
from . import storage
|
||||||
|
|
||||||
locked = storage.is_locked() or at_least_once
|
locked = storage.is_locked() or at_least_once
|
||||||
while locked:
|
while locked:
|
||||||
pin = await request_pin(session_id)
|
pin = await request_pin(ctx)
|
||||||
locked = not storage.unlock(pin, _render_pin_failure)
|
locked = not storage.unlock(pin, _render_pin_failure)
|
||||||
|
|
||||||
|
|
||||||
async def protect_by_pin_or_fail(session_id: int, at_least_once: bool=False):
|
async def protect_by_pin_or_fail(ctx: wire.Context, at_least_once: bool=False):
|
||||||
from trezor.messages.FailureType import PinInvalid
|
from trezor.messages.FailureType import PinInvalid
|
||||||
from . import storage
|
from . import storage
|
||||||
|
|
||||||
locked = storage.is_locked() or at_least_once
|
locked = storage.is_locked() or at_least_once
|
||||||
if locked:
|
if locked:
|
||||||
pin = await request_pin(session_id)
|
pin = await request_pin(ctx)
|
||||||
if not storage.unlock(pin, _render_pin_failure):
|
if not storage.unlock(pin, _render_pin_failure):
|
||||||
raise wire.FailureError(PinInvalid, 'PIN invalid')
|
raise wire.FailureError(PinInvalid, 'PIN invalid')
|
||||||
|
|
||||||
|
@ -5,20 +5,20 @@ from trezor.crypto import bip39
|
|||||||
_DEFAULT_CURVE = 'secp256k1'
|
_DEFAULT_CURVE = 'secp256k1'
|
||||||
|
|
||||||
|
|
||||||
async def get_root(session_id: int, curve_name=_DEFAULT_CURVE):
|
async def get_root(ctx: wire.Context, curve_name=_DEFAULT_CURVE):
|
||||||
seed = await get_seed(session_id)
|
seed = await get_seed(ctx)
|
||||||
root = bip32.from_seed(seed, curve_name)
|
root = bip32.from_seed(seed, curve_name)
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
async def get_seed(session_id: int) -> bytes:
|
async def get_seed(ctx: wire.Context) -> bytes:
|
||||||
from . import cache
|
from . import cache
|
||||||
if cache.seed is None:
|
if cache.seed is None:
|
||||||
cache.seed = await compute_seed(session_id)
|
cache.seed = await compute_seed(ctx)
|
||||||
return cache.seed
|
return cache.seed
|
||||||
|
|
||||||
|
|
||||||
async def compute_seed(session_id: int) -> bytes:
|
async def compute_seed(ctx: wire.Context) -> bytes:
|
||||||
from trezor.messages.FailureType import ProcessError
|
from trezor.messages.FailureType import ProcessError
|
||||||
from .request_passphrase import protect_by_passphrase
|
from .request_passphrase import protect_by_passphrase
|
||||||
from .request_pin import protect_by_pin
|
from .request_pin import protect_by_pin
|
||||||
@ -27,9 +27,9 @@ async def compute_seed(session_id: int) -> bytes:
|
|||||||
if not storage.is_initialized():
|
if not storage.is_initialized():
|
||||||
raise wire.FailureError(ProcessError, 'Device is not initialized')
|
raise wire.FailureError(ProcessError, 'Device is not initialized')
|
||||||
|
|
||||||
await protect_by_pin(session_id)
|
await protect_by_pin(ctx)
|
||||||
|
|
||||||
passphrase = await protect_by_passphrase(session_id)
|
passphrase = await protect_by_passphrase(ctx)
|
||||||
return bip39.seed(storage.get_mnemonic(), passphrase)
|
return bip39.seed(storage.get_mnemonic(), passphrase)
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,13 +4,13 @@ from trezor.messages.wire_types import \
|
|||||||
DebugLinkMemoryRead, DebugLinkMemoryWrite, DebugLinkFlashErase
|
DebugLinkMemoryRead, DebugLinkMemoryWrite, DebugLinkFlashErase
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_DebugLinkDecision(session_id, msg):
|
async def dispatch_DebugLinkDecision(ctx, msg):
|
||||||
from trezor.ui.confirm import CONFIRMED, CANCELLED
|
from trezor.ui.confirm import CONFIRMED, CANCELLED
|
||||||
from apps.common.confirm import signal
|
from apps.common.confirm import signal
|
||||||
signal.send(CONFIRMED if msg.yes_no else CANCELLED)
|
signal.send(CONFIRMED if msg.yes_no else CANCELLED)
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_DebugLinkGetState(session_id, msg):
|
async def dispatch_DebugLinkGetState(ctx, msg):
|
||||||
from trezor.messages.DebugLinkState import DebugLinkState
|
from trezor.messages.DebugLinkState import DebugLinkState
|
||||||
from apps.common import storage, request_pin
|
from apps.common import storage, request_pin
|
||||||
from apps.management import reset_device
|
from apps.management import reset_device
|
||||||
@ -36,11 +36,11 @@ async def dispatch_DebugLinkGetState(session_id, msg):
|
|||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_DebugLinkStop(session_id, msg):
|
async def dispatch_DebugLinkStop(ctx, msg):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_DebugLinkMemoryRead(session_id, msg):
|
async def dispatch_DebugLinkMemoryRead(ctx, msg):
|
||||||
from trezor.messages.DebugLinkMemory import DebugLinkMemory
|
from trezor.messages.DebugLinkMemory import DebugLinkMemory
|
||||||
from uctypes import bytes_at
|
from uctypes import bytes_at
|
||||||
m = DebugLinkMemory()
|
m = DebugLinkMemory()
|
||||||
@ -48,14 +48,14 @@ async def dispatch_DebugLinkMemoryRead(session_id, msg):
|
|||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_DebugLinkMemoryWrite(session_id, msg):
|
async def dispatch_DebugLinkMemoryWrite(ctx, msg):
|
||||||
from uctypes import bytearray_at
|
from uctypes import bytearray_at
|
||||||
l = len(msg.memory)
|
l = len(msg.memory)
|
||||||
data = bytearray_at(msg.address, l)
|
data = bytearray_at(msg.address, l)
|
||||||
data[0:l] = msg.memory
|
data[0:l] = msg.memory
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_DebugLinkFlashErase(session_id, msg):
|
async def dispatch_DebugLinkFlashErase(ctx, msg):
|
||||||
# TODO: erase(msg.sector)
|
# TODO: erase(msg.sector)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -3,13 +3,13 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_ethereum_get_address(session_id, msg):
|
async def layout_ethereum_get_address(ctx, msg):
|
||||||
from trezor.messages.EthereumAddress import EthereumAddress
|
from trezor.messages.EthereumAddress import EthereumAddress
|
||||||
from trezor.crypto.curve import secp256k1
|
from trezor.crypto.curve import secp256k1
|
||||||
from trezor.crypto.hashlib import sha3_256
|
from trezor.crypto.hashlib import sha3_256
|
||||||
from ..common import seed
|
from ..common import seed
|
||||||
|
|
||||||
node = await seed.get_root(session_id)
|
node = await seed.get_root(ctx)
|
||||||
node.derive_path(msg.address_n or ())
|
node.derive_path(msg.address_n or ())
|
||||||
|
|
||||||
seckey = node.private_key()
|
seckey = node.private_key()
|
||||||
@ -17,11 +17,11 @@ async def layout_ethereum_get_address(session_id, msg):
|
|||||||
address = sha3_256(public_key[1:]).digest(True)[12:] # Keccak
|
address = sha3_256(public_key[1:]).digest(True)[12:] # Keccak
|
||||||
|
|
||||||
if msg.show_display:
|
if msg.show_display:
|
||||||
await _show_address(session_id, address)
|
await _show_address(ctx, address)
|
||||||
return EthereumAddress(address=address)
|
return EthereumAddress(address=address)
|
||||||
|
|
||||||
|
|
||||||
async def _show_address(session_id, address):
|
async def _show_address(ctx, address):
|
||||||
from trezor.messages.ButtonRequestType import Address
|
from trezor.messages.ButtonRequestType import Address
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
from ..common.confirm import require_confirm
|
from ..common.confirm import require_confirm
|
||||||
@ -30,7 +30,7 @@ async def _show_address(session_id, address):
|
|||||||
|
|
||||||
content = Text('Confirm address', ui.ICON_RESET,
|
content = Text('Confirm address', ui.ICON_RESET,
|
||||||
ui.MONO, *_split_address(address))
|
ui.MONO, *_split_address(address))
|
||||||
await require_confirm(session_id, content, code=Address)
|
await require_confirm(ctx, content, code=Address)
|
||||||
|
|
||||||
|
|
||||||
def _split_address(address):
|
def _split_address(address):
|
||||||
|
@ -226,11 +226,11 @@ class Cmd:
|
|||||||
return Msg(self.cid, cla, ins, p1, p2, lc, data)
|
return Msg(self.cid, cla, ins, p1, p2, lc, data)
|
||||||
|
|
||||||
|
|
||||||
async def read_cmd(iface: int) -> Cmd:
|
async def read_cmd(iface: io.HID) -> Cmd:
|
||||||
desc_init = frame_init()
|
desc_init = frame_init()
|
||||||
desc_cont = frame_cont()
|
desc_cont = frame_cont()
|
||||||
|
|
||||||
buf, = await loop.select(iface)
|
buf, = await loop.select(iface.iface_num())
|
||||||
# log.debug(__name__, 'read init %s', buf)
|
# log.debug(__name__, 'read init %s', buf)
|
||||||
|
|
||||||
ifrm = overlay_struct(buf, desc_init)
|
ifrm = overlay_struct(buf, desc_init)
|
||||||
@ -252,7 +252,7 @@ async def read_cmd(iface: int) -> Cmd:
|
|||||||
data = data[:bcnt]
|
data = data[:bcnt]
|
||||||
|
|
||||||
while datalen < bcnt:
|
while datalen < bcnt:
|
||||||
buf, = await loop.select(iface)
|
buf, = await loop.select(iface.iface_num())
|
||||||
# log.debug(__name__, 'read cont %s', buf)
|
# log.debug(__name__, 'read cont %s', buf)
|
||||||
|
|
||||||
cfrm = overlay_struct(buf, desc_cont)
|
cfrm = overlay_struct(buf, desc_cont)
|
||||||
@ -282,7 +282,7 @@ async def read_cmd(iface: int) -> Cmd:
|
|||||||
return Cmd(ifrm.cid, ifrm.cmd, data)
|
return Cmd(ifrm.cid, ifrm.cmd, data)
|
||||||
|
|
||||||
|
|
||||||
def send_cmd(cmd: Cmd, iface: int) -> None:
|
def send_cmd(cmd: Cmd, iface: io.HID) -> None:
|
||||||
init_desc = frame_init()
|
init_desc = frame_init()
|
||||||
cont_desc = frame_cont()
|
cont_desc = frame_cont()
|
||||||
offset = 0
|
offset = 0
|
||||||
@ -295,7 +295,7 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
|
|||||||
frm.bcnt = datalen
|
frm.bcnt = datalen
|
||||||
|
|
||||||
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
||||||
io.send(iface, buf)
|
iface.write(buf)
|
||||||
# log.debug(__name__, 'send init %s', buf)
|
# log.debug(__name__, 'send init %s', buf)
|
||||||
|
|
||||||
if offset < datalen:
|
if offset < datalen:
|
||||||
@ -304,18 +304,17 @@ def send_cmd(cmd: Cmd, iface: int) -> None:
|
|||||||
while offset < datalen:
|
while offset < datalen:
|
||||||
frm.seq = seq
|
frm.seq = seq
|
||||||
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
||||||
utime.sleep_ms(1) # FIXME: do async send
|
utime.sleep_ms(1) # FIXME: async write
|
||||||
io.send(iface, buf)
|
iface.write(buf)
|
||||||
# log.debug(__name__, 'send cont %s', buf)
|
# log.debug(__name__, 'send cont %s', buf)
|
||||||
seq += 1
|
seq += 1
|
||||||
|
|
||||||
|
|
||||||
def boot():
|
def boot(iface: io.HID):
|
||||||
iface = 0x03
|
|
||||||
loop.schedule_task(handle_reports(iface))
|
loop.schedule_task(handle_reports(iface))
|
||||||
|
|
||||||
|
|
||||||
async def handle_reports(iface: int):
|
async def handle_reports(iface: io.HID):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
req = await read_cmd(iface)
|
req = await read_cmd(iface)
|
||||||
|
@ -4,7 +4,7 @@ from trezor.messages.wire_types import Initialize, GetFeatures, Ping
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def respond_Features(session_id, msg):
|
async def respond_Features(ctx, msg):
|
||||||
from apps.common import storage, coins
|
from apps.common import storage, coins
|
||||||
from trezor.messages.Features import Features
|
from trezor.messages.Features import Features
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ async def respond_Features(session_id, msg):
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def respond_Pong(session_id, msg):
|
async def respond_Pong(ctx, msg):
|
||||||
from trezor.messages.Success import Success
|
from trezor.messages.Success import Success
|
||||||
|
|
||||||
s = Success()
|
s = Success()
|
||||||
@ -36,11 +36,11 @@ async def respond_Pong(session_id, msg):
|
|||||||
|
|
||||||
if msg.pin_protection:
|
if msg.pin_protection:
|
||||||
from apps.common.request_pin import protect_by_pin
|
from apps.common.request_pin import protect_by_pin
|
||||||
await protect_by_pin(session_id)
|
await protect_by_pin(ctx)
|
||||||
|
|
||||||
if msg.passphrase_protection:
|
if msg.passphrase_protection:
|
||||||
from apps.common.request_passphrase import protect_by_passphrase
|
from apps.common.request_passphrase import protect_by_passphrase
|
||||||
await protect_by_passphrase(session_id)
|
await protect_by_passphrase(ctx)
|
||||||
|
|
||||||
# TODO: handle other fields:
|
# TODO: handle other fields:
|
||||||
# button_protection
|
# button_protection
|
||||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_apply_settings(session_id, msg):
|
async def layout_apply_settings(ctx, msg):
|
||||||
from trezor.messages.Success import Success
|
from trezor.messages.Success import Success
|
||||||
from trezor.messages.FailureType import ProcessError
|
from trezor.messages.FailureType import ProcessError
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
@ -11,7 +11,7 @@ async def layout_apply_settings(session_id, msg):
|
|||||||
from ..common.request_pin import protect_by_pin
|
from ..common.request_pin import protect_by_pin
|
||||||
from ..common import storage
|
from ..common import storage
|
||||||
|
|
||||||
await protect_by_pin(session_id)
|
await protect_by_pin(ctx)
|
||||||
|
|
||||||
if msg.homescreen is not None:
|
if msg.homescreen is not None:
|
||||||
raise wire.FailureError(
|
raise wire.FailureError(
|
||||||
@ -21,20 +21,20 @@ async def layout_apply_settings(session_id, msg):
|
|||||||
raise wire.FailureError(ProcessError, 'No setting provided')
|
raise wire.FailureError(ProcessError, 'No setting provided')
|
||||||
|
|
||||||
if msg.label is not None:
|
if msg.label is not None:
|
||||||
await require_confirm(session_id, Text(
|
await require_confirm(ctx, Text(
|
||||||
'Change label', ui.ICON_RESET,
|
'Change label', ui.ICON_RESET,
|
||||||
'Do you really want to', 'change label to',
|
'Do you really want to', 'change label to',
|
||||||
ui.BOLD, '%s' % msg.label))
|
ui.BOLD, '%s' % msg.label))
|
||||||
|
|
||||||
if msg.language is not None:
|
if msg.language is not None:
|
||||||
await require_confirm(session_id, Text(
|
await require_confirm(ctx, Text(
|
||||||
'Change language', ui.ICON_RESET,
|
'Change language', ui.ICON_RESET,
|
||||||
'Do you really want to', 'change language to',
|
'Do you really want to', 'change language to',
|
||||||
ui.BOLD, '%s' % msg.language,
|
ui.BOLD, '%s' % msg.language,
|
||||||
ui.NORMAL, '?'))
|
ui.NORMAL, '?'))
|
||||||
|
|
||||||
if msg.use_passphrase is not None:
|
if msg.use_passphrase is not None:
|
||||||
await require_confirm(session_id, Text(
|
await require_confirm(ctx, Text(
|
||||||
'Enable passphrase' if msg.use_passphrase else 'Disable passphrase',
|
'Enable passphrase' if msg.use_passphrase else 'Disable passphrase',
|
||||||
ui.ICON_RESET,
|
ui.ICON_RESET,
|
||||||
'Do you really want to',
|
'Do you really want to',
|
||||||
|
@ -2,52 +2,52 @@ from trezor import ui
|
|||||||
from trezor.utils import unimport
|
from trezor.utils import unimport
|
||||||
|
|
||||||
|
|
||||||
def confirm_set_pin(session_id):
|
def confirm_set_pin(ctx):
|
||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
return require_confirm(session_id, Text(
|
return require_confirm(ctx, Text(
|
||||||
'Change PIN', ui.ICON_RESET,
|
'Change PIN', ui.ICON_RESET,
|
||||||
'Do you really want to', ui.BOLD,
|
'Do you really want to', ui.BOLD,
|
||||||
'set new PIN?'))
|
'set new PIN?'))
|
||||||
|
|
||||||
|
|
||||||
def confirm_change_pin(session_id):
|
def confirm_change_pin(ctx):
|
||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
return require_confirm(session_id, Text(
|
return require_confirm(ctx, Text(
|
||||||
'Change PIN', ui.ICON_RESET,
|
'Change PIN', ui.ICON_RESET,
|
||||||
'Do you really want to', ui.BOLD,
|
'Do you really want to', ui.BOLD,
|
||||||
'change current PIN?'))
|
'change current PIN?'))
|
||||||
|
|
||||||
|
|
||||||
def confirm_remove_pin(session_id):
|
def confirm_remove_pin(ctx):
|
||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
return require_confirm(session_id, Text(
|
return require_confirm(ctx, Text(
|
||||||
'Remove PIN', ui.ICON_RESET,
|
'Remove PIN', ui.ICON_RESET,
|
||||||
'Do you really want to', ui.BOLD,
|
'Do you really want to', ui.BOLD,
|
||||||
'remove current PIN?'))
|
'remove current PIN?'))
|
||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_change_pin(session_id, msg):
|
async def layout_change_pin(ctx, msg):
|
||||||
from trezor.messages.Success import Success
|
from trezor.messages.Success import Success
|
||||||
from apps.common.request_pin import protect_by_pin, request_pin_twice
|
from apps.common.request_pin import protect_by_pin, request_pin_twice
|
||||||
from apps.common import storage
|
from apps.common import storage
|
||||||
|
|
||||||
if msg.remove:
|
if msg.remove:
|
||||||
if storage.is_protected_by_pin():
|
if storage.is_protected_by_pin():
|
||||||
await confirm_remove_pin(session_id)
|
await confirm_remove_pin(ctx)
|
||||||
await protect_by_pin(session_id, at_least_once=True)
|
await protect_by_pin(ctx, at_least_once=True)
|
||||||
pin = ''
|
pin = ''
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if storage.is_protected_by_pin():
|
if storage.is_protected_by_pin():
|
||||||
await confirm_change_pin(session_id)
|
await confirm_change_pin(ctx)
|
||||||
await protect_by_pin(session_id, at_least_once=True)
|
await protect_by_pin(ctx, at_least_once=True)
|
||||||
else:
|
else:
|
||||||
await confirm_set_pin(session_id)
|
await confirm_set_pin(ctx)
|
||||||
pin = await request_pin_twice(session_id)
|
pin = await request_pin_twice(ctx)
|
||||||
|
|
||||||
storage.load_settings(pin=pin)
|
storage.load_settings(pin=pin)
|
||||||
if pin:
|
if pin:
|
||||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_load_device(session_id, msg):
|
async def layout_load_device(ctx, msg):
|
||||||
from trezor.crypto import bip39
|
from trezor.crypto import bip39
|
||||||
from trezor.messages.Success import Success
|
from trezor.messages.Success import Success
|
||||||
from trezor.messages.FailureType import UnexpectedMessage, ProcessError
|
from trezor.messages.FailureType import UnexpectedMessage, ProcessError
|
||||||
@ -20,7 +20,7 @@ async def layout_load_device(session_id, msg):
|
|||||||
if not msg.skip_checksum and not bip39.check(msg.mnemonic):
|
if not msg.skip_checksum and not bip39.check(msg.mnemonic):
|
||||||
raise wire.FailureError(ProcessError, 'Mnemonic is not valid')
|
raise wire.FailureError(ProcessError, 'Mnemonic is not valid')
|
||||||
|
|
||||||
await require_confirm(session_id, Text(
|
await require_confirm(ctx, Text(
|
||||||
'Loading seed', ui.ICON_RESET,
|
'Loading seed', ui.ICON_RESET,
|
||||||
ui.BOLD, 'Loading private seed', 'is not recommended.',
|
ui.BOLD, 'Loading private seed', 'is not recommended.',
|
||||||
ui.NORMAL, 'Continue only if you', 'know what you are doing!'))
|
ui.NORMAL, 'Continue only if you', 'know what you are doing!'))
|
||||||
|
@ -11,7 +11,7 @@ def nth(n):
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_recovery_device(session_id, msg):
|
async def layout_recovery_device(ctx, msg):
|
||||||
|
|
||||||
msg = 'Please enter ' + nth(msg.word_count) + ' word'
|
msg = 'Please enter ' + nth(msg.word_count) + ' word'
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ if __debug__:
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_reset_device(session_id, msg):
|
async def layout_reset_device(ctx, msg):
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
from trezor.crypto import hashlib, random, bip39
|
from trezor.crypto import hashlib, random, bip39
|
||||||
from trezor.messages.EntropyRequest import EntropyRequest
|
from trezor.messages.EntropyRequest import EntropyRequest
|
||||||
@ -39,21 +39,21 @@ async def layout_reset_device(session_id, msg):
|
|||||||
if msg.display_random:
|
if msg.display_random:
|
||||||
entropy_lines = chunks(ubinascii.hexlify(internal_entropy), 16)
|
entropy_lines = chunks(ubinascii.hexlify(internal_entropy), 16)
|
||||||
entropy_content = Text('Internal entropy', ui.ICON_RESET, *entropy_lines)
|
entropy_content = Text('Internal entropy', ui.ICON_RESET, *entropy_lines)
|
||||||
await require_confirm(session_id, entropy_content, ButtonRequestType.ResetDevice)
|
await require_confirm(ctx, entropy_content, ButtonRequestType.ResetDevice)
|
||||||
|
|
||||||
if msg.pin_protection:
|
if msg.pin_protection:
|
||||||
pin = await request_pin_twice(session_id)
|
pin = await request_pin_twice(ctx)
|
||||||
else:
|
else:
|
||||||
pin = None
|
pin = None
|
||||||
|
|
||||||
external_entropy_ack = await wire.call(session_id, EntropyRequest(), EntropyAck)
|
external_entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
|
||||||
ctx = hashlib.sha256()
|
ctx = hashlib.sha256()
|
||||||
ctx.update(internal_entropy)
|
ctx.update(internal_entropy)
|
||||||
ctx.update(external_entropy_ack.entropy)
|
ctx.update(external_entropy_ack.entropy)
|
||||||
entropy = ctx.digest()
|
entropy = ctx.digest()
|
||||||
mnemonic = bip39.from_data(entropy[:msg.strength // 8])
|
mnemonic = bip39.from_data(entropy[:msg.strength // 8])
|
||||||
|
|
||||||
await show_mnemonic_by_word(session_id, mnemonic)
|
await show_mnemonic_by_word(ctx, mnemonic)
|
||||||
|
|
||||||
storage.load_mnemonic(mnemonic)
|
storage.load_mnemonic(mnemonic)
|
||||||
storage.load_settings(pin=pin,
|
storage.load_settings(pin=pin,
|
||||||
@ -64,7 +64,7 @@ async def layout_reset_device(session_id, msg):
|
|||||||
return Success(message='Initialized')
|
return Success(message='Initialized')
|
||||||
|
|
||||||
|
|
||||||
async def show_mnemonic_by_word(session_id, mnemonic):
|
async def show_mnemonic_by_word(ctx, mnemonic):
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
from trezor.messages.ButtonRequestType import ConfirmWord
|
from trezor.messages.ButtonRequestType import ConfirmWord
|
||||||
from apps.common.confirm import confirm
|
from apps.common.confirm import confirm
|
||||||
@ -80,7 +80,7 @@ async def show_mnemonic_by_word(session_id, mnemonic):
|
|||||||
while index < len(words):
|
while index < len(words):
|
||||||
word = words[index]
|
word = words[index]
|
||||||
current_word = word
|
current_word = word
|
||||||
await confirm(session_id,
|
await confirm(ctx,
|
||||||
Text(
|
Text(
|
||||||
'Recovery seed setup', ui.ICON_RESET,
|
'Recovery seed setup', ui.ICON_RESET,
|
||||||
ui.NORMAL, 'Write down seed word' if recovery else 'Confirm seed word', ' ',
|
ui.NORMAL, 'Write down seed word' if recovery else 'Confirm seed word', ' ',
|
||||||
|
@ -3,13 +3,13 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_wipe_device(session_id, msg):
|
async def layout_wipe_device(ctx, msg):
|
||||||
from trezor.messages.Success import Success
|
from trezor.messages.Success import Success
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
from ..common.confirm import hold_to_confirm
|
from ..common.confirm import hold_to_confirm
|
||||||
from ..common import storage
|
from ..common import storage
|
||||||
|
|
||||||
await hold_to_confirm(session_id, Text(
|
await hold_to_confirm(ctx, Text(
|
||||||
'WIPE DEVICE',
|
'WIPE DEVICE',
|
||||||
ui.ICON_WIPE,
|
ui.ICON_WIPE,
|
||||||
ui.NORMAL, 'Do you really want to', 'wipe the device?',
|
ui.NORMAL, 'Do you really want to', 'wipe the device?',
|
||||||
|
@ -26,7 +26,7 @@ def cipher_key_value(msg, seckey: bytes) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_cipher_key_value(session_id, msg):
|
async def layout_cipher_key_value(ctx, msg):
|
||||||
from trezor.messages.CipheredKeyValue import CipheredKeyValue
|
from trezor.messages.CipheredKeyValue import CipheredKeyValue
|
||||||
from ..common import seed
|
from ..common import seed
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ async def layout_cipher_key_value(session_id, msg):
|
|||||||
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
|
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
|
||||||
ui.display.text(10, 60, msg.key, ui.MONO, ui.WHITE, ui.BLACK)
|
ui.display.text(10, 60, msg.key, ui.MONO, ui.WHITE, ui.BLACK)
|
||||||
|
|
||||||
node = await seed.get_root(session_id)
|
node = await seed.get_root(ctx)
|
||||||
node.derive_path(msg.address_n)
|
node.derive_path(msg.address_n)
|
||||||
|
|
||||||
value = cipher_key_value(msg, node.private_key())
|
value = cipher_key_value(msg, node.private_key())
|
||||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_get_address(session_id, msg):
|
async def layout_get_address(ctx, msg):
|
||||||
from trezor.messages.Address import Address
|
from trezor.messages.Address import Address
|
||||||
from trezor.messages.FailureType import ProcessError
|
from trezor.messages.FailureType import ProcessError
|
||||||
from ..common import coins
|
from ..common import coins
|
||||||
@ -15,18 +15,18 @@ async def layout_get_address(session_id, msg):
|
|||||||
address_n = msg.address_n or ()
|
address_n = msg.address_n or ()
|
||||||
coin_name = msg.coin_name or 'Bitcoin'
|
coin_name = msg.coin_name or 'Bitcoin'
|
||||||
|
|
||||||
node = await seed.get_root(session_id)
|
node = await seed.get_root(ctx)
|
||||||
node.derive_path(address_n)
|
node.derive_path(address_n)
|
||||||
coin = coins.by_name(coin_name)
|
coin = coins.by_name(coin_name)
|
||||||
address = node.address(coin.address_type)
|
address = node.address(coin.address_type)
|
||||||
|
|
||||||
if msg.show_display:
|
if msg.show_display:
|
||||||
await _show_address(session_id, address)
|
await _show_address(ctx, address)
|
||||||
|
|
||||||
return Address(address=address)
|
return Address(address=address)
|
||||||
|
|
||||||
|
|
||||||
async def _show_address(session_id, address):
|
async def _show_address(ctx, address):
|
||||||
from trezor.messages.ButtonRequestType import Address
|
from trezor.messages.ButtonRequestType import Address
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
from trezor.ui.qr import Qr
|
from trezor.ui.qr import Qr
|
||||||
@ -37,7 +37,7 @@ async def _show_address(session_id, address):
|
|||||||
content = Container(
|
content = Container(
|
||||||
Qr(address, (120, 135), 3),
|
Qr(address, (120, 135), 3),
|
||||||
Text('Confirm address', ui.ICON_RESET, ui.MONO, *lines))
|
Text('Confirm address', ui.ICON_RESET, ui.MONO, *lines))
|
||||||
await require_confirm(session_id, content, code=Address)
|
await require_confirm(ctx, content, code=Address)
|
||||||
|
|
||||||
|
|
||||||
def _split_address(address):
|
def _split_address(address):
|
||||||
|
@ -3,18 +3,18 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_get_entropy(session_id, msg):
|
async def layout_get_entropy(ctx, msg):
|
||||||
from trezor.messages.Entropy import Entropy
|
from trezor.messages.Entropy import Entropy
|
||||||
from trezor.crypto import random
|
from trezor.crypto import random
|
||||||
|
|
||||||
l = min(msg.size, 1024)
|
l = min(msg.size, 1024)
|
||||||
|
|
||||||
await _show_entropy(session_id)
|
await _show_entropy(ctx)
|
||||||
|
|
||||||
return Entropy(entropy=random.bytes(l))
|
return Entropy(entropy=random.bytes(l))
|
||||||
|
|
||||||
|
|
||||||
async def _show_entropy(session_id):
|
async def _show_entropy(ctx):
|
||||||
from trezor.messages.ButtonRequestType import ProtectCall
|
from trezor.messages.ButtonRequestType import ProtectCall
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
from trezor.ui.container import Container
|
from trezor.ui.container import Container
|
||||||
@ -23,4 +23,4 @@ async def _show_entropy(session_id):
|
|||||||
content = Container(
|
content = Container(
|
||||||
Text('Confirm entropy', ui.ICON_RESET, ui.MONO, 'Do you really want to send entropy?'))
|
Text('Confirm entropy', ui.ICON_RESET, ui.MONO, 'Do you really want to send entropy?'))
|
||||||
|
|
||||||
await require_confirm(session_id, content, code=ProtectCall)
|
await require_confirm(ctx, content, code=ProtectCall)
|
||||||
|
@ -2,7 +2,7 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_get_public_key(session_id, msg):
|
async def layout_get_public_key(ctx, msg):
|
||||||
from trezor.messages.HDNodeType import HDNodeType
|
from trezor.messages.HDNodeType import HDNodeType
|
||||||
from trezor.messages.PublicKey import PublicKey
|
from trezor.messages.PublicKey import PublicKey
|
||||||
from ..common import coins
|
from ..common import coins
|
||||||
@ -11,7 +11,7 @@ async def layout_get_public_key(session_id, msg):
|
|||||||
address_n = msg.address_n or ()
|
address_n = msg.address_n or ()
|
||||||
coin_name = msg.coin_name or 'Bitcoin'
|
coin_name = msg.coin_name or 'Bitcoin'
|
||||||
|
|
||||||
node = await seed.get_root(session_id)
|
node = await seed.get_root(ctx)
|
||||||
node.derive_path(address_n)
|
node.derive_path(address_n)
|
||||||
coin = coins.by_name(coin_name)
|
coin = coins.by_name(coin_name)
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ def sign_challenge(seckey: bytes,
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_sign_identity(session_id, msg):
|
async def layout_sign_identity(ctx, msg):
|
||||||
from trezor.messages.SignedIdentity import SignedIdentity
|
from trezor.messages.SignedIdentity import SignedIdentity
|
||||||
from ..common import coins
|
from ..common import coins
|
||||||
from ..common import seed
|
from ..common import seed
|
||||||
@ -92,7 +92,7 @@ async def layout_sign_identity(session_id, msg):
|
|||||||
display_identity(identity, msg.challenge_visual)
|
display_identity(identity, msg.challenge_visual)
|
||||||
|
|
||||||
address_n = get_identity_path(identity, msg.identity.index or 0)
|
address_n = get_identity_path(identity, msg.identity.index or 0)
|
||||||
node = await seed.get_root(session_id, msg.ecdsa_curve_name)
|
node = await seed.get_root(ctx, msg.ecdsa_curve_name)
|
||||||
node.derive_path(address_n)
|
node.derive_path(address_n)
|
||||||
|
|
||||||
coin = coins.by_name('Bitcoin')
|
coin = coins.by_name('Bitcoin')
|
||||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_sign_message(session_id, msg):
|
async def layout_sign_message(ctx, msg):
|
||||||
from trezor.messages.MessageSignature import MessageSignature
|
from trezor.messages.MessageSignature import MessageSignature
|
||||||
from trezor.crypto.curve import secp256k1
|
from trezor.crypto.curve import secp256k1
|
||||||
from ..common import coins
|
from ..common import coins
|
||||||
@ -18,7 +18,7 @@ async def layout_sign_message(session_id, msg):
|
|||||||
coin_name = msg.coin_name or 'Bitcoin'
|
coin_name = msg.coin_name or 'Bitcoin'
|
||||||
coin = coins.by_name(coin_name)
|
coin = coins.by_name(coin_name)
|
||||||
|
|
||||||
node = await seed.get_root(session_id)
|
node = await seed.get_root(ctx)
|
||||||
node.derive_path(msg.address_n)
|
node.derive_path(msg.address_n)
|
||||||
|
|
||||||
seckey = node.private_key()
|
seckey = node.private_key()
|
||||||
|
@ -3,7 +3,7 @@ from trezor import wire
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def sign_tx(session_id, msg):
|
async def sign_tx(ctx, msg):
|
||||||
from trezor.messages.RequestType import TXFINISHED
|
from trezor.messages.RequestType import TXFINISHED
|
||||||
from trezor.messages.wire_types import TxAck
|
from trezor.messages.wire_types import TxAck
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ async def sign_tx(session_id, msg):
|
|||||||
from . import signing
|
from . import signing
|
||||||
from . import layout
|
from . import layout
|
||||||
|
|
||||||
root = await seed.get_root(session_id)
|
root = await seed.get_root(ctx)
|
||||||
|
|
||||||
signer = signing.sign_tx(msg, root)
|
signer = signing.sign_tx(msg, root)
|
||||||
res = None
|
res = None
|
||||||
@ -23,13 +23,13 @@ async def sign_tx(session_id, msg):
|
|||||||
if req.__qualname__ == 'TxRequest':
|
if req.__qualname__ == 'TxRequest':
|
||||||
if req.request_type == TXFINISHED:
|
if req.request_type == TXFINISHED:
|
||||||
break
|
break
|
||||||
res = await wire.call(session_id, req, TxAck)
|
res = await ctx.call(req, TxAck)
|
||||||
elif req.__qualname__ == 'UiConfirmOutput':
|
elif req.__qualname__ == 'UiConfirmOutput':
|
||||||
res = await layout.confirm_output(session_id, req.output, req.coin)
|
res = await layout.confirm_output(ctx, req.output, req.coin)
|
||||||
elif req.__qualname__ == 'UiConfirmTotal':
|
elif req.__qualname__ == 'UiConfirmTotal':
|
||||||
res = await layout.confirm_total(session_id, req.spending, req.fee, req.coin)
|
res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin)
|
||||||
elif req.__qualname__ == 'UiConfirmFeeOverThreshold':
|
elif req.__qualname__ == 'UiConfirmFeeOverThreshold':
|
||||||
res = await layout.confirm_feeoverthreshold(session_id, req.fee, req.coin)
|
res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin)
|
||||||
else:
|
else:
|
||||||
raise TypeError('Invalid signing instruction')
|
raise TypeError('Invalid signing instruction')
|
||||||
return req
|
return req
|
||||||
|
@ -14,22 +14,22 @@ def split_address(address):
|
|||||||
return chunks(address, 17)
|
return chunks(address, 17)
|
||||||
|
|
||||||
|
|
||||||
async def confirm_output(session_id, output, coin):
|
async def confirm_output(ctx, output, coin):
|
||||||
content = Text('Confirm output', ui.ICON_RESET,
|
content = Text('Confirm output', ui.ICON_RESET,
|
||||||
ui.BOLD, format_amount(output.amount, coin),
|
ui.BOLD, format_amount(output.amount, coin),
|
||||||
ui.NORMAL, 'to',
|
ui.NORMAL, 'to',
|
||||||
ui.MONO, *split_address(output.address))
|
ui.MONO, *split_address(output.address))
|
||||||
return await confirm(session_id, content, ButtonRequestType.ConfirmOutput)
|
return await confirm(ctx, content, ButtonRequestType.ConfirmOutput)
|
||||||
|
|
||||||
|
|
||||||
async def confirm_total(session_id, spending, fee, coin):
|
async def confirm_total(ctx, spending, fee, coin):
|
||||||
content = Text('Confirm transaction', ui.ICON_RESET,
|
content = Text('Confirm transaction', ui.ICON_RESET,
|
||||||
'Sending: %s' % format_amount(spending, coin),
|
'Sending: %s' % format_amount(spending, coin),
|
||||||
'Fee: %s' % format_amount(fee, coin))
|
'Fee: %s' % format_amount(fee, coin))
|
||||||
return await hold_to_confirm(session_id, content, ButtonRequestType.SignTx)
|
return await hold_to_confirm(ctx, content, ButtonRequestType.SignTx)
|
||||||
|
|
||||||
|
|
||||||
async def confirm_feeoverthreshold(session_id, fee, coin):
|
async def confirm_feeoverthreshold(ctx, fee, coin):
|
||||||
content = Text('Confirm high fee:', ui.ICON_RESET,
|
content = Text('Confirm high fee:', ui.ICON_RESET,
|
||||||
ui.BOLD, format_amount(fee, coin))
|
ui.BOLD, format_amount(fee, coin))
|
||||||
return await confirm(session_id, content, ButtonRequestType.FeeOverThreshold)
|
return await confirm(ctx, content, ButtonRequestType.FeeOverThreshold)
|
||||||
|
@ -3,7 +3,7 @@ from trezor.utils import unimport
|
|||||||
|
|
||||||
|
|
||||||
@unimport
|
@unimport
|
||||||
async def layout_verify_message(session_id, msg):
|
async def layout_verify_message(ctx, msg):
|
||||||
from trezor.messages.Success import Success
|
from trezor.messages.Success import Success
|
||||||
from trezor.crypto.curve import secp256k1
|
from trezor.crypto.curve import secp256k1
|
||||||
from trezor.crypto.hashlib import ripemd160, sha256
|
from trezor.crypto.hashlib import ripemd160, sha256
|
||||||
|
48
src/main.py
48
src/main.py
@ -4,26 +4,7 @@ from trezor import io
|
|||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor import main
|
from trezor import main
|
||||||
|
|
||||||
# Load applications
|
# initialize the USB stack
|
||||||
from apps.common import storage
|
|
||||||
if __debug__:
|
|
||||||
from apps import debug
|
|
||||||
from apps import homescreen
|
|
||||||
from apps import management
|
|
||||||
from apps import wallet
|
|
||||||
from apps import ethereum
|
|
||||||
from apps import fido_u2f
|
|
||||||
|
|
||||||
# Boot applications
|
|
||||||
if __debug__:
|
|
||||||
debug.boot()
|
|
||||||
homescreen.boot()
|
|
||||||
management.boot()
|
|
||||||
wallet.boot()
|
|
||||||
ethereum.boot()
|
|
||||||
fido_u2f.boot()
|
|
||||||
|
|
||||||
# Intialize the USB stack
|
|
||||||
usb_wire = io.HID(
|
usb_wire = io.HID(
|
||||||
iface_num=0x00,
|
iface_num=0x00,
|
||||||
ep_in=0x81,
|
ep_in=0x81,
|
||||||
@ -90,11 +71,30 @@ usb.add(usb_vcp)
|
|||||||
usb.add(usb_u2f)
|
usb.add(usb_u2f)
|
||||||
usb.open()
|
usb.open()
|
||||||
|
|
||||||
# Initialize the wire codec pipeline
|
# load applications
|
||||||
wire.setup(usb_wire.iface_num())
|
from apps.common import storage
|
||||||
|
if __debug__:
|
||||||
|
from apps import debug
|
||||||
|
from apps import homescreen
|
||||||
|
from apps import management
|
||||||
|
from apps import wallet
|
||||||
|
from apps import ethereum
|
||||||
|
from apps import fido_u2f
|
||||||
|
|
||||||
# Load default homescreen
|
# boot applications
|
||||||
|
if __debug__:
|
||||||
|
debug.boot()
|
||||||
|
homescreen.boot()
|
||||||
|
management.boot()
|
||||||
|
wallet.boot()
|
||||||
|
ethereum.boot()
|
||||||
|
fido_u2f.boot(usb_u2f)
|
||||||
|
|
||||||
|
# initialize the wire codec pipeline
|
||||||
|
wire.setup(usb_wire)
|
||||||
|
|
||||||
|
# load default homescreen
|
||||||
from apps.homescreen.homescreen import layout_homescreen
|
from apps.homescreen.homescreen import layout_homescreen
|
||||||
|
|
||||||
# Run main even loop and specify which screen is default
|
# run main even loop and specify which screen is default
|
||||||
main.run(default_workflow=layout_homescreen)
|
main.run(default_workflow=layout_homescreen)
|
||||||
|
@ -14,7 +14,7 @@ async def load_uvarint(reader):
|
|||||||
shift = 0
|
shift = 0
|
||||||
byte = 0x80
|
byte = 0x80
|
||||||
while byte & 0x80:
|
while byte & 0x80:
|
||||||
await reader.readinto(buffer)
|
await reader.areadinto(buffer)
|
||||||
byte = buffer[0]
|
byte = buffer[0]
|
||||||
result += (byte & 0x7F) << shift
|
result += (byte & 0x7F) << shift
|
||||||
shift += 7
|
shift += 7
|
||||||
@ -27,7 +27,7 @@ async def dump_uvarint(writer, n):
|
|||||||
while shifted:
|
while shifted:
|
||||||
shifted = n >> 7
|
shifted = n >> 7
|
||||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||||
await writer.write(buffer)
|
await writer.awrite(buffer)
|
||||||
n = shifted
|
n = shifted
|
||||||
|
|
||||||
|
|
||||||
@ -69,11 +69,11 @@ class LimitedReader:
|
|||||||
self.reader = reader
|
self.reader = reader
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
|
|
||||||
async def readinto(self, buf):
|
async def areadinto(self, buf):
|
||||||
if self.limit < len(buf):
|
if self.limit < len(buf):
|
||||||
raise EOFError
|
raise EOFError
|
||||||
else:
|
else:
|
||||||
nread = await self.reader.readinto(buf)
|
nread = await self.reader.areadinto(buf)
|
||||||
self.limit -= nread
|
self.limit -= nread
|
||||||
return nread
|
return nread
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ class CountingWriter:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.size = 0
|
self.size = 0
|
||||||
|
|
||||||
async def write(self, buf):
|
async def awrite(self, buf):
|
||||||
nwritten = len(buf)
|
nwritten = len(buf)
|
||||||
self.size += nwritten
|
self.size += nwritten
|
||||||
return nwritten
|
return nwritten
|
||||||
@ -112,7 +112,7 @@ async def load_message(reader, msg_type):
|
|||||||
await load_uvarint(reader)
|
await load_uvarint(reader)
|
||||||
elif wtype == 2:
|
elif wtype == 2:
|
||||||
ivalue = await load_uvarint(reader)
|
ivalue = await load_uvarint(reader)
|
||||||
await reader.readinto(bytearray(ivalue))
|
await reader.areadinto(bytearray(ivalue))
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
continue
|
continue
|
||||||
@ -129,10 +129,10 @@ async def load_message(reader, msg_type):
|
|||||||
fvalue = bool(ivalue)
|
fvalue = bool(ivalue)
|
||||||
elif ftype is BytesType:
|
elif ftype is BytesType:
|
||||||
fvalue = bytearray(ivalue)
|
fvalue = bytearray(ivalue)
|
||||||
await reader.readinto(fvalue)
|
await reader.areadinto(fvalue)
|
||||||
elif ftype is UnicodeType:
|
elif ftype is UnicodeType:
|
||||||
fvalue = bytearray(ivalue)
|
fvalue = bytearray(ivalue)
|
||||||
await reader.readinto(fvalue)
|
await reader.areadinto(fvalue)
|
||||||
fvalue = str(fvalue, 'utf8')
|
fvalue = str(fvalue, 'utf8')
|
||||||
elif issubclass(ftype, MessageType):
|
elif issubclass(ftype, MessageType):
|
||||||
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
|
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
|
||||||
@ -186,11 +186,11 @@ async def dump_message(writer, msg):
|
|||||||
|
|
||||||
elif ftype is BytesType:
|
elif ftype is BytesType:
|
||||||
await dump_uvarint(writer, len(svalue))
|
await dump_uvarint(writer, len(svalue))
|
||||||
await writer.write(svalue)
|
await writer.awrite(svalue)
|
||||||
|
|
||||||
elif ftype is UnicodeType:
|
elif ftype is UnicodeType:
|
||||||
await dump_uvarint(writer, len(svalue))
|
await dump_uvarint(writer, len(svalue))
|
||||||
await writer.write(bytes(svalue, 'utf8'))
|
await writer.awrite(bytes(svalue, 'utf8'))
|
||||||
|
|
||||||
elif issubclass(ftype, MessageType):
|
elif issubclass(ftype, MessageType):
|
||||||
counter = CountingWriter()
|
counter = CountingWriter()
|
||||||
|
@ -8,58 +8,82 @@ from trezor import workflow
|
|||||||
from . import codec_v1
|
from . import codec_v1
|
||||||
from . import codec_v2
|
from . import codec_v2
|
||||||
|
|
||||||
workflows = {}
|
workflow_handlers = {}
|
||||||
|
|
||||||
|
|
||||||
def register(wire_type, handler, *args):
|
def register(mtype, handler, *args):
|
||||||
if wire_type in workflows:
|
'''Register `handler` to get scheduled after `mtype` message is received.'''
|
||||||
|
if mtype in workflow_handlers:
|
||||||
raise KeyError
|
raise KeyError
|
||||||
workflows[wire_type] = (handler, args)
|
workflow_handlers[mtype] = (handler, args)
|
||||||
|
|
||||||
|
|
||||||
def setup(interface):
|
def setup(iface):
|
||||||
session_supervisor = codec_v2.SesssionSupervisor(interface,
|
'''Initialize the wire stack on passed USB interface.'''
|
||||||
session_handler)
|
session_supervisor = codec_v2.SesssionSupervisor(iface, session_handler)
|
||||||
session_supervisor.open(codec_v1.SESSION_ID)
|
session_supervisor.open(codec_v1.SESSION_ID)
|
||||||
loop.schedule_task(session_supervisor.listen())
|
loop.schedule_task(session_supervisor.listen())
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
def __init__(self, interface, session_id):
|
def __init__(self, iface, sid):
|
||||||
self.interface = interface
|
self.iface = iface
|
||||||
self.session_id = session_id
|
self.sid = sid
|
||||||
|
|
||||||
def get_reader(self):
|
|
||||||
if self.session_id == codec_v1.SESSION_ID:
|
|
||||||
return codec_v1.Reader(self.interface)
|
|
||||||
else:
|
|
||||||
return codec_v2.Reader(self.interface, self.session_id)
|
|
||||||
|
|
||||||
def get_writer(self, mtype, msize):
|
|
||||||
if self.session_id == codec_v1.SESSION_ID:
|
|
||||||
return codec_v1.Writer(self.interface, mtype, msize)
|
|
||||||
else:
|
|
||||||
return codec_v2.Writer(self.interface, self.session_id, mtype, msize)
|
|
||||||
|
|
||||||
async def read(self, types):
|
|
||||||
reader = self.get_reader()
|
|
||||||
await reader.open()
|
|
||||||
if reader.type not in types:
|
|
||||||
raise UnexpectedMessageError(reader)
|
|
||||||
return await protobuf.load_message(reader,
|
|
||||||
messages.get_type(reader.type))
|
|
||||||
|
|
||||||
async def write(self, msg):
|
|
||||||
counter = protobuf.CountingWriter()
|
|
||||||
await protobuf.dump_message(counter, msg)
|
|
||||||
writer = self.get_writer(msg.MESSAGE_WIRE_TYPE, counter.size)
|
|
||||||
await protobuf.dump_message(writer, msg)
|
|
||||||
await writer.close()
|
|
||||||
|
|
||||||
async def call(self, msg, types):
|
async def call(self, msg, types):
|
||||||
|
'''
|
||||||
|
Reply with `msg` and wait for one of `types`. See `self.write()` and
|
||||||
|
`self.read()`.
|
||||||
|
'''
|
||||||
await self.write(msg)
|
await self.write(msg)
|
||||||
return await self.read(types)
|
return await self.read(types)
|
||||||
|
|
||||||
|
async def read(self, types):
|
||||||
|
'''
|
||||||
|
Wait for incoming message on this wire context and return it. Raises
|
||||||
|
`UnexpectedMessageError` if the message type does not match one of
|
||||||
|
`types`; and caller should always make sure to re-raise it.
|
||||||
|
'''
|
||||||
|
reader = self.getreader()
|
||||||
|
|
||||||
|
await reader.aopen() # wait for the message header
|
||||||
|
|
||||||
|
# if we got a message with unexpected type, raise the reader via
|
||||||
|
# `UnexpectedMessageError` and let the session handler deal with it
|
||||||
|
if reader.type not in types:
|
||||||
|
raise UnexpectedMessageError(reader)
|
||||||
|
|
||||||
|
# look up the protobuf class and parse the message
|
||||||
|
pbtype = messages.get_type(reader.type)
|
||||||
|
return await protobuf.load_message(reader, pbtype)
|
||||||
|
|
||||||
|
async def write(self, msg):
|
||||||
|
'''
|
||||||
|
Write a protobuf message to this wire context.
|
||||||
|
'''
|
||||||
|
writer = self.getwriter()
|
||||||
|
|
||||||
|
# get the message size
|
||||||
|
counter = protobuf.CountingWriter()
|
||||||
|
await protobuf.dump_message(counter, msg)
|
||||||
|
|
||||||
|
# write the message
|
||||||
|
writer.setheader(msg.MESSAGE_WIRE_TYPE, counter.size)
|
||||||
|
await protobuf.dump_message(writer, msg)
|
||||||
|
await writer.aclose()
|
||||||
|
|
||||||
|
def getreader(self):
|
||||||
|
if self.sid == codec_v1.SESSION_ID:
|
||||||
|
return codec_v1.Reader(self.iface)
|
||||||
|
else:
|
||||||
|
return codec_v2.Reader(self.iface, self.sid)
|
||||||
|
|
||||||
|
def getwriter(self):
|
||||||
|
if self.sid == codec_v1.SESSION_ID:
|
||||||
|
return codec_v1.Writer(self.iface)
|
||||||
|
else:
|
||||||
|
return codec_v2.Writer(self.iface, self.sid)
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedMessageError(Exception):
|
class UnexpectedMessageError(Exception):
|
||||||
def __init__(self, reader):
|
def __init__(self, reader):
|
||||||
@ -74,60 +98,69 @@ class FailureError(Exception):
|
|||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
class Workflow:
|
async def session_handler(iface, sid):
|
||||||
def __init__(self, default):
|
reader = None
|
||||||
self.handlers = {}
|
ctx = Context(iface, sid)
|
||||||
self.default = default
|
while True:
|
||||||
|
try:
|
||||||
async def __call__(self, interface, session_id):
|
# wait for new message, if needed, and find handler
|
||||||
ctx = Context(interface, session_id)
|
if not reader:
|
||||||
while True:
|
reader = ctx.getreader()
|
||||||
|
await reader.aopen()
|
||||||
try:
|
try:
|
||||||
reader = ctx.get_reader()
|
handler, args = workflow_handlers[reader.type]
|
||||||
await reader.open()
|
except KeyError:
|
||||||
try:
|
handler, args = unexpected_msg, ()
|
||||||
handler = self.handlers[reader.type]
|
|
||||||
except KeyError:
|
await handler(ctx, reader, *args)
|
||||||
handler = self.default
|
|
||||||
try:
|
except UnexpectedMessageError as exc:
|
||||||
await handler(ctx, reader)
|
# retry with opened reader from the exception
|
||||||
except UnexpectedMessageError as unexp_msg:
|
reader = exc.reader
|
||||||
reader = unexp_msg.reader
|
continue
|
||||||
except Exception as e:
|
except FailureError as exc:
|
||||||
log.exception(__name__, e)
|
# we log FailureError as warning, not as exception
|
||||||
|
log.warning(__name__, 'failure: %s', exc.message)
|
||||||
|
except Exception as exc:
|
||||||
|
# sessions are never closed by raised exceptions
|
||||||
|
log.exception(__name__, exc)
|
||||||
|
|
||||||
|
# read new message in next iteration
|
||||||
|
reader = None
|
||||||
|
|
||||||
|
|
||||||
async def protobuf_workflow(ctx, reader, handler, *args):
|
async def protobuf_workflow(ctx, reader, handler, *args):
|
||||||
msg = await protobuf.load_message(reader, messages.get_type(reader.type))
|
from trezor.messages.Failure import Failure
|
||||||
|
from trezor.messages.FailureType import FirmwareError
|
||||||
|
|
||||||
|
req = await protobuf.load_message(reader, messages.get_type(reader.type))
|
||||||
try:
|
try:
|
||||||
res = await handler(reader.sid, msg, *args)
|
res = await handler(ctx, req, *args)
|
||||||
except Exception as exc:
|
except UnexpectedMessageError:
|
||||||
if not isinstance(exc, UnexpectedMessageError):
|
# session handler takes care of this one
|
||||||
await ctx.write(make_failure_msg(exc))
|
|
||||||
raise
|
raise
|
||||||
else:
|
except FailureError as exc:
|
||||||
if res:
|
# respond with specific code and message
|
||||||
await ctx.write(res)
|
await ctx.write(Failure(code=exc.code, message=exc.message))
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
# respond with a generic code and message
|
||||||
|
await ctx.write(Failure(code=FirmwareError, message='Firmware error'))
|
||||||
|
raise
|
||||||
|
if res:
|
||||||
|
# respond with a specific response
|
||||||
|
await ctx.write(res)
|
||||||
|
|
||||||
|
|
||||||
async def handle_unexp_msg(ctx, reader):
|
async def unexpected_msg(ctx, reader):
|
||||||
|
from trezor.messages.Failure import Failure
|
||||||
|
from trezor.messages.FailureType import UnexpectedMessage
|
||||||
|
|
||||||
# receive the message and throw it away
|
# receive the message and throw it away
|
||||||
while reader.size > 0:
|
while reader.size > 0:
|
||||||
buf = bytearray(reader.size)
|
buf = bytearray(reader.size)
|
||||||
await reader.readinto(buf)
|
await reader.areadinto(buf)
|
||||||
|
|
||||||
# respond with an unknown message error
|
# respond with an unknown message error
|
||||||
from trezor.messages.Failure import Failure
|
|
||||||
from trezor.messages.FailureType import UnexpectedMessage
|
|
||||||
await ctx.write(
|
await ctx.write(
|
||||||
Failure(code=UnexpectedMessage, message='Unexpected message'))
|
Failure(code=UnexpectedMessage, message='Unexpected message'))
|
||||||
|
|
||||||
def make_failure_msg(exc):
|
|
||||||
from trezor.messages.Failure import Failure
|
|
||||||
from trezor.messages.FailureType import FirmwareError
|
|
||||||
if isinstance(exc, FailureError):
|
|
||||||
code = exc.code
|
|
||||||
message = exc.message
|
|
||||||
else:
|
|
||||||
code = FirmwareError
|
|
||||||
message = 'Firmware Error'
|
|
||||||
return Failure(code=code, message=message)
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
import ustruct
|
import ustruct
|
||||||
|
|
||||||
from trezor import io
|
|
||||||
from trezor import loop
|
from trezor import loop
|
||||||
from trezor import utils
|
from trezor import utils
|
||||||
|
|
||||||
@ -32,13 +31,13 @@ class Reader:
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
|
return '<ReaderV1: type=%d size=%dB>' % (self.type, self.size)
|
||||||
|
|
||||||
async def open(self):
|
async def aopen(self):
|
||||||
'''
|
'''
|
||||||
Begin the message transmission by waiting for initial V2 message report
|
Begin the message transmission by waiting for initial V2 message report
|
||||||
on this session. `self.type` and `self.size` are initialized and
|
on this session. `self.type` and `self.size` are initialized and
|
||||||
available after `open()` returns.
|
available after `aopen()` returns.
|
||||||
'''
|
'''
|
||||||
read = loop.select(self.iface | loop.READ)
|
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||||
while True:
|
while True:
|
||||||
# wait for initial report
|
# wait for initial report
|
||||||
report = await read
|
report = await read
|
||||||
@ -55,7 +54,7 @@ class Reader:
|
|||||||
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
|
self.data = report[_REP_INIT_DATA:_REP_INIT_DATA + msize]
|
||||||
self.ofs = 0
|
self.ofs = 0
|
||||||
|
|
||||||
async def readinto(self, buf):
|
async def areadinto(self, buf):
|
||||||
'''
|
'''
|
||||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||||
@ -64,7 +63,7 @@ class Reader:
|
|||||||
if self.size < len(buf):
|
if self.size < len(buf):
|
||||||
raise EOFError
|
raise EOFError
|
||||||
|
|
||||||
read = loop.select(self.iface | loop.READ)
|
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||||
nread = 0
|
nread = 0
|
||||||
while nread < len(buf):
|
while nread < len(buf):
|
||||||
if self.ofs == len(self.data):
|
if self.ofs == len(self.data):
|
||||||
@ -93,20 +92,28 @@ class Writer:
|
|||||||
async-file-like interface.
|
async-file-like interface.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, iface, mtype, msize):
|
def __init__(self, iface):
|
||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.type = mtype
|
self.type = None
|
||||||
self.size = msize
|
self.size = None
|
||||||
self.data = bytearray(_REP_LEN)
|
self.data = bytearray(_REP_LEN)
|
||||||
self.ofs = _REP_INIT_DATA
|
self.ofs = 0
|
||||||
|
|
||||||
# load the report with initial header
|
|
||||||
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
|
return '<WriterV2: type=%d size=%dB>' % (self.type, self.size)
|
||||||
|
|
||||||
async def write(self, buf):
|
def setheader(self, mtype, msize):
|
||||||
|
'''
|
||||||
|
Reset the writer state and load the message header with passed type and
|
||||||
|
total message size.
|
||||||
|
'''
|
||||||
|
self.type = mtype
|
||||||
|
self.size = msize
|
||||||
|
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC,
|
||||||
|
_REP_MAGIC, mtype, msize)
|
||||||
|
self.ofs = _REP_INIT_DATA
|
||||||
|
|
||||||
|
async def awrite(self, buf):
|
||||||
'''
|
'''
|
||||||
Encode and write every byte from `buf`. Does not need to be called in
|
Encode and write every byte from `buf`. Does not need to be called in
|
||||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||||
@ -115,7 +122,7 @@ class Writer:
|
|||||||
if self.size < len(buf):
|
if self.size < len(buf):
|
||||||
raise EOFError
|
raise EOFError
|
||||||
|
|
||||||
write = loop.select(self.iface | loop.WRITE)
|
write = loop.select(self.iface.iface_num() | loop.WRITE)
|
||||||
nwritten = 0
|
nwritten = 0
|
||||||
while nwritten < len(buf):
|
while nwritten < len(buf):
|
||||||
# copy as much as possible to report buffer
|
# copy as much as possible to report buffer
|
||||||
@ -127,12 +134,12 @@ class Writer:
|
|||||||
if self.ofs == _REP_LEN:
|
if self.ofs == _REP_LEN:
|
||||||
# we are at the end of the report, flush it
|
# we are at the end of the report, flush it
|
||||||
await write
|
await write
|
||||||
io.write(self.iface, self.data)
|
self.iface.write(self.data)
|
||||||
self.ofs = _REP_CONT_DATA
|
self.ofs = _REP_CONT_DATA
|
||||||
|
|
||||||
return nwritten
|
return nwritten
|
||||||
|
|
||||||
async def close(self):
|
async def aclose(self):
|
||||||
'''Flush and close the message transmission.'''
|
'''Flush and close the message transmission.'''
|
||||||
if self.ofs != _REP_CONT_DATA:
|
if self.ofs != _REP_CONT_DATA:
|
||||||
# we didn't write anything or last write() wasn't report-aligned,
|
# we didn't write anything or last write() wasn't report-aligned,
|
||||||
@ -141,5 +148,5 @@ class Writer:
|
|||||||
self.data[self.ofs] = 0x00
|
self.data[self.ofs] = 0x00
|
||||||
self.ofs += 1
|
self.ofs += 1
|
||||||
|
|
||||||
await loop.select(self.iface | loop.WRITE)
|
await loop.select(self.iface.iface_num() | loop.WRITE)
|
||||||
io.send(self.iface, self.data)
|
self.iface.write(self.data)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
import ustruct
|
import ustruct
|
||||||
|
|
||||||
from trezor import io
|
|
||||||
from trezor import loop
|
from trezor import loop
|
||||||
from trezor import utils
|
from trezor import utils
|
||||||
from trezor.crypto import random
|
from trezor.crypto import random
|
||||||
@ -28,11 +27,11 @@ _REP_MARKER_CONT = const(0x02)
|
|||||||
_REP_MARKER_OPEN = const(0x03)
|
_REP_MARKER_OPEN = const(0x03)
|
||||||
_REP_MARKER_CLOSE = const(0x04)
|
_REP_MARKER_CLOSE = const(0x04)
|
||||||
|
|
||||||
_REP = '>BL' # marker, session_id
|
_REP = '>BL' # marker, session_id
|
||||||
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
|
_REP_INIT = '>BLLL' # marker, session_id, message_type, message_size
|
||||||
_REP_CONT = '>BLL' # marker, session_id, sequence
|
_REP_CONT = '>BLL' # marker, session_id, sequence
|
||||||
_REP_INIT_DATA = const(13) # offset of data in init report
|
_REP_INIT_DATA = const(13) # offset of data in init report
|
||||||
_REP_CONT_DATA = const(9) # offset of data in cont report
|
_REP_CONT_DATA = const(9) # offset of data in cont report
|
||||||
|
|
||||||
|
|
||||||
class Reader:
|
class Reader:
|
||||||
@ -51,15 +50,16 @@ class Reader:
|
|||||||
self.seq = 0
|
self.seq = 0
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type, self.size)
|
return '<Reader: sid=%x type=%d size=%dB>' % (self.sid, self.type,
|
||||||
|
self.size)
|
||||||
|
|
||||||
async def open(self):
|
async def aopen(self):
|
||||||
'''
|
'''
|
||||||
Begin the message transmission by waiting for initial V2 message report
|
Begin the message transmission by waiting for initial V2 message report
|
||||||
on this session. `self.type` and `self.size` are initialized and
|
on this session. `self.type` and `self.size` are initialized and
|
||||||
available after `open()` returns.
|
available after `aopen()` returns.
|
||||||
'''
|
'''
|
||||||
read = loop.select(self.iface | loop.READ)
|
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||||
while True:
|
while True:
|
||||||
# wait for initial report
|
# wait for initial report
|
||||||
report = await read
|
report = await read
|
||||||
@ -74,7 +74,7 @@ class Reader:
|
|||||||
self.ofs = 0
|
self.ofs = 0
|
||||||
self.seq = 0
|
self.seq = 0
|
||||||
|
|
||||||
async def readinto(self, buf):
|
async def areadinto(self, buf):
|
||||||
'''
|
'''
|
||||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||||
@ -83,7 +83,7 @@ class Reader:
|
|||||||
if self.size < len(buf):
|
if self.size < len(buf):
|
||||||
raise EOFError
|
raise EOFError
|
||||||
|
|
||||||
read = loop.select(self.iface | loop.READ)
|
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||||
nread = 0
|
nread = 0
|
||||||
while nread < len(buf):
|
while nread < len(buf):
|
||||||
if self.ofs == len(self.data):
|
if self.ofs == len(self.data):
|
||||||
@ -115,20 +115,28 @@ class Writer:
|
|||||||
interface.
|
interface.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, iface, sid, mtype, msize):
|
def __init__(self, iface, sid):
|
||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
|
self.type = None
|
||||||
|
self.size = None
|
||||||
|
self.data = bytearray(_REP_LEN)
|
||||||
|
self.ofs = 0
|
||||||
|
self.seq = 0
|
||||||
|
|
||||||
|
def setheader(self, mtype, msize):
|
||||||
|
'''
|
||||||
|
Reset the writer state and load the message header with passed type and
|
||||||
|
total message size.
|
||||||
|
'''
|
||||||
self.type = mtype
|
self.type = mtype
|
||||||
self.size = msize
|
self.size = msize
|
||||||
self.data = bytearray(_REP_LEN)
|
ustruct.pack_into(_REP_INIT, self.data, 0, _REP_MARKER_INIT, self.sid,
|
||||||
|
mtype, msize)
|
||||||
self.ofs = _REP_INIT_DATA
|
self.ofs = _REP_INIT_DATA
|
||||||
self.seq = 0
|
self.seq = 0
|
||||||
|
|
||||||
# load the report with initial header
|
async def awrite(self, buf):
|
||||||
ustruct.pack_into(_REP_INIT, self.data, 0,
|
|
||||||
_REP_MARKER_INIT, sid, mtype, msize)
|
|
||||||
|
|
||||||
async def write(self, buf):
|
|
||||||
'''
|
'''
|
||||||
Encode and write every byte from `buf`. Does not need to be called in
|
Encode and write every byte from `buf`. Does not need to be called in
|
||||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||||
@ -137,7 +145,7 @@ class Writer:
|
|||||||
if self.size < len(buf):
|
if self.size < len(buf):
|
||||||
raise EOFError
|
raise EOFError
|
||||||
|
|
||||||
write = loop.select(self.iface | loop.WRITE)
|
write = loop.select(self.iface.iface_num() | loop.WRITE)
|
||||||
nwritten = 0
|
nwritten = 0
|
||||||
while nwritten < len(buf):
|
while nwritten < len(buf):
|
||||||
# copy as much as possible to report buffer
|
# copy as much as possible to report buffer
|
||||||
@ -149,15 +157,15 @@ class Writer:
|
|||||||
if self.ofs == _REP_LEN:
|
if self.ofs == _REP_LEN:
|
||||||
# we are at the end of the report, flush it, and prepare header
|
# we are at the end of the report, flush it, and prepare header
|
||||||
await write
|
await write
|
||||||
io.send(self.iface, self.data)
|
self.iface.write(self.data)
|
||||||
ustruct.pack_into(_REP_CONT, self.data, 0,
|
ustruct.pack_into(_REP_CONT, self.data, 0, _REP_MARKER_CONT,
|
||||||
_REP_MARKER_CONT, self.sid, self.seq)
|
self.sid, self.seq)
|
||||||
self.ofs = _REP_CONT_DATA
|
self.ofs = _REP_CONT_DATA
|
||||||
self.seq += 1
|
self.seq += 1
|
||||||
|
|
||||||
return nwritten
|
return nwritten
|
||||||
|
|
||||||
async def close(self):
|
async def aclose(self):
|
||||||
'''Flush and close the message transmission.'''
|
'''Flush and close the message transmission.'''
|
||||||
if self.ofs != _REP_CONT_DATA:
|
if self.ofs != _REP_CONT_DATA:
|
||||||
# we didn't write anything or last write() wasn't report-aligned,
|
# we didn't write anything or last write() wasn't report-aligned,
|
||||||
@ -166,8 +174,8 @@ class Writer:
|
|||||||
self.data[self.ofs] = 0x00
|
self.data[self.ofs] = 0x00
|
||||||
self.ofs += 1
|
self.ofs += 1
|
||||||
|
|
||||||
await loop.select(self.iface | loop.WRITE)
|
await loop.select(self.iface.iface_num() | loop.WRITE)
|
||||||
io.send(self.iface, self.data)
|
self.iface.write(self.data)
|
||||||
|
|
||||||
|
|
||||||
class SesssionSupervisor:
|
class SesssionSupervisor:
|
||||||
@ -186,8 +194,8 @@ class SesssionSupervisor:
|
|||||||
After close request, the handling task is closed and session terminated.
|
After close request, the handling task is closed and session terminated.
|
||||||
Both requests receive responses confirming the operation.
|
Both requests receive responses confirming the operation.
|
||||||
'''
|
'''
|
||||||
read = loop.select(self.iface | loop.READ)
|
read = loop.select(self.iface.iface_num() | loop.READ)
|
||||||
write = loop.select(self.iface | loop.WRITE)
|
write = loop.select(self.iface.iface_num() | loop.WRITE)
|
||||||
while True:
|
while True:
|
||||||
report = await read
|
report = await read
|
||||||
repmarker, repsid = ustruct.unpack(_REP, report)
|
repmarker, repsid = ustruct.unpack(_REP, report)
|
||||||
@ -225,8 +233,8 @@ class SesssionSupervisor:
|
|||||||
|
|
||||||
def writeopen(self, sid):
|
def writeopen(self, sid):
|
||||||
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
|
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_OPEN, sid)
|
||||||
io.write(self.iface, self.session_report)
|
self.iface.write(self.session_report)
|
||||||
|
|
||||||
def writeclose(self, sid):
|
def writeclose(self, sid):
|
||||||
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
|
ustruct.pack_into(_REP, self.session_report, 0, _REP_MARKER_CLOSE, sid)
|
||||||
io.write(self.iface, self.session_report)
|
self.iface.write(self.session_report)
|
||||||
|
@ -26,26 +26,26 @@ def test_reader():
|
|||||||
|
|
||||||
# open, expected one read
|
# open, expected one read
|
||||||
first_report = report_header + message[:rep_len - len(report_header)]
|
first_report = report_header + message[:rep_len - len(report_header)]
|
||||||
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||||
assert_eq(reader.type, message_type)
|
assert_eq(reader.type, message_type)
|
||||||
assert_eq(reader.size, message_len)
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
# empty read
|
# empty read
|
||||||
empty_buffer = bytearray()
|
empty_buffer = bytearray()
|
||||||
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
|
||||||
assert_eq(len(empty_buffer), 0)
|
assert_eq(len(empty_buffer), 0)
|
||||||
assert_eq(reader.size, message_len)
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
# short read, expected no read
|
# short read, expected no read
|
||||||
short_buffer = bytearray(32)
|
short_buffer = bytearray(32)
|
||||||
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
|
||||||
assert_eq(len(short_buffer), 32)
|
assert_eq(len(short_buffer), 32)
|
||||||
assert_eq(short_buffer, message[:len(short_buffer)])
|
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||||
assert_eq(reader.size, message_len - len(short_buffer))
|
assert_eq(reader.size, message_len - len(short_buffer))
|
||||||
|
|
||||||
# aligned read, expected no read
|
# aligned read, expected no read
|
||||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||||
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
|
||||||
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||||
|
|
||||||
@ -53,12 +53,12 @@ def test_reader():
|
|||||||
next_report_header = bytearray(unhexlify('3f'))
|
next_report_header = bytearray(unhexlify('3f'))
|
||||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||||
onebyte_buffer = bytearray(1)
|
onebyte_buffer = bytearray(1)
|
||||||
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||||
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||||
|
|
||||||
# too long read, raises eof
|
# too long read, raises eof
|
||||||
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||||
|
|
||||||
# long read, expect multiple reads
|
# long read, expect multiple reads
|
||||||
start_size = reader.size
|
start_size = reader.size
|
||||||
@ -74,12 +74,12 @@ def test_reader():
|
|||||||
prev_report = next_reports[i - 1] if i > 0 else None
|
prev_report = next_reports[i - 1] if i > 0 else None
|
||||||
expected_syscalls.append((prev_report, Select(READ | interface)))
|
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||||
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
assert_async(reader.areadinto(long_buffer), expected_syscalls)
|
||||||
assert_eq(long_buffer, message[-start_size:])
|
assert_eq(long_buffer, message[-start_size:])
|
||||||
assert_eq(reader.size, 0)
|
assert_eq(reader.size, 0)
|
||||||
|
|
||||||
# one byte read, raises eof
|
# one byte read, raises eof
|
||||||
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
|
||||||
|
|
||||||
|
|
||||||
def test_writer():
|
def test_writer():
|
||||||
@ -87,7 +87,8 @@ def test_writer():
|
|||||||
interface = 0xdeadbeef
|
interface = 0xdeadbeef
|
||||||
message_type = 0x87654321
|
message_type = 0x87654321
|
||||||
message_len = 1024
|
message_len = 1024
|
||||||
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID, message_type, message_len)
|
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID)
|
||||||
|
writer.setheader(message_type, message_len)
|
||||||
|
|
||||||
# init header corresponding to the data above
|
# init header corresponding to the data above
|
||||||
report_header = bytearray(unhexlify('3f2323432100000400'))
|
report_header = bytearray(unhexlify('3f2323432100000400'))
|
||||||
@ -96,14 +97,14 @@ def test_writer():
|
|||||||
|
|
||||||
# empty write
|
# empty write
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
|
||||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||||
assert_eq(writer.size, start_size)
|
assert_eq(writer.size, start_size)
|
||||||
|
|
||||||
# short write, expected no report
|
# short write, expected no report
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
short_payload = bytearray(range(4))
|
short_payload = bytearray(range(4))
|
||||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||||
assert_eq(writer.size, start_size - len(short_payload))
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
assert_eq(writer.data,
|
assert_eq(writer.data,
|
||||||
report_header
|
report_header
|
||||||
@ -118,7 +119,7 @@ def test_writer():
|
|||||||
+ short_payload
|
+ short_payload
|
||||||
+ aligned_payload
|
+ aligned_payload
|
||||||
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||||
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||||
assert_eq(writer.size, start_size - len(aligned_payload))
|
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||||
msg.send.assert_called_n_times(1)
|
msg.send.assert_called_n_times(1)
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
@ -126,7 +127,7 @@ def test_writer():
|
|||||||
# short write, expected no report, but data starts with correct seq and cont marker
|
# short write, expected no report, but data starts with correct seq and cont marker
|
||||||
report_header = bytearray(unhexlify('3f'))
|
report_header = bytearray(unhexlify('3f'))
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||||
assert_eq(writer.size, start_size - len(short_payload))
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||||
report_header + short_payload)
|
report_header + short_payload)
|
||||||
@ -142,19 +143,19 @@ def test_writer():
|
|||||||
# test write
|
# test write
|
||||||
expected_write_reports = expected_reports[:-1]
|
expected_write_reports = expected_reports[:-1]
|
||||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||||
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
assert_eq(writer.size, start_size - len(long_payload))
|
assert_eq(writer.size, start_size - len(long_payload))
|
||||||
msg.send.assert_called_n_times(len(expected_write_reports))
|
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
# test write raises eof
|
# test write raises eof
|
||||||
msg.send = mock_call(msg.send, [])
|
msg.send = mock_call(msg.send, [])
|
||||||
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
|
||||||
msg.send.assert_called_n_times(0)
|
msg.send.assert_called_n_times(0)
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
# test close
|
# test close
|
||||||
expected_close_reports = expected_reports[-1:]
|
expected_close_reports = expected_reports[-1:]
|
||||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||||
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
assert_eq(writer.size, 0)
|
assert_eq(writer.size, 0)
|
||||||
msg.send.assert_called_n_times(len(expected_close_reports))
|
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
|
@ -26,26 +26,26 @@ def test_reader():
|
|||||||
|
|
||||||
# open, expected one read
|
# open, expected one read
|
||||||
first_report = report_header + message[:rep_len - len(report_header)]
|
first_report = report_header + message[:rep_len - len(report_header)]
|
||||||
assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
|
||||||
assert_eq(reader.type, message_type)
|
assert_eq(reader.type, message_type)
|
||||||
assert_eq(reader.size, message_len)
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
# empty read
|
# empty read
|
||||||
empty_buffer = bytearray()
|
empty_buffer = bytearray()
|
||||||
assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
|
assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()),])
|
||||||
assert_eq(len(empty_buffer), 0)
|
assert_eq(len(empty_buffer), 0)
|
||||||
assert_eq(reader.size, message_len)
|
assert_eq(reader.size, message_len)
|
||||||
|
|
||||||
# short read, expected no read
|
# short read, expected no read
|
||||||
short_buffer = bytearray(32)
|
short_buffer = bytearray(32)
|
||||||
assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
|
assert_async(reader.areadinto(short_buffer), [(None, StopIteration()),])
|
||||||
assert_eq(len(short_buffer), 32)
|
assert_eq(len(short_buffer), 32)
|
||||||
assert_eq(short_buffer, message[:len(short_buffer)])
|
assert_eq(short_buffer, message[:len(short_buffer)])
|
||||||
assert_eq(reader.size, message_len - len(short_buffer))
|
assert_eq(reader.size, message_len - len(short_buffer))
|
||||||
|
|
||||||
# aligned read, expected no read
|
# aligned read, expected no read
|
||||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||||
assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
|
assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()),])
|
||||||
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||||
|
|
||||||
@ -53,12 +53,12 @@ def test_reader():
|
|||||||
next_report_header = bytearray(unhexlify('021234567800000000'))
|
next_report_header = bytearray(unhexlify('021234567800000000'))
|
||||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||||
onebyte_buffer = bytearray(1)
|
onebyte_buffer = bytearray(1)
|
||||||
assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
|
||||||
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||||
|
|
||||||
# too long read, raises eof
|
# too long read, raises eof
|
||||||
assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()),])
|
||||||
|
|
||||||
# long read, expect multiple reads
|
# long read, expect multiple reads
|
||||||
start_size = reader.size
|
start_size = reader.size
|
||||||
@ -74,12 +74,12 @@ def test_reader():
|
|||||||
prev_report = next_reports[i - 1] if i > 0 else None
|
prev_report = next_reports[i - 1] if i > 0 else None
|
||||||
expected_syscalls.append((prev_report, Select(READ | interface)))
|
expected_syscalls.append((prev_report, Select(READ | interface)))
|
||||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||||
assert_async(reader.readinto(long_buffer), expected_syscalls)
|
assert_async(reader.areadinto(long_buffer), expected_syscalls)
|
||||||
assert_eq(long_buffer, message[-start_size:])
|
assert_eq(long_buffer, message[-start_size:])
|
||||||
assert_eq(reader.size, 0)
|
assert_eq(reader.size, 0)
|
||||||
|
|
||||||
# one byte read, raises eof
|
# one byte read, raises eof
|
||||||
assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
|
assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()),])
|
||||||
|
|
||||||
|
|
||||||
def test_writer():
|
def test_writer():
|
||||||
@ -88,7 +88,8 @@ def test_writer():
|
|||||||
session_id = 0x12345678
|
session_id = 0x12345678
|
||||||
message_type = 0x87654321
|
message_type = 0x87654321
|
||||||
message_len = 1024
|
message_len = 1024
|
||||||
writer = codec_v2.Writer(interface, session_id, message_type, message_len)
|
writer = codec_v2.Writer(interface, session_id)
|
||||||
|
writer.setheader(message_type, message_len)
|
||||||
|
|
||||||
# init header corresponding to the data above
|
# init header corresponding to the data above
|
||||||
report_header = bytearray(unhexlify('01123456788765432100000400'))
|
report_header = bytearray(unhexlify('01123456788765432100000400'))
|
||||||
@ -97,14 +98,14 @@ def test_writer():
|
|||||||
|
|
||||||
# empty write
|
# empty write
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
assert_async(writer.write(bytearray()), [(None, StopIteration()),])
|
assert_async(writer.awrite(bytearray()), [(None, StopIteration()),])
|
||||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||||
assert_eq(writer.size, start_size)
|
assert_eq(writer.size, start_size)
|
||||||
|
|
||||||
# short write, expected no report
|
# short write, expected no report
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
short_payload = bytearray(range(4))
|
short_payload = bytearray(range(4))
|
||||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||||
assert_eq(writer.size, start_size - len(short_payload))
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
assert_eq(writer.data,
|
assert_eq(writer.data,
|
||||||
report_header
|
report_header
|
||||||
@ -119,7 +120,7 @@ def test_writer():
|
|||||||
+ short_payload
|
+ short_payload
|
||||||
+ aligned_payload
|
+ aligned_payload
|
||||||
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
|
||||||
assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
|
||||||
assert_eq(writer.size, start_size - len(aligned_payload))
|
assert_eq(writer.size, start_size - len(aligned_payload))
|
||||||
msg.send.assert_called_n_times(1)
|
msg.send.assert_called_n_times(1)
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
@ -127,7 +128,7 @@ def test_writer():
|
|||||||
# short write, expected no report, but data starts with correct seq and cont marker
|
# short write, expected no report, but data starts with correct seq and cont marker
|
||||||
report_header = bytearray(unhexlify('021234567800000000'))
|
report_header = bytearray(unhexlify('021234567800000000'))
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
assert_async(writer.write(short_payload), [(None, StopIteration()),])
|
assert_async(writer.awrite(short_payload), [(None, StopIteration()),])
|
||||||
assert_eq(writer.size, start_size - len(short_payload))
|
assert_eq(writer.size, start_size - len(short_payload))
|
||||||
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
||||||
report_header + short_payload)
|
report_header + short_payload)
|
||||||
@ -145,19 +146,19 @@ def test_writer():
|
|||||||
# test write
|
# test write
|
||||||
expected_write_reports = expected_reports[:-1]
|
expected_write_reports = expected_reports[:-1]
|
||||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
|
||||||
assert_async(writer.write(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
assert_eq(writer.size, start_size - len(long_payload))
|
assert_eq(writer.size, start_size - len(long_payload))
|
||||||
msg.send.assert_called_n_times(len(expected_write_reports))
|
msg.send.assert_called_n_times(len(expected_write_reports))
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
# test write raises eof
|
# test write raises eof
|
||||||
msg.send = mock_call(msg.send, [])
|
msg.send = mock_call(msg.send, [])
|
||||||
assert_async(writer.write(bytearray(1)), [(None, EOFError())])
|
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
|
||||||
msg.send.assert_called_n_times(0)
|
msg.send.assert_called_n_times(0)
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
# test close
|
# test close
|
||||||
expected_close_reports = expected_reports[-1:]
|
expected_close_reports = expected_reports[-1:]
|
||||||
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
|
||||||
assert_async(writer.close(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())])
|
||||||
assert_eq(writer.size, 0)
|
assert_eq(writer.size, 0)
|
||||||
msg.send.assert_called_n_times(len(expected_close_reports))
|
msg.send.assert_called_n_times(len(expected_close_reports))
|
||||||
msg.send = msg.send.original
|
msg.send = msg.send.original
|
||||||
|
Loading…
Reference in New Issue
Block a user