diff --git a/gns3server/services/authentication.py b/gns3server/services/authentication.py index 1a7cde65..efb69741 100644 --- a/gns3server/services/authentication.py +++ b/gns3server/services/authentication.py @@ -36,10 +36,6 @@ DEFAULT_JWT_SECRET_KEY = "efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76 class AuthService: - def __init__(self): - - self._controller_config = Config.instance().settings.Controller - def hash_password(self, password: str) -> str: return pwd_context.hash(password) @@ -56,15 +52,15 @@ class AuthService: ) -> str: if not expires_in: - expires_in = self._controller_config.jwt_access_token_expire_minutes + expires_in = Config.instance().settings.Controller.jwt_access_token_expire_minutes expire = datetime.utcnow() + timedelta(minutes=expires_in) to_encode = {"sub": username, "exp": expire} if secret_key is None: - secret_key = self._controller_config.jwt_secret_key + secret_key = Config.instance().settings.Controller.jwt_secret_key if secret_key is None: secret_key = DEFAULT_JWT_SECRET_KEY log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!") - algorithm = self._controller_config.jwt_algorithm + algorithm = Config.instance().settings.Controller.jwt_algorithm encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) return encoded_jwt @@ -77,11 +73,11 @@ class AuthService: ) try: if secret_key is None: - secret_key = self._controller_config.jwt_secret_key + secret_key = Config.instance().settings.Controller.jwt_secret_key if secret_key is None: secret_key = DEFAULT_JWT_SECRET_KEY log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!") - algorithm = self._controller_config.jwt_algorithm + algorithm = Config.instance().settings.Controller.jwt_algorithm payload = jwt.decode(token, secret_key, algorithms=[algorithm]) username: str = payload.get("sub") if username is None: diff --git a/tests/api/routes/compute/test_qemu_nodes.py b/tests/api/routes/compute/test_qemu_nodes.py index 9f64fea2..11596239 100644 --- a/tests/api/routes/compute/test_qemu_nodes.py +++ b/tests/api/routes/compute/test_qemu_nodes.py @@ -445,7 +445,6 @@ async def test_create_img_absolute_non_local(app: FastAPI, client: AsyncClient, async def test_create_img_absolute_local(app: FastAPI, client: AsyncClient, config) -> None: - config.settings.Server.local = True params = { "qemu_img": "/tmp/qemu-img", "path": "/tmp/hda.qcow2", diff --git a/tests/api/routes/controller/test_controller.py b/tests/api/routes/controller/test_controller.py index 686f0163..37656b0c 100644 --- a/tests/api/routes/controller/test_controller.py +++ b/tests/api/routes/controller/test_controller.py @@ -30,7 +30,6 @@ pytestmark = pytest.mark.asyncio async def test_shutdown_local(app: FastAPI, client: AsyncClient, config: Config) -> None: os.kill = MagicMock() - config.settings.Server.local = True response = await client.post(app.url_path_for("shutdown")) assert response.status_code == status.HTTP_204_NO_CONTENT assert os.kill.called diff --git a/tests/api/routes/controller/test_projects.py b/tests/api/routes/controller/test_projects.py index 81a61aec..01cea201 100644 --- a/tests/api/routes/controller/test_projects.py +++ b/tests/api/routes/controller/test_projects.py @@ -178,7 +178,6 @@ async def test_open_project(app: FastAPI, client: AsyncClient, project: Project) async def test_load_project(app: FastAPI, client: AsyncClient, project: Project, config) -> None: - config.settings.Server.local = True with asyncio_patch("gns3server.controller.Controller.load_project", return_value=project) as mock: response = await client.post(app.url_path_for("load_project"), json={"path": "/tmp/test.gns3"}) assert response.status_code == status.HTTP_201_CREATED diff --git a/tests/compute/iou/test_iou_vm.py b/tests/compute/iou/test_iou_vm.py index 56a00669..de729afb 100644 --- a/tests/compute/iou/test_iou_vm.py +++ b/tests/compute/iou/test_iou_vm.py @@ -222,7 +222,6 @@ async def test_close(vm, port_manager): def test_path(vm, fake_iou_bin, config): - config.settings.Server.local = True vm.path = fake_iou_bin assert vm.path == fake_iou_bin @@ -235,7 +234,6 @@ def test_path_relative(vm, fake_iou_bin): def test_path_invalid_bin(vm, tmpdir, config): - config.settings.Server.local = True path = str(tmpdir / "test.bin") with open(path, "w+") as f: diff --git a/tests/conftest.py b/tests/conftest.py index 0da67ea4..bd0ddf98 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import shutil import sys import os import uuid +import configparser from fastapi import FastAPI from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -24,6 +25,7 @@ from gns3server.api.routes.controller.dependencies.database import get_db_sessio from gns3server import schemas from gns3server.schemas.computes import Protocol from gns3server.services import auth_service +from gns3server.services.authentication import DEFAULT_JWT_SECRET_KEY sys._called_from_test = True sys.original_platform = sys.platform @@ -192,9 +194,14 @@ def compute_project(tmpdir): @pytest.fixture -def config(): +def config(tmpdir): - config = Config.instance() + path = str(tmpdir / "server.conf") + config = configparser.ConfigParser() + with open(path, "w+") as f: + config.write(f) + Config.reset() + config = Config.instance(files=[path]) config.clear() return config @@ -220,7 +227,6 @@ def symbols_dir(config): path = config.settings.Server.symbols_path os.makedirs(path, exist_ok=True) - print(path) return path @@ -337,6 +343,9 @@ def run_around_tests(monkeypatch, config, port_manager):#port_manager, controlle for module in MODULES: module._instance = None + config.settings.Controller.jwt_secret_key = DEFAULT_JWT_SECRET_KEY + config.settings.Server.secrets_dir = os.path.join(tmppath, 'secrets') + os.makedirs(os.path.join(tmppath, 'projects')) config.settings.Server.projects_path = os.path.join(tmppath, 'projects') config.settings.Server.symbols_path = os.path.join(tmppath, 'symbols') @@ -368,4 +377,3 @@ def run_around_tests(monkeypatch, config, port_manager):#port_manager, controlle shutil.rmtree(tmppath) except BaseException: pass - diff --git a/tests/controller/test_snapshot.py b/tests/controller/test_snapshot.py index 07d30db9..fa6d2237 100644 --- a/tests/controller/test_snapshot.py +++ b/tests/controller/test_snapshot.py @@ -95,7 +95,6 @@ async def test_restore(project, controller, config): assert len(project.nodes) == 2 controller._notification = MagicMock() - config.settings.Server.local = True await snapshot.restore() assert "snapshot.restored" in [c[0][0] for c in controller.notification.project_emit.call_args_list] diff --git a/tests/test_run.py b/tests/test_run.py index 0e89de66..d5b38b2a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -67,7 +67,7 @@ def test_parse_arguments(capsys, config, tmpdir): # assert "optional arguments" in out assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1" - assert run.parse_arguments([]).host == "localhost" + assert run.parse_arguments([]).host == "0.0.0.0" server_config.host = "192.168.1.2" assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1" assert run.parse_arguments([]).host == "192.168.1.2" diff --git a/tests/utils/test_interfaces.py b/tests/utils/test_interfaces.py index ae6aa72c..ecf67f65 100644 --- a/tests/utils/test_interfaces.py +++ b/tests/utils/test_interfaces.py @@ -41,7 +41,6 @@ def test_interfaces(): def test_has_netmask(config): - config.settings.Server.allowed_interfaces = "lo0,lo" if sys.platform.startswith("win"): # No loopback pass