mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
Added session depth
This commit is contained in:
parent
8e5abb560e
commit
813fb233a1
@ -64,28 +64,34 @@ class BitkeyClient(object):
|
|||||||
print '----------------------'
|
print '----------------------'
|
||||||
print "Sending", self._pprint(msg)
|
print "Sending", self._pprint(msg)
|
||||||
|
|
||||||
self.transport.write(msg)
|
try:
|
||||||
resp = self.transport.read_blocking()
|
self.transport.session_begin()
|
||||||
|
|
||||||
if isinstance(resp, proto.ButtonRequest):
|
self.transport.write(msg)
|
||||||
if self.debuglink and self.debug_button:
|
resp = self.transport.read_blocking()
|
||||||
print "Pressing button", self.debug_button
|
|
||||||
self.debuglink.press_button(self.debug_button)
|
if isinstance(resp, proto.ButtonRequest):
|
||||||
|
if self.debuglink and self.debug_button:
|
||||||
return self.call(proto.ButtonAck())
|
print "Pressing button", self.debug_button
|
||||||
|
self.debuglink.press_button(self.debug_button)
|
||||||
if isinstance(resp, proto.PinMatrixRequest):
|
|
||||||
if self.debuglink:
|
|
||||||
if self.debug_pin:
|
|
||||||
pin = self.debuglink.read_pin_encoded()
|
|
||||||
msg2 = proto.PinMatrixAck(pin=pin)
|
|
||||||
else:
|
|
||||||
msg2 = proto.PinMatrixAck(pin='444444222222')
|
|
||||||
else:
|
|
||||||
pin = self.input_func("PIN required: ", resp.message)
|
|
||||||
msg2 = proto.PinMatrixAck(pin=pin)
|
|
||||||
|
|
||||||
return self.call(msg2)
|
return self.call(proto.ButtonAck())
|
||||||
|
|
||||||
|
if isinstance(resp, proto.PinMatrixRequest):
|
||||||
|
if self.debuglink:
|
||||||
|
if self.debug_pin:
|
||||||
|
pin = self.debuglink.read_pin_encoded()
|
||||||
|
msg2 = proto.PinMatrixAck(pin=pin)
|
||||||
|
else:
|
||||||
|
msg2 = proto.PinMatrixAck(pin='444444222222')
|
||||||
|
else:
|
||||||
|
pin = self.input_func("PIN required: ", resp.message)
|
||||||
|
msg2 = proto.PinMatrixAck(pin=pin)
|
||||||
|
|
||||||
|
return self.call(msg2)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.transport.session_end()
|
||||||
|
|
||||||
if isinstance(resp, proto.Failure):
|
if isinstance(resp, proto.Failure):
|
||||||
self.message_func(resp.message)
|
self.message_func(resp.message)
|
||||||
@ -132,49 +138,55 @@ class BitkeyClient(object):
|
|||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
# Prepare and send initial message
|
try:
|
||||||
tx = proto.SignTx()
|
self.transport.session_begin()
|
||||||
tx.inputs_count = len(inputs)
|
|
||||||
tx.outputs_count = len(outputs)
|
|
||||||
res = self.call(tx)
|
|
||||||
|
|
||||||
# Prepare structure for signatures
|
|
||||||
signatures = [None]*len(inputs)
|
|
||||||
serialized_tx = ''
|
|
||||||
|
|
||||||
counter = 0
|
|
||||||
while True:
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
if isinstance(res, proto.Failure):
|
# Prepare and send initial message
|
||||||
raise CallException("Signing failed")
|
tx = proto.SignTx()
|
||||||
|
tx.inputs_count = len(inputs)
|
||||||
if not isinstance(res, proto.TxRequest):
|
tx.outputs_count = len(outputs)
|
||||||
raise CallException("Unexpected message")
|
res = self.call(tx)
|
||||||
|
|
||||||
# If there's some part of signed transaction, let's add it
|
# Prepare structure for signatures
|
||||||
if res.serialized_tx:
|
signatures = [None]*len(inputs)
|
||||||
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
|
serialized_tx = ''
|
||||||
serialized_tx += res.serialized_tx
|
|
||||||
|
counter = 0
|
||||||
if res.signed_index >= 0 and res.signature:
|
while True:
|
||||||
print "!!! SIGNED INPUT", res.signed_index
|
counter += 1
|
||||||
signatures[res.signed_index] = res.signature
|
|
||||||
|
|
||||||
if res.request_index < 0:
|
if isinstance(res, proto.Failure):
|
||||||
# Device didn't ask for more information, finish workflow
|
raise CallException("Signing failed")
|
||||||
break
|
|
||||||
|
if not isinstance(res, proto.TxRequest):
|
||||||
# Device asked for one more information, let's process it.
|
raise CallException("Unexpected message")
|
||||||
if res.request_type == proto.TXOUTPUT:
|
|
||||||
res = self.call(outputs[res.request_index])
|
# If there's some part of signed transaction, let's add it
|
||||||
continue
|
if res.serialized_tx:
|
||||||
|
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
|
||||||
elif res.request_type == proto.TXINPUT:
|
serialized_tx += res.serialized_tx
|
||||||
print "REQUESTING", res.request_index
|
|
||||||
res = self.call(inputs[res.request_index])
|
if res.signed_index >= 0 and res.signature:
|
||||||
continue
|
print "!!! SIGNED INPUT", res.signed_index
|
||||||
|
signatures[res.signed_index] = res.signature
|
||||||
|
|
||||||
|
if res.request_index < 0:
|
||||||
|
# Device didn't ask for more information, finish workflow
|
||||||
|
break
|
||||||
|
|
||||||
|
# Device asked for one more information, let's process it.
|
||||||
|
if res.request_type == proto.TXOUTPUT:
|
||||||
|
res = self.call(outputs[res.request_index])
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif res.request_type == proto.TXINPUT:
|
||||||
|
print "REQUESTING", res.request_index
|
||||||
|
res = self.call(inputs[res.request_index])
|
||||||
|
continue
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.transport.session_end()
|
||||||
|
|
||||||
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
|
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
|
||||||
(time.time() - start, counter, len(serialized_tx))
|
(time.time() - start, counter, len(serialized_tx))
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ class NotImplementedException(Exception):
|
|||||||
class Transport(object):
|
class Transport(object):
|
||||||
def __init__(self, device, *args, **kwargs):
|
def __init__(self, device, *args, **kwargs):
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.session_depth = 0
|
||||||
self._open()
|
self._open()
|
||||||
|
|
||||||
def _open(self):
|
def _open(self):
|
||||||
@ -21,8 +22,25 @@ class Transport(object):
|
|||||||
def _read(self):
|
def _read(self):
|
||||||
raise NotImplementedException("Not implemented")
|
raise NotImplementedException("Not implemented")
|
||||||
|
|
||||||
|
def _session_begin(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _session_end(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def ready_to_read(self):
|
def ready_to_read(self):
|
||||||
raise NotImplementedException("Not implemented")
|
raise NotImplementedException("Not implemented")
|
||||||
|
|
||||||
|
def session_begin(self):
|
||||||
|
if self.session_depth == 0:
|
||||||
|
self._session_begin()
|
||||||
|
self.session_depth += 1
|
||||||
|
|
||||||
|
def session_end(self):
|
||||||
|
self.session_depth -= 1
|
||||||
|
self.session_depth = max(0, self.session_depth)
|
||||||
|
if self.session_depth == 0:
|
||||||
|
self._session_end()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._close()
|
self._close()
|
||||||
|
Loading…
Reference in New Issue
Block a user