1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-20 20:31:06 +00:00

trezorlib/transport: smarter handling of prefix search

For UDP transport, it's useful to be able to specify a path that should be tried directly,
without enumerating first.
This commit is contained in:
matejcik 2018-03-02 18:22:33 +01:00
parent d2913c20bd
commit 6519657808
2 changed files with 28 additions and 9 deletions

View File

@ -53,7 +53,7 @@ class Transport(object):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def find_by_path(cls, path, prefix_search = True): def find_by_path(cls, path, prefix_search=False):
for device in cls.enumerate(): for device in cls.enumerate():
if path is None or device.get_path() == path \ if path is None or device.get_path() == path \
or (prefix_search and device.get_path().startswith(path)): or (prefix_search and device.get_path().startswith(path)):

View File

@ -53,15 +53,34 @@ class UdpTransport(Transport):
host, port = self.device host, port = self.device
return UdpTransport('{}:{}'.format(host, port+1), self.protocol) return UdpTransport('{}:{}'.format(host, port+1), self.protocol)
@staticmethod @classmethod
def enumerate(): def _try_path(cls, path):
d = cls(path)
try:
d.open()
if d._ping():
return d
else:
raise TransportException('No TREZOR device found at address {}'.format(path))
finally:
d.close()
@classmethod
def enumerate(cls):
devices = [] devices = []
d = UdpTransport("%s:%d" % (UdpTransport.DEFAULT_HOST, UdpTransport.DEFAULT_PORT)) default_path = '{}:{}'.format(cls.DEFAULT_HOST, cls.DEFAULT_PORT)
d.open() try:
if d._ping(): return [cls._try_path(default_path)]
devices.append(d) except TransportException:
d.close() return []
return devices
@classmethod
def find_by_path(cls, path, prefix_search=False):
if prefix_search:
return super().find_by_path(path, prefix_search)
else:
path = path.replace('{}:'.format(cls.PATH_PREFIX), '')
return cls._try_path(path)
def open(self): def open(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)