1
0
mirror of https://github.com/GNS3/gns3-server synced 2025-01-24 06:51:19 +00:00

Add zstandard compression support for project export

This commit is contained in:
grossmj 2022-05-31 18:08:34 +07:00
parent 37c7bc4956
commit 8a964390f8
7 changed files with 167 additions and 17 deletions

View File

@ -21,10 +21,10 @@ API routes for projects.
import os import os
import asyncio import asyncio
import tempfile import tempfile
import zipfile
import aiofiles import aiofiles
import time import time
import urllib.parse import urllib.parse
import gns3server.utils.zipfile_zstd as zipfile
import logging import logging
@ -285,7 +285,7 @@ async def export_project(
include_snapshots: bool = False, include_snapshots: bool = False,
include_images: bool = False, include_images: bool = False,
reset_mac_addresses: bool = False, reset_mac_addresses: bool = False,
compression: str = "zip", compression: schemas.ProjectCompression = "zstd",
) -> StreamingResponse: ) -> StreamingResponse:
""" """
Export a project as a portable archive. Export a project as a portable archive.
@ -300,6 +300,8 @@ async def export_project(
compression = zipfile.ZIP_BZIP2 compression = zipfile.ZIP_BZIP2
elif compression_query == "lzma": elif compression_query == "lzma":
compression = zipfile.ZIP_LZMA compression = zipfile.ZIP_LZMA
elif compression_query == "zstd":
compression = zipfile.ZIP_ZSTANDARD
try: try:
begin = time.time() begin = time.time()

View File

@ -28,7 +28,7 @@ from .controller.appliances import ApplianceVersion, Appliance
from .controller.drawings import Drawing from .controller.drawings import Drawing
from .controller.gns3vm import GNS3VM from .controller.gns3vm import GNS3VM
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression
from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission
from .controller.tokens import Token from .controller.tokens import Token

View File

@ -102,3 +102,15 @@ class ProjectFile(BaseModel):
path: str = Field(..., description="File path") path: str = Field(..., description="File path")
md5sum: str = Field(..., description="File checksum") md5sum: str = Field(..., description="File checksum")
class ProjectCompression(str, Enum):
"""
Supported project compression.
"""
none = "none"
zip = "zip"
bzip2 = "bzip2"
lzma = "lzma"
zstd = "zstd"

View File

@ -43,26 +43,37 @@ from zipfile import (
stringEndArchive64Locator, stringEndArchive64Locator,
) )
ZIP_ZSTANDARD = 93 # zstandard is supported by WinZIP v24 and later, PowerArchiver 2021 and 7-Zip-zstd
ZSTANDARD_VERSION = 20
stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor
def _get_compressor(compress_type): def _get_compressor(compress_type, compresslevel=None):
""" """
Return the compressor. Return the compressor.
""" """
if compress_type == zipfile.ZIP_DEFLATED: if compress_type == zipfile.ZIP_DEFLATED:
from zipfile import zlib from zipfile import zlib
if compresslevel is not None:
return zlib.compressobj(compresslevel, zlib.DEFLATED, -15)
return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15) return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
elif compress_type == zipfile.ZIP_BZIP2: elif compress_type == zipfile.ZIP_BZIP2:
from zipfile import bz2 from zipfile import bz2
if compresslevel is not None:
return bz2.BZ2Compressor(compresslevel)
return bz2.BZ2Compressor() return bz2.BZ2Compressor()
# compresslevel is ignored for ZIP_LZMA
elif compress_type == zipfile.ZIP_LZMA: elif compress_type == zipfile.ZIP_LZMA:
from zipfile import LZMACompressor from zipfile import LZMACompressor
return LZMACompressor() return LZMACompressor()
elif compress_type == ZIP_ZSTANDARD:
import zstandard as zstd
if compresslevel is not None:
params = zstd.ZstdCompressionParameters.from_level(compresslevel, threads=-1, enable_ldm=True, window_log=31)
return zstd.ZstdCompressor(compression_params=params).compressobj()
return zstd.ZstdCompressor().compressobj()
else: else:
return None return None
@ -129,7 +140,15 @@ class ZipInfo(zipfile.ZipInfo):
class ZipFile(zipfile.ZipFile): class ZipFile(zipfile.ZipFile):
def __init__(self, fileobj=None, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True, chunksize=32768): def __init__(
self,
fileobj=None,
mode="w",
compression=zipfile.ZIP_STORED,
allowZip64=True,
compresslevel=None,
chunksize=32768
):
"""Open the ZIP file with mode write "w".""" """Open the ZIP file with mode write "w"."""
if mode not in ("w",): if mode not in ("w",):
@ -138,7 +157,13 @@ class ZipFile(zipfile.ZipFile):
fileobj = PointerIO() fileobj = PointerIO()
self._comment = b"" self._comment = b""
zipfile.ZipFile.__init__(self, fileobj, mode=mode, compression=compression, allowZip64=allowZip64) zipfile.ZipFile.__init__(
self, fileobj,
mode=mode,
compression=compression,
compresslevel=compresslevel,
allowZip64=allowZip64
)
self._chunksize = chunksize self._chunksize = chunksize
self.paths_to_write = [] self.paths_to_write = []
@ -195,23 +220,33 @@ class ZipFile(zipfile.ZipFile):
for chunk in self._close(): for chunk in self._close():
yield chunk yield chunk
def write(self, filename, arcname=None, compress_type=None): def write(self, filename, arcname=None, compress_type=None, compresslevel=None):
""" """
Write a file to the archive under the name `arcname`. Write a file to the archive under the name `arcname`.
""" """
kwargs = {"filename": filename, "arcname": arcname, "compress_type": compress_type} kwargs = {
"filename": filename,
"arcname": arcname,
"compress_type": compress_type,
"compresslevel": compresslevel
}
self.paths_to_write.append(kwargs) self.paths_to_write.append(kwargs)
def write_iter(self, arcname, iterable, compress_type=None): def write_iter(self, arcname, iterable, compress_type=None, compresslevel=None):
""" """
Write the bytes iterable `iterable` to the archive under the name `arcname`. Write the bytes iterable `iterable` to the archive under the name `arcname`.
""" """
kwargs = {"arcname": arcname, "iterable": iterable, "compress_type": compress_type} kwargs = {
"arcname": arcname,
"iterable": iterable,
"compress_type": compress_type,
"compresslevel": compresslevel
}
self.paths_to_write.append(kwargs) self.paths_to_write.append(kwargs)
def writestr(self, arcname, data, compress_type=None): def writestr(self, arcname, data, compress_type=None, compresslevel=None):
""" """
Writes a str into ZipFile by wrapping data as a generator Writes a str into ZipFile by wrapping data as a generator
""" """
@ -219,9 +254,9 @@ class ZipFile(zipfile.ZipFile):
def _iterable(): def _iterable():
yield data yield data
return self.write_iter(arcname, _iterable(), compress_type=compress_type) return self.write_iter(arcname, _iterable(), compress_type=compress_type, compresslevel=compresslevel)
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None): async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None, compresslevel=None):
""" """
Put the bytes from filename into the archive under the name `arcname`. Put the bytes from filename into the archive under the name `arcname`.
""" """
@ -256,6 +291,11 @@ class ZipFile(zipfile.ZipFile):
else: else:
zinfo.compress_type = compress_type zinfo.compress_type = compress_type
if compresslevel is None:
zinfo._compresslevel = self.compresslevel
else:
zinfo._compresslevel = compresslevel
if st: if st:
zinfo.file_size = st[6] zinfo.file_size = st[6]
else: else:
@ -279,7 +319,7 @@ class ZipFile(zipfile.ZipFile):
yield self.fp.write(zinfo.FileHeader(False)) yield self.fp.write(zinfo.FileHeader(False))
return return
cmpr = _get_compressor(zinfo.compress_type) cmpr = _get_compressor(zinfo.compress_type, zinfo._compresslevel)
# Must overwrite CRC and sizes with correct data later # Must overwrite CRC and sizes with correct data later
zinfo.CRC = CRC = 0 zinfo.CRC = CRC = 0
@ -369,6 +409,8 @@ class ZipFile(zipfile.ZipFile):
min_version = max(zipfile.BZIP2_VERSION, min_version) min_version = max(zipfile.BZIP2_VERSION, min_version)
elif zinfo.compress_type == zipfile.ZIP_LZMA: elif zinfo.compress_type == zipfile.ZIP_LZMA:
min_version = max(zipfile.LZMA_VERSION, min_version) min_version = max(zipfile.LZMA_VERSION, min_version)
elif zinfo.compress_type == ZIP_ZSTANDARD:
min_version = max(ZSTANDARD_VERSION, min_version)
extract_version = max(min_version, zinfo.extract_version) extract_version = max(min_version, zinfo.extract_version)
create_version = max(min_version, zinfo.create_version) create_version = max(min_version, zinfo.create_version)

View File

@ -0,0 +1,10 @@
# NOTE: this patches the standard zipfile module
from . import _zipfile
from zipfile import *
from zipfile import (
ZIP_ZSTANDARD,
ZSTANDARD_VERSION,
)

View File

@ -0,0 +1,20 @@
import functools
class patch:
originals = {}
def __init__(self, host, name):
self.host = host
self.name = name
def __call__(self, func):
original = getattr(self.host, self.name)
self.originals[self.name] = original
functools.update_wrapper(func, original)
setattr(self.host, self.name, func)
return func

View File

@ -0,0 +1,64 @@
import zipfile
import zstandard as zstd
import inspect
from ._patcher import patch
zipfile.ZIP_ZSTANDARD = 93
zipfile.compressor_names[zipfile.ZIP_ZSTANDARD] = 'zstandard'
zipfile.ZSTANDARD_VERSION = 20
@patch(zipfile, '_check_compression')
def zstd_check_compression(compression):
if compression == zipfile.ZIP_ZSTANDARD:
pass
else:
patch.originals['_check_compression'](compression)
class ZstdDecompressObjWrapper:
def __init__(self, o):
self.o = o
def __getattr__(self, attr):
if attr == 'eof':
return False
return getattr(self.o, attr)
@patch(zipfile, '_get_decompressor')
def zstd_get_decompressor(compress_type):
if compress_type == zipfile.ZIP_ZSTANDARD:
return ZstdDecompressObjWrapper(zstd.ZstdDecompressor(max_window_size=2147483648).decompressobj())
else:
return patch.originals['_get_decompressor'](compress_type)
if 'compresslevel' in inspect.signature(zipfile._get_compressor).parameters:
@patch(zipfile, '_get_compressor')
def zstd_get_compressor(compress_type, compresslevel=None):
if compress_type == zipfile.ZIP_ZSTANDARD:
if compresslevel is None:
compresslevel = 3
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
else:
return patch.originals['_get_compressor'](compress_type, compresslevel=compresslevel)
else:
@patch(zipfile, '_get_compressor')
def zstd_get_compressor(compress_type, compresslevel=None):
if compress_type == zipfile.ZIP_ZSTANDARD:
if compresslevel is None:
compresslevel = 3
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
else:
return patch.originals['_get_compressor'](compress_type)
@patch(zipfile.ZipInfo, 'FileHeader')
def zstd_FileHeader(self, zip64=None):
if self.compress_type == zipfile.ZIP_ZSTANDARD:
self.create_version = max(self.create_version, zipfile.ZSTANDARD_VERSION)
self.extract_version = max(self.extract_version, zipfile.ZSTANDARD_VERSION)
return patch.originals['FileHeader'](self, zip64=zip64)