1
0
mirror of https://github.com/GNS3/gns3-server synced 2024-11-24 17:28:08 +00:00

Generate new config for each test. Fixes tests.

This commit is contained in:
grossmj 2021-04-12 19:37:59 +09:30
parent 30ebae207f
commit 1b5a5de4bc
9 changed files with 18 additions and 21 deletions

View File

@ -36,10 +36,6 @@ DEFAULT_JWT_SECRET_KEY = "efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76
class AuthService: class AuthService:
def __init__(self):
self._controller_config = Config.instance().settings.Controller
def hash_password(self, password: str) -> str: def hash_password(self, password: str) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)
@ -56,15 +52,15 @@ class AuthService:
) -> str: ) -> str:
if not expires_in: if not expires_in:
expires_in = self._controller_config.jwt_access_token_expire_minutes expires_in = Config.instance().settings.Controller.jwt_access_token_expire_minutes
expire = datetime.utcnow() + timedelta(minutes=expires_in) expire = datetime.utcnow() + timedelta(minutes=expires_in)
to_encode = {"sub": username, "exp": expire} to_encode = {"sub": username, "exp": expire}
if secret_key is None: if secret_key is None:
secret_key = self._controller_config.jwt_secret_key secret_key = Config.instance().settings.Controller.jwt_secret_key
if secret_key is None: if secret_key is None:
secret_key = DEFAULT_JWT_SECRET_KEY secret_key = DEFAULT_JWT_SECRET_KEY
log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!") log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!")
algorithm = self._controller_config.jwt_algorithm algorithm = Config.instance().settings.Controller.jwt_algorithm
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt return encoded_jwt
@ -77,11 +73,11 @@ class AuthService:
) )
try: try:
if secret_key is None: if secret_key is None:
secret_key = self._controller_config.jwt_secret_key secret_key = Config.instance().settings.Controller.jwt_secret_key
if secret_key is None: if secret_key is None:
secret_key = DEFAULT_JWT_SECRET_KEY secret_key = DEFAULT_JWT_SECRET_KEY
log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!") log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!")
algorithm = self._controller_config.jwt_algorithm algorithm = Config.instance().settings.Controller.jwt_algorithm
payload = jwt.decode(token, secret_key, algorithms=[algorithm]) payload = jwt.decode(token, secret_key, algorithms=[algorithm])
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:

View File

@ -445,7 +445,6 @@ async def test_create_img_absolute_non_local(app: FastAPI, client: AsyncClient,
async def test_create_img_absolute_local(app: FastAPI, client: AsyncClient, config) -> None: async def test_create_img_absolute_local(app: FastAPI, client: AsyncClient, config) -> None:
config.settings.Server.local = True
params = { params = {
"qemu_img": "/tmp/qemu-img", "qemu_img": "/tmp/qemu-img",
"path": "/tmp/hda.qcow2", "path": "/tmp/hda.qcow2",

View File

@ -30,7 +30,6 @@ pytestmark = pytest.mark.asyncio
async def test_shutdown_local(app: FastAPI, client: AsyncClient, config: Config) -> None: async def test_shutdown_local(app: FastAPI, client: AsyncClient, config: Config) -> None:
os.kill = MagicMock() os.kill = MagicMock()
config.settings.Server.local = True
response = await client.post(app.url_path_for("shutdown")) response = await client.post(app.url_path_for("shutdown"))
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
assert os.kill.called assert os.kill.called

View File

@ -178,7 +178,6 @@ async def test_open_project(app: FastAPI, client: AsyncClient, project: Project)
async def test_load_project(app: FastAPI, client: AsyncClient, project: Project, config) -> None: async def test_load_project(app: FastAPI, client: AsyncClient, project: Project, config) -> None:
config.settings.Server.local = True
with asyncio_patch("gns3server.controller.Controller.load_project", return_value=project) as mock: with asyncio_patch("gns3server.controller.Controller.load_project", return_value=project) as mock:
response = await client.post(app.url_path_for("load_project"), json={"path": "/tmp/test.gns3"}) response = await client.post(app.url_path_for("load_project"), json={"path": "/tmp/test.gns3"})
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED

View File

@ -222,7 +222,6 @@ async def test_close(vm, port_manager):
def test_path(vm, fake_iou_bin, config): def test_path(vm, fake_iou_bin, config):
config.settings.Server.local = True
vm.path = fake_iou_bin vm.path = fake_iou_bin
assert vm.path == fake_iou_bin assert vm.path == fake_iou_bin
@ -235,7 +234,6 @@ def test_path_relative(vm, fake_iou_bin):
def test_path_invalid_bin(vm, tmpdir, config): def test_path_invalid_bin(vm, tmpdir, config):
config.settings.Server.local = True
path = str(tmpdir / "test.bin") path = str(tmpdir / "test.bin")
with open(path, "w+") as f: with open(path, "w+") as f:

View File

@ -5,6 +5,7 @@ import shutil
import sys import sys
import os import os
import uuid import uuid
import configparser
from fastapi import FastAPI from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@ -24,6 +25,7 @@ from gns3server.api.routes.controller.dependencies.database import get_db_sessio
from gns3server import schemas from gns3server import schemas
from gns3server.schemas.computes import Protocol from gns3server.schemas.computes import Protocol
from gns3server.services import auth_service from gns3server.services import auth_service
from gns3server.services.authentication import DEFAULT_JWT_SECRET_KEY
sys._called_from_test = True sys._called_from_test = True
sys.original_platform = sys.platform sys.original_platform = sys.platform
@ -192,9 +194,14 @@ def compute_project(tmpdir):
@pytest.fixture @pytest.fixture
def config(): def config(tmpdir):
config = Config.instance() path = str(tmpdir / "server.conf")
config = configparser.ConfigParser()
with open(path, "w+") as f:
config.write(f)
Config.reset()
config = Config.instance(files=[path])
config.clear() config.clear()
return config return config
@ -220,7 +227,6 @@ def symbols_dir(config):
path = config.settings.Server.symbols_path path = config.settings.Server.symbols_path
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
print(path)
return path return path
@ -337,6 +343,9 @@ def run_around_tests(monkeypatch, config, port_manager):#port_manager, controlle
for module in MODULES: for module in MODULES:
module._instance = None module._instance = None
config.settings.Controller.jwt_secret_key = DEFAULT_JWT_SECRET_KEY
config.settings.Server.secrets_dir = os.path.join(tmppath, 'secrets')
os.makedirs(os.path.join(tmppath, 'projects')) os.makedirs(os.path.join(tmppath, 'projects'))
config.settings.Server.projects_path = os.path.join(tmppath, 'projects') config.settings.Server.projects_path = os.path.join(tmppath, 'projects')
config.settings.Server.symbols_path = os.path.join(tmppath, 'symbols') config.settings.Server.symbols_path = os.path.join(tmppath, 'symbols')
@ -368,4 +377,3 @@ def run_around_tests(monkeypatch, config, port_manager):#port_manager, controlle
shutil.rmtree(tmppath) shutil.rmtree(tmppath)
except BaseException: except BaseException:
pass pass

View File

@ -95,7 +95,6 @@ async def test_restore(project, controller, config):
assert len(project.nodes) == 2 assert len(project.nodes) == 2
controller._notification = MagicMock() controller._notification = MagicMock()
config.settings.Server.local = True
await snapshot.restore() await snapshot.restore()
assert "snapshot.restored" in [c[0][0] for c in controller.notification.project_emit.call_args_list] assert "snapshot.restored" in [c[0][0] for c in controller.notification.project_emit.call_args_list]

View File

@ -67,7 +67,7 @@ def test_parse_arguments(capsys, config, tmpdir):
# assert "optional arguments" in out # assert "optional arguments" in out
assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1" assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1"
assert run.parse_arguments([]).host == "localhost" assert run.parse_arguments([]).host == "0.0.0.0"
server_config.host = "192.168.1.2" server_config.host = "192.168.1.2"
assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1" assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1"
assert run.parse_arguments([]).host == "192.168.1.2" assert run.parse_arguments([]).host == "192.168.1.2"

View File

@ -41,7 +41,6 @@ def test_interfaces():
def test_has_netmask(config): def test_has_netmask(config):
config.settings.Server.allowed_interfaces = "lo0,lo"
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
# No loopback # No loopback
pass pass