You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/core/src/trezor/wire/__init__.py

270 lines
8.2 KiB

import protobuf
from trezor import log, loop, messages, utils, workflow
from trezor.messages import FailureType
from trezor.wire import codec_v1
from trezor.wire.errors import Error
# import all errors into namespace, so that `wire.Error` is available elsewhere
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if False:
from typing import (
Any,
Awaitable,
Dict,
Callable,
Iterable,
List,
Optional,
Tuple,
Type,
)
from trezorio import WireInterface
from protobuf import LoadedMessageType, MessageType
Handler = Callable[..., loop.Task]
# Registry of workflows keyed by message wire type. `register` stores a
# `(handler, args)` pair per type and `session_handler` looks entries up by
# the type of the incoming message.
workflow_handlers = {} # type: Dict[int, Tuple[Handler, Iterable]]
def add(mtype: int, pkgname: str, modname: str, namespace: List = None) -> None:
    """
    Register a Protobuf workflow whose handler module is imported on demand.

    The handler is resolved by `import_workflow` as the attribute `modname`
    of the module `pkgname.modname`. If `namespace` is given, the call is
    additionally routed through `keychain_workflow`, which passes a keychain
    obtained for that namespace to the handler.
    """
    if namespace is None:
        # plain workflow: decode the message, then import and run the handler
        register(mtype, protobuf_workflow, import_workflow, pkgname, modname)
        return

    # keychain-backed workflow: decode, obtain keychain, import and run
    register(
        mtype,
        protobuf_workflow,
        keychain_workflow,
        namespace,
        import_workflow,
        pkgname,
        modname,
    )
def register(mtype: int, handler: Handler, *args: Any) -> None:
    """
    Register `handler` to get scheduled after `mtype` message is received.

    `mtype` is either a numeric wire type or a `protobuf.MessageType`
    subclass, in which case its `MESSAGE_WIRE_TYPE` is used as the key.
    Any extra `args` are stored together with the handler.

    Raises `KeyError`, carrying the offending wire type, if a handler is
    already registered for `mtype`.
    """
    if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType):
        mtype = mtype.MESSAGE_WIRE_TYPE
    if mtype in workflow_handlers:
        # attach the wire type so a duplicate registration can be tracked down
        raise KeyError(mtype)
    workflow_handlers[mtype] = (handler, args)
def setup(iface: WireInterface) -> None:
    """
    Start the wire stack on the passed interface.

    Creates a `session_handler` for `iface` using the `codec_v1` session id
    and hands it to the event loop for scheduling.
    """
    session_task = session_handler(iface, codec_v1.SESSION_ID)
    loop.schedule(session_task)
class Context:
    """
    Wire context bound to a single interface and session id.

    Offers message-level reads and writes built on the `codec_v1` reader
    and writer created by `make_reader` / `make_writer`.
    """

    def __init__(self, iface: WireInterface, sid: int) -> None:
        self.sid = sid
        self.iface = iface

    async def call(
        self, msg: MessageType, exptype: Type[LoadedMessageType]
    ) -> LoadedMessageType:
        """Send `msg` and return the reply, which must be of type `exptype`."""
        await self.write(msg)
        # drop the local reference to the request before waiting for the reply
        del msg
        reply = await self.read(exptype)
        return reply

    async def call_any(self, msg: MessageType, *allowed_types: int) -> MessageType:
        """Send `msg` and return the reply, whose wire type must be one of `allowed_types`."""
        await self.write(msg)
        # drop the local reference to the request before waiting for the reply
        del msg
        reply = await self.read_any(allowed_types)
        return reply

    async def read(
        self, exptype: Optional[Type[LoadedMessageType]]
    ) -> LoadedMessageType:
        """
        Wait for the next message and decode it as `exptype`.

        Raises `UnexpectedMessageError`, carrying the opened reader, when
        `exptype` is None or the incoming wire type does not match it.
        """
        incoming = self.make_reader()

        if __debug__:
            log.debug(
                __name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype
            )

        # block until the message header has arrived
        await incoming.aopen()

        if exptype is not None and incoming.type == exptype.MESSAGE_WIRE_TYPE:
            # the expected message, decode its payload
            return await protobuf.load_message(incoming, exptype)

        # anything else is passed on, together with its reader, so that the
        # session handler can deal with it
        raise UnexpectedMessageError(incoming)

    async def read_any(self, allowed_types: Iterable[int]) -> MessageType:
        """
        Wait for the next message and decode it if its type is allowed.

        Raises `UnexpectedMessageError`, carrying the opened reader, when
        the incoming wire type is not in `allowed_types`.
        """
        incoming = self.make_reader()

        if __debug__:
            log.debug(
                __name__,
                "%s:%x read: %s",
                self.iface.iface_num(),
                self.sid,
                allowed_types,
            )

        # block until the message header has arrived
        await incoming.aopen()

        if reader_type_allowed := incoming.type in allowed_types if False else None:
            pass  # pragma: no cover

        if incoming.type in allowed_types:
            # look up the protobuf class for this wire type and decode
            msg_class = messages.get_type(incoming.type)
            return await protobuf.load_message(incoming, msg_class)

        # anything else is passed on, together with its reader, so that the
        # session handler can deal with it
        raise UnexpectedMessageError(incoming)

    async def write(self, msg: protobuf.MessageType) -> None:
        """Encode `msg` and send it out through a fresh writer."""
        outgoing = self.make_writer()

        if __debug__:
            log.debug(
                __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
            )

        # the header needs the encoded size, so count it before dumping
        msg_fields = msg.get_fields()
        payload_size = protobuf.count_message(msg, msg_fields)

        outgoing.setheader(msg.MESSAGE_WIRE_TYPE, payload_size)
        await protobuf.dump_message(outgoing, msg, msg_fields)
        await outgoing.aclose()

    def wait(self, *tasks: Awaitable) -> Any:
        """
        Wait until one of the passed tasks finishes, and return the result,
        while servicing the wire context. The tasks are spawned together
        with `read(None)`, so a message arriving before any of the tasks
        ends results in `UnexpectedMessageError`.
        """
        wire_watch = self.read(None)
        return loop.spawn(wire_watch, *tasks)

    def make_reader(self) -> codec_v1.Reader:
        """Create a `codec_v1` reader on this context's interface."""
        return codec_v1.Reader(self.iface)

    def make_writer(self) -> codec_v1.Writer:
        """Create a `codec_v1` writer on this context's interface."""
        return codec_v1.Writer(self.iface)
class UnexpectedMessageError(Exception):
    """
    Signals that a message of an unexpected type has arrived.

    The exception carries the reader on which the message header was
    already received, so that `session_handler` can continue with it.
    """

    def __init__(self, reader: codec_v1.Reader) -> None:
        super().__init__()
        # reader of the unexpected message, picked up by the session handler
        self.reader = reader
async def session_handler(iface: WireInterface, sid: int) -> None:
    """
    Serve one wire session in an endless loop.

    Each iteration obtains a message reader, looks up the handler registered
    for the message type (falling back to `unexpected_msg`), and runs it as a
    workflow. Exceptions derived from `Exception` never end the loop; an
    `UnexpectedMessageError` makes the next iteration start with the reader
    it carries instead of waiting for a new message.
    """
    ctx = Context(iface, sid)
    reader = None

    while True:
        # reader for the next iteration; None means "wait for a new message"
        next_reader = None
        try:
            if not reader:
                # nothing handed over from the previous round, wait for a header
                reader = ctx.make_reader()
                await reader.aopen()

            # unknown message types are answered by `unexpected_msg`
            handler, args = workflow_handlers.get(reader.type, (unexpected_msg, ()))

            modules = utils.unimport_begin()
            task = handler(ctx, reader, *args)
            try:
                workflow.onstart(task)
                await task
            finally:
                workflow.onclose(task)
                utils.unimport_end(modules)

        except UnexpectedMessageError as exc:
            # the workflow received a message it did not expect; dispatch it
            # in the next iteration using the reader that is already open
            next_reader = exc.reader

        except Error as exc:
            # wire.Error is logged as a warning, not as an exception
            if __debug__:
                log.warning(__name__, "failure: %s", exc.message)

        except Exception as exc:
            # raised exceptions never close the session
            if __debug__:
                log.exception(__name__, exc)

        reader = next_reader
async def protobuf_workflow(
    ctx: Context, reader: codec_v1.Reader, handler: Handler, *args: Any
) -> None:
    """
    Decode the incoming message, run `handler` on it and send the response.

    A truthy handler result is written back to `ctx`. A `wire.Error` is
    reported to the host as a `Failure` with its own code and message, any
    other `Exception` as a `FirmwareError` failure; both are re-raised
    afterwards. `UnexpectedMessageError` is passed through untouched.
    """
    from trezor.messages.Failure import Failure

    request_type = messages.get_type(reader.type)
    req = await protobuf.load_message(reader, request_type)

    try:
        res = await handler(ctx, req, *args)
    except UnexpectedMessageError:
        # left for the session handler
        raise
    except Error as exc:
        # report the specific code and message
        failure = Failure(code=exc.code, message=exc.message)
        await ctx.write(failure)
        raise
    except Exception as e:
        # report a generic failure; details are included in debug builds only
        if __debug__:
            message = "{}: {}".format(type(e), e)
        else:
            message = "Firmware error"
        await ctx.write(Failure(code=FailureType.FirmwareError, message=message))
        raise
    else:
        if res:
            # the handler produced a response, send it
            await ctx.write(res)
async def keychain_workflow(
    ctx: Context,
    req: protobuf.MessageType,
    namespace: List,
    handler: Handler,
    *args: Any
) -> Any:
    """
    Run `handler` with a keychain for `namespace` appended to its arguments.

    The keychain is obtained through `seed.get_keychain` and its `__del__`
    is invoked explicitly once the handler has finished, whether it
    returned or raised.
    """
    from apps.common import seed

    keychain = await seed.get_keychain(ctx, namespace)
    # the keychain goes last, after any arguments passed through
    handler_args = args + (keychain,)
    try:
        result = await handler(ctx, req, *handler_args)
        return result
    finally:
        keychain.__del__()
def import_workflow(
    ctx: Context, req: protobuf.MessageType, pkgname: str, modname: str, *args: Any
) -> Any:
    """
    Import the module `pkgname.modname` and call its attribute `modname`.

    The attribute is called with `ctx`, `req` and any extra `args`; whatever
    that call returns is handed back to the caller unchanged.
    """
    # a non-empty fromlist makes `__import__` return the submodule itself
    module = __import__(
        "%s.%s" % (pkgname, modname), None, None, (modname,), 0
    )  # type: ignore
    workflow_fn = getattr(module, modname)
    return workflow_fn(ctx, req, *args)
async def unexpected_msg(ctx: Context, reader: codec_v1.Reader) -> None:
    """
    Fallback handler for messages without a registered workflow.

    Reads and discards the rest of the message, then answers with an
    `UnexpectedMessage` failure.
    """
    from trezor.messages.Failure import Failure

    # consume the remaining payload; its content is not used
    await read_full_msg(reader)

    # tell the host the message was not expected
    failure = Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
    await ctx.write(failure)
async def read_full_msg(reader: codec_v1.Reader) -> None:
    """
    Read and throw away what is left of the message behind `reader`.

    Keeps reading into a scratch buffer sized to `reader.size` until
    `reader.size` is no longer positive.
    """
    remaining = reader.size
    while remaining > 0:
        scratch = bytearray(remaining)
        await reader.areadinto(scratch)
        remaining = reader.size