From 8a964390f8b5c5c602db08daec3f3cc23282aa60 Mon Sep 17 00:00:00 2001 From: grossmj Date: Tue, 31 May 2022 18:08:34 +0700 Subject: [PATCH] Add zstandard compression support for project export --- gns3server/api/routes/controller/projects.py | 6 +- gns3server/schemas/__init__.py | 2 +- gns3server/schemas/controller/projects.py | 12 ++++ gns3server/utils/asyncio/aiozipstream.py | 70 ++++++++++++++++---- gns3server/utils/zipfile_zstd/__init__.py | 10 +++ gns3server/utils/zipfile_zstd/_patcher.py | 20 ++++++ gns3server/utils/zipfile_zstd/_zipfile.py | 64 ++++++++++++++++++ 7 files changed, 167 insertions(+), 17 deletions(-) create mode 100644 gns3server/utils/zipfile_zstd/__init__.py create mode 100644 gns3server/utils/zipfile_zstd/_patcher.py create mode 100644 gns3server/utils/zipfile_zstd/_zipfile.py diff --git a/gns3server/api/routes/controller/projects.py b/gns3server/api/routes/controller/projects.py index f5c10b52..21f92240 100644 --- a/gns3server/api/routes/controller/projects.py +++ b/gns3server/api/routes/controller/projects.py @@ -21,10 +21,10 @@ API routes for projects. import os import asyncio import tempfile -import zipfile import aiofiles import time import urllib.parse +import gns3server.utils.zipfile_zstd as zipfile import logging @@ -285,7 +285,7 @@ async def export_project( include_snapshots: bool = False, include_images: bool = False, reset_mac_addresses: bool = False, - compression: str = "zip", + compression: schemas.ProjectCompression = "zstd", ) -> StreamingResponse: """ Export a project as a portable archive. @@ -300,6 +300,8 @@ async def export_project( compression = zipfile.ZIP_BZIP2 elif compression_query == "lzma": compression = zipfile.ZIP_LZMA + elif compression_query == "zstd": + compression = zipfile.ZIP_ZSTANDARD try: begin = time.time() diff --git a/gns3server/schemas/__init__.py b/gns3server/schemas/__init__.py index 03a67b97..77a5c9c3 100644 --- a/gns3server/schemas/__init__.py +++ b/gns3server/schemas/__init__.py @@ -28,7 +28,7 @@ from .controller.appliances import ApplianceVersion, Appliance from .controller.drawings import Drawing from .controller.gns3vm import GNS3VM from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node -from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile +from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission from .controller.tokens import Token diff --git a/gns3server/schemas/controller/projects.py b/gns3server/schemas/controller/projects.py index 2c7c846e..4d98e7c5 100644 --- a/gns3server/schemas/controller/projects.py +++ b/gns3server/schemas/controller/projects.py @@ -102,3 +102,15 @@ class ProjectFile(BaseModel): path: str = Field(..., description="File path") md5sum: str = Field(..., description="File checksum") + + +class ProjectCompression(str, Enum): + """ + Supported project compression. + """ + + none = "none" + zip = "zip" + bzip2 = "bzip2" + lzma = "lzma" + zstd = "zstd" diff --git a/gns3server/utils/asyncio/aiozipstream.py b/gns3server/utils/asyncio/aiozipstream.py index 451a3165..2510fe8e 100644 --- a/gns3server/utils/asyncio/aiozipstream.py +++ b/gns3server/utils/asyncio/aiozipstream.py @@ -43,26 +43,37 @@ from zipfile import ( stringEndArchive64Locator, ) + +ZIP_ZSTANDARD = 93 # zstandard is supported by WinZIP v24 and later, PowerArchiver 2021 and 7-Zip-zstd +ZSTANDARD_VERSION = 20 stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor -def _get_compressor(compress_type): +def _get_compressor(compress_type, compresslevel=None): """ Return the compressor. """ if compress_type == zipfile.ZIP_DEFLATED: from zipfile import zlib - + if compresslevel is not None: + return zlib.compressobj(compresslevel, zlib.DEFLATED, -15) return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15) elif compress_type == zipfile.ZIP_BZIP2: from zipfile import bz2 - + if compresslevel is not None: + return bz2.BZ2Compressor(compresslevel) return bz2.BZ2Compressor() + # compresslevel is ignored for ZIP_LZMA elif compress_type == zipfile.ZIP_LZMA: from zipfile import LZMACompressor - return LZMACompressor() + elif compress_type == ZIP_ZSTANDARD: + import zstandard as zstd + if compresslevel is not None: + params = zstd.ZstdCompressionParameters.from_level(compresslevel, threads=-1, enable_ldm=True, window_log=31) + return zstd.ZstdCompressor(compression_params=params).compressobj() + return zstd.ZstdCompressor().compressobj() else: return None @@ -129,7 +140,15 @@ class ZipInfo(zipfile.ZipInfo): class ZipFile(zipfile.ZipFile): - def __init__(self, fileobj=None, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True, chunksize=32768): + def __init__( + self, + fileobj=None, + mode="w", + compression=zipfile.ZIP_STORED, + allowZip64=True, + compresslevel=None, + chunksize=32768 + ): """Open the ZIP file with mode write "w".""" if mode not in ("w",): @@ -138,7 +157,13 @@ class ZipFile(zipfile.ZipFile): fileobj = PointerIO() self._comment = b"" - zipfile.ZipFile.__init__(self, fileobj, mode=mode, compression=compression, allowZip64=allowZip64) + zipfile.ZipFile.__init__( + self, fileobj, + mode=mode, + compression=compression, + compresslevel=compresslevel, + allowZip64=allowZip64 + ) self._chunksize = chunksize self.paths_to_write = [] @@ -195,23 +220,33 @@ class ZipFile(zipfile.ZipFile): for chunk in self._close(): yield chunk - def write(self, filename, arcname=None, compress_type=None): + def write(self, filename, arcname=None, compress_type=None, compresslevel=None): """ Write a file to the archive under the name `arcname`. """ - kwargs = {"filename": filename, "arcname": arcname, "compress_type": compress_type} + kwargs = { + "filename": filename, + "arcname": arcname, + "compress_type": compress_type, + "compresslevel": compresslevel + } self.paths_to_write.append(kwargs) - def write_iter(self, arcname, iterable, compress_type=None): + def write_iter(self, arcname, iterable, compress_type=None, compresslevel=None): """ Write the bytes iterable `iterable` to the archive under the name `arcname`. """ - kwargs = {"arcname": arcname, "iterable": iterable, "compress_type": compress_type} + kwargs = { + "arcname": arcname, + "iterable": iterable, + "compress_type": compress_type, + "compresslevel": compresslevel + } self.paths_to_write.append(kwargs) - def writestr(self, arcname, data, compress_type=None): + def writestr(self, arcname, data, compress_type=None, compresslevel=None): """ Writes a str into ZipFile by wrapping data as a generator """ @@ -219,9 +254,9 @@ class ZipFile(zipfile.ZipFile): def _iterable(): yield data - return self.write_iter(arcname, _iterable(), compress_type=compress_type) + return self.write_iter(arcname, _iterable(), compress_type=compress_type, compresslevel=compresslevel) - async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None): + async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None, compresslevel=None): """ Put the bytes from filename into the archive under the name `arcname`. """ @@ -256,6 +291,11 @@ class ZipFile(zipfile.ZipFile): else: zinfo.compress_type = compress_type + if compresslevel is None: + zinfo._compresslevel = self.compresslevel + else: + zinfo._compresslevel = compresslevel + if st: zinfo.file_size = st[6] else: @@ -279,7 +319,7 @@ class ZipFile(zipfile.ZipFile): yield self.fp.write(zinfo.FileHeader(False)) return - cmpr = _get_compressor(zinfo.compress_type) + cmpr = _get_compressor(zinfo.compress_type, zinfo._compresslevel) # Must overwrite CRC and sizes with correct data later zinfo.CRC = CRC = 0 @@ -369,6 +409,8 @@ class ZipFile(zipfile.ZipFile): min_version = max(zipfile.BZIP2_VERSION, min_version) elif zinfo.compress_type == zipfile.ZIP_LZMA: min_version = max(zipfile.LZMA_VERSION, min_version) + elif zinfo.compress_type == ZIP_ZSTANDARD: + min_version = max(ZSTANDARD_VERSION, min_version) extract_version = max(min_version, zinfo.extract_version) create_version = max(min_version, zinfo.create_version) diff --git a/gns3server/utils/zipfile_zstd/__init__.py b/gns3server/utils/zipfile_zstd/__init__.py new file mode 100644 index 00000000..61f10fd7 --- /dev/null +++ b/gns3server/utils/zipfile_zstd/__init__.py @@ -0,0 +1,10 @@ + +# NOTE: this patches the standard zipfile module +from . import _zipfile + +from zipfile import * +from zipfile import ( + ZIP_ZSTANDARD, + ZSTANDARD_VERSION, +) + diff --git a/gns3server/utils/zipfile_zstd/_patcher.py b/gns3server/utils/zipfile_zstd/_patcher.py new file mode 100644 index 00000000..83d7fc7c --- /dev/null +++ b/gns3server/utils/zipfile_zstd/_patcher.py @@ -0,0 +1,20 @@ +import functools + + +class patch: + + originals = {} + + def __init__(self, host, name): + self.host = host + self.name = name + + def __call__(self, func): + original = getattr(self.host, self.name) + self.originals[self.name] = original + + functools.update_wrapper(func, original) + setattr(self.host, self.name, func) + + return func + diff --git a/gns3server/utils/zipfile_zstd/_zipfile.py b/gns3server/utils/zipfile_zstd/_zipfile.py new file mode 100644 index 00000000..5748ad94 --- /dev/null +++ b/gns3server/utils/zipfile_zstd/_zipfile.py @@ -0,0 +1,64 @@ +import zipfile +import zstandard as zstd +import inspect + +from ._patcher import patch + + +zipfile.ZIP_ZSTANDARD = 93 +zipfile.compressor_names[zipfile.ZIP_ZSTANDARD] = 'zstandard' +zipfile.ZSTANDARD_VERSION = 20 + + +@patch(zipfile, '_check_compression') +def zstd_check_compression(compression): + if compression == zipfile.ZIP_ZSTANDARD: + pass + else: + patch.originals['_check_compression'](compression) + + +class ZstdDecompressObjWrapper: + def __init__(self, o): + self.o = o + + def __getattr__(self, attr): + if attr == 'eof': + return False + return getattr(self.o, attr) + + +@patch(zipfile, '_get_decompressor') +def zstd_get_decompressor(compress_type): + if compress_type == zipfile.ZIP_ZSTANDARD: + return ZstdDecompressObjWrapper(zstd.ZstdDecompressor(max_window_size=2147483648).decompressobj()) + else: + return patch.originals['_get_decompressor'](compress_type) + + +if 'compresslevel' in inspect.signature(zipfile._get_compressor).parameters: + @patch(zipfile, '_get_compressor') + def zstd_get_compressor(compress_type, compresslevel=None): + if compress_type == zipfile.ZIP_ZSTANDARD: + if compresslevel is None: + compresslevel = 3 + return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj() + else: + return patch.originals['_get_compressor'](compress_type, compresslevel=compresslevel) +else: + @patch(zipfile, '_get_compressor') + def zstd_get_compressor(compress_type, compresslevel=None): + if compress_type == zipfile.ZIP_ZSTANDARD: + if compresslevel is None: + compresslevel = 3 + return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj() + else: + return patch.originals['_get_compressor'](compress_type) + + +@patch(zipfile.ZipInfo, 'FileHeader') +def zstd_FileHeader(self, zip64=None): + if self.compress_type == zipfile.ZIP_ZSTANDARD: + self.create_version = max(self.create_version, zipfile.ZSTANDARD_VERSION) + self.extract_version = max(self.extract_version, zipfile.ZSTANDARD_VERSION) + return patch.originals['FileHeader'](self, zip64=zip64)