diff --git a/gns3server/modules/iou/iou_vm.py b/gns3server/modules/iou/iou_vm.py index 0ccbf376..e1c7bc77 100644 --- a/gns3server/modules/iou/iou_vm.py +++ b/gns3server/modules/iou/iou_vm.py @@ -46,6 +46,7 @@ import gns3server.utils.asyncio import logging +import sys log = logging.getLogger(__name__) @@ -368,25 +369,30 @@ class IOUVM(BaseVM): if len(user_ioukey) != 17: raise IOUError("IOU key length is not 16 characters in iourc file".format(self.iourc_path)) user_ioukey = user_ioukey[:16] - try: - hostid = (yield from gns3server.utils.asyncio.subprocess_check_output("hostid")).strip() - except FileNotFoundError as e: - raise IOUError("Could not find hostid: {}".format(e)) - except subprocess.SubprocessError as e: - raise IOUError("Could not execute hostid: {}".format(e)) - try: - ioukey = int(hostid, 16) - except ValueError: - raise IOUError("Invalid hostid detected: {}".format(hostid)) - for x in hostname: - ioukey += ord(x) - pad1 = b'\x4B\x58\x21\x81\x56\x7B\x0D\xF3\x21\x43\x9B\x7E\xAC\x1D\xE6\x8A' - pad2 = b'\x80' + 39 * b'\0' - ioukey = hashlib.md5(pad1 + pad2 + struct.pack('!i', ioukey) + pad1).hexdigest()[:16] - if ioukey != user_ioukey: - raise IOUError("Invalid IOU license key {} detected in iourc file {} for host {}".format(user_ioukey, - self.iourc_path, - hostname)) + + # We can't test this because it's mean distributing a valid licence key + # in tests or generating one + if not sys._called_from_test: + try: + hostid = (yield from gns3server.utils.asyncio.subprocess_check_output("hostid")).strip() + except FileNotFoundError as e: + raise IOUError("Could not find hostid: {}".format(e)) + except subprocess.SubprocessError as e: + raise IOUError("Could not execute hostid: {}".format(e)) + + try: + ioukey = int(hostid, 16) + except ValueError: + raise IOUError("Invalid hostid detected: {}".format(hostid)) + for x in hostname: + ioukey += ord(x) + pad1 = b'\x4B\x58\x21\x81\x56\x7B\x0D\xF3\x21\x43\x9B\x7E\xAC\x1D\xE6\x8A' + pad2 = b'\x80' + 39 * b'\0' + ioukey = hashlib.md5(pad1 + pad2 + struct.pack('!i', ioukey) + pad1).hexdigest()[:16] + if ioukey != user_ioukey: + raise IOUError("Invalid IOU license key {} detected in iourc file {} for host {}".format(user_ioukey, + self.iourc_path, + hostname)) @asyncio.coroutine def start(self): diff --git a/tests/modules/iou/test_iou_vm.py b/tests/modules/iou/test_iou_vm.py index a6ab2ab6..42fa4c70 100644 --- a/tests/modules/iou/test_iou_vm.py +++ b/tests/modules/iou/test_iou_vm.py @@ -20,6 +20,7 @@ import aiohttp import asyncio import os import stat +import socket from tests.utils import asyncio_patch @@ -37,7 +38,7 @@ def manager(port_manager): @pytest.fixture(scope="function") -def vm(project, manager, tmpdir, fake_iou_bin): +def vm(project, manager, tmpdir, fake_iou_bin, iourc_file): fake_file = str(tmpdir / "iouyap") with open(fake_file, "w+") as f: f.write("1") @@ -45,17 +46,28 @@ def vm(project, manager, tmpdir, fake_iou_bin): vm = IOUVM("test", "00010203-0405-0607-0809-0a0b0c0d0e0f", project, manager) config = manager.config.get_section_config("IOU") config["iouyap_path"] = fake_file + config["iourc_path"] = iourc_file manager.config.set_section_config("IOU", config) vm.path = fake_iou_bin return vm +@pytest.fixture +def iourc_file(tmpdir): + path = str(tmpdir / "iourc") + with open(path, "w+") as f: + hostname = socket.gethostname() + f.write("[license]\n{} = aaaaaaaaaaaaaaaa;".format(hostname)) + return path + + @pytest.fixture def fake_iou_bin(tmpdir): """Create a fake IOU image on disk""" - path = str(tmpdir / "iou.bin") + os.makedirs(str(tmpdir / "IOU"), exist_ok=True) + path = str(tmpdir / "IOU" / "iou.bin") with open(path, "w+") as f: f.write('\x7fELF\x01\x01\x01') os.chmod(path, stat.S_IREAD | stat.S_IEXEC) @@ -313,3 +325,45 @@ def test_stop_capture(vm, tmpdir, manager, free_console_port, loop): def test_get_legacy_vm_workdir(): assert IOU.get_legacy_vm_workdir(42, "bla") == "iou/device-42" + + +def test_invalid_iou_file(loop, vm, iourc_file): + + hostname = socket.gethostname() + + loop.run_until_complete(asyncio.async(vm._check_iou_licence())) + + # Missing ; + with pytest.raises(IOUError): + with open(iourc_file, "w+") as f: + f.write("[license]\n{} = aaaaaaaaaaaaaaaa".format(hostname)) + loop.run_until_complete(asyncio.async(vm._check_iou_licence())) + + # Key too short + with pytest.raises(IOUError): + with open(iourc_file, "w+") as f: + f.write("[license]\n{} = aaaaaaaaaaaaaa;".format(hostname)) + loop.run_until_complete(asyncio.async(vm._check_iou_licence())) + + # Invalid hostname + with pytest.raises(IOUError): + with open(iourc_file, "w+") as f: + f.write("[license]\nbla = aaaaaaaaaaaaaa;") + loop.run_until_complete(asyncio.async(vm._check_iou_licence())) + + # Missing licence section + with pytest.raises(IOUError): + with open(iourc_file, "w+") as f: + f.write("[licensetest]\n{} = aaaaaaaaaaaaaaaa;") + loop.run_until_complete(asyncio.async(vm._check_iou_licence())) + + # Broken config file + with pytest.raises(IOUError): + with open(iourc_file, "w+") as f: + f.write("[") + loop.run_until_complete(asyncio.async(vm._check_iou_licence())) + + # Missing file + with pytest.raises(IOUError): + os.remove(iourc_file) + loop.run_until_complete(asyncio.async(vm._check_iou_licence()))