1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

Added session depth

This commit is contained in:
slush 2013-09-09 15:36:17 +02:00
parent 8e5abb560e
commit 813fb233a1
2 changed files with 92 additions and 62 deletions

View File

@ -64,28 +64,34 @@ class BitkeyClient(object):
print '----------------------' print '----------------------'
print "Sending", self._pprint(msg) print "Sending", self._pprint(msg)
self.transport.write(msg) try:
resp = self.transport.read_blocking() self.transport.session_begin()
if isinstance(resp, proto.ButtonRequest): self.transport.write(msg)
if self.debuglink and self.debug_button: resp = self.transport.read_blocking()
print "Pressing button", self.debug_button
self.debuglink.press_button(self.debug_button) if isinstance(resp, proto.ButtonRequest):
if self.debuglink and self.debug_button:
return self.call(proto.ButtonAck()) print "Pressing button", self.debug_button
self.debuglink.press_button(self.debug_button)
if isinstance(resp, proto.PinMatrixRequest):
if self.debuglink:
if self.debug_pin:
pin = self.debuglink.read_pin_encoded()
msg2 = proto.PinMatrixAck(pin=pin)
else:
msg2 = proto.PinMatrixAck(pin='444444222222')
else:
pin = self.input_func("PIN required: ", resp.message)
msg2 = proto.PinMatrixAck(pin=pin)
return self.call(msg2) return self.call(proto.ButtonAck())
if isinstance(resp, proto.PinMatrixRequest):
if self.debuglink:
if self.debug_pin:
pin = self.debuglink.read_pin_encoded()
msg2 = proto.PinMatrixAck(pin=pin)
else:
msg2 = proto.PinMatrixAck(pin='444444222222')
else:
pin = self.input_func("PIN required: ", resp.message)
msg2 = proto.PinMatrixAck(pin=pin)
return self.call(msg2)
finally:
self.transport.session_end()
if isinstance(resp, proto.Failure): if isinstance(resp, proto.Failure):
self.message_func(resp.message) self.message_func(resp.message)
@ -132,49 +138,55 @@ class BitkeyClient(object):
start = time.time() start = time.time()
# Prepare and send initial message try:
tx = proto.SignTx() self.transport.session_begin()
tx.inputs_count = len(inputs)
tx.outputs_count = len(outputs)
res = self.call(tx)
# Prepare structure for signatures
signatures = [None]*len(inputs)
serialized_tx = ''
counter = 0
while True:
counter += 1
if isinstance(res, proto.Failure): # Prepare and send initial message
raise CallException("Signing failed") tx = proto.SignTx()
tx.inputs_count = len(inputs)
if not isinstance(res, proto.TxRequest): tx.outputs_count = len(outputs)
raise CallException("Unexpected message") res = self.call(tx)
# If there's some part of signed transaction, let's add it # Prepare structure for signatures
if res.serialized_tx: signatures = [None]*len(inputs)
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx) serialized_tx = ''
serialized_tx += res.serialized_tx
counter = 0
if res.signed_index >= 0 and res.signature: while True:
print "!!! SIGNED INPUT", res.signed_index counter += 1
signatures[res.signed_index] = res.signature
if res.request_index < 0: if isinstance(res, proto.Failure):
# Device didn't ask for more information, finish workflow raise CallException("Signing failed")
break
if not isinstance(res, proto.TxRequest):
# Device asked for one more information, let's process it. raise CallException("Unexpected message")
if res.request_type == proto.TXOUTPUT:
res = self.call(outputs[res.request_index]) # If there's some part of signed transaction, let's add it
continue if res.serialized_tx:
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
elif res.request_type == proto.TXINPUT: serialized_tx += res.serialized_tx
print "REQUESTING", res.request_index
res = self.call(inputs[res.request_index]) if res.signed_index >= 0 and res.signature:
continue print "!!! SIGNED INPUT", res.signed_index
signatures[res.signed_index] = res.signature
if res.request_index < 0:
# Device didn't ask for more information, finish workflow
break
# Device asked for one more information, let's process it.
if res.request_type == proto.TXOUTPUT:
res = self.call(outputs[res.request_index])
continue
elif res.request_type == proto.TXINPUT:
print "REQUESTING", res.request_index
res = self.call(inputs[res.request_index])
continue
finally:
self.transport.session_end()
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \ print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
(time.time() - start, counter, len(serialized_tx)) (time.time() - start, counter, len(serialized_tx))

View File

@ -7,6 +7,7 @@ class NotImplementedException(Exception):
class Transport(object): class Transport(object):
def __init__(self, device, *args, **kwargs): def __init__(self, device, *args, **kwargs):
self.device = device self.device = device
self.session_depth = 0
self._open() self._open()
def _open(self): def _open(self):
@ -21,8 +22,25 @@ class Transport(object):
def _read(self): def _read(self):
raise NotImplementedException("Not implemented") raise NotImplementedException("Not implemented")
def _session_begin(self):
pass
def _session_end(self):
pass
def ready_to_read(self): def ready_to_read(self):
raise NotImplementedException("Not implemented") raise NotImplementedException("Not implemented")
def session_begin(self):
if self.session_depth == 0:
self._session_begin()
self.session_depth += 1
def session_end(self):
self.session_depth -= 1
self.session_depth = max(0, self.session_depth)
if self.session_depth == 0:
self._session_end()
def close(self): def close(self):
self._close() self._close()