mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
Added session depth
This commit is contained in:
parent
8e5abb560e
commit
813fb233a1
@ -64,28 +64,34 @@ class BitkeyClient(object):
|
||||
print '----------------------'
|
||||
print "Sending", self._pprint(msg)
|
||||
|
||||
self.transport.write(msg)
|
||||
resp = self.transport.read_blocking()
|
||||
try:
|
||||
self.transport.session_begin()
|
||||
|
||||
if isinstance(resp, proto.ButtonRequest):
|
||||
if self.debuglink and self.debug_button:
|
||||
print "Pressing button", self.debug_button
|
||||
self.debuglink.press_button(self.debug_button)
|
||||
self.transport.write(msg)
|
||||
resp = self.transport.read_blocking()
|
||||
|
||||
return self.call(proto.ButtonAck())
|
||||
if isinstance(resp, proto.ButtonRequest):
|
||||
if self.debuglink and self.debug_button:
|
||||
print "Pressing button", self.debug_button
|
||||
self.debuglink.press_button(self.debug_button)
|
||||
|
||||
if isinstance(resp, proto.PinMatrixRequest):
|
||||
if self.debuglink:
|
||||
if self.debug_pin:
|
||||
pin = self.debuglink.read_pin_encoded()
|
||||
msg2 = proto.PinMatrixAck(pin=pin)
|
||||
return self.call(proto.ButtonAck())
|
||||
|
||||
if isinstance(resp, proto.PinMatrixRequest):
|
||||
if self.debuglink:
|
||||
if self.debug_pin:
|
||||
pin = self.debuglink.read_pin_encoded()
|
||||
msg2 = proto.PinMatrixAck(pin=pin)
|
||||
else:
|
||||
msg2 = proto.PinMatrixAck(pin='444444222222')
|
||||
else:
|
||||
msg2 = proto.PinMatrixAck(pin='444444222222')
|
||||
else:
|
||||
pin = self.input_func("PIN required: ", resp.message)
|
||||
msg2 = proto.PinMatrixAck(pin=pin)
|
||||
pin = self.input_func("PIN required: ", resp.message)
|
||||
msg2 = proto.PinMatrixAck(pin=pin)
|
||||
|
||||
return self.call(msg2)
|
||||
return self.call(msg2)
|
||||
|
||||
finally:
|
||||
self.transport.session_end()
|
||||
|
||||
if isinstance(resp, proto.Failure):
|
||||
self.message_func(resp.message)
|
||||
@ -132,48 +138,54 @@ class BitkeyClient(object):
|
||||
|
||||
start = time.time()
|
||||
|
||||
# Prepare and send initial message
|
||||
tx = proto.SignTx()
|
||||
tx.inputs_count = len(inputs)
|
||||
tx.outputs_count = len(outputs)
|
||||
res = self.call(tx)
|
||||
try:
|
||||
self.transport.session_begin()
|
||||
|
||||
# Prepare structure for signatures
|
||||
signatures = [None]*len(inputs)
|
||||
serialized_tx = ''
|
||||
# Prepare and send initial message
|
||||
tx = proto.SignTx()
|
||||
tx.inputs_count = len(inputs)
|
||||
tx.outputs_count = len(outputs)
|
||||
res = self.call(tx)
|
||||
|
||||
counter = 0
|
||||
while True:
|
||||
counter += 1
|
||||
# Prepare structure for signatures
|
||||
signatures = [None]*len(inputs)
|
||||
serialized_tx = ''
|
||||
|
||||
if isinstance(res, proto.Failure):
|
||||
raise CallException("Signing failed")
|
||||
counter = 0
|
||||
while True:
|
||||
counter += 1
|
||||
|
||||
if not isinstance(res, proto.TxRequest):
|
||||
raise CallException("Unexpected message")
|
||||
if isinstance(res, proto.Failure):
|
||||
raise CallException("Signing failed")
|
||||
|
||||
# If there's some part of signed transaction, let's add it
|
||||
if res.serialized_tx:
|
||||
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
|
||||
serialized_tx += res.serialized_tx
|
||||
if not isinstance(res, proto.TxRequest):
|
||||
raise CallException("Unexpected message")
|
||||
|
||||
if res.signed_index >= 0 and res.signature:
|
||||
print "!!! SIGNED INPUT", res.signed_index
|
||||
signatures[res.signed_index] = res.signature
|
||||
# If there's some part of signed transaction, let's add it
|
||||
if res.serialized_tx:
|
||||
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
|
||||
serialized_tx += res.serialized_tx
|
||||
|
||||
if res.request_index < 0:
|
||||
# Device didn't ask for more information, finish workflow
|
||||
break
|
||||
if res.signed_index >= 0 and res.signature:
|
||||
print "!!! SIGNED INPUT", res.signed_index
|
||||
signatures[res.signed_index] = res.signature
|
||||
|
||||
# Device asked for one more information, let's process it.
|
||||
if res.request_type == proto.TXOUTPUT:
|
||||
res = self.call(outputs[res.request_index])
|
||||
continue
|
||||
if res.request_index < 0:
|
||||
# Device didn't ask for more information, finish workflow
|
||||
break
|
||||
|
||||
elif res.request_type == proto.TXINPUT:
|
||||
print "REQUESTING", res.request_index
|
||||
res = self.call(inputs[res.request_index])
|
||||
continue
|
||||
# Device asked for one more information, let's process it.
|
||||
if res.request_type == proto.TXOUTPUT:
|
||||
res = self.call(outputs[res.request_index])
|
||||
continue
|
||||
|
||||
elif res.request_type == proto.TXINPUT:
|
||||
print "REQUESTING", res.request_index
|
||||
res = self.call(inputs[res.request_index])
|
||||
continue
|
||||
|
||||
finally:
|
||||
self.transport.session_end()
|
||||
|
||||
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
|
||||
(time.time() - start, counter, len(serialized_tx))
|
||||
|
@ -7,6 +7,7 @@ class NotImplementedException(Exception):
|
||||
class Transport(object):
|
||||
def __init__(self, device, *args, **kwargs):
|
||||
self.device = device
|
||||
self.session_depth = 0
|
||||
self._open()
|
||||
|
||||
def _open(self):
|
||||
@ -21,9 +22,26 @@ class Transport(object):
|
||||
def _read(self):
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def _session_begin(self):
|
||||
pass
|
||||
|
||||
def _session_end(self):
|
||||
pass
|
||||
|
||||
def ready_to_read(self):
|
||||
raise NotImplementedException("Not implemented")
|
||||
|
||||
def session_begin(self):
|
||||
if self.session_depth == 0:
|
||||
self._session_begin()
|
||||
self.session_depth += 1
|
||||
|
||||
def session_end(self):
|
||||
self.session_depth -= 1
|
||||
self.session_depth = max(0, self.session_depth)
|
||||
if self.session_depth == 0:
|
||||
self._session_end()
|
||||
|
||||
def close(self):
|
||||
self._close()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user