1
0
mirror of https://github.com/GNS3/gns3-server synced 2025-02-17 18:42:00 +00:00

Use Pydantic to validate the server config file.

This commit is contained in:
grossmj 2021-04-12 17:02:23 +09:30
parent 478119b40d
commit 30ebae207f
61 changed files with 871 additions and 695 deletions

View File

@ -1,22 +1,27 @@
[Server] [Server]
; What protocol the server uses (http or https)
protocol = http
; IP where the server listen for connections ; IP where the server listen for connections
host = 0.0.0.0 host = 0.0.0.0
; HTTP port for controlling the servers ; HTTP port for controlling the servers
port = 3080 port = 3080
; Secrets directory
secrets_dir = /home/gns3/.config/GNS3/secrets
; Options to enable SSL encryption ; Options to enable SSL encryption
ssl = False ssl = False
certfile = /home/gns3/.config/GNS3/ssl/server.cert certfile = /home/gns3/.config/GNS3/ssl/server.cert
certkey = /home/gns3/.config/GNS3/ssl/server.key certkey = /home/gns3/.config/GNS3/ssl/server.key
; Options for JWT tokens (user authentication)
jwt_secret_key = efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76cea5e33d4e
jwt_algorithm = HS256
jwt_access_token_expire_minutes = 1440
; Path where devices images are stored ; Path where devices images are stored
images_path = /home/gns3/GNS3/images images_path = /home/gns3/GNS3/images
; Additional paths to look for images
additional_images_paths = /opt/images;/mnt/disk1/images
; Path where user projects are stored ; Path where user projects are stored
projects_path = /home/gns3/GNS3/projects projects_path = /home/gns3/GNS3/projects
@ -26,6 +31,9 @@ appliances_path = /home/gns3/GNS3/appliances
; Path where custom device symbols are stored ; Path where custom device symbols are stored
symbols_path = /home/gns3/GNS3/symbols symbols_path = /home/gns3/GNS3/symbols
; Path where custom configs are stored
configs_path = /home/gns3/GNS3/configs
; Option to automatically send crash reports to the GNS3 team ; Option to automatically send crash reports to the GNS3 team
report_errors = True report_errors = True
@ -64,6 +72,13 @@ allowed_interfaces = eth0,eth1,virbr0
; Default is virbr0 on Linux (requires libvirt) and vmnet8 for other platforms (requires VMware) ; Default is virbr0 on Linux (requires libvirt) and vmnet8 for other platforms (requires VMware)
default_nat_interface = vmnet10 default_nat_interface = vmnet10
[Controller]
; Options for JWT tokens (user authentication)
jwt_secret_key = efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76cea5e33d4e
jwt_algorithm = HS256
jwt_access_token_expire_minutes = 1440
[VPCS] [VPCS]
; VPCS executable location, default: search in PATH ; VPCS executable location, default: search in PATH
;vpcs_path = vpcs ;vpcs_path = vpcs
@ -83,12 +98,24 @@ iourc_path = /home/gns3/.iourc
; Validate if the iourc license file is correct. If you turn this off and your licence is invalid IOU will not start and no errors will be shown. ; Validate if the iourc license file is correct. If you turn this off and your licence is invalid IOU will not start and no errors will be shown.
license_check = True license_check = True
[VirtualBox]
; Path to the VBoxManage binary used to manage VirtualBox
vboxmanage_path = vboxmanage
[VMware]
; Path to the vmrun binary used to manage VMware
vmrun_path = vmrun
vmnet_start_range = 2
vmnet_end_range = 255
block_host_traffic = False
[Qemu] [Qemu]
; !! Remember to add the gns3 user to the KVM group, otherwise you will not have read / write permissions to /dev/kvm !! (Linux only, has priority over enable_hardware_acceleration) ; Use Qemu monitor feature to communicate with Qemu VMs
enable_kvm = True enable_monitor = True
; Require KVM to be installed in order to start VMs (Linux only, has priority over require_hardware_acceleration) ; IP used to listen for the monitor
require_kvm = True monitor_host = 127.0.0.1
; !! Remember to add the gns3 user to the KVM group, otherwise you will not have read / write permissions to /dev/kvm !!
; Enable hardware acceleration (all platforms) ; Enable hardware acceleration (all platforms)
enable_hardware_acceleration = True enable_hardware_acceleration = True
; Require hardware acceleration in order to start VMs (all platforms) ; Require hardware acceleration in order to start VMs
require_hardware_acceleration = False require_hardware_acceleration = False

View File

@ -83,8 +83,7 @@ def compute_version() -> dict:
Retrieve the server version number. Retrieve the server version number.
""" """
config = Config.instance() local_server = Config.instance().settings.Server.local
local_server = config.get_section_config("Server").getboolean("local", False)
return {"version": __version__, "local": local_server} return {"version": __version__, "local": local_server}
@ -153,8 +152,7 @@ async def create_qemu_image(image_data: schemas.QemuImageCreate):
""" """
if os.path.isabs(image_data.path): if os.path.isabs(image_data.path):
config = Config.instance() if Config.instance().settings.Server.local is False:
if config.get_section_config("Server").getboolean("local", False) is False:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
await Qemu.instance().create_disk(image_data.qemu_img, image_data.path, jsonable_encoder(image_data, exclude_unset=True)) await Qemu.instance().create_disk(image_data.qemu_img, image_data.path, jsonable_encoder(image_data, exclude_unset=True))
@ -169,8 +167,7 @@ async def update_qemu_image(image_data: schemas.QemuImageUpdate):
""" """
if os.path.isabs(image_data.path): if os.path.isabs(image_data.path):
config = Config.instance() if Config.instance().settings.Server.local is False:
if config.get_section_config("Server").getboolean("local", False) is False:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if image_data.extend: if image_data.extend:

View File

@ -148,12 +148,7 @@ async def start_qemu_node(node: QemuVM = Depends(dep_node)):
""" """
qemu_manager = Qemu.instance() qemu_manager = Qemu.instance()
hardware_accel = qemu_manager.config.get_section_config("Qemu").getboolean("enable_hardware_acceleration", True) hardware_accel = qemu_manager.config.settings.Qemu.enable_hardware_acceleration
if sys.platform.startswith("linux"):
# the enable_kvm option was used before version 2.0 and has priority
enable_kvm = qemu_manager.config.get_section_config("Qemu").getboolean("enable_kvm")
if enable_kvm is not None:
hardware_accel = enable_kvm
if hardware_accel and "-no-kvm" not in node.options and "-no-hax" not in node.options: if hardware_accel and "-no-kvm" not in node.options and "-no-hax" not in node.options:
pm = ProjectManager.instance() pm = ProjectManager.instance()
if pm.check_hardware_virtualization(node) is False: if pm.check_hardware_virtualization(node) is False:

View File

@ -43,8 +43,7 @@ async def shutdown():
Shutdown the local server Shutdown the local server
""" """
config = Config.instance() if Config.instance().settings.Server.local is False:
if config.get_section_config("Server").getboolean("local", False) is False:
raise ControllerForbiddenError("You can only stop a local server") raise ControllerForbiddenError("You can only stop a local server")
log.info("Start shutting down the server") log.info("Start shutting down the server")
@ -76,8 +75,7 @@ def get_version():
Return the server version number. Return the server version number.
""" """
config = Config.instance() local_server = Config.instance().settings.Server.local
local_server = config.get_section_config("Server").getboolean("local", False)
return {"version": __version__, "local": local_server} return {"version": __version__, "local": local_server}

View File

@ -183,9 +183,8 @@ async def load_project(path: str = Body(..., embed=True)):
""" """
controller = Controller.instance() controller = Controller.instance()
config = Config.instance()
dot_gns3_file = path dot_gns3_file = path
if config.get_section_config("Server").getboolean("local", False) is False: if Config.instance().settings.Server.local is False:
log.error("Cannot load '{}' because the server has not been started with the '--local' parameter".format(dot_gns3_file)) log.error("Cannot load '{}' because the server has not been started with the '--local' parameter".format(dot_gns3_file))
raise ControllerForbiddenError("Cannot load project when server is not local") raise ControllerForbiddenError("Cannot load project when server is not local")
project = await controller.load_project(dot_gns3_file,) project = await controller.load_project(dot_gns3_file,)
@ -313,8 +312,7 @@ async def import_project(project_id: UUID, request: Request, path: Optional[Path
""" """
controller = Controller.instance() controller = Controller.instance()
config = Config.instance() if Config.instance().settings.Server.local is False:
if not config.get_section_config("Server").getboolean("local", False):
raise ControllerForbiddenError("The server is not local") raise ControllerForbiddenError("The server is not local")
# We write the content to a temporary location and after we extract it all. # We write the content to a temporary location and after we extract it all.
@ -353,8 +351,7 @@ async def duplicate_project(project_data: schemas.ProjectDuplicate, project: Pro
""" """
if project_data.path: if project_data.path:
config = Config.instance() if Config.instance().settings.Server.local is False:
if config.get_section_config("Server").getboolean("local", False) is False:
raise ControllerForbiddenError("The server is not a local server") raise ControllerForbiddenError("The server is not a local server")
location = project_data.path location = project_data.path
else: else:

View File

@ -418,7 +418,6 @@ class BaseManager:
return "" return ""
orig_path = path orig_path = path
server_config = self.config.get_section_config("Server")
img_directory = self.get_images_directory() img_directory = self.get_images_directory()
valid_directory_prefices = images_directories(self._NODE_TYPE) valid_directory_prefices = images_directories(self._NODE_TYPE)
if extra_dir: if extra_dir:
@ -445,7 +444,7 @@ class BaseManager:
raise ImageMissingError(orig_path) raise ImageMissingError(orig_path)
# For local server we allow using absolute path outside image directory # For local server we allow using absolute path outside image directory
if server_config.getboolean("local", False) is True: if Config.instance().settings.Server.local is True:
log.debug("Searching for '{}'".format(orig_path)) log.debug("Searching for '{}'".format(orig_path))
path = force_unix_path(path) path = force_unix_path(path)
if os.path.exists(path): if os.path.exists(path):

View File

@ -373,9 +373,8 @@ class BaseNode:
Returns the VNC console port range. Returns the VNC console port range.
""" """
server_config = self._manager.config.get_section_config("Server") vnc_console_start_port_range = self._manager.config.settings.Server.vnc_console_start_port_range
vnc_console_start_port_range = server_config.getint("vnc_console_start_port_range", 5900) vnc_console_end_port_range = self._manager.config.settings.Server.vnc_console_end_port_range
vnc_console_end_port_range = server_config.getint("vnc_console_end_port_range", 10000)
if not 5900 <= vnc_console_start_port_range <= 65535: if not 5900 <= vnc_console_start_port_range <= 65535:
raise NodeError("The VNC console start port range must be between 5900 and 65535") raise NodeError("The VNC console start port range must be between 5900 and 65535")
@ -685,8 +684,7 @@ class BaseNode:
:returns: path to uBridge :returns: path to uBridge
""" """
path = self._manager.config.get_section_config("Server").get("ubridge_path", "ubridge") path = shutil.which(self._manager.config.settings.Server.ubridge_path)
path = shutil.which(path)
return path return path
async def _ubridge_send(self, command): async def _ubridge_send(self, command):
@ -721,8 +719,7 @@ class BaseNode:
if require_privileged_access and not self._manager.has_privileged_access(self.ubridge_path): if require_privileged_access and not self._manager.has_privileged_access(self.ubridge_path):
raise NodeError("uBridge requires root access or the capability to interact with network adapters") raise NodeError("uBridge requires root access or the capability to interact with network adapters")
server_config = self._manager.config.get_section_config("Server") server_host = self._manager.config.settings.Server.host
server_host = server_config.get("host")
if not self.ubridge: if not self.ubridge:
self._ubridge_hypervisor = Hypervisor(self._project, self.ubridge_path, self.working_dir, server_host) self._ubridge_hypervisor = Hypervisor(self._project, self.ubridge_path, self.working_dir, server_host)
log.info("Starting new uBridge hypervisor {}:{}".format(self._ubridge_hypervisor.host, self._ubridge_hypervisor.port)) log.info("Starting new uBridge hypervisor {}:{}".format(self._ubridge_hypervisor.host, self._ubridge_hypervisor.port))

View File

@ -36,12 +36,16 @@ class Nat(Cloud):
def __init__(self, name, node_id, project, manager, ports=None): def __init__(self, name, node_id, project, manager, ports=None):
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
nat_interface = Config.instance().get_section_config("Server").get("default_nat_interface", "virbr0") nat_interface = Config.instance().settings.Server.default_nat_interface
if not nat_interface:
nat_interface = "virbr0"
if nat_interface not in [interface["name"] for interface in gns3server.utils.interfaces.interfaces()]: if nat_interface not in [interface["name"] for interface in gns3server.utils.interfaces.interfaces()]:
raise NodeError("NAT interface {} is missing, please install libvirt".format(nat_interface)) raise NodeError("NAT interface {} is missing, please install libvirt".format(nat_interface))
interface = nat_interface interface = nat_interface
else: else:
nat_interface = Config.instance().get_section_config("Server").get("default_nat_interface", "vmnet8") nat_interface = Config.instance().settings.Server.default_nat_interface
if not nat_interface:
nat_interface = "vmnet8"
interfaces = list(filter(lambda x: nat_interface in x.lower(), interfaces = list(filter(lambda x: nat_interface in x.lower(),
[interface["name"] for interface in gns3server.utils.interfaces.interfaces()])) [interface["name"] for interface in gns3server.utils.interfaces.interfaces()]))
if not len(interfaces): if not len(interfaces):

View File

@ -248,7 +248,7 @@ class Dynamips(BaseManager):
def find_dynamips(self): def find_dynamips(self):
# look for Dynamips # look for Dynamips
dynamips_path = self.config.get_section_config("Dynamips").get("dynamips_path", "dynamips") dynamips_path = self.config.settings.Dynamips.dynamips_path
if not os.path.isabs(dynamips_path): if not os.path.isabs(dynamips_path):
dynamips_path = shutil.which(dynamips_path) dynamips_path = shutil.which(dynamips_path)
@ -279,8 +279,7 @@ class Dynamips(BaseManager):
# FIXME: hypervisor should always listen to 127.0.0.1 # FIXME: hypervisor should always listen to 127.0.0.1
# See https://github.com/GNS3/dynamips/issues/62 # See https://github.com/GNS3/dynamips/issues/62
server_config = self.config.get_section_config("Server") server_host = self.config.settings.Server.host
server_host = server_config.get("host")
try: try:
info = socket.getaddrinfo(server_host, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) info = socket.getaddrinfo(server_host, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
@ -310,7 +309,7 @@ class Dynamips(BaseManager):
async def ghost_ios_support(self, vm): async def ghost_ios_support(self, vm):
ghost_ios_support = self.config.get_section_config("Dynamips").getboolean("ghost_ios_support", True) ghost_ios_support = self.config.settings.Dynamips.ghost_ios_support
if ghost_ios_support: if ghost_ios_support:
async with Dynamips._ghost_ios_lock: async with Dynamips._ghost_ios_lock:
try: try:
@ -483,11 +482,11 @@ class Dynamips(BaseManager):
except IndexError: except IndexError:
raise DynamipsError("WIC slot {} doesn't exist on this router".format(wic_slot_id)) raise DynamipsError("WIC slot {} doesn't exist on this router".format(wic_slot_id))
mmap_support = self.config.get_section_config("Dynamips").getboolean("mmap_support", True) mmap_support = self.config.settings.Dynamips.mmap_support
if mmap_support is False: if mmap_support is False:
await vm.set_mmap(False) await vm.set_mmap(False)
sparse_memory_support = self.config.get_section_config("Dynamips").getboolean("sparse_memory_support", True) sparse_memory_support = self.config.settings.Dynamips.sparse_memory_support
if sparse_memory_support is False: if sparse_memory_support is False:
await vm.set_sparsemem(False) await vm.set_sparsemem(False)

View File

@ -95,9 +95,6 @@ class IOUVM(BaseNode):
self._application_id = application_id self._application_id = application_id
self._l1_keepalives = False # used to overcome the always-up Ethernet interfaces (not supported by all IOSes). self._l1_keepalives = False # used to overcome the always-up Ethernet interfaces (not supported by all IOSes).
def _config(self):
return self._manager.config.get_section_config("IOU")
def _nvram_changed(self, path): def _nvram_changed(self, path):
""" """
Called when the NVRAM file has changed Called when the NVRAM file has changed
@ -248,7 +245,7 @@ class IOUVM(BaseNode):
:returns: path to IOURC :returns: path to IOURC
""" """
iourc_path = self._config().get("iourc_path") iourc_path = self._manager.config.settings.IOU.iourc_path
if not iourc_path: if not iourc_path:
# look for the iourc file in the temporary dir. # look for the iourc file in the temporary dir.
path = os.path.join(self.temporary_directory, "iourc") path = os.path.join(self.temporary_directory, "iourc")
@ -401,7 +398,7 @@ class IOUVM(BaseNode):
try: try:
# we allow license check to be disabled server wide # we allow license check to be disabled server wide
server_wide_license_check = self._config().getboolean("license_check", True) server_wide_license_check = self._manager.config.settings.IOU.license_check
except ValueError: except ValueError:
raise IOUError("Invalid licence check setting") raise IOUError("Invalid licence check setting")

View File

@ -43,15 +43,13 @@ class PortManager:
self._used_tcp_ports = set() self._used_tcp_ports = set()
self._used_udp_ports = set() self._used_udp_ports = set()
server_config = Config.instance().get_section_config("Server") console_start_port_range = Config.instance().settings.Server.console_start_port_range
console_end_port_range = Config.instance().settings.Server.console_end_port_range
console_start_port_range = server_config.getint("console_start_port_range", 5000)
console_end_port_range = server_config.getint("console_end_port_range", 10000)
self._console_port_range = (console_start_port_range, console_end_port_range) self._console_port_range = (console_start_port_range, console_end_port_range)
log.debug(f"Console port range is {console_start_port_range}-{console_end_port_range}") log.debug(f"Console port range is {console_start_port_range}-{console_end_port_range}")
udp_start_port_range = server_config.getint("udp_start_port_range", 20000) udp_start_port_range = Config.instance().settings.Server.udp_start_port_range
udp_end_port_range = server_config.getint("udp_end_port_range", 30000) udp_end_port_range = Config.instance().settings.Server.udp_end_port_range
self._udp_port_range = (udp_start_port_range, udp_end_port_range) self._udp_port_range = (udp_start_port_range, udp_end_port_range)
log.debug(f"UDP port range is {udp_start_port_range}-{udp_end_port_range}") log.debug(f"UDP port range is {udp_start_port_range}-{udp_end_port_range}")
@ -86,8 +84,7 @@ class PortManager:
Bind console host to 0.0.0.0 if remote connections are allowed. Bind console host to 0.0.0.0 if remote connections are allowed.
""" """
server_config = Config.instance().get_section_config("Server") remote_console_connections = Config.instance().settings.Server.allow_remote_console
remote_console_connections = server_config.getboolean("allow_remote_console")
if remote_console_connections: if remote_console_connections:
log.warning("Remote console connections are allowed") log.warning("Remote console connections are allowed")
self._console_host = "0.0.0.0" self._console_host = "0.0.0.0"

View File

@ -85,13 +85,9 @@ class Project:
"variables": self._variables "variables": self._variables
} }
def _config(self):
return Config.instance().get_section_config("Server")
def is_local(self): def is_local(self):
return self._config().getboolean("local", False) return Config.instance().settings.Server.local
@property @property
def id(self): def id(self):

View File

@ -73,9 +73,9 @@ class QemuVM(BaseNode):
def __init__(self, name, node_id, project, manager, linked_clone=True, qemu_path=None, console=None, console_type="telnet", aux=None, aux_type="none", platform=None): def __init__(self, name, node_id, project, manager, linked_clone=True, qemu_path=None, console=None, console_type="telnet", aux=None, aux_type="none", platform=None):
super().__init__(name, node_id, project, manager, console=console, console_type=console_type, linked_clone=linked_clone, aux=aux, aux_type=aux_type, wrap_console=True, wrap_aux=True) super().__init__(name, node_id, project, manager, console=console, console_type=console_type, linked_clone=linked_clone, aux=aux, aux_type=aux_type, wrap_console=True, wrap_aux=True)
server_config = manager.config.get_section_config("Server")
self._host = server_config.get("host", "127.0.0.1") self._host = manager.config.settings.Server.host
self._monitor_host = server_config.get("monitor_host", "127.0.0.1") self._monitor_host = manager.config.settings.Qemu.monitor_host
self._process = None self._process = None
self._cpulimit_process = None self._cpulimit_process = None
self._monitor = None self._monitor = None
@ -1055,7 +1055,7 @@ class QemuVM(BaseNode):
await self.resume() await self.resume()
return return
if self._manager.config.get_section_config("Qemu").getboolean("monitor", True): if self._manager.config.settings.Qemu.enable_monitor:
try: try:
info = socket.getaddrinfo(self._monitor_host, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) info = socket.getaddrinfo(self._monitor_host, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
if not info: if not info:
@ -2112,17 +2112,8 @@ class QemuVM(BaseNode):
:returns: Boolean True if we need to enable hardware acceleration :returns: Boolean True if we need to enable hardware acceleration
""" """
enable_hardware_accel = self.manager.config.get_section_config("Qemu").getboolean("enable_hardware_acceleration", True) enable_hardware_accel = self.manager.config.settings.Qemu.enable_hardware_acceleration
require_hardware_accel = self.manager.config.get_section_config("Qemu").getboolean("require_hardware_acceleration", True) require_hardware_accel = self.manager.config.settings.Qemu.require_hardware_acceleration
if sys.platform.startswith("linux"):
# compatibility: these options were used before version 2.0 and have priority
enable_kvm = self.manager.config.get_section_config("Qemu").getboolean("enable_kvm")
if enable_kvm is not None:
enable_hardware_accel = enable_kvm
require_kvm = self.manager.config.get_section_config("Qemu").getboolean("require_kvm")
if require_kvm is not None:
require_hardware_accel = require_kvm
if enable_hardware_accel and "-no-kvm" not in options and "-no-hax" not in options: if enable_hardware_accel and "-no-kvm" not in options and "-no-hax" not in options:
# Turn OFF hardware acceleration for non x86 architectures # Turn OFF hardware acceleration for non x86 architectures
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
@ -2137,7 +2128,7 @@ class QemuVM(BaseNode):
if sys.platform.startswith("linux") and not os.path.exists("/dev/kvm"): if sys.platform.startswith("linux") and not os.path.exists("/dev/kvm"):
if require_hardware_accel: if require_hardware_accel:
raise QemuError("KVM acceleration cannot be used (/dev/kvm doesn't exist). It is possible to turn off KVM support in the gns3_server.conf by adding enable_kvm = false to the [Qemu] section.") raise QemuError("KVM acceleration cannot be used (/dev/kvm doesn't exist). It is possible to turn off KVM support in the gns3_server.conf by adding enable_hardware_acceleration = false to the [Qemu] section.")
else: else:
return False return False
elif sys.platform.startswith("win"): elif sys.platform.startswith("win"):

View File

@ -57,7 +57,7 @@ class VirtualBox(BaseManager):
def find_vboxmanage(self): def find_vboxmanage(self):
# look for VBoxManage # look for VBoxManage
vboxmanage_path = self.config.get_section_config("VirtualBox").get("vboxmanage_path") vboxmanage_path = self.config.settings.VirtualBox.vboxmanage_path
if vboxmanage_path: if vboxmanage_path:
if not os.path.isabs(vboxmanage_path): if not os.path.isabs(vboxmanage_path):
vboxmanage_path = shutil.which(vboxmanage_path) vboxmanage_path = shutil.which(vboxmanage_path)

View File

@ -91,7 +91,7 @@ class VMware(BaseManager):
""" """
# look for vmrun # look for vmrun
vmrun_path = self.config.get_section_config("VMware").get("vmrun_path") vmrun_path = self.config.settings.VMware.vmrun_path
if not vmrun_path: if not vmrun_path:
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
vmrun_path = shutil.which("vmrun") vmrun_path = shutil.which("vmrun")
@ -309,8 +309,8 @@ class VMware(BaseManager):
def is_managed_vmnet(self, vmnet): def is_managed_vmnet(self, vmnet):
self._vmnet_start_range = self.config.get_section_config("VMware").getint("vmnet_start_range", self._vmnet_start_range) self._vmnet_start_range = self.config.settings.VMware.vmnet_start_range
self._vmnet_end_range = self.config.get_section_config("VMware").getint("vmnet_end_range", self._vmnet_end_range) self._vmnet_end_range = self.config.settings.VMware.vmnet_end_range
match = re.search(r"vmnet([0-9]+)$", vmnet, re.IGNORECASE) match = re.search(r"vmnet([0-9]+)$", vmnet, re.IGNORECASE)
if match: if match:
vmnet_number = match.group(1) vmnet_number = match.group(1)

View File

@ -336,7 +336,7 @@ class VMwareVM(BaseNode):
# special case on OSX, we cannot bind VMnet interfaces using the libpcap # special case on OSX, we cannot bind VMnet interfaces using the libpcap
await self._ubridge_send('bridge add_nio_fusion_vmnet {name} "{interface}"'.format(name=vnet, interface=vmnet_interface)) await self._ubridge_send('bridge add_nio_fusion_vmnet {name} "{interface}"'.format(name=vnet, interface=vmnet_interface))
else: else:
block_host_traffic = self.manager.config.get_section_config("VMware").getboolean("block_host_traffic", False) block_host_traffic = self.manager.config.VMware.block_host_traffic
await self._add_ubridge_ethernet_connection(vnet, vmnet_interface, block_host_traffic) await self._add_ubridge_ethernet_connection(vnet, vmnet_interface, block_host_traffic)
if isinstance(nio, NIOUDP): if isinstance(nio, NIOUDP):

View File

@ -138,7 +138,7 @@ class VPCSVM(BaseNode):
:returns: path to VPCS :returns: path to VPCS
""" """
vpcs_path = self._manager.config.get_section_config("VPCS").get("vpcs_path", "vpcs") vpcs_path = self._manager.config.settings.VPCS.vpcs_path
if not os.path.isabs(vpcs_path): if not os.path.isabs(vpcs_path):
vpcs_path = shutil.which(vpcs_path) vpcs_path = shutil.which(vpcs_path)
return vpcs_path return vpcs_path

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (C) 2015 GNS3 Technologies Inc. # Copyright (C) 2021 GNS3 Technologies Inc.
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by # it under the terms of the GNU General Public License as published by
@ -16,14 +16,17 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
Reads the configuration file and store the settings for the controller & compute. Reads the configuration file and store the settings for the server.
""" """
import sys import sys
import os import os
import shutil import shutil
import secrets
import configparser import configparser
from pydantic import ValidationError
from .schemas import ServerConfig
from .version import __version_info__ from .version import __version_info__
from .utils.file_watcher import FileWatcher from .utils.file_watcher import FileWatcher
@ -32,18 +35,19 @@ log = logging.getLogger(__name__)
class Config: class Config:
""" """
Configuration file management using configparser. Configuration file management using configparser.
:param files: Array of configuration files (optional) :param files: Array of configuration files (optional)
:param profile: Profile settings (default use standard settings file) :param profile: Profile settings (default use standard config file)
""" """
def __init__(self, files=None, profile=None): def __init__(self, files=None, profile=None):
self._settings = None
self._files = files self._files = files
self._profile = profile self._profile = profile
if files and len(files): if files and len(files):
directory_name = os.path.dirname(files[0]) directory_name = os.path.dirname(files[0])
if not directory_name or directory_name == "": if not directory_name or directory_name == "":
@ -79,15 +83,6 @@ class Config:
versioned_user_dir = os.path.join(appdata, appname, version) versioned_user_dir = os.path.join(appdata, appname, version)
server_filename = "gns3_server.ini" server_filename = "gns3_server.ini"
controller_filename = "gns3_controller.ini"
# move gns3_controller.conf to gns3_controller.ini (file was renamed in 2.2.0 on Windows)
old_controller_filename = os.path.join(legacy_user_dir, "gns3_controller.conf")
if os.path.exists(old_controller_filename):
try:
shutil.copyfile(old_controller_filename, os.path.join(legacy_user_dir, controller_filename))
except OSError as e:
log.error("Cannot move old controller configuration file: {}".format(e))
if self._files is None and not hasattr(sys, "_called_from_test"): if self._files is None and not hasattr(sys, "_called_from_test"):
self._files = [os.path.join(os.getcwd(), server_filename), self._files = [os.path.join(os.getcwd(), server_filename),
@ -106,7 +101,6 @@ class Config:
home = os.path.expanduser("~") home = os.path.expanduser("~")
server_filename = "gns3_server.conf" server_filename = "gns3_server.conf"
controller_filename = "gns3_controller.conf"
if self._profile: if self._profile:
legacy_user_dir = os.path.join(home, ".config", appname, "profiles", self._profile) legacy_user_dir = os.path.join(home, ".config", appname, "profiles", self._profile)
@ -128,7 +122,7 @@ class Config:
if self._main_config_file is None: if self._main_config_file is None:
# TODO: migrate versioned config file from a previous version of GNS3 (for instance 2.2 -> 2.3) + support profiles # TODO: migrate versioned config file from a previous version of GNS3 (for instance 2.2 -> 3.0) + support profiles
# migrate post version 2.2.0 config files if they exist # migrate post version 2.2.0 config files if they exist
os.makedirs(versioned_user_dir, exist_ok=True) os.makedirs(versioned_user_dir, exist_ok=True)
try: try:
@ -137,12 +131,6 @@ class Config:
new_server_config = os.path.join(versioned_user_dir, server_filename) new_server_config = os.path.join(versioned_user_dir, server_filename)
if not os.path.exists(new_server_config) and os.path.exists(old_server_config): if not os.path.exists(new_server_config) and os.path.exists(old_server_config):
shutil.copyfile(old_server_config, new_server_config) shutil.copyfile(old_server_config, new_server_config)
# migrate the controller config file
old_controller_config = os.path.join(legacy_user_dir, controller_filename)
new_controller_config = os.path.join(versioned_user_dir, controller_filename)
if not os.path.exists(new_controller_config) and os.path.exists(old_controller_config):
shutil.copyfile(old_controller_config, os.path.join(versioned_user_dir, new_controller_config))
except OSError as e: except OSError as e:
log.error("Cannot migrate old config files: {}".format(e)) log.error("Cannot migrate old config files: {}".format(e))
@ -155,6 +143,16 @@ class Config:
self.clear() self.clear()
self._watch_config_file() self._watch_config_file()
@property
def settings(self) -> ServerConfig:
"""
Return the settings.
"""
if self._settings is None:
return ServerConfig()
return self._settings
def listen_for_config_changes(self, callback): def listen_for_config_changes(self, callback):
""" """
Call the callback when the configuration file change Call the callback when the configuration file change
@ -170,20 +168,17 @@ class Config:
@property @property
def config_dir(self): def config_dir(self):
"""
Return the directory where the configuration file is located.
"""
return os.path.dirname(self._main_config_file) return os.path.dirname(self._main_config_file)
@property
def controller_config(self):
if sys.platform.startswith("win"):
controller_config_filename = "gns3_controller.ini"
else:
controller_config_filename = "gns3_controller.conf"
return os.path.join(self.config_dir, controller_config_filename)
@property @property
def server_config(self): def server_config(self):
"""
Return the server configuration file path.
"""
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
server_config_filename = "gns3_server.ini" server_config_filename = "gns3_server.ini"
@ -196,21 +191,24 @@ class Config:
Restart with a clean config Restart with a clean config
""" """
self._config = configparser.ConfigParser(interpolation=None)
# Override config from command line even if we modify the config file and live reload it.
self._override_config = {}
self.read_config() self.read_config()
def _watch_config_file(self): def _watch_config_file(self):
"""
Add config files to be monitored for changes.
"""
for file in self._files: for file in self._files:
if os.path.exists(file): if os.path.exists(file):
self._watched_files[file] = FileWatcher(file, self._config_file_change) self._watched_files[file] = FileWatcher(file, self._config_file_change)
def _config_file_change(self, path): def _config_file_change(self, file_path):
"""
Callback when a config file has been updated.
"""
log.info(f"'{file_path}' has been updated, reloading the config...")
self.read_config() self.read_config()
for section in self._override_config:
self.set_section_config(section, self._override_config[section])
for callback in self._watch_callback: for callback in self._watch_callback:
callback() callback()
@ -220,93 +218,70 @@ class Config:
""" """
self.read_config() self.read_config()
for section in self._override_config:
self.set_section_config(section, self._override_config[section])
def get_config_files(self): def get_config_files(self):
"""
Return the config files in use.
"""
return self._watched_files return self._watched_files
def _load_jwt_secret_key(self):
"""
Load the JWT secret key.
"""
jwt_secret_key_path = os.path.join(self._settings.Server.secrets_dir, "gns3_jwt_secret_key")
if not os.path.exists(jwt_secret_key_path):
log.info(f"No JWT secret key configured, generating one in '{jwt_secret_key_path}'...")
try:
with open(jwt_secret_key_path, "w+", encoding="utf-8") as fd:
fd.write(secrets.token_hex(32))
except OSError as e:
log.error(f"Could not create JWT secret key file '{jwt_secret_key_path}': {e}")
try:
with open(jwt_secret_key_path, encoding="utf-8") as fd:
jwt_secret_key_content = fd.read()
self._settings.Controller.jwt_secret_key = jwt_secret_key_content
except OSError as e:
log.error(f"Could not read JWT secret key file '{jwt_secret_key_path}': {e}")
def _load_secret_files(self):
"""
Load the secret files.
"""
if not self._settings.Server.secrets_dir:
self._settings.Server.secrets_dir = os.path.dirname(self.server_config)
self._load_jwt_secret_key()
def read_config(self): def read_config(self):
""" """
Read the configuration files. Read the configuration files and validate the settings.
""" """
config = configparser.ConfigParser(interpolation=None)
try: try:
parsed_files = self._config.read(self._files, encoding="utf-8") parsed_files = config.read(self._files, encoding="utf-8")
except configparser.Error as e: except configparser.Error as e:
log.error("Can't parse configuration file: %s", str(e)) log.error("Can't parse configuration file: %s", str(e))
return return
if not parsed_files: if not parsed_files:
log.warning("No configuration file could be found or read") log.warning("No configuration file could be found or read")
else: return
for file in parsed_files:
log.info("Load configuration file {}".format(file))
self._watched_files[file] = os.stat(file).st_mtime
def write_config(self): for file in parsed_files:
""" log.info(f"Load configuration file '{file}'")
Write the server configuration file. self._watched_files[file] = os.stat(file).st_mtime
"""
try: try:
os.makedirs(os.path.dirname(self.server_config), exist_ok=True) self._settings = ServerConfig(**config._sections)
with open(self.server_config, 'w+') as fd: except ValidationError as e:
self._config.write(fd) log.error(f"Could not validate config: {e}")
except OSError as e: return
log.error("Cannot write server configuration file '{}': {}".format(self.server_config, e))
def get_default_section(self): self._load_secret_files()
"""
Get the default configuration section.
:returns: configparser section
"""
return self._config["DEFAULT"]
def get_section_config(self, section):
"""
Get a specific configuration section.
Returns the default section if none can be found.
:returns: configparser section
"""
if section not in self._config:
return self._config["DEFAULT"]
return self._config[section]
def set_section_config(self, section, content):
"""
Set a specific configuration section. It's not
dumped on the disk.
:param section: Section name
:param content: A dictionary with section content
"""
if not self._config.has_section(section):
self._config.add_section(section)
for key in content:
if isinstance(content[key], bool):
content[key] = str(content[key]).lower()
self._config.set(section, key, content[key])
self._override_config[section] = content
def set(self, section, key, value):
"""
Set a config value.
It's not dumped on the disk.
If the section doesn't exists the section is created
"""
conf = self.get_section_config(section)
if isinstance(value, bool):
conf[key] = str(value)
else:
conf[key] = value
self.set_section_config(section, conf)
@staticmethod @staticmethod
def instance(*args, **kwargs): def instance(*args, **kwargs):

View File

@ -59,17 +59,15 @@ class Controller:
self._iou_license_settings = {"iourc_content": "", self._iou_license_settings = {"iourc_content": "",
"license_check": True} "license_check": True}
self._config_loaded = False self._config_loaded = False
self._config_file = Config.instance().controller_config
log.info("Load controller configuration file {}".format(self._config_file))
async def start(self, computes=None): async def start(self, computes=None):
log.info("Controller is starting") log.info("Controller is starting")
self.load_base_files() self.load_base_files()
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
Config.instance().listen_for_config_changes(self._update_config) Config.instance().listen_for_config_changes(self._update_config)
host = server_config.get("host", "localhost") host = server_config.host
port = server_config.getint("port", 3080) port = server_config.port
# clients will use the IP they use to connect to # clients will use the IP they use to connect to
# the controller if console_host is 0.0.0.0 # the controller if console_host is 0.0.0.0
@ -83,13 +81,13 @@ class Controller:
self._load_controller_settings() self._load_controller_settings()
if server_config.getboolean("ssl"): if server_config.ssl:
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
log.critical("SSL mode is not supported on Windows") log.critical("SSL mode is not supported on Windows")
raise SystemExit raise SystemExit
self._ssl_context = self._create_ssl_context(server_config) self._ssl_context = self._create_ssl_context(server_config)
protocol = server_config.get("protocol", "http") protocol = server_config.protocol
if self._ssl_context and protocol != "https": if self._ssl_context and protocol != "https":
log.warning("Protocol changed to 'https' for local compute because SSL is enabled".format(port)) log.warning("Protocol changed to 'https' for local compute because SSL is enabled".format(port))
protocol = "https" protocol = "https"
@ -100,8 +98,8 @@ class Controller:
host=host, host=host,
console_host=console_host, console_host=console_host,
port=port, port=port,
user=server_config.get("user", ""), user=server_config.user,
password=server_config.get("password", ""), password=server_config.password,
force=True, force=True,
connect=True, connect=True,
ssl_context=self._ssl_context) ssl_context=self._ssl_context)
@ -128,8 +126,8 @@ class Controller:
import ssl import ssl
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
certfile = server_config["certfile"] certfile = server_config.certfile
certkey = server_config["certkey"] certkey = server_config.certkey
try: try:
ssl_context.load_cert_chain(certfile, certkey) ssl_context.load_cert_chain(certfile, certkey)
except FileNotFoundError: except FileNotFoundError:
@ -153,9 +151,8 @@ class Controller:
""" """
if self._local_server: if self._local_server:
server_config = Config.instance().get_section_config("Server") self._local_server.user = Config.instance().settings.Server.user
self._local_server.user = server_config.get("user") self._local_server.password = Config.instance().settings.Server.password
self._local_server.password = server_config.get("password")
async def stop(self): async def stop(self):
@ -169,7 +166,7 @@ class Controller:
except (ComputeError, ControllerError, OSError): except (ComputeError, ControllerError, OSError):
pass pass
await self.gns3vm.exit_vm() await self.gns3vm.exit_vm()
#self.save() self.save()
self._computes = {} self._computes = {}
self._projects = {} self._projects = {}
@ -187,20 +184,6 @@ class Controller:
await self.load_projects() await self.load_projects()
def check_can_write_config(self):
"""
Check if the controller configuration can be written on disk
:returns: boolean
"""
try:
os.makedirs(os.path.dirname(self._config_file), exist_ok=True)
if not os.access(self._config_file, os.W_OK):
raise ControllerNotFoundError("Change rejected, cannot write to controller configuration file '{}'".format(self._config_file))
except OSError as e:
raise ControllerError("Change rejected: {}".format(e))
def save(self): def save(self):
""" """
Save the controller configuration on disk Save the controller configuration on disk
@ -209,68 +192,83 @@ class Controller:
if self._config_loaded is False: if self._config_loaded is False:
return return
controller_settings = {"gns3vm": self.gns3vm.__json__(), if self._iou_license_settings["iourc_content"]:
"iou_license": self._iou_license_settings,
"appliances_etag": self._appliance_manager.appliances_etag,
"version": __version__}
# for compute in self._computes.values(): iou_config = Config.instance().settings.IOU
# if compute.id != "local" and compute.id != "vm": server_config = Config.instance().settings.Server
# controller_settings["computes"].append({"host": compute.host,
# "name": compute.name,
# "port": compute.port,
# "protocol": compute.protocol,
# "user": compute.user,
# "password": compute.password,
# "compute_id": compute.id})
try: if iou_config.iourc_path:
os.makedirs(os.path.dirname(self._config_file), exist_ok=True) iourc_path = iou_config.iourc_path
with open(self._config_file, 'w+') as f: else:
json.dump(controller_settings, f, indent=4) os.makedirs(os.path.dirname(server_config.secrets_dir), exist_ok=True)
except OSError as e: iourc_path = os.path.join(server_config.secrets_dir, "gns3_iourc_license")
log.error("Cannot write controller configuration file '{}': {}".format(self._config_file, e))
try:
with open(iourc_path, 'w+') as f:
f.write(self._iou_license_settings["iourc_content"])
log.info(f"iourc file '{iourc_path}' saved")
except OSError as e:
log.error(f"Cannot write IOU license file '{iourc_path}': {e}")
# if self._appliance_manager.appliances_etag:
# config._config.set("Controller", "appliances_etag", self._appliance_manager.appliances_etag)
# config.write_config()
def _load_controller_settings(self): def _load_controller_settings(self):
""" """
Reload the controller configuration from disk Reload the controller configuration from disk
""" """
try: # try:
if not os.path.exists(self._config_file): # if not os.path.exists(self._config_file):
self._config_loaded = True # self._config_loaded = True
self.save() # self.save()
with open(self._config_file) as f: # with open(self._config_file) as f:
controller_settings = json.load(f) # controller_settings = json.load(f)
except (OSError, ValueError) as e: # except (OSError, ValueError) as e:
log.critical("Cannot load configuration file '{}': {}".format(self._config_file, e)) # log.critical("Cannot load configuration file '{}': {}".format(self._config_file, e))
return [] # return []
# load GNS3 VM settings # load GNS3 VM settings
if "gns3vm" in controller_settings: # if "gns3vm" in controller_settings:
gns3_vm_settings = controller_settings["gns3vm"] # gns3_vm_settings = controller_settings["gns3vm"]
if "port" not in gns3_vm_settings: # if "port" not in gns3_vm_settings:
# port setting was added in version 2.2.8 # # port setting was added in version 2.2.8
# the default port was 3080 before this # # the default port was 3080 before this
gns3_vm_settings["port"] = 3080 # gns3_vm_settings["port"] = 3080
self.gns3vm.settings = gns3_vm_settings # self.gns3vm.settings = gns3_vm_settings
# load the IOU license settings # load the IOU license settings
if "iou_license" in controller_settings: iou_config = Config.instance().settings.IOU
self._iou_license_settings = controller_settings["iou_license"] server_config = Config.instance().settings.Server
self._appliance_manager.appliances_etag = controller_settings.get("appliances_etag") #controller_config.getboolean("iou_license_check", True)
self._appliance_manager.load_appliances()
if iou_config.iourc_path:
iourc_path = iou_config.iourc_path
else:
iourc_path = os.path.join(server_config.secrets_dir, "gns3_iourc_license")
if os.path.exists(iourc_path):
try:
with open(iourc_path, 'r') as f:
self._iou_license_settings["iourc_content"] = f.read()
log.info(f"iourc file '{iourc_path}' loaded")
except OSError as e:
log.error(f"Cannot read IOU license file '{iourc_path}': {e}")
self._iou_license_settings["license_check"] = iou_config.license_check
#self._appliance_manager.appliances_etag = controller_config.get("appliances_etag", None)
#self._appliance_manager.load_appliances()
self._config_loaded = True self._config_loaded = True
return controller_settings.get("computes", [])
async def load_projects(self): async def load_projects(self):
""" """
Preload the list of projects from disk Preload the list of projects from disk
""" """
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
projects_path = os.path.expanduser(server_config.get("projects_path", "~/GNS3/projects")) projects_path = os.path.expanduser(server_config.projects_path)
os.makedirs(projects_path, exist_ok=True) os.makedirs(projects_path, exist_ok=True)
try: try:
for project_path in os.listdir(projects_path): for project_path in os.listdir(projects_path):
@ -305,8 +303,8 @@ class Controller:
Get the image storage directory Get the image storage directory
""" """
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
images_path = os.path.expanduser(server_config.get("images_path", "~/GNS3/images")) images_path = os.path.expanduser(server_config.images_path)
os.makedirs(images_path, exist_ok=True) os.makedirs(images_path, exist_ok=True)
return images_path return images_path
@ -315,8 +313,8 @@ class Controller:
Get the configs storage directory Get the configs storage directory
""" """
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
configs_path = os.path.expanduser(server_config.get("configs_path", "~/GNS3/configs")) configs_path = os.path.expanduser(server_config.configs_path)
os.makedirs(configs_path, exist_ok=True) os.makedirs(configs_path, exist_ok=True)
return configs_path return configs_path
@ -348,7 +346,7 @@ class Controller:
compute = Compute(compute_id=compute_id, controller=self, name=name, **kwargs) compute = Compute(compute_id=compute_id, controller=self, name=name, **kwargs)
self._computes[compute.id] = compute self._computes[compute.id] = compute
self.save() #self.save()
if connect: if connect:
asyncio.ensure_future(compute.connect()) asyncio.ensure_future(compute.connect())
self.notification.controller_emit("compute.created", compute.__json__()) self.notification.controller_emit("compute.created", compute.__json__())
@ -394,7 +392,7 @@ class Controller:
await self.close_compute_projects(compute) await self.close_compute_projects(compute)
await compute.close() await compute.close()
del self._computes[compute_id] del self._computes[compute_id]
self.save() #self.save()
self.notification.controller_emit("compute.deleted", compute.__json__()) self.notification.controller_emit("compute.deleted", compute.__json__())
@property @property
@ -557,8 +555,8 @@ class Controller:
def projects_directory(self): def projects_directory(self):
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
return os.path.expanduser(server_config.get("projects_path", "~/GNS3/projects")) return os.path.expanduser(server_config.projects_path)
@staticmethod @staticmethod
def instance(): def instance():

View File

@ -71,8 +71,8 @@ class ApplianceManager:
Get the image storage directory Get the image storage directory
""" """
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
appliances_path = os.path.expanduser(server_config.get("appliances_path", "~/GNS3/appliances")) appliances_path = os.path.expanduser(server_config.appliances_path)
os.makedirs(appliances_path, exist_ok=True) os.makedirs(appliances_path, exist_ok=True)
return appliances_path return appliances_path

View File

@ -121,9 +121,9 @@ class Compute:
else: else:
self._user = user.strip() self._user = user.strip()
if password: if password:
self._password = password.strip() self._password = password
try: try:
self._auth = aiohttp.BasicAuth(self._user, self._password, "utf-8") self._auth = aiohttp.BasicAuth(self._user, self._password.get_secret_value(), "utf-8")
except ValueError as e: except ValueError as e:
log.error(str(e)) log.error(str(e))
else: else:

View File

@ -413,9 +413,6 @@ class Project:
self._path = path self._path = path
def _config(self):
return Config.instance().get_section_config("Server")
@property @property
def captures_directory(self): def captures_directory(self):
""" """
@ -870,8 +867,8 @@ class Project:
depending of the operating system depending of the operating system
""" """
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
path = os.path.expanduser(server_config.get("projects_path", "~/GNS3/projects")) path = os.path.expanduser(server_config.projects_path)
path = os.path.normpath(path) path = os.path.normpath(path)
try: try:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)

View File

@ -112,7 +112,9 @@ class Symbols:
return symbols return symbols
def symbols_path(self): def symbols_path(self):
directory = os.path.expanduser(Config.instance().get_section_config("Server").get("symbols_path", "~/GNS3/symbols"))
server_config = Config.instance().settings.Server
directory = os.path.expanduser(server_config.symbols_path)
if directory: if directory:
try: try:
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)

View File

@ -142,8 +142,7 @@ class CrashReport:
log.warning(".git directory detected, crash reporting is turned off for developers.") log.warning(".git directory detected, crash reporting is turned off for developers.")
return return
server_config = Config.instance().get_section_config("Server") if Config.instance().settings.Server.report_errors:
if server_config.getboolean("report_errors"):
if not SENTRY_SDK_AVAILABLE: if not SENTRY_SDK_AVAILABLE:
log.warning("Cannot capture exception: Sentry SDK is not available") log.warning("Cannot capture exception: Sentry SDK is not available")

View File

@ -54,6 +54,10 @@ class ComputesRepository(BaseRepository):
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute: async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
password = compute_create.password
if password:
password = password.get_secret_value()
db_compute = models.Compute( db_compute = models.Compute(
compute_id=compute_create.compute_id, compute_id=compute_create.compute_id,
name=compute_create.name, name=compute_create.name,
@ -61,7 +65,7 @@ class ComputesRepository(BaseRepository):
host=compute_create.host, host=compute_create.host,
port=compute_create.port, port=compute_create.port,
user=compute_create.user, user=compute_create.user,
password=compute_create.password password=password
) )
self._db_session.add(db_compute) self._db_session.add(db_compute)
await self._db_session.commit() await self._db_session.commit()
@ -72,6 +76,10 @@ class ComputesRepository(BaseRepository):
update_values = compute_update.dict(exclude_unset=True) update_values = compute_update.dict(exclude_unset=True)
password = compute_update.password
if password:
update_values["password"] = password.get_secret_value()
query = update(models.Compute) \ query = update(models.Compute) \
.where(models.Compute.compute_id == compute_id) \ .where(models.Compute.compute_id == compute_id) \
.values(update_values) .values(update_values)

View File

@ -100,12 +100,10 @@ def parse_arguments(argv):
parser.add_argument("--config", help="Configuration file") parser.add_argument("--config", help="Configuration file")
parser.add_argument("--certfile", help="SSL cert file") parser.add_argument("--certfile", help="SSL cert file")
parser.add_argument("--certkey", help="SSL key file") parser.add_argument("--certkey", help="SSL key file")
parser.add_argument("--record", help="save curl requests into a file (for developers)")
parser.add_argument("-L", "--local", action="store_true", help="local mode (allows some insecure operations)") parser.add_argument("-L", "--local", action="store_true", help="local mode (allows some insecure operations)")
parser.add_argument("-A", "--allow", action="store_true", help="allow remote connections to local console ports") parser.add_argument("-A", "--allow", action="store_true", help="allow remote connections to local console ports")
parser.add_argument("-q", "--quiet", action="store_true", help="do not show logs on stdout") parser.add_argument("-q", "--quiet", action="store_true", help="do not show logs on stdout")
parser.add_argument("-d", "--debug", action="store_true", help="show debug logs") parser.add_argument("-d", "--debug", action="store_true", help="show debug logs")
parser.add_argument("--shell", action="store_true", help="start a shell inside the server (debugging purpose only you need to install ptpython before)")
parser.add_argument("--log", help="send output to logfile instead of console") parser.add_argument("--log", help="send output to logfile instead of console")
parser.add_argument("--logmaxsize", help="maximum logfile size in bytes (default is 10MB)") parser.add_argument("--logmaxsize", help="maximum logfile size in bytes (default is 10MB)")
parser.add_argument("--logbackupcount", help="number of historical log files to keep (default is 10)") parser.add_argument("--logbackupcount", help="number of historical log files to keep (default is 10)")
@ -120,50 +118,37 @@ def parse_arguments(argv):
else: else:
Config.instance(profile=args.profile) Config.instance(profile=args.profile)
config = Config.instance().get_section_config("Server") config = Config.instance().settings
defaults = { defaults = {
"host": config.get("host", "0.0.0.0"), "host": config.Server.host,
"port": config.getint("port", 3080), "port": config.Server.port,
"ssl": config.getboolean("ssl", False), "ssl": config.Server.ssl,
"certfile": config.get("certfile", ""), "certfile": config.Server.certfile,
"certkey": config.get("certkey", ""), "certkey": config.Server.certkey,
"record": config.get("record", ""), "local": config.Server.local,
"local": config.getboolean("local", False), "allow": config.Server.allow_remote_console,
"allow": config.getboolean("allow_remote_console", False), "quiet": config.Server.quiet,
"quiet": config.getboolean("quiet", False), "debug": config.Server.debug,
"debug": config.getboolean("debug", False), "logfile": config.Server.logfile,
"logfile": config.getboolean("logfile", ""), "logmaxsize": config.Server.logmaxsize,
"logmaxsize": config.getint("logmaxsize", 10000000), # default is 10MB "logbackupcount": config.Server.logbackupcount,
"logbackupcount": config.getint("logbackupcount", 10), "logcompression": config.Server.logcompression
"logcompression": config.getboolean("logcompression", False)
} }
parser.set_defaults(**defaults) parser.set_defaults(**defaults)
return parser.parse_args(argv) return parser.parse_args(argv)
def set_config(args): def set_config(args):
config = Config.instance() config = Config.instance().settings
server_config = config.get_section_config("Server") config.Server.local = args.local
jwt_secret_key = server_config.get("jwt_secret_key", None) config.Server.allow_remote_console = args.allow
if not jwt_secret_key: config.Server.host = args.host
log.info("No JWT secret key configured, generating one...") config.Server.port = args.port
if not config._config.has_section("Server"): config.Server.ssl = args.ssl
config._config.add_section("Server") config.Server.certfile = args.certfile
config._config.set("Server", "jwt_secret_key", secrets.token_hex(32)) config.Server.certkey = args.certkey
config.write_config() config.Server.debug = args.debug
server_config["local"] = str(args.local)
server_config["allow_remote_console"] = str(args.allow)
server_config["host"] = args.host
server_config["port"] = str(args.port)
server_config["ssl"] = str(args.ssl)
server_config["certfile"] = args.certfile
server_config["certkey"] = args.certkey
server_config["record"] = args.record
server_config["debug"] = str(args.debug)
server_config["shell"] = str(args.shell)
config.set_section_config("Server", server_config)
def pid_lock(path): def pid_lock(path):
@ -280,17 +265,13 @@ def run():
log.info("Config file {} loaded".format(config_file)) log.info("Config file {} loaded".format(config_file))
set_config(args) set_config(args)
server_config = Config.instance().get_section_config("Server") config = Config.instance().settings
if server_config.getboolean("local"): if config.Server.local:
log.warning("Local mode is enabled. Beware, clients will have full control on your filesystem") log.warning("Local mode is enabled. Beware, clients will have full control on your filesystem")
if server_config.getboolean("auth"): if config.Server.auth:
user = server_config.get("user", "").strip() log.info("HTTP authentication is enabled with username '{}'".format(config.Server.user))
if not user:
log.critical("HTTP authentication is enabled but no username is configured")
return
log.info("HTTP authentication is enabled with username '{}'".format(user))
# we only support Python 3 version >= 3.6 # we only support Python 3 version >= 3.6
if sys.version_info < (3, 6, 0): if sys.version_info < (3, 6, 0):
@ -311,8 +292,8 @@ def run():
return return
CrashReport.instance() CrashReport.instance()
host = server_config["host"] host = config.Server.host
port = int(server_config["port"]) port = config.Server.port
PortManager.instance().console_host = host PortManager.instance().console_host = host
signal_handling() signal_handling()
@ -325,22 +306,19 @@ def run():
if log.getEffectiveLevel() == logging.DEBUG: if log.getEffectiveLevel() == logging.DEBUG:
access_log = True access_log = True
certfile = None if config.Server.ssl:
certkey = None
if server_config.getboolean("ssl"):
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
log.critical("SSL mode is not supported on Windows") log.critical("SSL mode is not supported on Windows")
raise SystemExit raise SystemExit
certfile = server_config["certfile"]
certkey = server_config["certkey"]
log.info("SSL is enabled") log.info("SSL is enabled")
config = uvicorn.Config(app, config = uvicorn.Config(app,
host=host, host=host,
port=port, port=port,
access_log=access_log, access_log=access_log,
ssl_certfile=certfile, ssl_certfile=config.Server.certfile,
ssl_keyfile=certkey) ssl_keyfile=config.Server.certkey,
lifespan="on")
# overwrite uvicorn loggers with our own logger # overwrite uvicorn loggers with our own logger
for uvicorn_logger_name in ("uvicorn", "uvicorn.error"): for uvicorn_logger_name in ("uvicorn", "uvicorn.error"):

View File

@ -15,7 +15,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from .config import ServerConfig
from .iou_license import IOULicense from .iou_license import IOULicense
from .links import Link from .links import Link
from .common import ErrorMessage from .common import ErrorMessage

View File

@ -15,7 +15,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, SecretStr, validator
from typing import List, Optional, Union from typing import List, Optional, Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from enum import Enum from enum import Enum
@ -51,7 +51,7 @@ class ComputeCreate(ComputeBase):
""" """
compute_id: Union[str, UUID] = Field(default_factory=uuid4) compute_id: Union[str, UUID] = Field(default_factory=uuid4)
password: Optional[str] = None password: Optional[SecretStr] = None
class Config: class Config:
schema_extra = { schema_extra = {
@ -91,7 +91,7 @@ class ComputeUpdate(ComputeBase):
protocol: Optional[Protocol] = None protocol: Optional[Protocol] = None
host: Optional[str] = None host: Optional[str] = None
port: Optional[int] = Field(None, gt=0, le=65535) port: Optional[int] = Field(None, gt=0, le=65535)
password: Optional[str] = None password: Optional[SecretStr] = None
class Config: class Config:
schema_extra = { schema_extra = {

View File

@ -0,0 +1,196 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from enum import Enum
from pydantic import BaseModel, Field, SecretStr, validator
from typing import List
class ControllerSettings(BaseModel):
jwt_secret_key: str = None
jwt_algorithm: str = "HS256"
jwt_access_token_expire_minutes: int = 1440
class Config:
validate_assignment = True
anystr_strip_whitespace = True
class VPCSSettings(BaseModel):
vpcs_path: str = "vpcs"
class Config:
validate_assignment = True
anystr_strip_whitespace = True
class DynamipsSettings(BaseModel):
allocate_aux_console_ports: bool = False
mmap_support: bool = True
dynamips_path: str = "dynamips"
sparse_memory_support: bool = True
ghost_ios_support: bool = True
class Config:
validate_assignment = True
anystr_strip_whitespace = True
class IOUSettings(BaseModel):
iourc_path: str = None
license_check: bool = True
class Config:
validate_assignment = True
anystr_strip_whitespace = True
class QemuSettings(BaseModel):
enable_monitor: bool = True
monitor_host: str = "127.0.0.1"
enable_hardware_acceleration: bool = True
require_hardware_acceleration: bool = False
class Config:
validate_assignment = True
anystr_strip_whitespace = True
class VirtualBoxSettings(BaseModel):
vboxmanage_path: str = None
class Config:
validate_assignment = True
anystr_strip_whitespace = True
class VMwareSettings(BaseModel):
vmrun_path: str = None
vmnet_start_range: int = Field(2, ge=1, le=255)
vmnet_end_range: int = Field(255, ge=1, le=255) # should be limited to 19 on Windows
block_host_traffic: bool = False
@validator("vmnet_end_range")
def vmnet_port_range(cls, v, values):
if "vmnet_start_range" in values and v <= values["vmnet_start_range"]:
raise ValueError("vmnet_end_range must be > vmnet_start_range")
return v
class Config:
validate_assignment = True
anystr_strip_whitespace = True
class ServerProtocol(str, Enum):
http = "http"
https = "https"
class ServerSettings(BaseModel):
protocol: ServerProtocol = ServerProtocol.http
host: str = "0.0.0.0"
port: int = Field(3080, gt=0, le=65535)
secrets_dir: str = None
ssl: bool = False
certfile: str = None
certkey: str = None
images_path: str = "~/GNS3/images"
projects_path: str = "~/GNS3/projects"
appliances_path: str = "~/GNS3/appliances"
symbols_path: str = "~/GNS3/symbols"
configs_path: str = "~/GNS3/configs"
report_errors: bool = True
additional_images_paths: List[str] = Field(default_factory=list)
console_start_port_range: int = Field(5000, gt=0, le=65535)
console_end_port_range: int = Field(10000, gt=0, le=65535)
vnc_console_start_port_range: int = Field(5900, ge=5900, le=65535)
vnc_console_end_port_range: int = Field(10000, ge=5900, le=65535)
udp_start_port_range: int = Field(10000, gt=0, le=65535)
udp_end_port_range: int = Field(30000, gt=0, le=65535)
ubridge_path: str = "ubridge"
user: str = None
password: SecretStr = None
auth: bool = False
allowed_interfaces: List[str] = Field(default_factory=list)
default_nat_interface: str = None
logfile: str = None
logmaxsize: int = 10000000 # default is 10MB
logbackupcount: int = 10
logcompression: bool = False
local: bool = False
allow_remote_console: bool = False
quiet: bool = False
debug: bool = False
@validator("additional_images_paths", pre=True)
def split_additional_images_paths(cls, v):
if v:
return v.split(';')
return list()
@validator("allowed_interfaces", pre=True)
def split_allowed_interfaces(cls, v):
if v:
return v.split(',')
return list()
@validator("console_end_port_range")
def console_port_range(cls, v, values):
if "console_start_port_range" in values and v <= values["console_start_port_range"]:
raise ValueError("console_end_port_range must be > console_start_port_range")
return v
@validator("vnc_console_end_port_range")
def vnc_console_port_range(cls, v, values):
if "vnc_console_start_port_range" in values and v <= values["vnc_console_start_port_range"]:
raise ValueError("vnc_console_end_port_range must be > vnc_console_start_port_range")
return v
@validator("auth")
def validate_enable_auth(cls, v, values):
if v is True:
if "user" not in values or not values["user"]:
raise ValueError("HTTP authentication is enabled but no username is configured")
return v
class Config:
validate_assignment = True
anystr_strip_whitespace = True
use_enum_values = True
class ServerConfig(BaseModel):
Server: ServerSettings= ServerSettings()
Controller: ControllerSettings = ControllerSettings()
VPCS: VPCSSettings = VPCSSettings()
Dynamips: DynamipsSettings = DynamipsSettings()
IOU: IOUSettings = IOUSettings()
Qemu: QemuSettings = QemuSettings()
VirtualBox: VirtualBoxSettings = VirtualBoxSettings()
VMware: VMwareSettings = VMwareSettings()

View File

@ -38,7 +38,7 @@ class AuthService:
def __init__(self): def __init__(self):
self._server_config = Config.instance().get_section_config("Server") self._controller_config = Config.instance().settings.Controller
def hash_password(self, password: str) -> str: def hash_password(self, password: str) -> str:
@ -48,20 +48,6 @@ class AuthService:
return pwd_context.verify(password, hashed_password) return pwd_context.verify(password, hashed_password)
def get_secret_key(self):
"""
Should only be used by tests.
"""
return self._server_config.get("jwt_secret_key", None)
def get_algorithm(self):
"""
Should only be used by tests.
"""
return self._server_config.get("jwt_algorithm", None)
def create_access_token( def create_access_token(
self, self,
username, username,
@ -70,15 +56,15 @@ class AuthService:
) -> str: ) -> str:
if not expires_in: if not expires_in:
expires_in = self._server_config.getint("jwt_access_token_expire_minutes", 1440) expires_in = self._controller_config.jwt_access_token_expire_minutes
expire = datetime.utcnow() + timedelta(minutes=expires_in) expire = datetime.utcnow() + timedelta(minutes=expires_in)
to_encode = {"sub": username, "exp": expire} to_encode = {"sub": username, "exp": expire}
if secret_key is None: if secret_key is None:
secret_key = self._server_config.get("jwt_secret_key", None) secret_key = self._controller_config.jwt_secret_key
if secret_key is None: if secret_key is None:
secret_key = DEFAULT_JWT_SECRET_KEY secret_key = DEFAULT_JWT_SECRET_KEY
log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!") log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!")
algorithm = self._server_config.get("jwt_algorithm", "HS256") algorithm = self._controller_config.jwt_algorithm
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt return encoded_jwt
@ -91,11 +77,11 @@ class AuthService:
) )
try: try:
if secret_key is None: if secret_key is None:
secret_key = self._server_config.get("jwt_secret_key", None) secret_key = self._controller_config.jwt_secret_key
if secret_key is None: if secret_key is None:
secret_key = DEFAULT_JWT_SECRET_KEY secret_key = DEFAULT_JWT_SECRET_KEY
log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!") log.error("A JWT secret key must be configured to secure the server, using an unsecured default key!")
algorithm = self._server_config.get("jwt_algorithm", "HS256") algorithm = self._controller_config.jwt_algorithm
payload = jwt.decode(token, secret_key, algorithms=[algorithm]) payload = jwt.decode(token, secret_key, algorithms=[algorithm])
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:

View File

@ -35,8 +35,8 @@ def list_images(type):
files = set() files = set()
images = [] images = []
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
general_images_directory = os.path.expanduser(server_config.get("images_path", "~/GNS3/images")) general_images_directory = os.path.expanduser(server_config.images_path)
# Subfolder of the general_images_directory specific to this VM type # Subfolder of the general_images_directory specific to this VM type
default_directory = default_images_directory(type) default_directory = default_images_directory(type)
@ -106,8 +106,8 @@ def default_images_directory(type):
""" """
:returns: Return the default directory for a node type :returns: Return the default directory for a node type
""" """
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
img_dir = os.path.expanduser(server_config.get("images_path", "~/GNS3/images")) img_dir = os.path.expanduser(server_config.images_path)
if type == "qemu": if type == "qemu":
return os.path.join(img_dir, "QEMU") return os.path.join(img_dir, "QEMU")
elif type == "iou": elif type == "iou":
@ -125,17 +125,17 @@ def images_directories(type):
:param type: Type of emulator :param type: Type of emulator
""" """
server_config = Config.instance().get_section_config("Server")
server_config = Config.instance().settings.Server
paths = [] paths = []
img_dir = os.path.expanduser(server_config.get("images_path", "~/GNS3/images")) img_dir = os.path.expanduser(server_config.images_path)
type_img_directory = default_images_directory(type) type_img_directory = default_images_directory(type)
try: try:
os.makedirs(type_img_directory, exist_ok=True) os.makedirs(type_img_directory, exist_ok=True)
paths.append(type_img_directory) paths.append(type_img_directory)
except (OSError, PermissionError): except (OSError, PermissionError):
pass pass
for directory in server_config.get("additional_images_path", "").split(";"): for directory in server_config.additional_images_paths:
paths.append(directory) paths.append(directory)
# Compatibility with old topologies we look in parent directory # Compatibility with old topologies we look in parent directory
paths.append(img_dir) paths.append(img_dir)

View File

@ -184,9 +184,7 @@ def interfaces():
results = [] results = []
if not sys.platform.startswith("win"): if not sys.platform.startswith("win"):
allowed_interfaces = Config.instance().get_section_config("Server").get("allowed_interfaces", None) allowed_interfaces = Config.instance().settings.Server.allowed_interfaces
if allowed_interfaces:
allowed_interfaces = allowed_interfaces.split(',')
net_if_addrs = psutil.net_if_addrs() net_if_addrs = psutil.net_if_addrs()
for interface in sorted(net_if_addrs.keys()): for interface in sorted(net_if_addrs.keys()):
if allowed_interfaces and interface not in allowed_interfaces and not interface.startswith("gns3tap"): if allowed_interfaces and interface not in allowed_interfaces and not interface.startswith("gns3tap"):

View File

@ -27,8 +27,8 @@ def get_default_project_directory():
depending of the operating system depending of the operating system
""" """
server_config = Config.instance().get_section_config("Server") server_config = Config.instance().settings.Server
path = os.path.expanduser(server_config.get("projects_path", "~/GNS3/projects")) path = os.path.expanduser(server_config.projects_path)
path = os.path.normpath(path) path = os.path.normpath(path)
try: try:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
@ -45,11 +45,9 @@ def check_path_allowed(path):
Raise a 403 in case of error Raise a 403 in case of error
""" """
config = Config.instance().get_section_config("Server")
project_directory = get_default_project_directory() project_directory = get_default_project_directory()
if len(os.path.commonprefix([project_directory, path])) == len(project_directory): if len(os.path.commonprefix([project_directory, path])) == len(project_directory):
return return
if "local" in config and config.getboolean("local") is False: if Config.instance().settings.Server.local is False:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The path is not allowed") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The path is not allowed")

View File

@ -41,9 +41,8 @@ async def test_interfaces(app: FastAPI, client: AsyncClient) -> None:
assert isinstance(response.json(), list) assert isinstance(response.json(), list)
async def test_version_output(app: FastAPI, client: AsyncClient, config) -> None: async def test_version_output(app: FastAPI, client: AsyncClient) -> None:
config.set("Server", "local", "true")
response = await client.get(app.url_path_for("compute_version")) response = await client.get(app.url_path_for("compute_version"))
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == {'local': True, 'version': __version__} assert response.json() == {'local': True, 'version': __version__}

View File

@ -158,10 +158,10 @@ async def test_close_project_invalid_uuid(app: FastAPI, client: AsyncClient) ->
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
async def test_get_file(app: FastAPI, client: AsyncClient, tmpdir) -> None: async def test_get_file(app: FastAPI, client: AsyncClient, config, tmpdir) -> None:
with patch("gns3server.config.Config.get_section_config", return_value={"projects_path": str(tmpdir)}): config.settings.Server.projects_path = str(tmpdir)
project = ProjectManager.instance().create_project(project_id="01010203-0405-0607-0809-0a0b0c0d0e0b") project = ProjectManager.instance().create_project(project_id="01010203-0405-0607-0809-0a0b0c0d0e0b")
with open(os.path.join(project.path, "hello"), "w+") as f: with open(os.path.join(project.path, "hello"), "w+") as f:
f.write("world") f.write("world")
@ -179,10 +179,10 @@ async def test_get_file(app: FastAPI, client: AsyncClient, tmpdir) -> None:
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
async def test_write_file(app: FastAPI, client: AsyncClient, tmpdir) -> None: async def test_write_file(app: FastAPI, client: AsyncClient, config, tmpdir) -> None:
with patch("gns3server.config.Config.get_section_config", return_value={"projects_path": str(tmpdir)}): config.settings.Server.projects_path = str(tmpdir)
project = ProjectManager.instance().create_project(project_id="01010203-0405-0607-0809-0a0b0c0d0e0b") project = ProjectManager.instance().create_project(project_id="01010203-0405-0607-0809-0a0b0c0d0e0b")
response = await client.post(app.url_path_for("write_compute_project_file", response = await client.post(app.url_path_for("write_compute_project_file",
project_id=project.id, project_id=project.id,

View File

@ -425,9 +425,9 @@ async def test_create_img_relative(app: FastAPI, client: AsyncClient):
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
async def test_create_img_absolute_non_local(app: FastAPI, client: AsyncClient, config: dict) -> None: async def test_create_img_absolute_non_local(app: FastAPI, client: AsyncClient, config) -> None:
config.set("Server", "local", "false") config.settings.Server.local = False
params = { params = {
"qemu_img": "/tmp/qemu-img", "qemu_img": "/tmp/qemu-img",
"path": "/tmp/hda.qcow2", "path": "/tmp/hda.qcow2",
@ -443,9 +443,9 @@ async def test_create_img_absolute_non_local(app: FastAPI, client: AsyncClient,
assert response.status_code == 403 assert response.status_code == 403
async def test_create_img_absolute_local(app: FastAPI, client: AsyncClient, config: dict) -> None: async def test_create_img_absolute_local(app: FastAPI, client: AsyncClient, config) -> None:
config.set("Server", "local", "true") config.settings.Server.local = True
params = { params = {
"qemu_img": "/tmp/qemu-img", "qemu_img": "/tmp/qemu-img",
"path": "/tmp/hda.qcow2", "path": "/tmp/hda.qcow2",

View File

@ -30,7 +30,7 @@ pytestmark = pytest.mark.asyncio
async def test_shutdown_local(app: FastAPI, client: AsyncClient, config: Config) -> None: async def test_shutdown_local(app: FastAPI, client: AsyncClient, config: Config) -> None:
os.kill = MagicMock() os.kill = MagicMock()
config.set("Server", "local", True) config.settings.Server.local = True
response = await client.post(app.url_path_for("shutdown")) response = await client.post(app.url_path_for("shutdown"))
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
assert os.kill.called assert os.kill.called
@ -38,7 +38,7 @@ async def test_shutdown_local(app: FastAPI, client: AsyncClient, config: Config)
async def test_shutdown_non_local(app: FastAPI, client: AsyncClient, config: Config) -> None: async def test_shutdown_non_local(app: FastAPI, client: AsyncClient, config: Config) -> None:
config.set("Server", "local", False) config.settings.Server.local = False
response = await client.post(app.url_path_for("shutdown")) response = await client.post(app.url_path_for("shutdown"))
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN

View File

@ -178,7 +178,7 @@ async def test_open_project(app: FastAPI, client: AsyncClient, project: Project)
async def test_load_project(app: FastAPI, client: AsyncClient, project: Project, config) -> None: async def test_load_project(app: FastAPI, client: AsyncClient, project: Project, config) -> None:
config.set("Server", "local", "true") config.settings.Server.local = True
with asyncio_patch("gns3server.controller.Controller.load_project", return_value=project) as mock: with asyncio_patch("gns3server.controller.Controller.load_project", return_value=project) as mock:
response = await client.post(app.url_path_for("load_project"), json={"path": "/tmp/test.gns3"}) response = await client.post(app.url_path_for("load_project"), json={"path": "/tmp/test.gns3"})
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED

View File

@ -131,7 +131,7 @@ class TestAuthTokens:
config: Config config: Config
) -> None: ) -> None:
jwt_secret = config.get_section_config("Server").get("jwt_secret_key", DEFAULT_JWT_SECRET_KEY) jwt_secret = config.settings.Controller.jwt_secret_key
token = auth_service.create_access_token(test_user.username) token = auth_service.create_access_token(test_user.username)
payload = jwt.decode(token, jwt_secret, algorithms=["HS256"]) payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
username = payload.get("sub") username = payload.get("sub")
@ -139,7 +139,7 @@ class TestAuthTokens:
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient, config: Config) -> None: async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient, config: Config) -> None:
jwt_secret = config.get_section_config("Server").get("jwt_secret_key", DEFAULT_JWT_SECRET_KEY) jwt_secret = config.settings.Controller.jwt_secret_key
token = auth_service.create_access_token(None) token = auth_service.create_access_token(None)
with pytest.raises(jwt.JWTError): with pytest.raises(jwt.JWTError):
jwt.decode(token, jwt_secret, algorithms=["HS256"]) jwt.decode(token, jwt_secret, algorithms=["HS256"])
@ -171,11 +171,12 @@ class TestAuthTokens:
test_user: User, test_user: User,
wrong_secret: str, wrong_secret: str,
wrong_token: Optional[str], wrong_token: Optional[str],
config,
) -> None: ) -> None:
token = auth_service.create_access_token(test_user.username) token = auth_service.create_access_token(test_user.username)
if wrong_secret == "use correct secret": if wrong_secret == "use correct secret":
wrong_secret = auth_service._server_config.get("jwt_secret_key", DEFAULT_JWT_SECRET_KEY) wrong_secret = config.settings.Controller.jwt_secret_key
if wrong_token == "use correct token": if wrong_token == "use correct token":
wrong_token = token wrong_token = token
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
@ -192,7 +193,7 @@ class TestUserLogin:
config: Config config: Config
) -> None: ) -> None:
jwt_secret = config.get_section_config("Server").get("jwt_secret_key", DEFAULT_JWT_SECRET_KEY) jwt_secret = config.settings.Controller.jwt_secret_key
client.headers["content-type"] = "application/x-www-form-urlencoded" client.headers["content-type"] = "application/x-www-form-urlencoded"
login_data = { login_data = {
"username": test_user.username, "username": test_user.username,

View File

@ -25,9 +25,8 @@ from gns3server.version import __version__
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
async def test_version_output(app: FastAPI, client: AsyncClient, config) -> None: async def test_version_output(app: FastAPI, client: AsyncClient) -> None:
config.set("Server", "local", "true")
response = await client.get(app.url_path_for("get_version")) response = await client.get(app.url_path_for("get_version"))
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == {'local': True, 'version': __version__} assert response.json() == {'local': True, 'version': __version__}

View File

@ -37,19 +37,20 @@ async def manager(port_manager):
return m return m
def test_vm_invalid_dynamips_path(manager): def test_vm_invalid_dynamips_path(manager, config):
with patch("gns3server.config.Config.get_section_config", return_value={"dynamips_path": "/bin/test_fake"}): config.settings.Dynamips.dynamips_path = "/bin/test_fake"
with pytest.raises(DynamipsError): with pytest.raises(DynamipsError):
manager.find_dynamips() manager.find_dynamips()
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not supported by Windows") @pytest.mark.skipif(sys.platform.startswith("win"), reason="Not supported by Windows")
def test_vm_non_executable_dynamips_path(manager): def test_vm_non_executable_dynamips_path(manager, config):
tmpfile = tempfile.NamedTemporaryFile() tmpfile = tempfile.NamedTemporaryFile()
with patch("gns3server.config.Config.get_section_config", return_value={"dynamips_path": tmpfile.name}): config.settings.Dynamips.dynamips_path = tmpfile.name
with pytest.raises(DynamipsError): with pytest.raises(DynamipsError):
manager.find_dynamips() manager.find_dynamips()
def test_get_dynamips_id(manager): def test_get_dynamips_id(manager):

View File

@ -66,11 +66,11 @@ def test_convert_project_before_2_0_0_b3(compute_project, manager):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_router_invalid_dynamips_path(compute_project, manager): async def test_router_invalid_dynamips_path(compute_project, config, manager):
config = Config.instance() config = Config.instance()
config.set("Dynamips", "dynamips_path", "/bin/test_fake") config.settings.Dynamips.dynamips_path = "/bin/test_fake"
config.set("Dynamips", "allocate_aux_console_ports", False) config.settings.Dynamips.allocate_aux_console_ports = False
with pytest.raises(DynamipsError): with pytest.raises(DynamipsError):
router = Router("test", "00010203-0405-0607-0809-0a0b0c0d0e0e", compute_project, manager) router = Router("test", "00010203-0405-0607-0809-0a0b0c0d0e0e", compute_project, manager)

View File

@ -47,12 +47,10 @@ async def manager(port_manager):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
@pytest.mark.asyncio @pytest.mark.asyncio
async def vm(compute_project, manager, tmpdir, fake_iou_bin, iourc_file): async def vm(compute_project, manager, config, tmpdir, fake_iou_bin, iourc_file):
vm = IOUVM("test", str(uuid.uuid4()), compute_project, manager, application_id=1) vm = IOUVM("test", str(uuid.uuid4()), compute_project, manager, application_id=1)
config = manager.config.get_section_config("IOU") config.settings.IOU.iourc_path = iourc_file
config["iourc_path"] = iourc_file
manager.config.set_section_config("IOU", config)
vm.path = "iou.bin" vm.path = "iou.bin"
return vm return vm
@ -118,7 +116,7 @@ async def test_start(vm):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_with_iourc(vm, tmpdir): async def test_start_with_iourc(vm, tmpdir, config):
fake_file = str(tmpdir / "iourc") fake_file = str(tmpdir / "iourc")
with open(fake_file, "w+") as f: with open(fake_file, "w+") as f:
@ -131,13 +129,13 @@ async def test_start_with_iourc(vm, tmpdir):
vm._start_ubridge = AsyncioMagicMock(return_value=True) vm._start_ubridge = AsyncioMagicMock(return_value=True)
vm._ubridge_send = AsyncioMagicMock() vm._ubridge_send = AsyncioMagicMock()
with patch("gns3server.config.Config.get_section_config", return_value={"iourc_path": fake_file}): config.settings.IOU.iourc_path = fake_file
with asyncio_patch("asyncio.create_subprocess_exec", return_value=mock_process) as exec_mock: with asyncio_patch("asyncio.create_subprocess_exec", return_value=mock_process) as exec_mock:
mock_process.returncode = None mock_process.returncode = None
await vm.start() await vm.start()
assert vm.is_running() assert vm.is_running()
arsgs, kwargs = exec_mock.call_args arsgs, kwargs = exec_mock.call_args
assert kwargs["env"]["IOURC"] == fake_file assert kwargs["env"]["IOURC"] == fake_file
@pytest.mark.asyncio @pytest.mark.asyncio
@ -224,7 +222,7 @@ async def test_close(vm, port_manager):
def test_path(vm, fake_iou_bin, config): def test_path(vm, fake_iou_bin, config):
config.set_section_config("Server", {"local": True}) config.settings.Server.local = True
vm.path = fake_iou_bin vm.path = fake_iou_bin
assert vm.path == fake_iou_bin assert vm.path == fake_iou_bin
@ -237,7 +235,7 @@ def test_path_relative(vm, fake_iou_bin):
def test_path_invalid_bin(vm, tmpdir, config): def test_path_invalid_bin(vm, tmpdir, config):
config.set_section_config("Server", {"local": True}) config.settings.Server.local = True
path = str(tmpdir / "test.bin") path = str(tmpdir / "test.bin")
with open(path, "w+") as f: with open(path, "w+") as f:

View File

@ -77,7 +77,7 @@ async def vm(compute_project, manager, fake_qemu_binary, fake_qemu_img_binary):
vm._start_ubridge = AsyncioMagicMock() vm._start_ubridge = AsyncioMagicMock()
vm._ubridge_hypervisor = MagicMock() vm._ubridge_hypervisor = MagicMock()
vm._ubridge_hypervisor.is_running.return_value = True vm._ubridge_hypervisor.is_running.return_value = True
vm.manager.config.set("Qemu", "enable_hardware_acceleration", False) vm.manager.config.settings.Qemu.enable_hardware_acceleration = False
return vm return vm
@ -894,14 +894,14 @@ def test_get_qemu_img(vm, tmpdir):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_with_hardware_acceleration_darwin(darwin_platform, vm): async def test_run_with_hardware_acceleration_darwin(darwin_platform, vm):
vm.manager.config.set("Qemu", "enable_hardware_acceleration", False) vm.manager.config.settings.Qemu.enable_hardware_acceleration = False
assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "") is False assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "") is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_with_hardware_acceleration_windows(windows_platform, vm): async def test_run_with_hardware_acceleration_windows(windows_platform, vm):
vm.manager.config.set("Qemu", "enable_hardware_acceleration", False) vm.manager.config.settings.Qemu.enable_hardware_acceleration = False
assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "") is False assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "") is False
@ -909,7 +909,7 @@ async def test_run_with_hardware_acceleration_windows(windows_platform, vm):
async def test_run_with_kvm_linux(linux_platform, vm): async def test_run_with_kvm_linux(linux_platform, vm):
with patch("os.path.exists", return_value=True) as os_path: with patch("os.path.exists", return_value=True) as os_path:
vm.manager.config.set("Qemu", "enable_kvm", True) vm.manager.config.settings.Qemu.enable_hardware_acceleration = True
assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "") is True assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "") is True
os_path.assert_called_with("/dev/kvm") os_path.assert_called_with("/dev/kvm")
@ -918,7 +918,7 @@ async def test_run_with_kvm_linux(linux_platform, vm):
async def test_run_with_kvm_linux_options_no_kvm(linux_platform, vm): async def test_run_with_kvm_linux_options_no_kvm(linux_platform, vm):
with patch("os.path.exists", return_value=True) as os_path: with patch("os.path.exists", return_value=True) as os_path:
vm.manager.config.set("Qemu", "enable_kvm", True) vm.manager.config.settings.Qemu.enable_hardware_acceleration = True
assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "-no-kvm") is False assert await vm._run_with_hardware_acceleration("qemu-system-x86_64", "-no-kvm") is False
@ -926,7 +926,8 @@ async def test_run_with_kvm_linux_options_no_kvm(linux_platform, vm):
async def test_run_with_kvm_not_x86(linux_platform, vm): async def test_run_with_kvm_not_x86(linux_platform, vm):
with patch("os.path.exists", return_value=True): with patch("os.path.exists", return_value=True):
vm.manager.config.set("Qemu", "enable_kvm", True) vm.manager.config.settings.Qemu.enable_hardware_acceleration = True
vm.manager.config.settings.Qemu.require_hardware_acceleration = True
with pytest.raises(QemuError): with pytest.raises(QemuError):
await vm._run_with_hardware_acceleration("qemu-system-arm", "") await vm._run_with_hardware_acceleration("qemu-system-arm", "")
@ -935,6 +936,7 @@ async def test_run_with_kvm_not_x86(linux_platform, vm):
async def test_run_with_kvm_linux_dev_kvm_missing(linux_platform, vm): async def test_run_with_kvm_linux_dev_kvm_missing(linux_platform, vm):
with patch("os.path.exists", return_value=False): with patch("os.path.exists", return_value=False):
vm.manager.config.set("Qemu", "enable_kvm", True) vm.manager.config.settings.Qemu.enable_hardware_acceleration = True
vm.manager.config.settings.Qemu.require_hardware_acceleration = True
with pytest.raises(QemuError): with pytest.raises(QemuError):
await vm._run_with_hardware_acceleration("qemu-system-x86_64", "") await vm._run_with_hardware_acceleration("qemu-system-x86_64", "")

View File

@ -86,7 +86,7 @@ def test_get_abs_image_path(qemu, tmpdir, config):
path2 = force_unix_path(str(tmpdir / "QEMU" / "test2.bin")) path2 = force_unix_path(str(tmpdir / "QEMU" / "test2.bin"))
open(path2, 'w+').close() open(path2, 'w+').close()
config.set_section_config("Server", {"images_path": str(tmpdir)}) config.settings.Server.images_path = str(tmpdir)
assert qemu.get_abs_image_path(path1) == path1 assert qemu.get_abs_image_path(path1) == path1
assert qemu.get_abs_image_path("test1.bin") == path1 assert qemu.get_abs_image_path("test1.bin") == path1
assert qemu.get_abs_image_path(path2) == path2 assert qemu.get_abs_image_path(path2) == path2
@ -105,14 +105,16 @@ def test_get_abs_image_path_non_local(qemu, tmpdir, config):
path2 = force_unix_path(str(path2)) path2 = force_unix_path(str(path2))
# If non local we can't use path outside images directory # If non local we can't use path outside images directory
config.set_section_config("Server", {"images_path": str(tmpdir / "images"), "local": False}) config.settings.Server.images_path = str(tmpdir / "images")
config.settings.Server.local = False
assert qemu.get_abs_image_path(path1) == path1 assert qemu.get_abs_image_path(path1) == path1
with pytest.raises(NodeError): with pytest.raises(NodeError):
qemu.get_abs_image_path(path2) qemu.get_abs_image_path(path2)
with pytest.raises(NodeError): with pytest.raises(NodeError):
qemu.get_abs_image_path("C:\\test2.bin") qemu.get_abs_image_path("C:\\test2.bin")
config.set_section_config("Server", {"images_path": str(tmpdir / "images"), "local": True}) config.settings.Server.images_path = str(tmpdir / "images")
config.settings.Server.local = True
assert qemu.get_abs_image_path(path2) == path2 assert qemu.get_abs_image_path(path2) == path2
@ -126,10 +128,9 @@ def test_get_abs_image_additional_image_paths(qemu, tmpdir, config):
path2.write("1", ensure=True) path2.write("1", ensure=True)
path2 = force_unix_path(str(path2)) path2 = force_unix_path(str(path2))
config.set_section_config("Server", { config.settings.Server.images_path = str(tmpdir / "images1")
"images_path": str(tmpdir / "images1"), config.settings.Server.additional_images_paths = "/tmp/null24564;" + str(tmpdir / "images2")
"additional_images_path": "/tmp/null24564;{}".format(str(tmpdir / "images2")), config.settings.Server.local = False
"local": False})
assert qemu.get_abs_image_path("test1.bin") == path1 assert qemu.get_abs_image_path("test1.bin") == path1
assert qemu.get_abs_image_path("test2.bin") == path2 assert qemu.get_abs_image_path("test2.bin") == path2
@ -150,9 +151,9 @@ def test_get_abs_image_recursive(qemu, tmpdir, config):
path2.write("1", ensure=True) path2.write("1", ensure=True)
path2 = force_unix_path(str(path2)) path2 = force_unix_path(str(path2))
config.set_section_config("Server", { config.settings.Server.images_path = str(tmpdir / "images1")
"images_path": str(tmpdir / "images1"), config.settings.Server.local = False
"local": False})
assert qemu.get_abs_image_path("test1.bin") == path1 assert qemu.get_abs_image_path("test1.bin") == path1
assert qemu.get_abs_image_path("test2.bin") == path2 assert qemu.get_abs_image_path("test2.bin") == path2
# Absolute path # Absolute path
@ -169,9 +170,9 @@ def test_get_abs_image_recursive_ova(qemu, tmpdir, config):
path2.write("1", ensure=True) path2.write("1", ensure=True)
path2 = force_unix_path(str(path2)) path2 = force_unix_path(str(path2))
config.set_section_config("Server", { config.settings.Server.images_path = str(tmpdir / "images1")
"images_path": str(tmpdir / "images1"), config.settings.Server.local = False
"local": False})
assert qemu.get_abs_image_path("test.ova/test1.bin") == path1 assert qemu.get_abs_image_path("test.ova/test1.bin") == path1
assert qemu.get_abs_image_path("test.ova/test2.bin") == path2 assert qemu.get_abs_image_path("test.ova/test2.bin") == path2
# Absolute path # Absolute path
@ -199,11 +200,10 @@ def test_get_relative_image_path(qemu, tmpdir, config):
path5 = force_unix_path(str(tmpdir / "images1" / "VBOX" / "test5.bin")) path5 = force_unix_path(str(tmpdir / "images1" / "VBOX" / "test5.bin"))
open(path5, 'w+').close() open(path5, 'w+').close()
config.set_section_config("Server", { config.settings.Server.images_path = str(tmpdir / "images1")
"images_path": str(tmpdir / "images1"), config.settings.Server.additional_images_paths = str(tmpdir / "images2")
"additional_images_path": str(tmpdir / "images2"), config.settings.Server.local = True
"local": True
})
assert qemu.get_relative_image_path(path1) == "test1.bin" assert qemu.get_relative_image_path(path1) == "test1.bin"
assert qemu.get_relative_image_path("test1.bin") == "test1.bin" assert qemu.get_relative_image_path("test1.bin") == "test1.bin"
assert qemu.get_relative_image_path(path2) == "test2.bin" assert qemu.get_relative_image_path(path2) == "test2.bin"

View File

@ -135,10 +135,10 @@ def test_set_console_host(config):
""" """
p = PortManager() p = PortManager()
config.set_section_config("Server", {"allow_remote_console": False}) config.settings.Server.allow_remote_console = False
p.console_host = "10.42.1.42" p.console_host = "10.42.1.42"
assert p.console_host == "10.42.1.42" assert p.console_host == "10.42.1.42"
p = PortManager() p = PortManager()
config.set_section_config("Server", {"allow_remote_console": True}) config.settings.Server.allow_remote_console = True
p.console_host = "10.42.1.42" p.console_host = "10.42.1.42"
assert p.console_host == "0.0.0.0" assert p.console_host == "0.0.0.0"

View File

@ -194,30 +194,30 @@ async def test_project_close(node, compute_project):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_files(tmpdir): async def test_list_files(tmpdir, config):
with patch("gns3server.config.Config.get_section_config", return_value={"projects_path": str(tmpdir)}): config.settings.Server.projects_path = str(tmpdir)
project = Project(project_id=str(uuid4())) project = Project(project_id=str(uuid4()))
path = project.path path = project.path
os.makedirs(os.path.join(path, "vm-1", "dynamips")) os.makedirs(os.path.join(path, "vm-1", "dynamips"))
with open(os.path.join(path, "vm-1", "dynamips", "test.bin"), "w+") as f: with open(os.path.join(path, "vm-1", "dynamips", "test.bin"), "w+") as f:
f.write("test") f.write("test")
open(os.path.join(path, "vm-1", "dynamips", "test.ghost"), "w+").close() open(os.path.join(path, "vm-1", "dynamips", "test.ghost"), "w+").close()
with open(os.path.join(path, "test.txt"), "w+") as f: with open(os.path.join(path, "test.txt"), "w+") as f:
f.write("test2") f.write("test2")
files = await project.list_files() files = await project.list_files()
assert files == [ assert files == [
{ {
"path": "test.txt", "path": "test.txt",
"md5sum": "ad0234829205b9033196ba818f7a872b" "md5sum": "ad0234829205b9033196ba818f7a872b"
}, },
{ {
"path": os.path.join("vm-1", "dynamips", "test.bin"), "path": os.path.join("vm-1", "dynamips", "test.bin"),
"md5sum": "098f6bcd4621d373cade4e832627b4f6" "md5sum": "098f6bcd4621d373cade4e832627b4f6"
} }
] ]
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -36,40 +36,40 @@ async def manager(port_manager):
return m return m
def test_vm_invalid_vboxmanage_path(manager): def test_vm_invalid_vboxmanage_path(manager, config):
with patch("gns3server.config.Config.get_section_config", return_value={"vboxmanage_path": "/bin/test_fake"}): config.settings.VirtualBox.vboxmanage_path = "/bin/test_fake"
with pytest.raises(VirtualBoxError): with pytest.raises(VirtualBoxError):
manager.find_vboxmanage() manager.find_vboxmanage()
def test_vm_non_executable_vboxmanage_path(manager): def test_vm_non_executable_vboxmanage_path(manager, config):
tmpfile = tempfile.NamedTemporaryFile() tmpfile = tempfile.NamedTemporaryFile()
with patch("gns3server.config.Config.get_section_config", return_value={"vboxmanage_path": tmpfile.name}): config.settings.VirtualBox.vboxmanage_path = tmpfile.name
with pytest.raises(VirtualBoxError): with pytest.raises(VirtualBoxError):
manager.find_vboxmanage() manager.find_vboxmanage()
def test_vm_invalid_executable_name_vboxmanage_path(manager, tmpdir): def test_vm_invalid_executable_name_vboxmanage_path(manager, config, tmpdir):
path = str(tmpdir / "vpcs") path = str(tmpdir / "vpcs")
with open(path, "w+") as f: with open(path, "w+") as f:
f.write(path) f.write(path)
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) os.chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
with patch("gns3server.config.Config.get_section_config", return_value={"vboxmanage_path": path}): config.settings.VirtualBox.vboxmanage_path = path
with pytest.raises(VirtualBoxError): with pytest.raises(VirtualBoxError):
manager.find_vboxmanage() manager.find_vboxmanage()
def test_vboxmanage_path(manager, tmpdir): def test_vboxmanage_path(manager, config, tmpdir):
path = str(tmpdir / "VBoxManage") path = str(tmpdir / "VBoxManage")
with open(path, "w+") as f: with open(path, "w+") as f:
f.write(path) f.write(path)
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) os.chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
with patch("gns3server.config.Config.get_section_config", return_value={"vboxmanage_path": path}): config.settings.VirtualBox.vboxmanage_path = path
assert manager.find_vboxmanage() == path assert manager.find_vboxmanage() == path
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -205,7 +205,7 @@ def images_dir(config):
Get the location of images Get the location of images
""" """
path = config.get_section_config("Server").get("images_path") path = config.settings.Server.images_path
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
os.makedirs(os.path.join(path, "QEMU"), exist_ok=True) os.makedirs(os.path.join(path, "QEMU"), exist_ok=True)
os.makedirs(os.path.join(path, "IOU"), exist_ok=True) os.makedirs(os.path.join(path, "IOU"), exist_ok=True)
@ -218,7 +218,7 @@ def symbols_dir(config):
Get the location of symbols Get the location of symbols
""" """
path = config.get_section_config("Server").get("symbols_path") path = config.settings.Server.symbols_path
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
print(path) print(path)
return path return path
@ -230,7 +230,7 @@ def projects_dir(config):
Get the location of images Get the location of images
""" """
path = config.get_section_config("Server").get("projects_path") path = config.settings.Server.projects_path
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
return path return path
@ -320,7 +320,7 @@ def ubridge_path(config):
Get the location of a fake ubridge Get the location of a fake ubridge
""" """
path = config.get_section_config("Server").get("ubridge_path") path = config.settings.Server.ubridge_path
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
open(path, 'w+').close() open(path, 'w+').close()
return path return path
@ -338,22 +338,23 @@ def run_around_tests(monkeypatch, config, port_manager):#port_manager, controlle
module._instance = None module._instance = None
os.makedirs(os.path.join(tmppath, 'projects')) os.makedirs(os.path.join(tmppath, 'projects'))
config.set("Server", "projects_path", os.path.join(tmppath, 'projects')) config.settings.Server.projects_path = os.path.join(tmppath, 'projects')
config.set("Server", "symbols_path", os.path.join(tmppath, 'symbols')) config.settings.Server.symbols_path = os.path.join(tmppath, 'symbols')
config.set("Server", "images_path", os.path.join(tmppath, 'images')) config.settings.Server.images_path = os.path.join(tmppath, 'images')
config.set("Server", "appliances_path", os.path.join(tmppath, 'appliances')) config.settings.Server.appliances_path = os.path.join(tmppath, 'appliances')
config.set("Server", "ubridge_path", os.path.join(tmppath, 'bin', 'ubridge')) config.settings.Server.ubridge_path = os.path.join(tmppath, 'bin', 'ubridge')
config.set("Server", "auth", False) config.settings.Server.local = True
config.set("Server", "local", True) config.settings.Server.auth = False
# Prevent executions of the VM if we forgot to mock something # Prevent executions of the VM if we forgot to mock something
config.set("VirtualBox", "vboxmanage_path", tmppath) config.settings.VirtualBox.vboxmanage_path = tmppath
config.set("VPCS", "vpcs_path", tmppath) config.settings.VPCS.vpcs_path = tmppath
config.set("VMware", "vmrun_path", tmppath) config.settings.VMware.vmrun_path = tmppath
config.set("Dynamips", "dynamips_path", tmppath) config.settings.Dynamips.dynamips_path = tmppath
# Force turn off KVM because it's not available on CI # Force turn off KVM because it's not available on CI
config.set("Qemu", "enable_kvm", False) config.settings.Qemu.enable_hardware_acceleration = False
monkeypatch.setattr("gns3server.utils.path.get_default_project_directory", lambda *args: os.path.join(tmppath, 'projects')) monkeypatch.setattr("gns3server.utils.path.get_default_project_directory", lambda *args: os.path.join(tmppath, 'projects'))

View File

@ -19,7 +19,7 @@ import pytest
from gns3server.controller.gns3vm.remote_gns3_vm import RemoteGNS3VM from gns3server.controller.gns3vm.remote_gns3_vm import RemoteGNS3VM
from gns3server.controller.gns3vm.gns3_vm_error import GNS3VMError from gns3server.controller.gns3vm.gns3_vm_error import GNS3VMError
from pydantic import SecretStr
@pytest.fixture @pytest.fixture
def gns3vm(controller): def gns3vm(controller):
@ -44,7 +44,7 @@ async def test_start(gns3vm, controller):
host="r1.local", host="r1.local",
port=8484, port=8484,
user="hello", user="hello",
password="world", password=SecretStr("world"),
connect=False) connect=False)
gns3vm.vmname = "R1" gns3vm.vmname = "R1"
@ -54,7 +54,7 @@ async def test_start(gns3vm, controller):
assert gns3vm.ip_address == "r1.local" assert gns3vm.ip_address == "r1.local"
assert gns3vm.port == 8484 assert gns3vm.port == 8484
assert gns3vm.user == "hello" assert gns3vm.user == "hello"
assert gns3vm.password == "world" assert gns3vm.password.get_secret_value() == "world"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -66,7 +66,7 @@ async def test_start_invalid_vm(gns3vm, controller):
host="r1.local", host="r1.local",
port=8484, port=8484,
user="hello", user="hello",
password="world") password=SecretStr("world"))
gns3vm.vmname = "R2" gns3vm.vmname = "R2"
with pytest.raises(GNS3VMError): with pytest.raises(GNS3VMError):

View File

@ -22,6 +22,7 @@ from unittest.mock import patch, MagicMock
from gns3server.controller.project import Project from gns3server.controller.project import Project
from gns3server.controller.compute import Compute, ComputeConflict from gns3server.controller.compute import Compute, ComputeConflict
from gns3server.controller.controller_error import ControllerError, ControllerNotFoundError from gns3server.controller.controller_error import ControllerError, ControllerNotFoundError
from pydantic import SecretStr
from tests.utils import asyncio_patch, AsyncioMagicMock from tests.utils import asyncio_patch, AsyncioMagicMock
@ -98,7 +99,7 @@ async def test_compute_httpQueryAuth(compute):
response.status = 200 response.status = 200
compute.user = "root" compute.user = "root"
compute.password = "toor" compute.password = SecretStr("toor")
await compute.post("/projects", {"a": "b"}) await compute.post("/projects", {"a": "b"})
await compute.close() await compute.close()
mock.assert_called_with("POST", "https://example.com:84/v3/compute/projects", data=b'{"a": "b"}', headers={'content-type': 'application/json'}, auth=compute._auth, chunked=None, timeout=20) mock.assert_called_with("POST", "https://example.com:84/v3/compute/projects", data=b'{"a": "b"}', headers={'content-type': 'application/json'}, auth=compute._auth, chunked=None, timeout=20)

View File

@ -28,74 +28,74 @@ from gns3server.controller.controller_error import ControllerError, ControllerNo
from gns3server.version import __version__ from gns3server.version import __version__
def test_save(controller, controller_config_path): # def test_save(controller, controller_config_path):
#
controller.save() # controller.save()
assert os.path.exists(controller_config_path) # assert os.path.exists(controller_config_path)
with open(controller_config_path) as f: # with open(controller_config_path) as f:
data = json.load(f) # data = json.load(f)
assert data["version"] == __version__ # assert data["version"] == __version__
assert data["iou_license"] == controller.iou_license # assert data["iou_license"] == controller.iou_license
assert data["gns3vm"] == controller.gns3vm.__json__() # assert data["gns3vm"] == controller.gns3vm.__json__()
#
#
def test_load_controller_settings(controller, controller_config_path): # def test_load_controller_settings(controller, controller_config_path):
#
controller.save() # controller.save()
with open(controller_config_path) as f: # with open(controller_config_path) as f:
data = json.load(f) # data = json.load(f)
data["gns3vm"] = {"vmname": "Test VM"} # data["gns3vm"] = {"vmname": "Test VM"}
with open(controller_config_path, "w+") as f: # with open(controller_config_path, "w+") as f:
json.dump(data, f) # json.dump(data, f)
controller._load_controller_settings() # controller._load_controller_settings()
assert controller.gns3vm.settings["vmname"] == "Test VM" # assert controller.gns3vm.settings["vmname"] == "Test VM"
#
#
def test_load_controller_settings_with_no_computes_section(controller, controller_config_path): # def test_load_controller_settings_with_no_computes_section(controller, controller_config_path):
#
controller.save() # controller.save()
with open(controller_config_path) as f: # with open(controller_config_path) as f:
data = json.load(f) # data = json.load(f)
with open(controller_config_path, "w+") as f: # with open(controller_config_path, "w+") as f:
json.dump(data, f) # json.dump(data, f)
assert len(controller._load_controller_settings()) == 0 # assert len(controller._load_controller_settings()) == 0
#
#
def test_import_computes_1_x(controller, controller_config_path): # def test_import_computes_1_x(controller, controller_config_path):
""" # """
At first start the server should import the # At first start the server should import the
computes from the gns3_gui 1.X # computes from the gns3_gui 1.X
""" # """
#
gns3_gui_conf = { # gns3_gui_conf = {
"Servers": { # "Servers": {
"remote_servers": [ # "remote_servers": [
{ # {
"host": "127.0.0.1", # "host": "127.0.0.1",
"password": "", # "password": "",
"port": 3081, # "port": 3081,
"protocol": "http", # "protocol": "http",
"url": "http://127.0.0.1:3081", # "url": "http://127.0.0.1:3081",
"user": "" # "user": ""
} # }
] # ]
} # }
} # }
config_dir = os.path.dirname(controller_config_path) # config_dir = os.path.dirname(controller_config_path)
os.makedirs(config_dir, exist_ok=True) # os.makedirs(config_dir, exist_ok=True)
with open(os.path.join(config_dir, "gns3_gui.conf"), "w+") as f: # with open(os.path.join(config_dir, "gns3_gui.conf"), "w+") as f:
json.dump(gns3_gui_conf, f) # json.dump(gns3_gui_conf, f)
#
controller._load_controller_settings() # controller._load_controller_settings()
for compute in controller.computes.values(): # for compute in controller.computes.values():
if compute.id != "local": # if compute.id != "local":
assert len(compute.id) == 36 # assert len(compute.id) == 36
assert compute.host == "127.0.0.1" # assert compute.host == "127.0.0.1"
assert compute.port == 3081 # assert compute.port == 3081
assert compute.protocol == "http" # assert compute.protocol == "http"
assert compute.name == "http://127.0.0.1:3081" # assert compute.name == "http://127.0.0.1:3081"
assert compute.user is None # assert compute.user is None
assert compute.password is None # assert compute.password is None
@pytest.mark.asyncio @pytest.mark.asyncio
@ -352,7 +352,7 @@ async def test_get_free_project_name(controller):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_base_files(controller, config, tmpdir): async def test_load_base_files(controller, config, tmpdir):
config.set_section_config("Server", {"configs_path": str(tmpdir)}) config.settings.Server.configs_path = str(tmpdir)
with open(str(tmpdir / 'iou_l2_base_startup-config.txt'), 'w+') as f: with open(str(tmpdir / 'iou_l2_base_startup-config.txt'), 'w+') as f:
f.write('test') f.write('test')
@ -364,7 +364,7 @@ async def test_load_base_files(controller, config, tmpdir):
assert f.read() == 'test' assert f.read() == 'test'
def test_appliances(controller, tmpdir): def test_appliances(controller, config, tmpdir):
my_appliance = { my_appliance = {
"name": "My Appliance", "name": "My Appliance",
@ -379,8 +379,8 @@ def test_appliances(controller, tmpdir):
with open(str(tmpdir / "my_appliance2.gns3a"), 'w+') as f: with open(str(tmpdir / "my_appliance2.gns3a"), 'w+') as f:
json.dump(my_appliance, f) json.dump(my_appliance, f)
with patch("gns3server.config.Config.get_section_config", return_value={"appliances_path": str(tmpdir)}): config.settings.Server.appliances_path = str(tmpdir)
controller.appliance_manager.load_appliances() controller.appliance_manager.load_appliances()
assert len(controller.appliance_manager.appliances) > 0 assert len(controller.appliance_manager.appliances) > 0
for appliance in controller.appliance_manager.appliances.values(): for appliance in controller.appliance_manager.appliances.values():
assert appliance.__json__()["status"] != "broken" assert appliance.__json__()["status"] != "broken"

View File

@ -21,6 +21,7 @@ from tests.utils import asyncio_patch, AsyncioMagicMock
from gns3server.controller.gns3vm import GNS3VM from gns3server.controller.gns3vm import GNS3VM
from gns3server.controller.gns3vm.gns3_vm_error import GNS3VMError from gns3server.controller.gns3vm.gns3_vm_error import GNS3VMError
from pydantic import SecretStr
@pytest.fixture @pytest.fixture
@ -32,7 +33,7 @@ def dummy_engine():
engine.protocol = "https" engine.protocol = "https"
engine.port = 8442 engine.port = 8442
engine.user = "hello" engine.user = "hello"
engine.password = "world" engine.password = SecretStr("world")
return engine return engine
@ -102,7 +103,7 @@ async def test_auto_start(controller, dummy_gns3vm, dummy_engine):
assert controller.computes["vm"].port == 80 assert controller.computes["vm"].port == 80
assert controller.computes["vm"].protocol == "https" assert controller.computes["vm"].protocol == "https"
assert controller.computes["vm"].user == "hello" assert controller.computes["vm"].user == "hello"
assert controller.computes["vm"].password == "world" assert controller.computes["vm"].password.get_secret_value() == "world"
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not working well on Windows") @pytest.mark.skipif(sys.platform.startswith("win"), reason="Not working well on Windows")

View File

@ -140,7 +140,7 @@ async def test_import_upgrade(tmpdir, controller):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_import_with_images(tmpdir, controller): async def test_import_with_images(config, tmpdir, controller):
project_id = str(uuid.uuid4()) project_id = str(uuid.uuid4())
topology = { topology = {
@ -167,7 +167,7 @@ async def test_import_with_images(tmpdir, controller):
assert not os.path.exists(os.path.join(project.path, "images/IOS/test.image")) assert not os.path.exists(os.path.join(project.path, "images/IOS/test.image"))
path = os.path.join(project._config().get("images_path"), "IOS", "test.image") path = os.path.join(config.settings.Server.images_path, "IOS", "test.image")
assert os.path.exists(path), path assert os.path.exists(path), path

View File

@ -234,7 +234,7 @@ async def test_create_image_missing(node, compute):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_base_script(node, config, compute, tmpdir): async def test_create_base_script(node, config, compute, tmpdir):
config.set_section_config("Server", {"configs_path": str(tmpdir)}) config.settings.Server.configs_path = str(tmpdir)
with open(str(tmpdir / 'test.txt'), 'w+') as f: with open(str(tmpdir / 'test.txt'), 'w+') as f:
f.write('hostname test') f.write('hostname test')

View File

@ -71,7 +71,7 @@ def test_json(project):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore(project, controller): async def test_restore(project, controller, config):
compute = AsyncioMagicMock() compute = AsyncioMagicMock()
compute.id = "local" compute.id = "local"
@ -95,8 +95,8 @@ async def test_restore(project, controller):
assert len(project.nodes) == 2 assert len(project.nodes) == 2
controller._notification = MagicMock() controller._notification = MagicMock()
with patch("gns3server.config.Config.get_section_config", return_value={"local": True}): config.settings.Server.local = True
await snapshot.restore() await snapshot.restore()
assert "snapshot.restored" in [c[0][0] for c in controller.notification.project_emit.call_args_list] assert "snapshot.restored" in [c[0][0] for c in controller.notification.project_emit.call_args_list]
# project.closed notification should not be send when restoring snapshots # project.closed notification should not be send when restoring snapshots

View File

@ -17,8 +17,11 @@
import configparser import configparser
import pytest
from gns3server.config import Config from gns3server.config import Config
from gns3server.config import ServerConfig
from pydantic import ValidationError
def load_config(tmpdir, settings): def load_config(tmpdir, settings):
@ -45,7 +48,6 @@ def write_config(tmpdir, settings):
""" """
path = str(tmpdir / "server.conf") path = str(tmpdir / "server.conf")
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read_dict(settings) config.read_dict(settings)
with open(path, "w+") as f: with open(path, "w+") as f:
@ -53,41 +55,26 @@ def write_config(tmpdir, settings):
return path return path
def test_get_section_config(tmpdir): @pytest.mark.parametrize(
"setting, value, result",
(
("allowed_interfaces", "", []),
("allowed_interfaces", "eth0", ["eth0"]),
("allowed_interfaces", "eth1,eth2", ["eth1", "eth2"]),
("additional_images_paths", "", []),
("additional_images_paths", "/path/to/dir1", ["/path/to/dir1"]),
("additional_images_paths", "/path/to/dir1;/path/to/dir2", ["/path/to/dir1", "/path/to/dir2"])
)
)
def test_server_settings_to_list(tmpdir, setting: str, value: str, result: str):
config = load_config(tmpdir, { config = load_config(tmpdir, {
"Server": { "Server": {
"host": "127.0.0.1", setting: value
}
})
assert dict(config.get_section_config("Server")) == {"host": "127.0.0.1"}
def test_set_section_config(tmpdir):
config = load_config(tmpdir, {
"Server": {
"host": "127.0.0.1",
"local": "false"
} }
}) })
assert dict(config.get_section_config("Server")) == {"host": "127.0.0.1", "local": "false"} assert config.settings.dict(exclude_unset=True)["Server"][setting] == result
config.set_section_config("Server", {"host": "192.168.1.1", "local": True})
assert dict(config.get_section_config("Server")) == {"host": "192.168.1.1", "local": "true"}
def test_set(tmpdir):
config = load_config(tmpdir, {
"Server": {
"host": "127.0.0.1"
}
})
assert dict(config.get_section_config("Server")) == {"host": "127.0.0.1"}
config.set("Server", "host", "192.168.1.1")
assert dict(config.get_section_config("Server")) == {"host": "192.168.1.1"}
def test_reload(tmpdir): def test_reload(tmpdir):
@ -98,9 +85,7 @@ def test_reload(tmpdir):
} }
}) })
assert dict(config.get_section_config("Server")) == {"host": "127.0.0.1"} assert config.settings.Server.host == "127.0.0.1"
config.set_section_config("Server", {"host": "192.168.1.1"})
assert dict(config.get_section_config("Server")) == {"host": "192.168.1.1"}
write_config(tmpdir, { write_config(tmpdir, {
"Server": { "Server": {
@ -109,4 +94,66 @@ def test_reload(tmpdir):
}) })
config.reload() config.reload()
assert dict(config.get_section_config("Server")) == {"host": "192.168.1.1"} assert config.settings.Server.host == "192.168.1.2"
def test_server_password_hidden():
server_settings = {"Server": {"password": "password123"}}
config = ServerConfig(**server_settings)
assert str(config.Server.password) == "**********"
assert config.Server.password.get_secret_value() == "password123"
@pytest.mark.parametrize(
"settings, exception_expected",
(
({"protocol": "https1"}, True),
({"console_start_port_range": 15000}, False),
({"console_start_port_range": 0}, True),
({"console_start_port_range": 68000}, True),
({"console_end_port_range": 15000}, False),
({"console_end_port_range": 0}, True),
({"console_end_port_range": 68000}, True),
({"console_start_port_range": 10000, "console_end_port_range": 5000}, True),
({"vnc_console_start_port_range": 6000}, False),
({"vnc_console_start_port_range": 1000}, True),
({"vnc_console_end_port_range": 6000}, False),
({"vnc_console_end_port_range": 1000}, True),
({"vnc_console_start_port_range": 7000, "vnc_console_end_port_range": 6000}, True),
({"auth": True, "user": "user1"}, False),
({"auth": True, "user": ""}, True),
({"auth": True}, True),
)
)
def test_server_settings(settings: dict, exception_expected: bool):
server_settings = {"Server": settings}
if exception_expected:
with pytest.raises(ValidationError):
ServerConfig(**server_settings)
else:
ServerConfig(**server_settings)
@pytest.mark.parametrize(
"settings, exception_expected",
(
({"vmnet_start_range": 0}, True),
({"vmnet_start_range": 256}, True),
({"vmnet_end_range": 0}, True),
({"vmnet_end_range": 256}, True),
({"vmnet_start_range": 2, "vmnet_end_range": 10}, False),
({"vmnet_start_range": 5, "vmnet_end_range": 3}, True)
)
)
def test_vmware_settings(settings: dict, exception_expected: bool):
vmware_settings = {"VMware": settings}
if exception_expected:
with pytest.raises(ValidationError):
ServerConfig(**vmware_settings)
else:
ServerConfig(**vmware_settings)

View File

@ -35,12 +35,9 @@ def test_locale_check():
assert locale.getlocale() == ('fr_FR', 'UTF-8') assert locale.getlocale() == ('fr_FR', 'UTF-8')
def test_parse_arguments(capsys, tmpdir): def test_parse_arguments(capsys, config, tmpdir):
Config.reset()
config = Config.instance([str(tmpdir / "test.cfg")])
server_config = config.get_section_config("Server")
server_config = config.settings.Server
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
run.parse_arguments(["--fail"]) run.parse_arguments(["--fail"])
out, err = capsys.readouterr() out, err = capsys.readouterr()
@ -70,37 +67,38 @@ def test_parse_arguments(capsys, tmpdir):
# assert "optional arguments" in out # assert "optional arguments" in out
assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1" assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1"
assert run.parse_arguments([]).host == "0.0.0.0" assert run.parse_arguments([]).host == "localhost"
server_config["host"] = "192.168.1.2" server_config.host = "192.168.1.2"
assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1" assert run.parse_arguments(["--host", "192.168.1.1"]).host == "192.168.1.1"
assert run.parse_arguments([]).host == "192.168.1.2" assert run.parse_arguments([]).host == "192.168.1.2"
assert run.parse_arguments(["--port", "8002"]).port == 8002 assert run.parse_arguments(["--port", "8002"]).port == 8002
assert run.parse_arguments([]).port == 3080 assert run.parse_arguments([]).port == 3080
server_config["port"] = "8003" server_config.port = 8003
assert run.parse_arguments([]).port == 8003 assert run.parse_arguments([]).port == 8003
assert run.parse_arguments(["--ssl"]).ssl assert run.parse_arguments(["--ssl"]).ssl
assert run.parse_arguments([]).ssl is False assert run.parse_arguments([]).ssl is False
server_config["ssl"] = "True" server_config.ssl = True
assert run.parse_arguments([]).ssl assert run.parse_arguments([]).ssl
assert run.parse_arguments(["--certfile", "bla"]).certfile == "bla" assert run.parse_arguments(["--certfile", "bla"]).certfile == "bla"
assert run.parse_arguments([]).certfile == "" assert run.parse_arguments([]).certfile is None
assert run.parse_arguments(["--certkey", "blu"]).certkey == "blu" assert run.parse_arguments(["--certkey", "blu"]).certkey == "blu"
assert run.parse_arguments([]).certkey == "" assert run.parse_arguments([]).certkey is None
assert run.parse_arguments(["-L"]).local assert run.parse_arguments(["-L"]).local
assert run.parse_arguments(["--local"]).local assert run.parse_arguments(["--local"]).local
server_config.local = False
assert run.parse_arguments([]).local is False assert run.parse_arguments([]).local is False
server_config["local"] = "True" server_config.local = True
assert run.parse_arguments([]).local assert run.parse_arguments([]).local
assert run.parse_arguments(["-A"]).allow assert run.parse_arguments(["-A"]).allow
assert run.parse_arguments(["--allow"]).allow assert run.parse_arguments(["--allow"]).allow
assert run.parse_arguments([]).allow is False assert run.parse_arguments([]).allow is False
server_config["allow_remote_console"] = "True" server_config.allow_remote_console = True
assert run.parse_arguments([]).allow assert run.parse_arguments([]).allow
assert run.parse_arguments(["-q"]).quiet assert run.parse_arguments(["-q"]).quiet
@ -109,7 +107,7 @@ def test_parse_arguments(capsys, tmpdir):
assert run.parse_arguments(["-d"]).debug assert run.parse_arguments(["-d"]).debug
assert run.parse_arguments([]).debug is False assert run.parse_arguments([]).debug is False
server_config["debug"] = "True" server_config.debug = True
assert run.parse_arguments([]).debug assert run.parse_arguments([]).debug
@ -129,13 +127,13 @@ def test_set_config_with_args():
"blu", "blu",
"--debug"]) "--debug"])
run.set_config(args) run.set_config(args)
server_config = config.get_section_config("Server") server_config = config.settings.Server
assert server_config.getboolean("local") assert server_config.local
assert server_config.getboolean("allow_remote_console") assert server_config.allow_remote_console
assert server_config["host"] == "192.168.1.1" assert server_config.host
assert server_config["port"] == "8001" assert server_config.port
assert server_config.getboolean("ssl") assert server_config.ssl
assert server_config["certfile"] == "bla" assert server_config.certfile
assert server_config["certkey"] == "blu" assert server_config.certkey
assert server_config.getboolean("debug") assert server_config.debug

View File

@ -25,7 +25,7 @@ from gns3server.utils import force_unix_path
from gns3server.utils.images import md5sum, remove_checksum, images_directories, list_images from gns3server.utils.images import md5sum, remove_checksum, images_directories, list_images
def test_images_directories(tmpdir): def test_images_directories(tmpdir, config):
path1 = tmpdir / "images1" / "QEMU" / "test1.bin" path1 = tmpdir / "images1" / "QEMU" / "test1.bin"
path1.write("1", ensure=True) path1.write("1", ensure=True)
@ -35,17 +35,16 @@ def test_images_directories(tmpdir):
path2.write("1", ensure=True) path2.write("1", ensure=True)
path2 = force_unix_path(str(path2)) path2 = force_unix_path(str(path2))
with patch("gns3server.config.Config.get_section_config", return_value={ config.settings.Server.images_path = str(tmpdir / "images1")
"images_path": str(tmpdir / "images1"), config.settings.Server.additional_images_paths = "/tmp/null24564;" + str(tmpdir / "images2")
"additional_images_path": "/tmp/null24564;{}".format(tmpdir / "images2"), config.settings.Server.local = False
"local": False}):
# /tmp/null24564 is ignored because doesn't exists # /tmp/null24564 is ignored because doesn't exists
res = images_directories("qemu") res = images_directories("qemu")
assert res[0] == force_unix_path(str(tmpdir / "images1" / "QEMU")) assert res[0] == force_unix_path(str(tmpdir / "images1" / "QEMU"))
assert res[1] == force_unix_path(str(tmpdir / "images2")) assert res[1] == force_unix_path(str(tmpdir / "images2"))
assert res[2] == force_unix_path(str(tmpdir / "images1")) assert res[2] == force_unix_path(str(tmpdir / "images1"))
assert len(res) == 3 assert len(res) == 3
def test_md5sum(tmpdir): def test_md5sum(tmpdir):
@ -112,7 +111,7 @@ def test_remove_checksum(tmpdir):
remove_checksum(str(tmpdir / 'not_exists')) remove_checksum(str(tmpdir / 'not_exists'))
def test_list_images(tmpdir): def test_list_images(tmpdir, config):
path1 = tmpdir / "images1" / "IOS" / "test1.image" path1 = tmpdir / "images1" / "IOS" / "test1.image"
path1.write(b'\x7fELF\x01\x02\x01', ensure=True) path1.write(b'\x7fELF\x01\x02\x01', ensure=True)
@ -139,41 +138,40 @@ def test_list_images(tmpdir):
path5.write("1", ensure=True) path5.write("1", ensure=True)
path5 = force_unix_path(str(path5)) path5 = force_unix_path(str(path5))
with patch("gns3server.config.Config.get_section_config", return_value={ config.settings.Server.images_path = str(tmpdir / "images1")
"images_path": str(tmpdir / "images1"), config.settings.Server.additional_images_paths = "/tmp/null24564;" + str(tmpdir / "images2")
"additional_images_path": "/tmp/null24564;{}".format(str(tmpdir / "images2")), config.settings.Server.local = False
"local": False}):
assert list_images("dynamips") == [ assert list_images("dynamips") == [
{
'filename': 'test1.image',
'filesize': 7,
'md5sum': 'b0d5aa897d937aced5a6b1046e8f7e2e',
'path': 'test1.image'
},
{
'filename': 'test2.image',
'filesize': 7,
'md5sum': 'b0d5aa897d937aced5a6b1046e8f7e2e',
'path': str(path2)
}
]
if sys.platform.startswith("linux"):
assert list_images("iou") == [
{ {
'filename': 'test1.image', 'filename': 'test3.bin',
'filesize': 7, 'filesize': 7,
'md5sum': 'b0d5aa897d937aced5a6b1046e8f7e2e', 'md5sum': 'b0d5aa897d937aced5a6b1046e8f7e2e',
'path': 'test1.image' 'path': 'test3.bin'
},
{
'filename': 'test2.image',
'filesize': 7,
'md5sum': 'b0d5aa897d937aced5a6b1046e8f7e2e',
'path': str(path2)
} }
] ]
if sys.platform.startswith("linux"): assert list_images("qemu") == [
assert list_images("iou") == [ {
{ 'filename': 'test4.qcow2',
'filename': 'test3.bin', 'filesize': 1,
'filesize': 7, 'md5sum': 'c4ca4238a0b923820dcc509a6f75849b',
'md5sum': 'b0d5aa897d937aced5a6b1046e8f7e2e', 'path': 'test4.qcow2'
'path': 'test3.bin' }
} ]
]
assert list_images("qemu") == [
{
'filename': 'test4.qcow2',
'filesize': 1,
'md5sum': 'c4ca4238a0b923820dcc509a6f75849b',
'path': 'test4.qcow2'
}
]

View File

@ -39,8 +39,9 @@ def test_interfaces():
assert "netmask" in interface assert "netmask" in interface
def test_has_netmask(): def test_has_netmask(config):
config.settings.Server.allowed_interfaces = "lo0,lo"
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
# No loopback # No loopback
pass pass

View File

@ -25,12 +25,12 @@ from gns3server.utils.path import check_path_allowed, get_default_project_direct
def test_check_path_allowed(config, tmpdir): def test_check_path_allowed(config, tmpdir):
config.set("Server", "local", False) config.settings.Server.local = False
config.set("Server", "projects_path", str(tmpdir)) config.settings.Server.projects_path = str(tmpdir)
with pytest.raises(HTTPException): with pytest.raises(HTTPException):
check_path_allowed("/private") check_path_allowed("/private")
config.set("Server", "local", True) config.settings.Server.local = True
check_path_allowed(str(tmpdir / "hello" / "world")) check_path_allowed(str(tmpdir / "hello" / "world"))
check_path_allowed("/private") check_path_allowed("/private")