diff --git a/gns3server/api/routes/controller/projects.py b/gns3server/api/routes/controller/projects.py index f5c10b52..8dca5985 100644 --- a/gns3server/api/routes/controller/projects.py +++ b/gns3server/api/routes/controller/projects.py @@ -21,10 +21,10 @@ API routes for projects. import os import asyncio import tempfile -import zipfile import aiofiles import time import urllib.parse +import gns3server.utils.zipfile_zstd as zipfile import logging @@ -41,7 +41,7 @@ from pathlib import Path from gns3server import schemas from gns3server.controller import Controller from gns3server.controller.project import Project -from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError +from gns3server.controller.controller_error import ControllerError, ControllerBadRequestError from gns3server.controller.import_project import import_project as import_controller_project from gns3server.controller.export_project import export_project as export_controller_project from gns3server.utils.asyncio import aiozipstream @@ -285,7 +285,8 @@ async def export_project( include_snapshots: bool = False, include_images: bool = False, reset_mac_addresses: bool = False, - compression: str = "zip", + compression: schemas.ProjectCompression = "zstd", + compression_level: int = None, ) -> StreamingResponse: """ Export a project as a portable archive. @@ -294,12 +295,23 @@ async def export_project( compression_query = compression.lower() if compression_query == "zip": compression = zipfile.ZIP_DEFLATED + if compression_level is not None and (compression_level < 0 or compression_level > 9): + raise ControllerBadRequestError("Compression level must be between 0 and 9 for ZIP compression") elif compression_query == "none": compression = zipfile.ZIP_STORED elif compression_query == "bzip2": compression = zipfile.ZIP_BZIP2 + if compression_level is not None and (compression_level < 1 or compression_level > 9): + raise ControllerBadRequestError("Compression level must be between 1 and 9 for BZIP2 compression") elif compression_query == "lzma": compression = zipfile.ZIP_LZMA + elif compression_query == "zstd": + compression = zipfile.ZIP_ZSTANDARD + if compression_level is not None and (compression_level < 1 or compression_level > 22): + raise ControllerBadRequestError("Compression level must be between 1 and 22 for Zstandard compression") + + if compression_level is not None and compression_query in ("none", "lzma"): + raise ControllerBadRequestError(f"Compression level is not supported for '{compression_query}' compression method") try: begin = time.time() @@ -307,8 +319,10 @@ async def export_project( working_dir = os.path.abspath(os.path.join(project.path, os.pardir)) async def streamer(): + log.info(f"Exporting project '{project.name}' with '{compression_query}' compression " + f"(level {compression_level})") with tempfile.TemporaryDirectory(dir=working_dir) as tmpdir: - with aiozipstream.ZipFile(compression=compression) as zstream: + with aiozipstream.ZipFile(compression=compression, compresslevel=compression_level) as zstream: await export_controller_project( zstream, project, diff --git a/gns3server/api/server.py b/gns3server/api/server.py index 4c1c1b86..c3ceb816 100644 --- a/gns3server/api/server.py +++ b/gns3server/api/server.py @@ -166,12 +166,14 @@ async def sqlalchemry_error_handler(request: Request, exc: SQLAlchemyError): content={"message": "Database error detected, please check logs to find details"}, ) +# FIXME: do not use this middleware since it creates issue when using StreamingResponse +# see https://starlette-context.readthedocs.io/en/latest/middleware.html#why-are-there-two-middlewares-that-do-the-same-thing -@app.middleware("http") -async def add_extra_headers(request: Request, call_next): - start_time = time.time() - response = await call_next(request) - process_time = time.time() - start_time - response.headers["X-Process-Time"] = str(process_time) - response.headers["X-GNS3-Server-Version"] = f"{__version__}" - return response +# @app.middleware("http") +# async def add_extra_headers(request: Request, call_next): +# start_time = time.time() +# response = await call_next(request) +# process_time = time.time() - start_time +# response.headers["X-Process-Time"] = str(process_time) +# response.headers["X-GNS3-Server-Version"] = f"{__version__}" +# return response diff --git a/gns3server/controller/export_project.py b/gns3server/controller/export_project.py index 3b308a3f..4ae976d2 100644 --- a/gns3server/controller/export_project.py +++ b/gns3server/controller/export_project.py @@ -16,7 +16,6 @@ # along with this program. If not, see . import os -import sys import json import asyncio import aiofiles diff --git a/gns3server/controller/import_project.py b/gns3server/controller/import_project.py index f653cece..545c4ac1 100644 --- a/gns3server/controller/import_project.py +++ b/gns3server/controller/import_project.py @@ -20,10 +20,10 @@ import sys import json import uuid import shutil -import zipfile import aiofiles import itertools import tempfile +import gns3server.utils.zipfile_zstd as zipfile_zstd from .controller_error import ControllerError from .topology import load_topology @@ -60,9 +60,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non raise ControllerError("The destination path should not contain .gns3") try: - with zipfile.ZipFile(stream) as zip_file: + with zipfile_zstd.ZipFile(stream) as zip_file: project_file = zip_file.read("project.gns3").decode() - except zipfile.BadZipFile: + except zipfile_zstd.BadZipFile: raise ControllerError("Cannot import project, not a GNS3 project (invalid zip)") except KeyError: raise ControllerError("Cannot import project, project.gns3 file could not be found") @@ -92,9 +92,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non raise ControllerError("The project name contain non supported or invalid characters") try: - with zipfile.ZipFile(stream) as zip_file: + with zipfile_zstd.ZipFile(stream) as zip_file: await wait_run_in_executor(zip_file.extractall, path) - except zipfile.BadZipFile: + except zipfile_zstd.BadZipFile: raise ControllerError("Cannot extract files from GNS3 project (invalid zip)") topology = load_topology(os.path.join(path, "project.gns3")) @@ -264,11 +264,11 @@ async def _import_snapshots(snapshots_path, project_name, project_id): # extract everything to a temporary directory try: with open(snapshot_path, "rb") as f: - with zipfile.ZipFile(f) as zip_file: + with zipfile_zstd.ZipFile(f) as zip_file: await wait_run_in_executor(zip_file.extractall, tmpdir) except OSError as e: raise ControllerError(f"Cannot open snapshot '{os.path.basename(snapshot)}': {e}") - except zipfile.BadZipFile: + except zipfile_zstd.BadZipFile: raise ControllerError( f"Cannot extract files from snapshot '{os.path.basename(snapshot)}': not a GNS3 project (invalid zip)" ) @@ -294,7 +294,7 @@ async def _import_snapshots(snapshots_path, project_name, project_id): # write everything back to the original snapshot file try: - with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream: + with aiozipstream.ZipFile(compression=zipfile_zstd.ZIP_STORED) as zstream: for root, dirs, files in os.walk(tmpdir, topdown=True, followlinks=False): for file in files: path = os.path.join(root, file) diff --git a/gns3server/schemas/__init__.py b/gns3server/schemas/__init__.py index 03a67b97..77a5c9c3 100644 --- a/gns3server/schemas/__init__.py +++ b/gns3server/schemas/__init__.py @@ -28,7 +28,7 @@ from .controller.appliances import ApplianceVersion, Appliance from .controller.drawings import Drawing from .controller.gns3vm import GNS3VM from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node -from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile +from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission from .controller.tokens import Token diff --git a/gns3server/schemas/controller/projects.py b/gns3server/schemas/controller/projects.py index 2c7c846e..4d98e7c5 100644 --- a/gns3server/schemas/controller/projects.py +++ b/gns3server/schemas/controller/projects.py @@ -102,3 +102,15 @@ class ProjectFile(BaseModel): path: str = Field(..., description="File path") md5sum: str = Field(..., description="File checksum") + + +class ProjectCompression(str, Enum): + """ + Supported project compression. + """ + + none = "none" + zip = "zip" + bzip2 = "bzip2" + lzma = "lzma" + zstd = "zstd" diff --git a/gns3server/utils/asyncio/aiozipstream.py b/gns3server/utils/asyncio/aiozipstream.py index 451a3165..c1f74169 100644 --- a/gns3server/utils/asyncio/aiozipstream.py +++ b/gns3server/utils/asyncio/aiozipstream.py @@ -43,26 +43,38 @@ from zipfile import ( stringEndArchive64Locator, ) + +ZIP_ZSTANDARD = 93 # zstandard is supported by WinZIP v24 and later, PowerArchiver 2021 and 7-Zip-zstd +ZSTANDARD_VERSION = 20 stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor -def _get_compressor(compress_type): +def _get_compressor(compress_type, compresslevel=None): """ Return the compressor. """ if compress_type == zipfile.ZIP_DEFLATED: from zipfile import zlib - + if compresslevel is not None: + return zlib.compressobj(compresslevel, zlib.DEFLATED, -15) return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15) elif compress_type == zipfile.ZIP_BZIP2: from zipfile import bz2 - + if compresslevel is not None: + return bz2.BZ2Compressor(compresslevel) return bz2.BZ2Compressor() + # compresslevel is ignored for ZIP_LZMA elif compress_type == zipfile.ZIP_LZMA: from zipfile import LZMACompressor - return LZMACompressor() + elif compress_type == ZIP_ZSTANDARD: + import zstandard as zstd + if compresslevel is not None: + #params = zstd.ZstdCompressionParameters.from_level(compresslevel, threads=-1, enable_ldm=True, window_log=31) + #return zstd.ZstdCompressor(compression_params=params).compressobj() + return zstd.ZstdCompressor(level=compresslevel).compressobj() + return zstd.ZstdCompressor().compressobj() else: return None @@ -129,7 +141,15 @@ class ZipInfo(zipfile.ZipInfo): class ZipFile(zipfile.ZipFile): - def __init__(self, fileobj=None, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True, chunksize=32768): + def __init__( + self, + fileobj=None, + mode="w", + compression=zipfile.ZIP_STORED, + allowZip64=True, + compresslevel=None, + chunksize=32768 + ): """Open the ZIP file with mode write "w".""" if mode not in ("w",): @@ -138,7 +158,13 @@ class ZipFile(zipfile.ZipFile): fileobj = PointerIO() self._comment = b"" - zipfile.ZipFile.__init__(self, fileobj, mode=mode, compression=compression, allowZip64=allowZip64) + zipfile.ZipFile.__init__( + self, fileobj, + mode=mode, + compression=compression, + compresslevel=compresslevel, + allowZip64=allowZip64 + ) self._chunksize = chunksize self.paths_to_write = [] @@ -195,23 +221,33 @@ class ZipFile(zipfile.ZipFile): for chunk in self._close(): yield chunk - def write(self, filename, arcname=None, compress_type=None): + def write(self, filename, arcname=None, compress_type=None, compresslevel=None): """ Write a file to the archive under the name `arcname`. """ - kwargs = {"filename": filename, "arcname": arcname, "compress_type": compress_type} + kwargs = { + "filename": filename, + "arcname": arcname, + "compress_type": compress_type, + "compresslevel": compresslevel + } self.paths_to_write.append(kwargs) - def write_iter(self, arcname, iterable, compress_type=None): + def write_iter(self, arcname, iterable, compress_type=None, compresslevel=None): """ Write the bytes iterable `iterable` to the archive under the name `arcname`. """ - kwargs = {"arcname": arcname, "iterable": iterable, "compress_type": compress_type} + kwargs = { + "arcname": arcname, + "iterable": iterable, + "compress_type": compress_type, + "compresslevel": compresslevel + } self.paths_to_write.append(kwargs) - def writestr(self, arcname, data, compress_type=None): + def writestr(self, arcname, data, compress_type=None, compresslevel=None): """ Writes a str into ZipFile by wrapping data as a generator """ @@ -219,9 +255,9 @@ class ZipFile(zipfile.ZipFile): def _iterable(): yield data - return self.write_iter(arcname, _iterable(), compress_type=compress_type) + return self.write_iter(arcname, _iterable(), compress_type=compress_type, compresslevel=compresslevel) - async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None): + async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None, compresslevel=None): """ Put the bytes from filename into the archive under the name `arcname`. """ @@ -256,6 +292,11 @@ class ZipFile(zipfile.ZipFile): else: zinfo.compress_type = compress_type + if compresslevel is None: + zinfo._compresslevel = self.compresslevel + else: + zinfo._compresslevel = compresslevel + if st: zinfo.file_size = st[6] else: @@ -279,7 +320,7 @@ class ZipFile(zipfile.ZipFile): yield self.fp.write(zinfo.FileHeader(False)) return - cmpr = _get_compressor(zinfo.compress_type) + cmpr = _get_compressor(zinfo.compress_type, zinfo._compresslevel) # Must overwrite CRC and sizes with correct data later zinfo.CRC = CRC = 0 @@ -369,6 +410,8 @@ class ZipFile(zipfile.ZipFile): min_version = max(zipfile.BZIP2_VERSION, min_version) elif zinfo.compress_type == zipfile.ZIP_LZMA: min_version = max(zipfile.LZMA_VERSION, min_version) + elif zinfo.compress_type == ZIP_ZSTANDARD: + min_version = max(ZSTANDARD_VERSION, min_version) extract_version = max(min_version, zinfo.extract_version) create_version = max(min_version, zinfo.create_version) diff --git a/gns3server/utils/zipfile_zstd/__init__.py b/gns3server/utils/zipfile_zstd/__init__.py new file mode 100644 index 00000000..61f10fd7 --- /dev/null +++ b/gns3server/utils/zipfile_zstd/__init__.py @@ -0,0 +1,10 @@ + +# NOTE: this patches the standard zipfile module +from . import _zipfile + +from zipfile import * +from zipfile import ( + ZIP_ZSTANDARD, + ZSTANDARD_VERSION, +) + diff --git a/gns3server/utils/zipfile_zstd/_patcher.py b/gns3server/utils/zipfile_zstd/_patcher.py new file mode 100644 index 00000000..83d7fc7c --- /dev/null +++ b/gns3server/utils/zipfile_zstd/_patcher.py @@ -0,0 +1,20 @@ +import functools + + +class patch: + + originals = {} + + def __init__(self, host, name): + self.host = host + self.name = name + + def __call__(self, func): + original = getattr(self.host, self.name) + self.originals[self.name] = original + + functools.update_wrapper(func, original) + setattr(self.host, self.name, func) + + return func + diff --git a/gns3server/utils/zipfile_zstd/_zipfile.py b/gns3server/utils/zipfile_zstd/_zipfile.py new file mode 100644 index 00000000..5748ad94 --- /dev/null +++ b/gns3server/utils/zipfile_zstd/_zipfile.py @@ -0,0 +1,64 @@ +import zipfile +import zstandard as zstd +import inspect + +from ._patcher import patch + + +zipfile.ZIP_ZSTANDARD = 93 +zipfile.compressor_names[zipfile.ZIP_ZSTANDARD] = 'zstandard' +zipfile.ZSTANDARD_VERSION = 20 + + +@patch(zipfile, '_check_compression') +def zstd_check_compression(compression): + if compression == zipfile.ZIP_ZSTANDARD: + pass + else: + patch.originals['_check_compression'](compression) + + +class ZstdDecompressObjWrapper: + def __init__(self, o): + self.o = o + + def __getattr__(self, attr): + if attr == 'eof': + return False + return getattr(self.o, attr) + + +@patch(zipfile, '_get_decompressor') +def zstd_get_decompressor(compress_type): + if compress_type == zipfile.ZIP_ZSTANDARD: + return ZstdDecompressObjWrapper(zstd.ZstdDecompressor(max_window_size=2147483648).decompressobj()) + else: + return patch.originals['_get_decompressor'](compress_type) + + +if 'compresslevel' in inspect.signature(zipfile._get_compressor).parameters: + @patch(zipfile, '_get_compressor') + def zstd_get_compressor(compress_type, compresslevel=None): + if compress_type == zipfile.ZIP_ZSTANDARD: + if compresslevel is None: + compresslevel = 3 + return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj() + else: + return patch.originals['_get_compressor'](compress_type, compresslevel=compresslevel) +else: + @patch(zipfile, '_get_compressor') + def zstd_get_compressor(compress_type, compresslevel=None): + if compress_type == zipfile.ZIP_ZSTANDARD: + if compresslevel is None: + compresslevel = 3 + return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj() + else: + return patch.originals['_get_compressor'](compress_type) + + +@patch(zipfile.ZipInfo, 'FileHeader') +def zstd_FileHeader(self, zip64=None): + if self.compress_type == zipfile.ZIP_ZSTANDARD: + self.create_version = max(self.create_version, zipfile.ZSTANDARD_VERSION) + self.extract_version = max(self.extract_version, zipfile.ZSTANDARD_VERSION) + return patch.originals['FileHeader'](self, zip64=zip64) diff --git a/requirements.txt b/requirements.txt index 3aeef5e7..e59de919 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ passlib[bcrypt]==1.7.4 python-jose==3.3.0 email-validator==1.2.1 watchfiles==0.14.1 +zstandard==0.17.0 setuptools==60.6.0 # don't upgrade because of https://github.com/pypa/setuptools/issues/3084 diff --git a/tests/api/routes/controller/test_projects.py b/tests/api/routes/controller/test_projects.py index ce727916..0787a0aa 100644 --- a/tests/api/routes/controller/test_projects.py +++ b/tests/api/routes/controller/test_projects.py @@ -17,7 +17,6 @@ import uuid import os -import zipfile import json import pytest @@ -26,6 +25,7 @@ from httpx import AsyncClient from unittest.mock import patch, MagicMock from tests.utils import asyncio_patch +import gns3server.utils.zipfile_zstd as zipfile_zstd from gns3server.controller import Controller from gns3server.controller.project import Project @@ -261,7 +261,7 @@ async def test_export_with_images(app: FastAPI, client: AsyncClient, tmpdir, pro with open(str(tmpdir / 'project.zip'), 'wb+') as f: f.write(response.content) - with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip: + with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip: with myzip.open("a") as myfile: content = myfile.read() assert content == b"hello" @@ -304,7 +304,7 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir, with open(str(tmpdir / 'project.zip'), 'wb+') as f: f.write(response.content) - with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip: + with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip: with myzip.open("a") as myfile: content = myfile.read() assert content == b"hello" @@ -313,6 +313,67 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir, myzip.getinfo("images/IOS/test.image") +@pytest.mark.parametrize( + "compression, compression_level, status_code", + ( + ("none", None, status.HTTP_200_OK), + ("none", 4, status.HTTP_400_BAD_REQUEST), + ("zip", None, status.HTTP_200_OK), + ("zip", 1, status.HTTP_200_OK), + ("zip", 12, status.HTTP_400_BAD_REQUEST), + ("bzip2", None, status.HTTP_200_OK), + ("bzip2", 1, status.HTTP_200_OK), + ("bzip2", 13, status.HTTP_400_BAD_REQUEST), + ("lzma", None, status.HTTP_200_OK), + ("lzma", 1, status.HTTP_400_BAD_REQUEST), + ("zstd", None, status.HTTP_200_OK), + ("zstd", 12, status.HTTP_200_OK), + ("zstd", 23, status.HTTP_400_BAD_REQUEST), + ) +) +async def test_export_compression( + app: FastAPI, + client: AsyncClient, + tmpdir, + project: Project, + compression: str, + compression_level: int, + status_code: int +) -> None: + + project.dump = MagicMock() + os.makedirs(project.path, exist_ok=True) + + topology = { + "topology": { + "nodes": [ + { + "node_type": "qemu" + } + ] + } + } + with open(os.path.join(project.path, "test.gns3"), 'w+') as f: + json.dump(topology, f) + + params = {"compression": compression} + if compression_level: + params["compression_level"] = compression_level + response = await client.get(app.url_path_for("export_project", project_id=project.id), params=params) + assert response.status_code == status_code + + if response.status_code == status.HTTP_200_OK: + assert response.headers['CONTENT-TYPE'] == 'application/gns3project' + assert response.headers['CONTENT-DISPOSITION'] == 'attachment; filename="{}.gns3project"'.format(project.name) + + with open(str(tmpdir / 'project.zip'), 'wb+') as f: + f.write(response.content) + + with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip: + with myzip.open("project.gns3") as myfile: + myfile.read() + + async def test_get_file(app: FastAPI, client: AsyncClient, project: Project) -> None: os.makedirs(project.path, exist_ok=True)