1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 15:38:11 +00:00

apps.fido_u2f: more type annotations

This commit is contained in:
Jan Pochyla 2017-06-19 14:47:40 +02:00
parent 26267d532d
commit c94e02b0eb

View File

@ -195,14 +195,26 @@ def make_struct(desc):
return buf, uctypes.struct(uctypes.addressof(buf), desc, uctypes.BIG_ENDIAN) return buf, uctypes.struct(uctypes.addressof(buf), desc, uctypes.BIG_ENDIAN)
class Msg:
def __init__(self, cid: int, cla: int, ins: int, p1: int, p2: int, lc: int, data: bytes) -> None:
self.cid = cid
self.cla = cla
self.ins = ins
self.p1 = p1
self.p2 = p2
self.lc = lc
self.data = data
class Cmd: class Cmd:
def __init__(self, cid: int, cmd: int, data: bytes): def __init__(self, cid: int, cmd: int, data: bytes) -> None:
self.cid = cid self.cid = cid
self.cmd = cmd self.cmd = cmd
self.data = data self.data = data
def to_msg(self): def to_msg(self) -> Msg:
cla = self.data[_APDU_CLA] cla = self.data[_APDU_CLA]
ins = self.data[_APDU_INS] ins = self.data[_APDU_INS]
p1 = self.data[_APDU_P1] p1 = self.data[_APDU_P1]
@ -214,23 +226,11 @@ class Cmd:
return Msg(self.cid, cla, ins, p1, p2, lc, data) return Msg(self.cid, cla, ins, p1, p2, lc, data)
class Msg: async def read_cmd(iface: int) -> Cmd:
def __init__(self, cid: int, cla: int, ins: int, p1: int, p2: int, lc: int, data: bytes):
self.cid = cid
self.cla = cla
self.ins = ins
self.p1 = p1
self.p2 = p2
self.lc = lc
self.data = data
def read_cmd(iface: int) -> Cmd:
desc_init = frame_init() desc_init = frame_init()
desc_cont = frame_cont() desc_cont = frame_cont()
buf, = yield loop.select(iface) buf, = await loop.select(iface)
# log.debug(__name__, 'read init %s', buf) # log.debug(__name__, 'read init %s', buf)
ifrm = overlay_struct(buf, desc_init) ifrm = overlay_struct(buf, desc_init)
@ -252,7 +252,7 @@ def read_cmd(iface: int) -> Cmd:
data = data[:bcnt] data = data[:bcnt]
while datalen < bcnt: while datalen < bcnt:
buf, = yield loop.select(iface) buf, = await loop.select(iface)
# log.debug(__name__, 'read cont %s', buf) # log.debug(__name__, 'read cont %s', buf)
cfrm = overlay_struct(buf, desc_cont) cfrm = overlay_struct(buf, desc_cont)
@ -282,7 +282,7 @@ def read_cmd(iface: int) -> Cmd:
return Cmd(ifrm.cid, ifrm.cmd, data) return Cmd(ifrm.cid, ifrm.cmd, data)
def send_cmd(cmd: Cmd, iface: int): def send_cmd(cmd: Cmd, iface: int) -> None:
init_desc = frame_init() init_desc = frame_init()
cont_desc = frame_cont() cont_desc = frame_cont()
offset = 0 offset = 0
@ -315,10 +315,10 @@ def boot():
loop.schedule_task(handle_reports(iface)) loop.schedule_task(handle_reports(iface))
def handle_reports(iface: int): async def handle_reports(iface: int):
while True: while True:
try: try:
req = yield from read_cmd(iface) req = await read_cmd(iface)
if req is None: if req is None:
continue continue
resp = dispatch_cmd(req) resp = dispatch_cmd(req)
@ -333,11 +333,11 @@ def dispatch_cmd(req: Cmd) -> Cmd:
if m.cla != 0: if m.cla != 0:
log.warning(__name__, '_SW_CLA_NOT_SUPPORTED') log.warning(__name__, '_SW_CLA_NOT_SUPPORTED')
return msg_error(req, _SW_CLA_NOT_SUPPORTED) return msg_error(req.cid, _SW_CLA_NOT_SUPPORTED)
if m.lc + _APDU_DATA > len(req.data): if m.lc + _APDU_DATA > len(req.data):
log.warning(__name__, '_SW_WRONG_LENGTH') log.warning(__name__, '_SW_WRONG_LENGTH')
return msg_error(req, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
if m.ins == _MSG_REGISTER: if m.ins == _MSG_REGISTER:
log.debug(__name__, '_MSG_REGISTER') log.debug(__name__, '_MSG_REGISTER')
@ -350,7 +350,7 @@ def dispatch_cmd(req: Cmd) -> Cmd:
return msg_version(m) return msg_version(m)
else: else:
log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins) log.warning(__name__, '_SW_INS_NOT_SUPPORTED: %d', m.ins)
return msg_error(req, _SW_INS_NOT_SUPPORTED) return msg_error(req.cid, _SW_INS_NOT_SUPPORTED)
elif req.cmd == _CMD_INIT: elif req.cmd == _CMD_INIT:
log.debug(__name__, '_CMD_INIT') log.debug(__name__, '_CMD_INIT')
@ -393,14 +393,14 @@ _CONFIRM_AUTHENTICATE = const(1)
class ConfirmContent(ui.Widget): class ConfirmContent(ui.Widget):
def __init__(self, action: int, app_id: bytes): def __init__(self, action: int, app_id: bytes) -> None:
self.action = action self.action = action
self.app_id = app_id self.app_id = app_id
self.app_name = None self.app_name = None
self.app_icon = None self.app_icon = None
self.boot() self.boot()
def boot(self): def boot(self) -> None:
import ubinascii import ubinascii
from trezor import res from trezor import res
from . import knownapps from . import knownapps
@ -420,7 +420,7 @@ class ConfirmContent(ui.Widget):
self.app_name = name self.app_name = name
self.app_icon = icon self.app_icon = icon
def render(self): def render(self) -> None:
if self.action == _CONFIRM_REGISTER: if self.action == _CONFIRM_REGISTER:
header = 'U2F Register' header = 'U2F Register'
else: else:
@ -435,23 +435,23 @@ _CONFIRM_STATE_TIMEOUT_MS = const(10 * 1000)
class ConfirmState: class ConfirmState:
def __init__(self, action: int, app_id: bytes): def __init__(self, action: int, app_id: bytes) -> None:
self.action = action self.action = action
self.app_id = app_id self.app_id = app_id
self.deadline_ms = None self.deadline_ms = None
self.confirmed = None self.confirmed = None
self.task = None self.task = None
def fork(self): def fork(self) -> None:
self.deadline_ms = utime.ticks_ms() + _CONFIRM_STATE_TIMEOUT_MS self.deadline_ms = utime.ticks_ms() + _CONFIRM_STATE_TIMEOUT_MS
self.task = self.confirm() self.task = self.confirm()
workflow.start(self.task) workflow.start(self.task)
def kill(self): def kill(self) -> None:
if self.task is not None: if self.task is not None:
self.task.close() self.task.close()
async def confirm(self): async def confirm(self) -> None:
from trezor.ui.confirm import HoldToConfirmDialog from trezor.ui.confirm import HoldToConfirmDialog
content = ConfirmContent(self.action, self.app_id) content = ConfirmContent(self.action, self.app_id)
dialog = HoldToConfirmDialog(content) dialog = HoldToConfirmDialog(content)
@ -459,8 +459,8 @@ class ConfirmState:
self.confirmed = await dialog self.confirmed = await dialog
_state = None # state for msg_register and msg_authenticate, None or ConfirmState _state = None # type: Optional[ConfirmState] # state for msg_register and msg_authenticate
_lastreq = None # last received register/authenticate request, None or Req _lastreq = None # type: Optional[Msg] # last received register/authenticate request
def msg_register(req: Msg) -> Cmd: def msg_register(req: Msg) -> Cmd:
@ -471,12 +471,12 @@ def msg_register(req: Msg) -> Cmd:
if not storage.is_initialized(): if not storage.is_initialized():
log.warning(__name__, 'not initialized') log.warning(__name__, 'not initialized')
return msg_error(req, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# check length of input data # check length of input data
if len(req.data) != 64: if len(req.data) != 64:
log.warning(__name__, '_SW_WRONG_LENGTH req.data') log.warning(__name__, '_SW_WRONG_LENGTH req.data')
return msg_error(req, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
# parse challenge and app_id # parse challenge and app_id
chal = req.data[:32] chal = req.data[:32]
@ -498,7 +498,7 @@ def msg_register(req: Msg) -> Cmd:
_state.fork() _state.fork()
if _state.confirmed is None: if _state.confirmed is None:
log.info(__name__, 'waiting for button') log.info(__name__, 'waiting for button')
return msg_error(req, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None _state = None
buf = msg_register_sign(chal, app_id) buf = msg_register_sign(chal, app_id)
@ -565,18 +565,18 @@ def msg_authenticate(req: Msg) -> Cmd:
if not storage.is_initialized(): if not storage.is_initialized():
log.warning(__name__, 'not initialized') log.warning(__name__, 'not initialized')
return msg_error(req, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# we need at least keyHandleLen # we need at least keyHandleLen
if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN: if len(req.data) <= _REQ_CMD_AUTHENTICATE_KHLEN:
log.warning(__name__, '_SW_WRONG_LENGTH req.data') log.warning(__name__, '_SW_WRONG_LENGTH req.data')
return msg_error(req, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
# check keyHandleLen # check keyHandleLen
khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN] khlen = req.data[_REQ_CMD_AUTHENTICATE_KHLEN]
if khlen != 64: if khlen != 64:
log.warning(__name__, '_SW_WRONG_LENGTH khlen') log.warning(__name__, '_SW_WRONG_LENGTH khlen')
return msg_error(req, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
auth = overlay_struct(req.data, req_cmd_authenticate(khlen)) auth = overlay_struct(req.data, req_cmd_authenticate(khlen))
@ -584,17 +584,17 @@ def msg_authenticate(req: Msg) -> Cmd:
node = msg_authenticate_genkey(auth.appId, auth.keyHandle) node = msg_authenticate_genkey(auth.appId, auth.keyHandle)
if node is None: if node is None:
# specific error logged in msg_authenticate_genkey # specific error logged in msg_authenticate_genkey
return msg_error(req, _SW_WRONG_DATA) return msg_error(req.cid, _SW_WRONG_DATA)
# if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already # if _AUTH_CHECK_ONLY is requested, return, because keyhandle has been checked already
if req.p1 == _AUTH_CHECK_ONLY: if req.p1 == _AUTH_CHECK_ONLY:
log.info(__name__, '_AUTH_CHECK_ONLY') log.info(__name__, '_AUTH_CHECK_ONLY')
return msg_error(req, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
# from now on, only _AUTH_ENFORCE is supported # from now on, only _AUTH_ENFORCE is supported
if req.p1 != _AUTH_ENFORCE: if req.p1 != _AUTH_ENFORCE:
log.info(__name__, '_AUTH_ENFORCE') log.info(__name__, '_AUTH_ENFORCE')
return msg_error(req, _SW_WRONG_DATA) return msg_error(req.cid, _SW_WRONG_DATA)
# check equality with last request # check equality with last request
if _lastreq is None or _lastreq.__dict__ != req.__dict__: if _lastreq is None or _lastreq.__dict__ != req.__dict__:
@ -612,7 +612,7 @@ def msg_authenticate(req: Msg) -> Cmd:
_state.fork() _state.fork()
if _state.confirmed is None: if _state.confirmed is None:
log.info(__name__, 'waiting for button') log.info(__name__, 'waiting for button')
return msg_error(req, _SW_CONDITIONS_NOT_SATISFIED) return msg_error(req.cid, _SW_CONDITIONS_NOT_SATISFIED)
_state = None _state = None
buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key()) buf = msg_authenticate_sign(auth.chal, auth.appId, node.private_key())
@ -691,12 +691,12 @@ def msg_authenticate_sign(challenge: bytes, app_id: bytes, privkey: bytes) -> by
def msg_version(req: Msg) -> Cmd: def msg_version(req: Msg) -> Cmd:
if req.data: if req.data:
return msg_error(req, _SW_WRONG_LENGTH) return msg_error(req.cid, _SW_WRONG_LENGTH)
return Cmd(req.cid, _CMD_MSG, b'U2F_V2\x90\x00') # includes _SW_NO_ERROR return Cmd(req.cid, _CMD_MSG, b'U2F_V2\x90\x00') # includes _SW_NO_ERROR
def msg_error(req: Msg, code: int) -> Cmd: def msg_error(cid: int, code: int) -> Cmd:
return Cmd(req.cid, _CMD_MSG, ustruct.pack('>H', code)) return Cmd(cid, _CMD_MSG, ustruct.pack('>H', code))
def cmd_error(cid: int, code: int) -> Cmd: def cmd_error(cid: int, code: int) -> Cmd: