parent
df8b367cb2
commit
45f525f916
@ -0,0 +1,71 @@
|
||||
from common import *
|
||||
|
||||
from trezor import io
|
||||
from trezor.loop import wait
|
||||
from trezor.wire import thp_v1
|
||||
from trezor.wire.thp import channel
|
||||
from storage import cache_thp
|
||||
from ubinascii import hexlify
|
||||
from trezor.wire.thp import ChannelState
|
||||
|
||||
|
||||
class MockHID:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.data = []
|
||||
|
||||
def iface_num(self):
|
||||
return self.num
|
||||
|
||||
def write(self, msg):
|
||||
self.data.append(bytearray(msg))
|
||||
return len(msg)
|
||||
|
||||
def wait_object(self, mode):
|
||||
return wait(mode | self.num)
|
||||
|
||||
|
||||
def dummy_decode_iface(cached_iface: bytes):
|
||||
return MockHID(0xDEADBEEF)
|
||||
|
||||
|
||||
def getBytes(a):
|
||||
return hexlify(a).decode("utf-8")
|
||||
|
||||
|
||||
class TestTrezorHostProtocol(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
buffer = bytearray(64)
|
||||
thp_v1.set_buffer(buffer)
|
||||
channel._decode_iface = dummy_decode_iface
|
||||
|
||||
def test_simple(self):
|
||||
self.assertTrue(True)
|
||||
|
||||
def test_channel_allocation(self):
|
||||
cid_req = (
|
||||
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
|
||||
)
|
||||
expected_response = "41ffff001e001122334455667712340a0454335731100518002001280128026dcad4ba0000000000000000000000000000000000000000000000000000000000"
|
||||
test_counter = cache_thp.cid_counter + 1
|
||||
self.assertEqual(len(thp_v1.CHANNELS), 0)
|
||||
self.assertFalse(test_counter in thp_v1.CHANNELS)
|
||||
gen = thp_v1.thp_main_loop(self.interface, is_debug_session=True)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(cid_req)
|
||||
gen.send(None)
|
||||
self.assertEqual(
|
||||
getBytes(self.interface.data[-1]),
|
||||
expected_response,
|
||||
)
|
||||
self.assertTrue(test_counter in thp_v1.CHANNELS)
|
||||
self.assertEqual(len(thp_v1.CHANNELS), 1)
|
||||
|
||||
def test_channel_default_state_is_TH1(self):
|
||||
self.assertEqual(thp_v1.CHANNELS[4660].get_channel_state(), ChannelState.TH1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue