1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-12 16:30:56 +00:00

python/debuglink: add docstrings, rename functions for clearer usage

This commit is contained in:
matejcik 2020-02-12 15:38:18 +01:00
parent 4c8c96272c
commit 81a03edf61
11 changed files with 92 additions and 54 deletions

View File

@ -184,8 +184,11 @@ class DebugUI:
def __init__(self, debuglink: DebugLink): def __init__(self, debuglink: DebugLink):
self.debuglink = debuglink self.debuglink = debuglink
self.pin = None self.clear()
self.passphrase = "sphinx of black quartz, judge my wov"
def clear(self):
self.pins = None
self.passphrase = ""
self.input_flow = None self.input_flow = None
def button_request(self, code): def button_request(self, code):
@ -221,15 +224,15 @@ class DebugUI:
self.input_flow = self.INPUT_FLOW_DONE self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code=None): def get_pin(self, code=None):
if isinstance(self.pin, str): if self.pins is None:
return self.debuglink.encode_pin(self.pin) # respond with correct pin
elif self.pin == []:
raise AssertionError("PIN sequence ended prematurely")
elif self.pin:
return self.debuglink.encode_pin(self.pin.pop(0))
else:
return self.debuglink.read_pin_encoded() return self.debuglink.read_pin_encoded()
if self.pins == []:
raise AssertionError("PIN sequence ended prematurely")
else:
return self.debuglink.encode_pin(self.pins.pop(0))
def get_passphrase(self, available_on_device): def get_passphrase(self, available_on_device):
return self.passphrase return self.passphrase
@ -269,8 +272,6 @@ class TrezorClientDebugLink(TrezorClient):
self.expected_responses = None self.expected_responses = None
self.current_response = None self.current_response = None
# Use blank passphrase
self.set_passphrase("")
super().__init__(transport, ui=self.ui) super().__init__(transport, ui=self.ui)
def open(self): def open(self):
@ -282,6 +283,15 @@ class TrezorClientDebugLink(TrezorClient):
super().close() super().close()
def set_filter(self, message_type, callback): def set_filter(self, message_type, callback):
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
self.filters[message_type] = callback self.filters[message_type] = callback
def _filter_message(self, msg): def _filter_message(self, msg):
@ -293,10 +303,30 @@ class TrezorClientDebugLink(TrezorClient):
return msg return msg
def set_input_flow(self, input_flow): def set_input_flow(self, input_flow):
if input_flow is None: """Configure a sequence of input events for the current with-block.
self.ui.input_flow = None
return
The `input_flow` must be a generator function. A `yield` statement in the
input flow function waits for a ButtonRequest from the device, and returns
its code.
Example usage:
>>> def input_flow():
>>> # wait for first button prompt
>>> code = yield
>>> assert code == ButtonRequestType.Other
>>> # press No
>>> client.debug.press_no()
>>>
>>> # wait for second button prompt
>>> yield
>>> # press Yes
>>> client.debug.press_yes()
>>>
>>> with client:
>>> client.set_input_flow(input_flow)
>>> some_call(client)
"""
if not self.in_with_statement: if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement") raise RuntimeError("Must be called inside 'with' statement")
@ -331,29 +361,44 @@ class TrezorClientDebugLink(TrezorClient):
finally: finally:
# Cleanup # Cleanup
self.set_input_flow(None)
self.expected_responses = None self.expected_responses = None
self.current_response = None self.current_response = None
self.ui.pin = None self.ui.clear()
return False return False
def set_expected_responses(self, expected): def set_expected_responses(self, expected):
"""Set a sequence of expected responses to client calls.
Within a given with-block, the list of received responses from device must
match the list of expected responses, otherwise an AssertionError is raised.
If an expected response is given a field value other than None, that field value
must exactly match the received field value. If a given field is None
(or unspecified) in the expected response, the received field value is not
checked.
"""
if not self.in_with_statement: if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement") raise RuntimeError("Must be called inside 'with' statement")
self.expected_responses = expected self.expected_responses = expected
self.current_response = 0 self.current_response = 0
def set_pin(self, pin): def use_pin_sequence(self, pins):
if isinstance(pin, str): """Respond to PIN prompts from device with the provided PINs.
self.ui.pin = pin The sequence must be at least as long as the expected number of PIN prompts.
else: """
self.ui.pin = list(pin) # XXX This currently only works on T1 as a response to PinMatrixRequest, but
# if we modify trezor-core to introduce PIN prompts predictably (i.e. by
# a new ButtonRequestType), it could also be used on TT via debug.input()
self.ui.pins = list(pins)
def set_passphrase(self, passphrase): def use_passphrase(self, passphrase):
"""Respond to passphrase prompts from device with the provided passphrase."""
self.ui.passphrase = Mnemonic.normalize_string(passphrase) self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def set_mnemonic(self, mnemonic): def use_mnemonic(self, mnemonic):
"""Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def _raw_read(self): def _raw_read(self):

View File

@ -46,7 +46,7 @@ from ..common import MNEMONIC_SLIP39_BASIC_20_3of6
def test_cardano_get_address(client, path, expected_address): def test_cardano_get_address(client, path, expected_address):
# enter passphrase # enter passphrase
assert client.features.passphrase_protection is True assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR") client.use_passphrase("TREZOR")
address = get_address(client, parse_path(path)) address = get_address(client, parse_path(path))
assert address == expected_address assert address == expected_address

View File

@ -49,7 +49,7 @@ from ..common import MNEMONIC_SLIP39_BASIC_20_3of6
def test_cardano_get_public_key(client, path, public_key, chain_code): def test_cardano_get_public_key(client, path, public_key, chain_code):
# enter passphrase # enter passphrase
assert client.features.passphrase_protection is True assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR") client.use_passphrase("TREZOR")
key = get_public_key(client, parse_path(path)) key = get_public_key(client, parse_path(path))

View File

@ -137,7 +137,7 @@ def test_cardano_sign_tx(
client.debug.swipe_up() client.debug.swipe_up()
client.debug.press_yes() client.debug.press_yes()
client.set_passphrase("TREZOR") client.use_passphrase("TREZOR")
with client: with client:
client.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
client.set_input_flow(input_flow) client.set_input_flow(input_flow)

View File

@ -45,7 +45,7 @@ def _set_wipe_code(client, wipe_code):
messages.PinMatrixRequest(type=PinType.WipeCodeSecond), messages.PinMatrixRequest(type=PinType.WipeCodeSecond),
] ]
client.set_pin(pins) client.use_pin_sequence(pins)
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] [messages.ButtonRequest()]
+ pin_matrices + pin_matrices
@ -57,7 +57,7 @@ def _set_wipe_code(client, wipe_code):
def _change_pin(client, old_pin, new_pin): def _change_pin(client, old_pin, new_pin):
assert client.features.pin_protection is True assert client.features.pin_protection is True
with client: with client:
client.set_pin([old_pin, new_pin, new_pin]) client.use_pin_sequence([old_pin, new_pin, new_pin])
try: try:
return device.change_pin(client) return device.change_pin(client)
except exceptions.TrezorFailure as f: except exceptions.TrezorFailure as f:
@ -110,7 +110,7 @@ def test_set_wipe_code_mismatch(client):
# Let's set a new wipe code. # Let's set a new wipe code.
with client: with client:
client.set_pin([WIPE_CODE4, WIPE_CODE6]) client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6])
client.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),
@ -134,7 +134,7 @@ def test_set_wipe_code_to_pin(client):
# Let's try setting the wipe code to the curent PIN value. # Let's try setting the wipe code to the curent PIN value.
with client: with client:
client.set_pin([PIN4, PIN4]) client.use_pin_sequence([PIN4, PIN4])
client.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),
@ -157,7 +157,7 @@ def test_set_pin_to_wipe_code(client):
# Try to set the PIN to the current wipe code value. # Try to set the PIN to the current wipe code value.
with client: with client:
client.set_pin([WIPE_CODE4, WIPE_CODE4]) client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
client.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),

View File

@ -52,7 +52,7 @@ class TestDeviceLoad:
passphrase_protection=True, passphrase_protection=True,
label="test", label="test",
) )
client.set_passphrase("passphrase") client.use_passphrase("passphrase")
state = client.debug.state() state = client.debug.state()
assert state.mnemonic_secret == MNEMONIC12.encode() assert state.mnemonic_secret == MNEMONIC12.encode()
@ -114,7 +114,7 @@ class TestDeviceLoad:
language="en-US", language="en-US",
skip_checksum=True, skip_checksum=True,
) )
client.set_passphrase(passphrase_nfkd) client.use_passphrase(passphrase_nfkd)
address_nfkd = btc.get_address(client, "Bitcoin", []) address_nfkd = btc.get_address(client, "Bitcoin", [])
device.wipe(client) device.wipe(client)
@ -127,7 +127,7 @@ class TestDeviceLoad:
language="en-US", language="en-US",
skip_checksum=True, skip_checksum=True,
) )
client.set_passphrase(passphrase_nfc) client.use_passphrase(passphrase_nfc)
address_nfc = btc.get_address(client, "Bitcoin", []) address_nfc = btc.get_address(client, "Bitcoin", [])
device.wipe(client) device.wipe(client)
@ -140,7 +140,7 @@ class TestDeviceLoad:
language="en-US", language="en-US",
skip_checksum=True, skip_checksum=True,
) )
client.set_passphrase(passphrase_nfkc) client.use_passphrase(passphrase_nfkc)
address_nfkc = btc.get_address(client, "Bitcoin", []) address_nfkc = btc.get_address(client, "Bitcoin", [])
device.wipe(client) device.wipe(client)
@ -153,7 +153,7 @@ class TestDeviceLoad:
language="en-US", language="en-US",
skip_checksum=True, skip_checksum=True,
) )
client.set_passphrase(passphrase_nfd) client.use_passphrase(passphrase_nfd)
address_nfd = btc.get_address(client, "Bitcoin", []) address_nfd = btc.get_address(client, "Bitcoin", [])
assert address_nfkd == address_nfc assert address_nfkd == address_nfc

View File

@ -30,12 +30,12 @@ def test_128bit_passphrase(client):
xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV xprv9s21ZrQH143K3dzDLfeY3cMp23u5vDeFYftu5RPYZPucKc99mNEddU4w99GxdgUGcSfMpVDxhnR1XpJzZNXRN1m6xNgnzFS5MwMP6QyBRKV
""" """
assert client.features.passphrase_protection is True assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR") client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", []) address = btc.get_address(client, "Bitcoin", [])
assert address == "1CX5rv2vbSV8YFAZEAdMwRVqbxxswPnSPw" assert address == "1CX5rv2vbSV8YFAZEAdMwRVqbxxswPnSPw"
client.state = None client.state = None
client.clear_session() client.clear_session()
client.set_passphrase("ROZERT") client.use_passphrase("ROZERT")
address_compare = btc.get_address(client, "Bitcoin", []) address_compare = btc.get_address(client, "Bitcoin", [])
assert address != address_compare assert address != address_compare
@ -49,11 +49,11 @@ def test_256bit_passphrase(client):
xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c xprv9s21ZrQH143K2UspC9FRPfQC9NcDB4HPkx1XG9UEtuceYtpcCZ6ypNZWdgfxQ9dAFVeD1F4Zg4roY7nZm2LB7THPD6kaCege3M7EuS8v85c
""" """
assert client.features.passphrase_protection is True assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR") client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", []) address = btc.get_address(client, "Bitcoin", [])
assert address == "18oNx6UczHWASBQXc5XQqdSdAAZyhUwdQU" assert address == "18oNx6UczHWASBQXc5XQqdSdAAZyhUwdQU"
client.state = None client.state = None
client.clear_session() client.clear_session()
client.set_passphrase("ROZERT") client.use_passphrase("ROZERT")
address_compare = btc.get_address(client, "Bitcoin", []) address_compare = btc.get_address(client, "Bitcoin", [])
assert address != address_compare assert address != address_compare

View File

@ -30,7 +30,7 @@ def test_3of6_passphrase(client):
xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce xprv9s21ZrQH143K2pMWi8jrTawHaj16uKk4CSbvo4Zt61tcrmuUDMx2o1Byzcr3saXNGNvHP8zZgXVdJHsXVdzYFPavxvCyaGyGr1WkAYG83ce
""" """
assert client.features.passphrase_protection is True assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR") client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", []) address = btc.get_address(client, "Bitcoin", [])
assert address == "18oZEMRWurCZW1FeK8sWYyXuWx2bFqEKyX" assert address == "18oZEMRWurCZW1FeK8sWYyXuWx2bFqEKyX"
@ -50,6 +50,6 @@ def test_2of5_passphrase(client):
xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni xprv9s21ZrQH143K2o6EXEHpVy8TCYoMmkBnDCCESLdR2ieKwmcNG48ck2XJQY4waS7RUQcXqR9N7HnQbUVEDMWYyREdF1idQqxFHuCfK7fqFni
""" """
assert client.features.passphrase_protection is True assert client.features.passphrase_protection is True
client.set_passphrase("TREZOR") client.use_passphrase("TREZOR")
address = btc.get_address(client, "Bitcoin", []) address = btc.get_address(client, "Bitcoin", [])
assert address == "19Fjs9AvT13Y2Nx8GtoVfADmFWnccsPinQ" assert address == "19Fjs9AvT13Y2Nx8GtoVfADmFWnccsPinQ"

View File

@ -46,30 +46,23 @@ class TestProtectCall:
@pytest.mark.setup_client(pin="1234") @pytest.mark.setup_client(pin="1234")
def test_incorrect_pin(self, client): def test_incorrect_pin(self, client):
client.set_pin("5678")
with pytest.raises(PinException): with pytest.raises(PinException):
client.use_pin_sequence(["5678"])
self._some_protected_call(client) self._some_protected_call(client)
@pytest.mark.setup_client(pin="1234", passphrase=True) @pytest.mark.setup_client(pin="1234", passphrase=True)
def test_exponential_backoff_with_reboot(self, client): def test_exponential_backoff_with_reboot(self, client):
client.set_pin("5678")
def test_backoff(attempts, start): def test_backoff(attempts, start):
if attempts <= 1: if attempts <= 1:
expected = 0 expected = 0
else: else:
expected = (2 ** (attempts - 1)) - 1 expected = (2 ** (attempts - 1)) - 1
got = round(time.time() - start, 2) got = round(time.time() - start, 2)
msg = "Pin delay expected to be at least %s seconds, got %s" % (
expected,
got,
)
print(msg)
assert got >= expected assert got >= expected
for attempt in range(1, 4): for attempt in range(1, 4):
start = time.time() start = time.time()
with pytest.raises(PinException): with client, pytest.raises(PinException):
client.use_pin_sequence(["5678"])
self._some_protected_call(client) self._some_protected_call(client)
test_backoff(attempt, start) test_backoff(attempt, start)

View File

@ -125,7 +125,7 @@ class TestProtectionLevels:
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_recovery_device(self, client): def test_recovery_device(self, client):
client.set_mnemonic(MNEMONIC12) client.use_mnemonic(MNEMONIC12)
with client: with client:
client.set_expected_responses( client.set_expected_responses(
[proto.ButtonRequest()] [proto.ButtonRequest()]

View File

@ -15,7 +15,7 @@ def setup_device_legacy(client, pin, wipe_code):
) )
with client: with client:
client.set_pin([PIN, WIPE_CODE, WIPE_CODE]) client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
device.change_wipe_code(client) device.change_wipe_code(client)