1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-17 11:58:13 +00:00

src/apps/fido_u2f: fix confirmation, refactor

This commit is contained in:
Jan Pochyla 2018-03-01 05:13:01 +01:00
parent f74cbead5e
commit eda280213f
2 changed files with 147 additions and 137 deletions

View File

@ -234,7 +234,6 @@ async def read_cmd(iface: io.HID) -> Cmd:
read = loop.select(iface.iface_num() | io.POLL_READ) read = loop.select(iface.iface_num() | io.POLL_READ)
buf = await read buf = await read
# log.debug(__name__, 'read init %s', buf)
ifrm = overlay_struct(buf, desc_init) ifrm = overlay_struct(buf, desc_init)
bcnt = ifrm.bcnt bcnt = ifrm.bcnt
@ -244,6 +243,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
if ifrm.cmd & _TYPE_MASK == _TYPE_CONT: if ifrm.cmd & _TYPE_MASK == _TYPE_CONT:
# unexpected cont packet, abort current msg # unexpected cont packet, abort current msg
if __debug__:
log.warning(__name__, '_TYPE_CONT') log.warning(__name__, '_TYPE_CONT')
return None return None
@ -256,7 +256,6 @@ async def read_cmd(iface: io.HID) -> Cmd:
while datalen < bcnt: while datalen < bcnt:
buf = await read buf = await read
# log.debug(__name__, 'read cont %s', buf)
cfrm = overlay_struct(buf, desc_cont) cfrm = overlay_struct(buf, desc_cont)
@ -268,6 +267,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
if cfrm.cid != ifrm.cid: if cfrm.cid != ifrm.cid:
# cont frame for a different channel, reply with BUSY and skip # cont frame for a different channel, reply with BUSY and skip
if __debug__:
log.warning(__name__, '_ERR_CHANNEL_BUSY') log.warning(__name__, '_ERR_CHANNEL_BUSY')
await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface) await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface)
continue continue
@ -275,6 +275,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
if cfrm.seq != seq: if cfrm.seq != seq:
# cont frame for this channel, but incorrect seq number, abort # cont frame for this channel, but incorrect seq number, abort
# current msg # current msg
if __debug__:
log.warning(__name__, '_ERR_INVALID_SEQ') log.warning(__name__, '_ERR_INVALID_SEQ')
await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface) await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface)
return None return None
@ -299,7 +300,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None:
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen) offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
iface.write(buf) iface.write(buf)
# log.debug(__name__, 'send init %s', buf)
if offset < datalen: if offset < datalen:
frm = overlay_struct(buf, cont_desc) frm = overlay_struct(buf, cont_desc)
@ -312,7 +312,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None:
await write await write
if iface.write(buf) > 0: if iface.write(buf) > 0:
break break
# log.debug(__name__, 'send cont %s', buf)
seq += 1 seq += 1
@ -321,52 +320,64 @@ def boot(iface: io.HID):
async def handle_reports(iface: io.HID): async def handle_reports(iface: io.HID):
state = ConfirmState()
while True: while True:
try: try:
req = await read_cmd(iface) req = await read_cmd(iface)
if req is None: if req is None:
continue continue
resp = dispatch_cmd(req) resp = dispatch_cmd(req, state)
await send_cmd(resp, iface) await send_cmd(resp, iface)
except Exception as e: except Exception as e:
log.exception(__name__, e) log.exception(__name__, e)
def dispatch_cmd(req: Cmd) -> Cmd: def dispatch_cmd(req: Cmd, state: ConfirmState) -> Cmd:
if req.cmd == _CMD_MSG: if req.cmd == _CMD_MSG:
m = req.to_msg() m = req.to_msg()
if m.cla != 0: if m.cla != 0:
if __debug__:
log.warning(__name__, '_SW_CLA_NOT_SUPPORTED') log.warning(__name__, '_SW_CLA_NOT_SUPPORTED')
return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED) return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED)
if m.lc + _APDU_DATA > len(req.data): if m.lc + _APDU_DATA > len(req.data):
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH') log.warning(__name__, '_SW_WRONG_LENGTH')
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
if m.ins == _MSG_REGISTER: if m.ins == _MSG_REGISTER:
if __debug__:
log.debug(__name__, '_MSG_REGISTER') log.debug(__name__, '_MSG_REGISTER')
return msg_register(m) return msg_register(m, state)
elif m.ins == _MSG_AUTHENTICATE: elif m.ins == _MSG_AUTHENTICATE:
if __debug__:
log.debug(__name__, '_MSG_AUTHENTICATE') log.debug(__name__, '_MSG_AUTHENTICATE')
return msg_authenticate(m) return msg_authenticate(m, state)
elif m.ins == _MSG_VERSION: elif m.ins == _MSG_VERSION:
if __debug__:
log.debug(__name__, '_MSG_VERSION') log.debug(__name__, '_MSG_VERSION')
return msg_version(m) return msg_version(m)
else: else:
if __debug__:
log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins) log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins)
return msg_error(req.cid, _SW_INS_NOT_SUPPORTED) return msg_error(req.cid, _SW_INS_NOT_SUPPORTED)
elif req.cmd == _CMD_INIT: elif req.cmd == _CMD_INIT:
if __debug__:
log.debug(__name__, '_CMD_INIT') log.debug(__name__, '_CMD_INIT')
return cmd_init(req) return cmd_init(req)
elif req.cmd == _CMD_PING: elif req.cmd == _CMD_PING:
if __debug__:
log.debug(__name__, '_CMD_PING') log.debug(__name__, '_CMD_PING')
return req return req
elif req.cmd == _CMD_WINK: elif req.cmd == _CMD_WINK:
if __debug__:
log.debug(__name__, '_CMD_WINK') log.debug(__name__, '_CMD_WINK')
return req return req
else: else:
if __debug__:
log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd) log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd)
return cmd_error(req.cid, _ERR_INVALID_CMD) return cmd_error(req.cid, _ERR_INVALID_CMD)
@ -394,6 +405,63 @@ def cmd_init(req: Cmd) -> Cmd:
_CONFIRM_REGISTER = const(0) _CONFIRM_REGISTER = const(0)
_CONFIRM_AUTHENTICATE = const(1) _CONFIRM_AUTHENTICATE = const(1)
_CONFIRM_TIMEOUT_MS = const(10 * 1000)
class ConfirmState:
def __init__(self) -> None:
self.reset()
def reset(self):
self.action = None
self.checksum = None
self.app_id = None
self.confirmed = None
self.deadline = None
self.workflow = None
def compare(self, action: int, checksum: bytes) -> bool:
if self.action != action or self.checksum != checksum:
return False
if utime.ticks_ms() >= self.deadline:
return False
return True
def setup(self, action: int, checksum: bytes, app_id: bytes) -> None:
if self.workflow is not None:
loop.close(self.workflow)
if workflow.workflows:
return False
self.action = action
self.checksum = checksum
self.app_id = app_id
self.confirmed = None
self.workflow = self.confirm_workflow()
loop.schedule(self.workflow)
return True
def keepalive(self):
self.deadline = utime.ticks_ms() + _CONFIRM_TIMEOUT_MS
async def confirm_workflow(self) -> None:
try:
workflow.onstart(self.workflow)
await self.confirm_layout()
finally:
workflow.onclose(self.workflow)
self.workflow = None
@ui.layout
async def confirm_layout(self) -> None:
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
content = ConfirmContent(self.action, self.app_id)
dialog = ConfirmDialog(content, )
self.confirmed = await dialog == CONFIRMED
class ConfirmContent(ui.Widget): class ConfirmContent(ui.Widget):
@ -420,7 +488,7 @@ class ConfirmContent(ui.Widget):
name = knownapps.knownapps[app_id] name = knownapps.knownapps[app_id]
try: try:
icon = res.load('apps/fido_u2f/res/u2f_%s.toif' % name.lower().replace(' ', '_')) icon = res.load('apps/fido_u2f/res/u2f_%s.toif' % name.lower().replace(' ', '_'))
except FileNotFoundError: except Exception:
icon = res.load('apps/fido_u2f/res/u2f_generic.toif') icon = res.load('apps/fido_u2f/res/u2f_generic.toif')
else: else:
name = '%s...%s' % (hexlify(app_id[:4]).decode(), hexlify(app_id[-4:]).decode()) name = '%s...%s' % (hexlify(app_id[:4]).decode(), hexlify(app_id[-4:]).decode())
@ -438,74 +506,17 @@ class ConfirmContent(ui.Widget):
ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG) ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG)
_CONFIRM_STATE_TIMEOUT_MS = const(10 * 1000) def msg_register(req: Msg, state: ConfirmState) -> Cmd:
class ConfirmState:
def __init__(self, action: int, app_id: bytes) -> None:
self.action = action
self.app_id = app_id
self.deadline_ms = None
self.confirmed = None
self.task = None
def fork(self) -> None:
self.deadline_ms = utime.ticks_ms() + _CONFIRM_STATE_TIMEOUT_MS
self.task = self.confirm()
workflow.onstart(self.task)
loop.schedule(self.task)
def kill(self) -> None:
if self.task is not None:
loop.close(self.task)
self.task = None
async def confirm(self) -> None:
confirmed = False
try:
confirmed = await self.confirm_layout()
finally:
self.confirmed = confirmed
workflow.onclose(self.task)
@ui.layout
async def confirm_layout(self) -> None:
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
from trezor.ui.text import Text
if bytes(self.app_id) == _BOGUS_APPID:
text = Text(
'U2F mismatch', ui.ICON_WRONG,
'Another U2F device',
'was used to register',
'in this application.',
icon_color=ui.RED)
text.render()
await loop.sleep(3 * 1000 * 1000)
return True
content = ConfirmContent(self.action, self.app_id)
dialog = HoldToConfirmDialog(content)
return await dialog == CONFIRMED
_state = None # type: Optional[ConfirmState] # state for msg_register and msg_authenticate
_lastreq = None # type: Optional[Msg] # last received register/authenticate request
def msg_register(req: Msg) -> Cmd:
global _state
global _lastreq
from apps.common import storage from apps.common import storage
if not storage.is_initialized(): if not storage.is_initialized():
if __debug__:
log.warning(__name__, 'not initialized') log.warning(__name__, 'not initialized')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# check length of input data # check length of input data
if len(req.data) != 64: if len(req.data) != 64:
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data') log.warning(__name__, '_SW_WRONG_LENGTH req.data')
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
@ -514,26 +525,24 @@ def msg_register(req: Msg) -> Cmd:
app_id = req.data[32:] app_id = req.data[32:]
# check equality with last request # check equality with last request
if _lastreq is None or _lastreq.__dict__ != req.__dict__: if not state.compare(_CONFIRM_REGISTER, req.data):
if _state is not None: if not state.setup(_CONFIRM_REGISTER, req.data, app_id):
_state.kill() return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None state.keepalive()
_lastreq = req
# wait for a button or continue # wait for a button or continue
if _state is not None and utime.ticks_ms() > _state.deadline_ms: if not state.confirmed:
_state.kill() if __debug__:
_state = None
if _state is None:
_state = ConfirmState(_CONFIRM_REGISTER, app_id)
_state.fork()
if _state.confirmed is None:
log.info(__name__, 'waiting for button') log.info(__name__, 'waiting for button')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None
# sign the registration challenge and return
if __debug__:
log.info(__name__, 'signing register')
buf = msg_register_sign(chal, app_id) buf = msg_register_sign(chal, app_id)
state.reset()
return Cmd(req.cid, _CMD_MSG, buf) return Cmd(req.cid, _CMD_MSG, buf)
@ -586,25 +595,24 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
return buf return buf
def msg_authenticate(req: Msg) -> Cmd: def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
global _state
global _lastreq
from apps.common import storage from apps.common import storage
if not storage.is_initialized(): if not storage.is_initialized():
if __debug__:
log.warning(__name__, 'not initialized') log.warning(__name__, 'not initialized')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# we need at least keyHandleLen # we need at least keyHandleLen
if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN: if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN:
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH req.data') log.warning(__name__, '_SW_WRONG_LENGTH req.data')
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
# check keyHandleLen # check keyHandleLen
khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN] khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN]
if khlen != 64: if khlen != 64:
if __debug__:
log.warning(__name__, '_SW_WRONG_LENGTH khlen') log.warning(__name__, '_SW_WRONG_LENGTH khlen')
return msg_error(req.cid, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
@ -618,40 +626,39 @@ def msg_authenticate(req: Msg) -> Cmd:
# if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already # if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already
if req.p1 == _AUTH_CHECK_ONLY: if req.p1 == _AUTH_CHECK_ONLY:
if __debug__:
log.info(__name__, '_AUTH_CHECK_ONLY') log.info(__name__, '_AUTH_CHECK_ONLY')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# from now on, only _AUTH_ENFORCE is supported # from now on, only _AUTH_ENFORCE is supported
if req.p1 != _AUTH_ENFORCE: if req.p1 != _AUTH_ENFORCE:
if __debug__:
log.info(__name__, '_AUTH_ENFORCE') log.info(__name__, '_AUTH_ENFORCE')
return msg_error(req.cid, _SW_WRONG_DATA) return msg_error(req.cid, _SW_WRONG_DATA)
# check equality with last request # check equality with last request
if _lastreq is None or _lastreq.__dict__ != req.__dict__: if not state.compare(_CONFIRM_AUTHENTICATE, req.data):
if _state is not None: if not state.setup(_CONFIRM_AUTHENTICATE, req.data, auth.appId):
_state.kill() return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None state.keepalive()
_lastreq = req
# wait for a button or continue # wait for a button or continue
if _state is not None and utime.ticks_ms() > _state.deadline_ms: if not state.confirmed:
_state.kill() if __debug__:
_state = None
if _state is None:
_state = ConfirmState(_CONFIRM_AUTHENTICATE, auth.appId)
_state.fork()
if _state.confirmed is None:
log.info(__name__, 'waiting for button') log.info(__name__, 'waiting for button')
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None
# sign the authentication challenge and return
if __debug__:
log.info(__name__, 'signing authentication')
buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key()) buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key())
state.reset()
return Cmd(req.cid, _CMD_MSG, buf) return Cmd(req.cid, _CMD_MSG, buf)
def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes): def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
from apps.common import seed from apps.common import seed
# unpack the keypath from the first half of keyhandle # unpack the keypath from the first half of keyhandle
@ -661,6 +668,7 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# check high bit for hardened keys # check high bit for hardened keys
for i in keypath: for i in keypath:
if not i & 0x80000000: if not i & 0x80000000:
if __debug__:
log.warning(__name__, 'invalid key path') log.warning(__name__, 'invalid key path')
return None return None
@ -675,6 +683,7 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
# verify the hmac # verify the hmac
if keybase != keyhandle[32:]: if keybase != keyhandle[32:]:
if __debug__:
log.warning(__name__, 'invalid key handle') log.warning(__name__, 'invalid key handle')
return None return None

View File

@ -1,19 +1,19 @@
from trezor import loop from trezor import loop
started = [] workflows = []
default = None
default_handler = None
layouts = [] layouts = []
default = None
default_layout = None
def onstart(w): def onstart(w):
started.append(w) workflows.append(w)
def onclose(w): def onclose(w):
started.remove(w) workflows.remove(w)
if not started and not layouts and default_handler: if not layouts and default_layout:
startdefault(default_handler) startdefault(default_layout)
def closedefault(): def closedefault():
@ -24,13 +24,13 @@ def closedefault():
default = None default = None
def startdefault(handler): def startdefault(layout):
global default global default
global default_handler global default_layout
if not default: if not default:
default_handler = handler default_layout = layout
default = handler() default = layout()
loop.schedule(default) loop.schedule(default)
@ -47,4 +47,5 @@ def onlayoutstart(l):
def onlayoutclose(l): def onlayoutclose(l):
if l in layouts:
layouts.remove(l) layouts.remove(l)