diff --git a/bitkeylib/client.py b/bitkeylib/client.py index 9463468b58..81ed98941c 100644 --- a/bitkeylib/client.py +++ b/bitkeylib/client.py @@ -64,28 +64,34 @@ class BitkeyClient(object): print '----------------------' print "Sending", self._pprint(msg) - self.transport.write(msg) - resp = self.transport.read_blocking() - - if isinstance(resp, proto.ButtonRequest): - if self.debuglink and self.debug_button: - print "Pressing button", self.debug_button - self.debuglink.press_button(self.debug_button) - - return self.call(proto.ButtonAck()) - - if isinstance(resp, proto.PinMatrixRequest): - if self.debuglink: - if self.debug_pin: - pin = self.debuglink.read_pin_encoded() - msg2 = proto.PinMatrixAck(pin=pin) - else: - msg2 = proto.PinMatrixAck(pin='444444222222') - else: - pin = self.input_func("PIN required: ", resp.message) - msg2 = proto.PinMatrixAck(pin=pin) + try: + self.transport.session_begin() + + self.transport.write(msg) + resp = self.transport.read_blocking() + + if isinstance(resp, proto.ButtonRequest): + if self.debuglink and self.debug_button: + print "Pressing button", self.debug_button + self.debuglink.press_button(self.debug_button) - return self.call(msg2) + return self.call(proto.ButtonAck()) + + if isinstance(resp, proto.PinMatrixRequest): + if self.debuglink: + if self.debug_pin: + pin = self.debuglink.read_pin_encoded() + msg2 = proto.PinMatrixAck(pin=pin) + else: + msg2 = proto.PinMatrixAck(pin='444444222222') + else: + pin = self.input_func("PIN required: ", resp.message) + msg2 = proto.PinMatrixAck(pin=pin) + + return self.call(msg2) + + finally: + self.transport.session_end() if isinstance(resp, proto.Failure): self.message_func(resp.message) @@ -132,49 +138,55 @@ class BitkeyClient(object): start = time.time() - # Prepare and send initial message - tx = proto.SignTx() - tx.inputs_count = len(inputs) - tx.outputs_count = len(outputs) - res = self.call(tx) - - # Prepare structure for signatures - signatures = [None]*len(inputs) - serialized_tx = '' - - counter = 0 - while True: - counter += 1 + try: + self.transport.session_begin() - if isinstance(res, proto.Failure): - raise CallException("Signing failed") - - if not isinstance(res, proto.TxRequest): - raise CallException("Unexpected message") - - # If there's some part of signed transaction, let's add it - if res.serialized_tx: - print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx) - serialized_tx += res.serialized_tx - - if res.signed_index >= 0 and res.signature: - print "!!! SIGNED INPUT", res.signed_index - signatures[res.signed_index] = res.signature + # Prepare and send initial message + tx = proto.SignTx() + tx.inputs_count = len(inputs) + tx.outputs_count = len(outputs) + res = self.call(tx) + + # Prepare structure for signatures + signatures = [None]*len(inputs) + serialized_tx = '' + + counter = 0 + while True: + counter += 1 - if res.request_index < 0: - # Device didn't ask for more information, finish workflow - break - - # Device asked for one more information, let's process it. - if res.request_type == proto.TXOUTPUT: - res = self.call(outputs[res.request_index]) - continue - - elif res.request_type == proto.TXINPUT: - print "REQUESTING", res.request_index - res = self.call(inputs[res.request_index]) - continue - + if isinstance(res, proto.Failure): + raise CallException("Signing failed") + + if not isinstance(res, proto.TxRequest): + raise CallException("Unexpected message") + + # If there's some part of signed transaction, let's add it + if res.serialized_tx: + print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx) + serialized_tx += res.serialized_tx + + if res.signed_index >= 0 and res.signature: + print "!!! SIGNED INPUT", res.signed_index + signatures[res.signed_index] = res.signature + + if res.request_index < 0: + # Device didn't ask for more information, finish workflow + break + + # Device asked for one more information, let's process it. + if res.request_type == proto.TXOUTPUT: + res = self.call(outputs[res.request_index]) + continue + + elif res.request_type == proto.TXINPUT: + print "REQUESTING", res.request_index + res = self.call(inputs[res.request_index]) + continue + + finally: + self.transport.session_end() + print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \ (time.time() - start, counter, len(serialized_tx)) diff --git a/bitkeylib/transport.py b/bitkeylib/transport.py index ad13c3f628..c7e8060746 100644 --- a/bitkeylib/transport.py +++ b/bitkeylib/transport.py @@ -7,6 +7,7 @@ class NotImplementedException(Exception): class Transport(object): def __init__(self, device, *args, **kwargs): self.device = device + self.session_depth = 0 self._open() def _open(self): @@ -21,8 +22,25 @@ class Transport(object): def _read(self): raise NotImplementedException("Not implemented") + def _session_begin(self): + pass + + def _session_end(self): + pass + def ready_to_read(self): raise NotImplementedException("Not implemented") + + def session_begin(self): + if self.session_depth == 0: + self._session_begin() + self.session_depth += 1 + + def session_end(self): + self.session_depth -= 1 + self.session_depth = max(0, self.session_depth) + if self.session_depth == 0: + self._session_end() def close(self): self._close()