tests: finish extracting common functionality for emulators

pull/520/head
matejcik 5 years ago
parent 643122b651
commit 6e4921c030

@ -14,42 +14,38 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from collections import defaultdict
import os import os
import subprocess import subprocess
import tempfile import tempfile
import time import time
from collections import defaultdict
from trezorlib.debuglink import TrezorClientDebugLink from trezorlib.debuglink import TrezorClientDebugLink
from trezorlib.transport import TransportException, get_transport from trezorlib.transport import TransportException, get_transport
BINDIR = os.path.dirname(os.path.abspath(__file__)) + "/emulators" BINDIR = os.path.dirname(os.path.abspath(__file__)) + "/emulators"
ENV = {"SDL_VIDEODRIVER": "dummy"}
ROOT = os.path.dirname(os.path.abspath(__file__)) + "/../" ROOT = os.path.dirname(os.path.abspath(__file__)) + "/../"
LOCAL_BUILDS = { LOCAL_BUILD_PATHS = {
"core": ROOT + "core/build/unix/micropython", "core": ROOT + "core/build/unix/micropython",
"legacy": ROOT + "legacy/firmware/trezor.elf", "legacy": ROOT + "legacy/firmware/trezor.elf",
} }
BIN_DIR = os.path.dirname(os.path.abspath(__file__)) + "/emulators"
ENV = {"SDL_VIDEODRIVER": "dummy"}
def check_version(tag, ver_emu):
if tag.startswith("v") and len(tag.split(".")) == 3:
assert tag == "v" + ".".join(["%d" % i for i in ver_emu])
def check_version(tag, version_tuple):
if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3:
version = ".".join(str(i) for i in version_tuple)
if tag[1:] != version:
raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}")
def check_file(gen, tag):
if tag.startswith("/"): def filename_from_tag(gen, tag):
filename = tag return f"{BINDIR}/trezor-emu-{gen}-{tag}"
else:
filename = "%s/trezor-emu-%s-%s" % (BIN_DIR, gen, tag)
if not os.path.exists(filename):
raise ValueError(filename + " not found. Do not forget to build firmware.")
def get_tags(): def get_tags():
files = os.listdir(BIN_DIR) files = os.listdir(BINDIR)
if not files: if not files:
raise ValueError( raise ValueError(
"No files found. Use download_emulators.sh to download emulators." "No files found. Use download_emulators.sh to download emulators."
@ -58,6 +54,7 @@ def get_tags():
result = defaultdict(list) result = defaultdict(list)
for f in sorted(files): for f in sorted(files):
try: try:
# example: "trezor-emu-core-v2.1.1"
_, _, gen, tag = f.split("-", maxsplit=3) _, _, gen, tag = f.split("-", maxsplit=3)
result[gen].append(tag) result[gen].append(tag)
except ValueError: except ValueError:
@ -69,18 +66,28 @@ ALL_TAGS = get_tags()
class EmulatorWrapper: class EmulatorWrapper:
def __init__(self, gen, tag, storage=None): def __init__(self, gen, tag=None, executable=None, storage=None):
self.gen = gen self.gen = gen
self.tag = tag self.tag = tag
if executable is not None:
self.executable = executable
elif tag is not None:
self.executable = filename_from_tag(gen, tag)
else:
self.executable = LOCAL_BUILD_PATHS[gen]
if not os.path.exists(self.executable):
raise ValueError(f"emulator executable not found: {self.executable}")
self.workdir = tempfile.TemporaryDirectory() self.workdir = tempfile.TemporaryDirectory()
if storage: if storage:
open(self._storage_file(), "wb").write(storage) open(self._storage_file(), "wb").write(storage)
self.client = None
def __enter__(self): def __enter__(self):
if self.tag.startswith("/"): # full path+filename provided args = [self.executable]
args = [self.tag]
else: # only gen+tag provided
args = ["%s/trezor-emu-%s-%s" % (BINDIR, self.gen, self.tag)]
env = ENV env = ENV
if self.gen == "core": if self.gen == "core":
args += ["-m", "main"] args += ["-m", "main"]
@ -88,39 +95,35 @@ class EmulatorWrapper:
env["TREZOR_PROFILE_DIR"] = self.workdir.name env["TREZOR_PROFILE_DIR"] = self.workdir.name
# for firmware 2.1.1 and older # for firmware 2.1.1 and older
env["TREZOR_PROFILE"] = self.workdir.name env["TREZOR_PROFILE"] = self.workdir.name
self.client = None
self.process = subprocess.Popen( self.process = subprocess.Popen(
args, cwd=self.workdir.name, env=ENV, stdout=open(os.devnull, "w") args, cwd=self.workdir.name, env=env, stdout=open(os.devnull, "w")
) )
# wait until emulator is listening # wait until emulator is listening
i = 0 for _ in range(100):
while True:
try: try:
i += 1
if i > 100:
self.__exit__(None, None, None)
raise RuntimeError("Can't connect to emulator")
self.transport = get_transport("udp:127.0.0.1:21324")
except TransportException:
time.sleep(0.1) time.sleep(0.1)
continue transport = get_transport("udp:127.0.0.1:21324")
break break
self.client = TrezorClientDebugLink(self.transport) except TransportException:
pass
if self.process.poll() is not None:
self._cleanup()
raise RuntimeError("Emulator proces died")
else:
# could not connect after 100 attempts * 0.1s = 10s of waiting
self._cleanup()
raise RuntimeError("Can't connect to emulator")
self.client = TrezorClientDebugLink(transport)
self.client.open() self.client.open()
# check whether the reported version matches the expected one check_version(self.tag, self.client.version)
if self.tag[0] == "v":
version = "v%d.%d.%d" % (
self.client.features["major_version"],
self.client.features["minor_version"],
self.client.features["patch_version"],
)
assert self.tag == version, "expected: %s reported: %s" % (
self.tag,
version,
)
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self._cleanup()
return False
def _cleanup(self):
if self.client: if self.client:
self.client.close() self.client.close()
self.process.terminate() self.process.terminate()

@ -15,17 +15,17 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os import os
from collections import defaultdict
import pytest import pytest
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device
from trezorlib.tools import H_ from trezorlib.tools import H_
from ..emulators import ALL_TAGS, EmulatorWrapper
MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0) MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0)
MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0) MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0)
from ..emulators import EmulatorWrapper, ALL_TAGS, LOCAL_BUILDS
# **** COMMON DEFINITIONS **** # **** COMMON DEFINITIONS ****
@ -41,21 +41,24 @@ def for_all(*args, minimum_version=(1, 0, 0)):
if not args: if not args:
args = ("core", "legacy") args = ("core", "legacy")
enabled_gens = os.environ.get("TREZOR_UPGRADE_TEST", "").split(",") specified_gens = os.environ.get("TREZOR_UPGRADE_TEST")
if specified_gens is not None:
enabled_gens = specified_gens.split(",")
else:
enabled_gens = args
all_params = [] all_params = []
for gen in args: for gen in args:
if gen not in enabled_gens: if gen not in enabled_gens:
continue continue
try: try:
to_tag = LOCAL_BUILDS[gen] to_tag = None
from_tags = ALL_TAGS[gen] + [to_tag] from_tags = ALL_TAGS[gen] + [to_tag]
for from_tag in from_tags: for from_tag in from_tags:
if from_tag.startswith("v"): if from_tag is not None and from_tag.startswith("v"):
tag_version = tuple(int(n) for n in from_tag[1:].split(".")) tag_version = tuple(int(n) for n in from_tag[1:].split("."))
if tag_version < minimum_version: if tag_version < minimum_version:
continue continue
check_file(gen, from_tag)
all_params.append((gen, from_tag, to_tag)) all_params.append((gen, from_tag, to_tag))
except KeyError: except KeyError:
pass pass
@ -69,7 +72,6 @@ def for_all(*args, minimum_version=(1, 0, 0)):
@for_all() @for_all()
def test_upgrade_load(gen, from_tag, to_tag): def test_upgrade_load(gen, from_tag, to_tag):
def asserts(tag, client): def asserts(tag, client):
check_version(tag, emu.client.version)
assert not client.features.pin_protection assert not client.features.pin_protection
assert not client.features.passphrase_protection assert not client.features.passphrase_protection
assert client.features.initialized assert client.features.initialized
@ -98,7 +100,6 @@ def test_upgrade_load(gen, from_tag, to_tag):
@for_all("legacy") @for_all("legacy")
def test_upgrade_reset(gen, from_tag, to_tag): def test_upgrade_reset(gen, from_tag, to_tag):
def asserts(tag, client): def asserts(tag, client):
check_version(tag, emu.client.version)
assert not client.features.pin_protection assert not client.features.pin_protection
assert not client.features.passphrase_protection assert not client.features.passphrase_protection
assert client.features.initialized assert client.features.initialized
@ -132,7 +133,6 @@ def test_upgrade_reset(gen, from_tag, to_tag):
@for_all() @for_all()
def test_upgrade_reset_skip_backup(gen, from_tag, to_tag): def test_upgrade_reset_skip_backup(gen, from_tag, to_tag):
def asserts(tag, client): def asserts(tag, client):
check_version(tag, emu.client.version)
assert not client.features.pin_protection assert not client.features.pin_protection
assert not client.features.passphrase_protection assert not client.features.passphrase_protection
assert client.features.initialized assert client.features.initialized
@ -167,7 +167,6 @@ def test_upgrade_reset_skip_backup(gen, from_tag, to_tag):
@for_all(minimum_version=(1, 7, 2)) @for_all(minimum_version=(1, 7, 2))
def test_upgrade_reset_no_backup(gen, from_tag, to_tag): def test_upgrade_reset_no_backup(gen, from_tag, to_tag):
def asserts(tag, client): def asserts(tag, client):
check_version(tag, emu.client.version)
assert not client.features.pin_protection assert not client.features.pin_protection
assert not client.features.passphrase_protection assert not client.features.passphrase_protection
assert client.features.initialized assert client.features.initialized

Loading…
Cancel
Save