1
0
mirror of https://github.com/etesync/server synced 2024-11-24 09:48:09 +00:00
etesync-server/etebase_server/fastapi/msgpack.py

80 lines
3.0 KiB
Python
Raw Normal View History

2020-12-23 21:29:08 +00:00
import typing as t
2020-12-23 21:29:08 +00:00
from fastapi.routing import APIRoute, get_request_handler
from pydantic import BaseModel
2020-12-23 21:29:08 +00:00
from starlette.requests import Request
from starlette.responses import Response
from .db_hack import django_db_cleanup_decorator
2024-06-08 21:51:44 +00:00
from .utils import msgpack_decode, msgpack_encode
2020-12-23 21:29:08 +00:00
class MsgpackRequest(Request):
    """Starlette request whose ``body()`` yields the msgpack-decoded payload.

    ``raw_body()`` still gives access to the undecoded bytes.
    """

    media_type = "application/msgpack"

    async def raw_body(self) -> bytes:
        """Return the request body exactly as received (undecoded bytes)."""
        return await super().body()

    async def body(self) -> t.Any:
        """Return the request body decoded with ``msgpack_decode``.

        The result is not bytes but whatever ``msgpack_decode`` produces, so the
        return type is ``Any``. The decoded value is cached on ``self._json``,
        so the body is read and decoded only once per request object.
        """
        # NOTE(review): Starlette's ``Request.json()`` appears to cache under the
        # same ``_json`` attribute name; keep the name unless that is confirmed
        # to be irrelevant.
        if not hasattr(self, "_json"):
            body = await self.raw_body()
            self._json = msgpack_decode(body)
        return self._json
class MsgpackResponse(Response):
    """Starlette response that renders its content as msgpack bytes."""

    media_type = "application/msgpack"

    def render(self, content: t.Optional[t.Any]) -> bytes:
        """Serialize *content* into the response body.

        ``None`` produces an empty body. A pydantic ``BaseModel`` is first
        converted with ``model_dump()``. Every other value is passed straight to
        ``msgpack_encode``.
        """
        if content is None:
            return b""
        payload = content.model_dump() if isinstance(content, BaseModel) else content
        return msgpack_encode(payload)
2020-12-23 21:29:08 +00:00
class MsgpackRoute(APIRoute):
    """APIRoute that negotiates msgpack request and response bodies.

    The request class is picked from the ``Content-Type`` header using
    ``REQUESTS_CLASSES``. The response class is picked from the ``Accept``
    header using ``ROUTES_HANDLERS_CLASSES``, falling back to the route's own
    ``response_class``. Every endpoint is wrapped with
    ``django_db_cleanup_decorator``.
    """

    # keep track of content-type -> request classes
    REQUESTS_CLASSES: t.ClassVar[t.Dict[str, t.Type[Request]]] = {MsgpackRequest.media_type: MsgpackRequest}
    # keep track of content-type -> response classes
    ROUTES_HANDLERS_CLASSES: t.ClassVar[t.Dict[str, t.Type[Response]]] = {
        MsgpackResponse.media_type: MsgpackResponse
    }

    def __init__(self, path: str, endpoint: t.Callable[..., t.Any], *args, **kwargs):
        # Wrap before handing the endpoint to APIRoute so the wrapped callable is what gets registered.
        endpoint = django_db_cleanup_decorator(endpoint)
        super().__init__(path, endpoint, *args, **kwargs)

    @staticmethod
    def _normalize_media_type(value: t.Optional[str]) -> t.Optional[str]:
        """Return the bare, lower-cased media type of a header value.

        Parameters such as ``; charset=utf-8`` are dropped, so
        ``"Application/MsgPack; charset=utf-8"`` becomes ``"application/msgpack"``.
        ``None`` is returned unchanged.
        """
        if value is None:
            return None
        return value.split(";", 1)[0].strip().lower()

    def _get_media_type_route_handler(self, media_type):
        """Build a FastAPI request handler whose response class matches *media_type*.

        Unregistered media types, and ``None``, fall back to ``self.response_class``.
        """
        return get_request_handler(
            dependant=self.dependant,
            body_field=self.body_field,
            status_code=self.status_code,
            # use custom response class or fallback on default self.response_class
            response_class=self.ROUTES_HANDLERS_CLASSES.get(media_type, self.response_class),
            response_field=self.response_field,
            response_model_include=self.response_model_include,
            response_model_exclude=self.response_model_exclude,
            response_model_by_alias=self.response_model_by_alias,
            response_model_exclude_unset=self.response_model_exclude_unset,
            response_model_exclude_defaults=self.response_model_exclude_defaults,
            response_model_exclude_none=self.response_model_exclude_none,
            dependency_overrides_provider=self.dependency_overrides_provider,
        )

    def get_route_handler(self) -> t.Callable:
        """Return the ASGI-level handler that dispatches on Content-Type and Accept."""
        # Handlers are built lazily and reused across requests instead of being rebuilt each time.
        # Keys are restricted to registered Accept values plus ``None`` (default response class),
        # so arbitrary client-supplied Accept headers cannot grow this dict.
        handlers: t.Dict[t.Optional[str], t.Callable] = {}

        async def custom_route_handler(request: Request) -> Response:
            content_type = self._normalize_media_type(request.headers.get("Content-Type"))
            request_cls = self.REQUESTS_CLASSES.get(content_type) if content_type is not None else None
            if request_cls is not None:
                request = request_cls(request.scope, request.receive)
            # otherwise nothing is registered for content_type: process the request as-is

            accept = request.headers.get("Accept")
            handler_key = accept if accept in self.ROUTES_HANDLERS_CLASSES else None
            route_handler = handlers.get(handler_key)
            if route_handler is None:
                route_handler = self._get_media_type_route_handler(handler_key)
                handlers[handler_key] = route_handler
            return await route_handler(request)

        return custom_route_handler