1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-08 22:40:59 +00:00

Added session depth

This commit is contained in:
slush 2013-09-09 15:36:17 +02:00
parent 8e5abb560e
commit 813fb233a1
2 changed files with 92 additions and 62 deletions

View File

@ -64,28 +64,34 @@ class BitkeyClient(object):
print '----------------------'
print "Sending", self._pprint(msg)
self.transport.write(msg)
resp = self.transport.read_blocking()
if isinstance(resp, proto.ButtonRequest):
if self.debuglink and self.debug_button:
print "Pressing button", self.debug_button
self.debuglink.press_button(self.debug_button)
return self.call(proto.ButtonAck())
if isinstance(resp, proto.PinMatrixRequest):
if self.debuglink:
if self.debug_pin:
pin = self.debuglink.read_pin_encoded()
msg2 = proto.PinMatrixAck(pin=pin)
else:
msg2 = proto.PinMatrixAck(pin='444444222222')
else:
pin = self.input_func("PIN required: ", resp.message)
msg2 = proto.PinMatrixAck(pin=pin)
try:
self.transport.session_begin()
self.transport.write(msg)
resp = self.transport.read_blocking()
if isinstance(resp, proto.ButtonRequest):
if self.debuglink and self.debug_button:
print "Pressing button", self.debug_button
self.debuglink.press_button(self.debug_button)
return self.call(msg2)
return self.call(proto.ButtonAck())
if isinstance(resp, proto.PinMatrixRequest):
if self.debuglink:
if self.debug_pin:
pin = self.debuglink.read_pin_encoded()
msg2 = proto.PinMatrixAck(pin=pin)
else:
msg2 = proto.PinMatrixAck(pin='444444222222')
else:
pin = self.input_func("PIN required: ", resp.message)
msg2 = proto.PinMatrixAck(pin=pin)
return self.call(msg2)
finally:
self.transport.session_end()
if isinstance(resp, proto.Failure):
self.message_func(resp.message)
@ -132,49 +138,55 @@ class BitkeyClient(object):
start = time.time()
# Prepare and send initial message
tx = proto.SignTx()
tx.inputs_count = len(inputs)
tx.outputs_count = len(outputs)
res = self.call(tx)
# Prepare structure for signatures
signatures = [None]*len(inputs)
serialized_tx = ''
counter = 0
while True:
counter += 1
try:
self.transport.session_begin()
if isinstance(res, proto.Failure):
raise CallException("Signing failed")
if not isinstance(res, proto.TxRequest):
raise CallException("Unexpected message")
# If there's some part of signed transaction, let's add it
if res.serialized_tx:
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
serialized_tx += res.serialized_tx
if res.signed_index >= 0 and res.signature:
print "!!! SIGNED INPUT", res.signed_index
signatures[res.signed_index] = res.signature
# Prepare and send initial message
tx = proto.SignTx()
tx.inputs_count = len(inputs)
tx.outputs_count = len(outputs)
res = self.call(tx)
# Prepare structure for signatures
signatures = [None]*len(inputs)
serialized_tx = ''
counter = 0
while True:
counter += 1
if res.request_index < 0:
# Device didn't ask for more information, finish workflow
break
# Device asked for one more information, let's process it.
if res.request_type == proto.TXOUTPUT:
res = self.call(outputs[res.request_index])
continue
elif res.request_type == proto.TXINPUT:
print "REQUESTING", res.request_index
res = self.call(inputs[res.request_index])
continue
if isinstance(res, proto.Failure):
raise CallException("Signing failed")
if not isinstance(res, proto.TxRequest):
raise CallException("Unexpected message")
# If there's some part of signed transaction, let's add it
if res.serialized_tx:
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
serialized_tx += res.serialized_tx
if res.signed_index >= 0 and res.signature:
print "!!! SIGNED INPUT", res.signed_index
signatures[res.signed_index] = res.signature
if res.request_index < 0:
# Device didn't ask for more information, finish workflow
break
# Device asked for one more information, let's process it.
if res.request_type == proto.TXOUTPUT:
res = self.call(outputs[res.request_index])
continue
elif res.request_type == proto.TXINPUT:
print "REQUESTING", res.request_index
res = self.call(inputs[res.request_index])
continue
finally:
self.transport.session_end()
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
(time.time() - start, counter, len(serialized_tx))

View File

@ -7,6 +7,7 @@ class NotImplementedException(Exception):
class Transport(object):
def __init__(self, device, *args, **kwargs):
self.device = device
self.session_depth = 0
self._open()
def _open(self):
@ -21,8 +22,25 @@ class Transport(object):
def _read(self):
raise NotImplementedException("Not implemented")
def _session_begin(self):
pass
def _session_end(self):
pass
def ready_to_read(self):
raise NotImplementedException("Not implemented")
def session_begin(self):
if self.session_depth == 0:
self._session_begin()
self.session_depth += 1
def session_end(self):
self.session_depth -= 1
self.session_depth = max(0, self.session_depth)
if self.session_depth == 0:
self._session_end()
def close(self):
self._close()