You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
212 lines
6.3 KiB
212 lines
6.3 KiB
8 years ago
|
import protobuf
|
||
6 years ago
|
from trezor import log, loop, messages, utils, workflow
|
||
|
from trezor.wire import codec_v1
|
||
|
from trezor.wire.errors import *
|
||
8 years ago
|
|
||
6 years ago
|
from apps.common import seed
|
||
|
|
||
7 years ago
|
# Registry of message handlers: wire type -> (handler, args) tuple.
# Filled by `register()`; looked up by `session_handler()` for every new
# message (types missing from it are answered by `unexpected_msg`).
workflow_handlers = dict()
|
||
8 years ago
|
|
||
8 years ago
|
|
||
6 years ago
|
def add(mtype, pkgname, modname, namespace=None):
    """
    Register a Protobuf workflow whose handler is imported lazily.

    The handler is the function named `modname` inside module
    `pkgname.modname`; it is imported by `import_workflow` from within the
    running workflow, not at registration time. When `namespace` is given,
    `keychain_workflow` first obtains a keychain for that namespace and the
    handler receives it as its last positional argument.
    """
    if namespace is None:
        register(mtype, protobuf_workflow, import_workflow, pkgname, modname)
        return

    register(
        mtype,
        protobuf_workflow,
        keychain_workflow,
        namespace,
        import_workflow,
        pkgname,
        modname,
    )
|
||
6 years ago
|
|
||
|
|
||
7 years ago
|
def register(mtype, handler, *args):
    """
    Register `handler` to get scheduled after `mtype` message is received.

    `mtype` may be a wire type or a `protobuf.MessageType` subclass, in which
    case its `MESSAGE_WIRE_TYPE` is used. The session handler later invokes
    the entry as `handler(ctx, reader, *args)`.

    Raises:
        KeyError: a handler is already registered for this wire type; the
            exception carries the offending wire type.
    """
    if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType):
        mtype = mtype.MESSAGE_WIRE_TYPE
    if mtype in workflow_handlers:
        # name the colliding type so duplicate registrations are debuggable
        raise KeyError(mtype)
    workflow_handlers[mtype] = (handler, args)
|
||
8 years ago
|
|
||
8 years ago
|
|
||
7 years ago
|
def setup(iface):
    """Start the wire stack on the given USB interface.

    Schedules a `session_handler` for the default `codec_v1` session id.
    """
    session_task = session_handler(iface, codec_v1.SESSION_ID)
    loop.schedule(session_task)
|
||
8 years ago
|
|
||
|
|
||
7 years ago
|
class Context:
    """Wire context bound to one interface and one session id.

    Messages are framed with `codec_v1` and (de)serialized with `protobuf`.
    """

    def __init__(self, iface, sid):
        self.iface = iface
        self.sid = sid

    async def call(self, msg, *types):
        """
        Send `msg`, then wait for a reply whose wire type is in `types`.
        This is `self.write()` followed by `self.read()`.
        """
        await self.write(msg)
        # drop the local reference to the request before blocking on the
        # reply (presumably so it can be collected meanwhile)
        del msg
        return await self.read(types)

    async def read(self, types):
        """
        Wait for the next message on this context and return it decoded.

        If its wire type is not one of `types`, the already-opened reader is
        raised inside `UnexpectedMessageError`; callers must let that
        exception propagate so the session handler can dispatch the message.
        """
        incoming = self.getreader()

        if __debug__:
            log.debug(
                __name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, types
            )

        # block until the message header has arrived
        await incoming.aopen()

        if incoming.type in types:
            msg_class = messages.get_type(incoming.type)
            return await protobuf.load_message(incoming, msg_class)

        # not what we were waiting for: hand the open reader to the session
        raise UnexpectedMessageError(incoming)

    async def write(self, msg):
        """Serialize the protobuf message `msg` and send it on this context."""
        out = self.getwriter()

        if __debug__:
            log.debug(
                __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
            )

        # the header carries the encoded length, so compute it before dumping
        field_defs = msg.get_fields()
        encoded_len = protobuf.count_message(msg, field_defs)
        out.setheader(msg.MESSAGE_WIRE_TYPE, encoded_len)

        await protobuf.dump_message(out, msg, field_defs)
        await out.aclose()

    def wait(self, *tasks):
        """
        Run `tasks` while keeping an eye on incoming messages.

        Returns `loop.spawn(...)` over the tasks plus a `read(())` of this
        context. Because the accepted-types tuple is empty, any message that
        arrives before one of the tasks finishes raises
        `UnexpectedMessageError`.
        """
        any_message = self.read(())
        return loop.spawn(any_message, *tasks)

    def getreader(self):
        """Return a new `codec_v1.Reader` on this context's interface."""
        return codec_v1.Reader(self.iface)

    def getwriter(self):
        """Return a new `codec_v1.Writer` on this context's interface."""
        return codec_v1.Writer(self.iface)
|
||
8 years ago
|
|
||
|
|
||
7 years ago
|
class UnexpectedMessageError(Exception):
    """Raised when a message of a non-awaited type arrives.

    `reader` is the already-opened reader of that message, so the session
    handler can dispatch it without reading the header again.
    """

    def __init__(self, reader):
        Exception.__init__(self)
        self.reader = reader
|
||
8 years ago
|
|
||
|
|
||
7 years ago
|
async def session_handler(iface, sid):
    """
    Serve one wire session forever: read a message, dispatch it, repeat.

    The handler is looked up in `workflow_handlers` by wire type; unknown
    types go to `unexpected_msg`. If a workflow raises
    `UnexpectedMessageError`, the opened reader it carries is dispatched on
    the next iteration instead of reading a fresh message. Other
    `Exception`s are logged in debug builds (`wire.Error` as a warning) and
    do not end the session.
    """
    reader = None
    ctx = Context(iface, sid)
    while True:
        try:
            # wait for new message, if needed, and find handler
            if not reader:
                reader = ctx.getreader()
                await reader.aopen()
            try:
                handler, args = workflow_handlers[reader.type]
            except KeyError:
                handler, args = unexpected_msg, ()

            m = utils.unimport_begin()
            try:
                # creating the workflow can itself raise (e.g. a synchronous
                # handler passed to `register`), and `workflow.onclose` can
                # raise too; either way the unimport started above must be
                # closed, so it lives in its own outer `finally`
                w = handler(ctx, reader, *args)
                try:
                    workflow.onstart(w)
                    await w
                finally:
                    workflow.onclose(w)
            finally:
                utils.unimport_end(m)

        except UnexpectedMessageError as exc:
            # retry with opened reader from the exception
            reader = exc.reader
            continue
        except Error as exc:
            # we log wire.Error as warning, not as exception
            if __debug__:
                log.warning(__name__, "failure: %s", exc.message)
        except Exception as exc:
            # sessions are never closed by raised exceptions
            if __debug__:
                log.exception(__name__, exc)

        # read new message in next iteration
        reader = None
|
||
7 years ago
|
|
||
|
|
||
|
async def protobuf_workflow(ctx, reader, handler, *args):
    """
    Decode the pending message from `reader` and run `handler` on it.

    `handler(ctx, request, *args)` is awaited and a truthy result is written
    back on `ctx` as the response. A `wire.Error` is answered with a Failure
    carrying its code and message; any other `Exception` with a generic
    FirmwareError Failure. In both cases the exception is re-raised.
    `UnexpectedMessageError` is re-raised untouched for the session handler.
    """
    from trezor.messages.Failure import Failure

    request_class = messages.get_type(reader.type)
    request = await protobuf.load_message(reader, request_class)

    try:
        response = await handler(ctx, request, *args)
    except UnexpectedMessageError:
        # must be checked before the generic handler below
        raise
    except Error as err:
        failure = Failure(code=err.code, message=err.message)
        await ctx.write(failure)
        raise
    except Exception:
        failure = Failure(code=FailureType.FirmwareError, message="Firmware error")
        await ctx.write(failure)
        raise

    if not response:
        return
    await ctx.write(response)
|
||
8 years ago
|
|
||
8 years ago
|
|
||
6 years ago
|
async def keychain_workflow(ctx, req, namespace, handler, *args):
    """
    Run `handler` with a keychain for `namespace` appended to its arguments.

    The keychain comes from `seed.get_keychain(ctx, namespace)`, and its
    `__del__()` is called explicitly once the handler returns or raises.
    """
    keychain = await seed.get_keychain(ctx, namespace)
    call_args = args + (keychain,)
    try:
        return await handler(ctx, req, *call_args)
    finally:
        # deterministic release instead of waiting for GC; presumably wipes
        # key material — TODO confirm against `seed`
        keychain.__del__()
|
||
6 years ago
|
|
||
|
|
||
6 years ago
|
def import_workflow(ctx, req, pkgname, modname, *args):
    """
    Import module `pkgname.modname` and call its function named `modname`.

    The function receives `(ctx, req, *args)`. Its return value is passed
    through unchanged; callers such as `protobuf_workflow` await it.
    """
    target = __import__("%s.%s" % (pkgname, modname), None, None, (modname,), 0)
    entry = getattr(target, modname)
    return entry(ctx, req, *args)
|
||
6 years ago
|
|
||
|
|
||
7 years ago
|
async def unexpected_msg(ctx, reader):
    """
    Discard a message that has no registered handler and reply with Failure.

    The payload is drained in bounded chunks instead of into a single buffer
    of `reader.size` bytes. That size comes from the message header, so a
    large message would otherwise require one allocation of that whole size
    before the Failure reply could be sent.
    """
    from trezor.messages.Failure import Failure

    # receive the message and throw it away, at most `chunk_limit` bytes a time
    chunk_limit = 256
    buf = bytearray(min(reader.size, chunk_limit))
    while reader.size > 0:
        if reader.size < len(buf):
            # never request more than what is left of the message
            buf = bytearray(reader.size)
        await reader.areadinto(buf)

    # respond with an unknown message error
    await ctx.write(
        Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
    )
|