diff --git a/etebase_fastapi/db_hack.py b/etebase_fastapi/db_hack.py new file mode 100644 index 0000000..24d5824 --- /dev/null +++ b/etebase_fastapi/db_hack.py @@ -0,0 +1,27 @@ +""" +FIXME: this whole function is a hack around the django db limitations due to how db connections are cached and cleaned. +Essentially django assumes there's the django request dispatcher to automatically clean up after the ORM. +""" +import typing as t +from functools import wraps + +from django.db import close_old_connections, reset_queries + + +def django_db_cleanup(): + reset_queries() + close_old_connections() + + +def django_db_cleanup_decorator(func: t.Callable[..., t.Any]): + from inspect import iscoroutinefunction + + if iscoroutinefunction(func): + return func + + @wraps(func) + def wrapper(*args, **kwargs): + django_db_cleanup() + return func(*args, **kwargs) + + return wrapper diff --git a/etebase_fastapi/msgpack.py b/etebase_fastapi/msgpack.py index 8de8806..67627e1 100644 --- a/etebase_fastapi/msgpack.py +++ b/etebase_fastapi/msgpack.py @@ -1,10 +1,13 @@ import typing as t + +from fastapi import params from fastapi.routing import APIRoute, get_request_handler from pydantic import BaseModel from starlette.requests import Request from starlette.responses import Response from .utils import msgpack_encode, msgpack_decode +from .db_hack import django_db_cleanup_decorator class MsgpackRequest(Request): @@ -35,6 +38,22 @@ class MsgpackRoute(APIRoute): # keep track of content-type -> response classes ROUTES_HANDLERS_CLASSES = {MsgpackResponse.media_type: MsgpackResponse} + def __init__( + self, + path: str, + endpoint: t.Callable[..., t.Any], + *args, + dependencies: t.Optional[t.Sequence[params.Depends]] = None, + **kwargs + ): + if dependencies is not None: + dependencies = [ + params.Depends(django_db_cleanup_decorator(dep.dependency), use_cache=dep.use_cache) + for dep in dependencies + ] + endpoint = django_db_cleanup_decorator(endpoint) + super().__init__(path, endpoint, *args, dependencies=dependencies, **kwargs) + def _get_media_type_route_handler(self, media_type): return get_request_handler( dependant=self.dependant,