python/debuglink: add docstrings, rename functions for clearer usage

pull/840/head
matejcik 4 years ago
parent 4c8c96272c
commit 81a03edf61

@ -184,8 +184,11 @@ class DebugUI:
def __init__(self, debuglink: DebugLink):
self.debuglink = debuglink
self.pin = None
self.passphrase = "sphinx of black quartz, judge my wov"
self.clear()
def clear(self):
self.pins = None
self.passphrase = ""
self.input_flow = None
def button_request(self, code):
@ -221,14 +224,14 @@ class DebugUI:
self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code=None):
if isinstance(self.pin, str):
return self.debuglink.encode_pin(self.pin)
elif self.pin == []:
if self.pins is None:
# respond with correct pin
return self.debuglink.read_pin_encoded()
if self.pins == []:
raise AssertionError("PIN sequence ended prematurely")
elif self.pin:
return self.debuglink.encode_pin(self.pin.pop(0))
else:
return self.debuglink.read_pin_encoded()
return self.debuglink.encode_pin(self.pins.pop(0))
def get_passphrase(self, available_on_device):
return self.passphrase
@ -269,8 +272,6 @@ class TrezorClientDebugLink(TrezorClient):
self.expected_responses = None
self.current_response = None
# Use blank passphrase
self.set_passphrase("")
super().__init__(transport, ui=self.ui)
def open(self):
@ -282,6 +283,15 @@ class TrezorClientDebugLink(TrezorClient):
super().close()
def set_filter(self, message_type, callback):
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
self.filters[message_type] = callback
def _filter_message(self, msg):
@ -293,10 +303,30 @@ class TrezorClientDebugLink(TrezorClient):
return msg
def set_input_flow(self, input_flow):
if input_flow is None:
self.ui.input_flow = None
return
"""Configure a sequence of input events for the current with-block.
The `input_flow` must be a generator function. A `yield` statement in the
input flow function waits for a ButtonRequest from the device, and returns
its code.
Example usage:
>>> def input_flow():
>>> # wait for first button prompt
>>> code = yield
>>> assert code == ButtonRequestType.Other
>>> # press No
>>> client.debug.press_no()
>>>
>>> # wait for second button prompt
>>> yield
>>> # press Yes
>>> client.debug.press_yes()
>>>
>>> with client:
>>> client.set_input_flow(input_flow)
>>> some_call(client)
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
@ -331,29 +361,44 @@ class TrezorClientDebugLink(TrezorClient):
finally:
# Cleanup
self.set_input_flow(None)
self.expected_responses = None
self.current_response = None
self.ui.pin = None
self.ui.clear()
return False
def set_expected_responses(self, expected):
"""Set a sequence of expected responses to client calls.
Within a given with-block, the list of received responses from device must
match the list of expected responses, otherwise an AssertionError is raised.
If an expected response is given a field value other than None, that field value
must exactly match the received field value. If a given field is None
(or unspecified) in the expected response, the received field value is not
checked.
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.expected_responses = expected
self.current_response = 0
def set_pin(self, pin):
if isinstance(pin, str):
self.ui.pin = pin
else:
self.ui.pin = list(pin)
def set_passphrase(self, passphrase):
def use_pin_sequence(self, pins):
"""Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts.
"""
# XXX This currently only works on T1 as a response to PinMatrixRequest, but
# if we modify trezor-core to introduce PIN prompts predictably (i.e. by
# a new ButtonRequestType), it could also be used on TT via debug.input()
self.ui.pins = list(pins)
def use_passphrase(self, passphrase):
"""Respond to passphrase prompts from device with the provided passphrase."""
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def set_mnemonic(self, mnemonic):
def use_mnemonic(self, mnemonic):
"""Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def _raw_read(self):

@ -46,7 +46,7 @@ from ..common import MNEMONIC_SLIP39_BASIC_20_3of6
def test_cardano_get_address(client, path, expected_address):
# enter passphrase
assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR")
client.use_passphrase("TREZOR")
address = get_address(client, parse_path(path))
assert address == expected_address

@ -49,7 +49,7 @@ from ..common import MNEMONIC_SLIP39_BASIC_20_3of6
def test_cardano_get_public_key(client, path, public_key, chain_code):
# enter passphrase
assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR")
client.use_passphrase("TREZOR")
key = get_public_key(client, parse_path(path))

@ -137,7 +137,7 @@ def test_cardano_sign_tx(
client.debug.swipe_up()
client.debug.press_yes()
client.set_passphrase("TREZOR")
client.use_passphrase("TREZOR")
with client:
client.set_expected_responses(expected_responses)
client.set_input_flow(input_flow)

@ -45,7 +45,7 @@ def _set_wipe_code(client, wipe_code):
messages.PinMatrixRequest(type=PinType.WipeCodeSecond),
]
client.set_pin(pins)
client.use_pin_sequence(pins)
client.set_expected_responses(
[messages.ButtonRequest()]
+ pin_matrices
@ -57,7 +57,7 @@ def _set_wipe_code(client, wipe_code):
def _change_pin(client, old_pin, new_pin):
assert client.features.pin_protection is True
with client:
client.set_pin([old_pin, new_pin, new_pin])
client.use_pin_sequence([old_pin, new_pin, new_pin])
try:
return device.change_pin(client)
except exceptions.TrezorFailure as f:
@ -110,7 +110,7 @@ def test_set_wipe_code_mismatch(client):
# Let's set a new wipe code.
with client:
client.set_pin([WIPE_CODE4, WIPE_CODE6])
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6])
client.set_expected_responses(
[
messages.ButtonRequest(),
@ -134,7 +134,7 @@ def test_set_wipe_code_to_pin(client):
# Let's try setting the wipe code to the curent PIN value.
with client:
client.set_pin([PIN4, PIN4])
client.use_pin_sequence([PIN4, PIN4])
client.set_expected_responses(
[
messages.ButtonRequest(),
@ -157,7 +157,7 @@ def test_set_pin_to_wipe_code(client):
# Try to set the PIN to the current wipe code value.
with client:
client.set_pin([WIPE_CODE4, WIPE_CODE4])
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
client.set_expected_responses(
[
messages.ButtonRequest(),

@ -52,7 +52,7 @@ class TestDeviceLoad:
passphrase_protection=True,
label="test",
)
client.set_passphrase("passphrase")
client.use_passphrase("passphrase")
state = client.debug.state()
assert state.mnemonic_secret == MNEMONIC12.encode()
@ -114,7 +114,7 @@ class TestDeviceLoad:
language="en-US",
skip_checksum=True,
)
client.set_passphrase(passphrase_nfkd)
client.use_passphrase(passphrase_nfkd)
address_nfkd = btc.get_address(client, "Bitcoin", [])
device.wipe(client)
@ -127,7 +127,7 @@ class TestDeviceLoad:
language="en-US",
skip_checksum=True,
)
client.set_passphrase(passphrase_nfc)
client.use_passphrase(passphrase_nfc)
address_nfc = btc.get_address(client, "Bitcoin", [])
device.wipe(client)
@ -140,7 +140,7 @@ class TestDeviceLoad:
language="en-US",
skip_checksum=True,
)
client.set_passphrase(passphrase_nfkc)
client.use_passphrase(passphrase_nfkc)
address_nfkc = btc.get_address(client, "Bitcoin", [])
device.wipe(client)
@ -153,7 +153,7 @@ class TestDeviceLoad:
language="en-US",
skip_checksum=True,
)
client.set_passphrase(passphrase_nfd)
client.use_passphrase(passphrase_nfd)
address_nfd = btc.get_address(client, "Bitcoin", [])
assert address_nfkd == address_nfc

@ -30,12 +30,12 @@ def test_128bit_passphrase(client):
xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV
"""
assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR")
client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", [])
assert address == "1CX5rv2vbSV8YFAZEAdMwRVqbxxswPnSPw"
client.state = None
client.clear_session()
client.set_passphrase("ROZERT")
client.use_passphrase("ROZERT")
address_compare = btc.get_address(client, "Bitcoin", [])
assert address != address_compare
@ -49,11 +49,11 @@ def test_256bit_passphrase(client):
xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c
"""
assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR")
client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", [])
assert address == "18oNx6UczHWASBQXc5XQqdSdAAZyhUwdQU"
client.state = None
client.clear_session()
client.set_passphrase("ROZERT")
client.use_passphrase("ROZERT")
address_compare = btc.get_address(client, "Bitcoin", [])
assert address != address_compare

@ -30,7 +30,7 @@ def test_3of6_passphrase(client):
xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce
"""
assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR")
client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", [])
assert address == "18oZEMRWurCZW1FeK8sWYyXuWx2bFqEKyX"
@ -50,6 +50,6 @@ def test_2of5_passphrase(client):
xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni
"""
assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR")
client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", [])
assert address == "19Fjs9AvT13Y2Nx8GtoVfADmFWnccsPinQ"

@ -46,30 +46,23 @@ class TestProtectCall:
@pytest.mark.setup_client(pin="1234")
def test_incorrect_pin(self, client):
client.set_pin("5678")
with pytest.raises(PinException):
client.use_pin_sequence(["5678"])
self._some_protected_call(client)
@pytest.mark.setup_client(pin="1234", passphrase=True)
def test_exponential_backoff_with_reboot(self, client):
client.set_pin("5678")
def test_backoff(attempts, start):
if attempts <= 1:
expected = 0
else:
expected = (2 ** (attempts - 1)) - 1
got = round(time.time() - start, 2)
msg = "Pin delay expected to be at least %s seconds, got %s" % (
expected,
got,
)
print(msg)
assert got >= expected
for attempt in range(1, 4):
start = time.time()
with pytest.raises(PinException):
with client, pytest.raises(PinException):
client.use_pin_sequence(["5678"])
self._some_protected_call(client)
test_backoff(attempt, start)

@ -125,7 +125,7 @@ class TestProtectionLevels:
@pytest.mark.setup_client(uninitialized=True)
def test_recovery_device(self, client):
client.set_mnemonic(MNEMONIC12)
client.use_mnemonic(MNEMONIC12)
with client:
client.set_expected_responses(
[proto.ButtonRequest()]

@ -15,7 +15,7 @@ def setup_device_legacy(client, pin, wipe_code):
)
with client:
client.set_pin([PIN, WIPE_CODE, WIPE_CODE])
client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
device.change_wipe_code(client)

Loading…
Cancel
Save