mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-16 19:38:09 +00:00
src/apps/fido_u2f: fix confirmation, refactor
This commit is contained in:
parent
f74cbead5e
commit
eda280213f
@ -234,7 +234,6 @@ async def read_cmd(iface: io.HID) -> Cmd:
|
||||
read = loop.select(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
buf = await read
|
||||
# log.debug(__name__, 'read init %s', buf)
|
||||
|
||||
ifrm = overlay_struct(buf, desc_init)
|
||||
bcnt = ifrm.bcnt
|
||||
@ -244,6 +243,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
|
||||
|
||||
if ifrm.cmd & _TYPE_MASK == _TYPE_CONT:
|
||||
# unexpected cont packet, abort current msg
|
||||
if __debug__:
|
||||
log.warning(__name__, '_TYPE_CONT')
|
||||
return None
|
||||
|
||||
@ -256,7 +256,6 @@ async def read_cmd(iface: io.HID) -> Cmd:
|
||||
|
||||
while datalen < bcnt:
|
||||
buf = await read
|
||||
# log.debug(__name__, 'read cont %s', buf)
|
||||
|
||||
cfrm = overlay_struct(buf, desc_cont)
|
||||
|
||||
@ -268,6 +267,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
|
||||
|
||||
if cfrm.cid != ifrm.cid:
|
||||
# cont frame for a different channel, reply with BUSY and skip
|
||||
if __debug__:
|
||||
log.warning(__name__, '_ERR_CHANNEL_BUSY')
|
||||
await send_cmd(cmd_error(cfrm.cid, _ERR_CHANNEL_BUSY), iface)
|
||||
continue
|
||||
@ -275,6 +275,7 @@ async def read_cmd(iface: io.HID) -> Cmd:
|
||||
if cfrm.seq != seq:
|
||||
# cont frame for this channel, but incorrect seq number, abort
|
||||
# current msg
|
||||
if __debug__:
|
||||
log.warning(__name__, '_ERR_INVALID_SEQ')
|
||||
await send_cmd(cmd_error(cfrm.cid, _ERR_INVALID_SEQ), iface)
|
||||
return None
|
||||
@ -299,7 +300,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None:
|
||||
|
||||
offset += utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
||||
iface.write(buf)
|
||||
# log.debug(__name__, 'send init %s', buf)
|
||||
|
||||
if offset < datalen:
|
||||
frm = overlay_struct(buf, cont_desc)
|
||||
@ -312,7 +312,6 @@ async def send_cmd(cmd: Cmd, iface: io.HID) -> None:
|
||||
await write
|
||||
if iface.write(buf) > 0:
|
||||
break
|
||||
# log.debug(__name__, 'send cont %s', buf)
|
||||
seq += 1
|
||||
|
||||
|
||||
@ -321,52 +320,64 @@ def boot(iface: io.HID):
|
||||
|
||||
|
||||
async def handle_reports(iface: io.HID):
|
||||
state = ConfirmState()
|
||||
|
||||
while True:
|
||||
try:
|
||||
req = await read_cmd(iface)
|
||||
if req is None:
|
||||
continue
|
||||
resp = dispatch_cmd(req)
|
||||
resp = dispatch_cmd(req, state)
|
||||
await send_cmd(resp, iface)
|
||||
except Exception as e:
|
||||
log.exception(__name__, e)
|
||||
|
||||
|
||||
def dispatch_cmd(req: Cmd) -> Cmd:
|
||||
def dispatch_cmd(req: Cmd, state: ConfirmState) -> Cmd:
|
||||
if req.cmd == _CMD_MSG:
|
||||
m = req.to_msg()
|
||||
|
||||
if m.cla != 0:
|
||||
if __debug__:
|
||||
log.warning(__name__, '_SW_CLA_NOT_SUPPORTED')
|
||||
return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED)
|
||||
|
||||
if m.lc + _APDU_DATA > len(req.data):
|
||||
if __debug__:
|
||||
log.warning(__name__, '_SW_WRONG_LENGTH')
|
||||
return msg_error(req.cid, _SW_WRONG_LENGTH)
|
||||
|
||||
if m.ins == _MSG_REGISTER:
|
||||
if __debug__:
|
||||
log.debug(__name__, '_MSG_REGISTER')
|
||||
return msg_register(m)
|
||||
return msg_register(m, state)
|
||||
elif m.ins == _MSG_AUTHENTICATE:
|
||||
if __debug__:
|
||||
log.debug(__name__, '_MSG_AUTHENTICATE')
|
||||
return msg_authenticate(m)
|
||||
return msg_authenticate(m, state)
|
||||
elif m.ins == _MSG_VERSION:
|
||||
if __debug__:
|
||||
log.debug(__name__, '_MSG_VERSION')
|
||||
return msg_version(m)
|
||||
else:
|
||||
if __debug__:
|
||||
log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins)
|
||||
return msg_error(req.cid, _SW_INS_NOT_SUPPORTED)
|
||||
|
||||
elif req.cmd == _CMD_INIT:
|
||||
if __debug__:
|
||||
log.debug(__name__, '_CMD_INIT')
|
||||
return cmd_init(req)
|
||||
elif req.cmd == _CMD_PING:
|
||||
if __debug__:
|
||||
log.debug(__name__, '_CMD_PING')
|
||||
return req
|
||||
elif req.cmd == _CMD_WINK:
|
||||
if __debug__:
|
||||
log.debug(__name__, '_CMD_WINK')
|
||||
return req
|
||||
else:
|
||||
if __debug__:
|
||||
log.warning(__name__, '_ERR_INVALID_CMD: %d', req.cmd)
|
||||
return cmd_error(req.cid, _ERR_INVALID_CMD)
|
||||
|
||||
@ -394,6 +405,63 @@ def cmd_init(req: Cmd) -> Cmd:
|
||||
|
||||
_CONFIRM_REGISTER = const(0)
|
||||
_CONFIRM_AUTHENTICATE = const(1)
|
||||
_CONFIRM_TIMEOUT_MS = const(10 * 1000)
|
||||
|
||||
|
||||
class ConfirmState:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.action = None
|
||||
self.checksum = None
|
||||
self.app_id = None
|
||||
|
||||
self.confirmed = None
|
||||
self.deadline = None
|
||||
self.workflow = None
|
||||
|
||||
def compare(self, action: int, checksum: bytes) -> bool:
|
||||
if self.action != action or self.checksum != checksum:
|
||||
return False
|
||||
if utime.ticks_ms() >= self.deadline:
|
||||
return False
|
||||
return True
|
||||
|
||||
def setup(self, action: int, checksum: bytes, app_id: bytes) -> None:
|
||||
if self.workflow is not None:
|
||||
loop.close(self.workflow)
|
||||
if workflow.workflows:
|
||||
return False
|
||||
|
||||
self.action = action
|
||||
self.checksum = checksum
|
||||
self.app_id = app_id
|
||||
|
||||
self.confirmed = None
|
||||
self.workflow = self.confirm_workflow()
|
||||
loop.schedule(self.workflow)
|
||||
return True
|
||||
|
||||
def keepalive(self):
|
||||
self.deadline = utime.ticks_ms() + _CONFIRM_TIMEOUT_MS
|
||||
|
||||
async def confirm_workflow(self) -> None:
|
||||
try:
|
||||
workflow.onstart(self.workflow)
|
||||
await self.confirm_layout()
|
||||
finally:
|
||||
workflow.onclose(self.workflow)
|
||||
self.workflow = None
|
||||
|
||||
@ui.layout
|
||||
async def confirm_layout(self) -> None:
|
||||
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
|
||||
|
||||
content = ConfirmContent(self.action, self.app_id)
|
||||
dialog = ConfirmDialog(content, )
|
||||
self.confirmed = await dialog == CONFIRMED
|
||||
|
||||
|
||||
class ConfirmContent(ui.Widget):
|
||||
@ -420,7 +488,7 @@ class ConfirmContent(ui.Widget):
|
||||
name = knownapps.knownapps[app_id]
|
||||
try:
|
||||
icon = res.load('apps/fido_u2f/res/u2f_%s.toif' % name.lower().replace(' ', '_'))
|
||||
except FileNotFoundError:
|
||||
except Exception:
|
||||
icon = res.load('apps/fido_u2f/res/u2f_generic.toif')
|
||||
else:
|
||||
name = '%s...%s' % (hexlify(app_id[:4]).decode(), hexlify(app_id[-4:]).decode())
|
||||
@ -438,74 +506,17 @@ class ConfirmContent(ui.Widget):
|
||||
ui.display.text_center(ui.WIDTH // 2, 168, self.app_name, ui.MONO, ui.FG, ui.BG)
|
||||
|
||||
|
||||
_CONFIRM_STATE_TIMEOUT_MS = const(10 * 1000)
|
||||
|
||||
|
||||
class ConfirmState:
|
||||
|
||||
def __init__(self, action: int, app_id: bytes) -> None:
|
||||
self.action = action
|
||||
self.app_id = app_id
|
||||
self.deadline_ms = None
|
||||
self.confirmed = None
|
||||
self.task = None
|
||||
|
||||
def fork(self) -> None:
|
||||
self.deadline_ms = utime.ticks_ms() + _CONFIRM_STATE_TIMEOUT_MS
|
||||
self.task = self.confirm()
|
||||
workflow.onstart(self.task)
|
||||
loop.schedule(self.task)
|
||||
|
||||
def kill(self) -> None:
|
||||
if self.task is not None:
|
||||
loop.close(self.task)
|
||||
self.task = None
|
||||
|
||||
async def confirm(self) -> None:
|
||||
confirmed = False
|
||||
try:
|
||||
confirmed = await self.confirm_layout()
|
||||
finally:
|
||||
self.confirmed = confirmed
|
||||
workflow.onclose(self.task)
|
||||
|
||||
@ui.layout
|
||||
async def confirm_layout(self) -> None:
|
||||
from trezor.ui.confirm import HoldToConfirmDialog, CONFIRMED
|
||||
from trezor.ui.text import Text
|
||||
|
||||
if bytes(self.app_id) == _BOGUS_APPID:
|
||||
text = Text(
|
||||
'U2F mismatch', ui.ICON_WRONG,
|
||||
'Another U2F device',
|
||||
'was used to register',
|
||||
'in this application.',
|
||||
icon_color=ui.RED)
|
||||
text.render()
|
||||
await loop.sleep(3 * 1000 * 1000)
|
||||
return True
|
||||
|
||||
content = ConfirmContent(self.action, self.app_id)
|
||||
dialog = HoldToConfirmDialog(content)
|
||||
return await dialog == CONFIRMED
|
||||
|
||||
|
||||
_state = None # type: Optional[ConfirmState] # state for msg_register and msg_authenticate
|
||||
_lastreq = None # type: Optional[Msg] # last received register/authenticate request
|
||||
|
||||
|
||||
def msg_register(req: Msg) -> Cmd:
|
||||
global _state
|
||||
global _lastreq
|
||||
|
||||
def msg_register(req: Msg, state: ConfirmState) -> Cmd:
|
||||
from apps.common import storage
|
||||
|
||||
if not storage.is_initialized():
|
||||
if __debug__:
|
||||
log.warning(__name__, 'not initialized')
|
||||
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
|
||||
|
||||
# check length of input data
|
||||
if len(req.data) != 64:
|
||||
if __debug__:
|
||||
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
|
||||
return msg_error(req.cid, _SW_WRONG_LENGTH)
|
||||
|
||||
@ -514,26 +525,24 @@ def msg_register(req: Msg) -> Cmd:
|
||||
app_id = req.data[32:]
|
||||
|
||||
# check equality with last request
|
||||
if _lastreq is None or _lastreq.__dict__ != req.__dict__:
|
||||
if _state is not None:
|
||||
_state.kill()
|
||||
_state = None
|
||||
_lastreq = req
|
||||
if not state.compare(_CONFIRM_REGISTER, req.data):
|
||||
if not state.setup(_CONFIRM_REGISTER, req.data, app_id):
|
||||
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
|
||||
state.keepalive()
|
||||
|
||||
# wait for a button or continue
|
||||
if _state is not None and utime.ticks_ms() > _state.deadline_ms:
|
||||
_state.kill()
|
||||
_state = None
|
||||
if _state is None:
|
||||
_state = ConfirmState(_CONFIRM_REGISTER, app_id)
|
||||
_state.fork()
|
||||
if _state.confirmed is None:
|
||||
if not state.confirmed:
|
||||
if __debug__:
|
||||
log.info(__name__, 'waiting for button')
|
||||
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
|
||||
_state = None
|
||||
|
||||
# sign the registration challenge and return
|
||||
if __debug__:
|
||||
log.info(__name__, 'signing register')
|
||||
buf = msg_register_sign(chal, app_id)
|
||||
|
||||
state.reset()
|
||||
|
||||
return Cmd(req.cid, _CMD_MSG, buf)
|
||||
|
||||
|
||||
@ -586,25 +595,24 @@ def msg_register_sign(challenge: bytes, app_id: bytes) -> bytes:
|
||||
return buf
|
||||
|
||||
|
||||
def msg_authenticate(req: Msg) -> Cmd:
|
||||
|
||||
global _state
|
||||
global _lastreq
|
||||
|
||||
def msg_authenticate(req: Msg, state: ConfirmState) -> Cmd:
|
||||
from apps.common import storage
|
||||
|
||||
if not storage.is_initialized():
|
||||
if __debug__:
|
||||
log.warning(__name__, 'not initialized')
|
||||
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
|
||||
|
||||
# we need at least keyHandleLen
|
||||
if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN:
|
||||
if __debug__:
|
||||
log.warning(__name__, '_SW_WRONG_LENGTH req.data')
|
||||
return msg_error(req.cid, _SW_WRONG_LENGTH)
|
||||
|
||||
# check keyHandleLen
|
||||
khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN]
|
||||
if khlen != 64:
|
||||
if __debug__:
|
||||
log.warning(__name__, '_SW_WRONG_LENGTH khlen')
|
||||
return msg_error(req.cid, _SW_WRONG_LENGTH)
|
||||
|
||||
@ -618,40 +626,39 @@ def msg_authenticate(req: Msg) -> Cmd:
|
||||
|
||||
# if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already
|
||||
if req.p1 == _AUTH_CHECK_ONLY:
|
||||
if __debug__:
|
||||
log.info(__name__, '_AUTH_CHECK_ONLY')
|
||||
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
|
||||
|
||||
# from now on, only _AUTH_ENFORCE is supported
|
||||
if req.p1 != _AUTH_ENFORCE:
|
||||
if __debug__:
|
||||
log.info(__name__, '_AUTH_ENFORCE')
|
||||
return msg_error(req.cid, _SW_WRONG_DATA)
|
||||
|
||||
# check equality with last request
|
||||
if _lastreq is None or _lastreq.__dict__ != req.__dict__:
|
||||
if _state is not None:
|
||||
_state.kill()
|
||||
_state = None
|
||||
_lastreq = req
|
||||
if not state.compare(_CONFIRM_AUTHENTICATE, req.data):
|
||||
if not state.setup(_CONFIRM_AUTHENTICATE, req.data, auth.appId):
|
||||
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
|
||||
state.keepalive()
|
||||
|
||||
# wait for a button or continue
|
||||
if _state is not None and utime.ticks_ms() > _state.deadline_ms:
|
||||
_state.kill()
|
||||
_state = None
|
||||
if _state is None:
|
||||
_state = ConfirmState(_CONFIRM_AUTHENTICATE, auth.appId)
|
||||
_state.fork()
|
||||
if _state.confirmed is None:
|
||||
if not state.confirmed:
|
||||
if __debug__:
|
||||
log.info(__name__, 'waiting for button')
|
||||
return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
|
||||
_state = None
|
||||
|
||||
# sign the authentication challenge and return
|
||||
if __debug__:
|
||||
log.info(__name__, 'signing authentication')
|
||||
buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key())
|
||||
|
||||
state.reset()
|
||||
|
||||
return Cmd(req.cid, _CMD_MSG, buf)
|
||||
|
||||
|
||||
def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
|
||||
|
||||
from apps.common import seed
|
||||
|
||||
# unpack the keypath from the first half of keyhandle
|
||||
@ -661,6 +668,7 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
|
||||
# check high bit for hardened keys
|
||||
for i in keypath:
|
||||
if not i & 0x80000000:
|
||||
if __debug__:
|
||||
log.warning(__name__, 'invalid key path')
|
||||
return None
|
||||
|
||||
@ -675,6 +683,7 @@ def msg_authenticate_genkey(app_id: bytes, keyhandle: bytes):
|
||||
|
||||
# verify the hmac
|
||||
if keybase != keyhandle[32:]:
|
||||
if __debug__:
|
||||
log.warning(__name__, 'invalid key handle')
|
||||
return None
|
||||
|
||||
|
@ -1,19 +1,19 @@
|
||||
from trezor import loop
|
||||
|
||||
started = []
|
||||
default = None
|
||||
default_handler = None
|
||||
workflows = []
|
||||
layouts = []
|
||||
default = None
|
||||
default_layout = None
|
||||
|
||||
|
||||
def onstart(w):
|
||||
started.append(w)
|
||||
workflows.append(w)
|
||||
|
||||
|
||||
def onclose(w):
|
||||
started.remove(w)
|
||||
if not started and not layouts and default_handler:
|
||||
startdefault(default_handler)
|
||||
workflows.remove(w)
|
||||
if not layouts and default_layout:
|
||||
startdefault(default_layout)
|
||||
|
||||
|
||||
def closedefault():
|
||||
@ -24,13 +24,13 @@ def closedefault():
|
||||
default = None
|
||||
|
||||
|
||||
def startdefault(handler):
|
||||
def startdefault(layout):
|
||||
global default
|
||||
global default_handler
|
||||
global default_layout
|
||||
|
||||
if not default:
|
||||
default_handler = handler
|
||||
default = handler()
|
||||
default_layout = layout
|
||||
default = layout()
|
||||
loop.schedule(default)
|
||||
|
||||
|
||||
@ -47,4 +47,5 @@ def onlayoutstart(l):
|
||||
|
||||
|
||||
def onlayoutclose(l):
|
||||
if l in layouts:
|
||||
layouts.remove(l)
|
||||
|
Loading…
Reference in New Issue
Block a user