From 0be14a7b0e49ff7afaac52e5defe360ea0456779 Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Sat, 8 Jun 2024 20:17:02 -0400 Subject: [PATCH] Fixes for fastapi. --- etebase_server/fastapi/msgpack.py | 5 +- .../fastapi/routers/authentication.py | 12 ++--- etebase_server/fastapi/routers/collection.py | 51 ++++++++++++------- etebase_server/fastapi/routers/invitation.py | 14 ++--- etebase_server/fastapi/routers/member.py | 12 +++-- 5 files changed, 56 insertions(+), 38 deletions(-) diff --git a/etebase_server/fastapi/msgpack.py b/etebase_server/fastapi/msgpack.py index d857b16..9820852 100644 --- a/etebase_server/fastapi/msgpack.py +++ b/etebase_server/fastapi/msgpack.py @@ -12,9 +12,12 @@ from .utils import msgpack_decode, msgpack_encode class MsgpackRequest(Request): media_type = "application/msgpack" + async def raw_body(self) -> bytes: + return await super().body() + async def body(self) -> bytes: if not hasattr(self, "_json"): - body = await super().body() + body = await self.raw_body() self._json = msgpack_decode(body) return self._json diff --git a/etebase_server/fastapi/routers/authentication.py b/etebase_server/fastapi/routers/authentication.py index 5f1d8bb..533b0eb 100644 --- a/etebase_server/fastapi/routers/authentication.py +++ b/etebase_server/fastapi/routers/authentication.py @@ -23,7 +23,7 @@ from etebase_server.myauth.models import UserType, get_typed_user_model from ..dependencies import AuthData, get_auth_data, get_authenticated_user from ..exceptions import AuthenticationFailed, HttpError, transform_validation_error -from ..msgpack import MsgpackRoute +from ..msgpack import MsgpackResponse, MsgpackRoute from ..utils import BaseModel, get_user_username_email_kwargs, msgpack_decode, msgpack_encode, permission_responses User = get_typed_user_model() @@ -76,7 +76,7 @@ class LoginOut(BaseModel): class Authentication(BaseModel): class Config: - ignored_types= (cached_property,) + ignored_types = (cached_property,) response: bytes signature: bytes @@ -188,7 +188,7 @@ def login_challenge(user: UserType = Depends(get_login_user)): "userId": user.id, } challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder)) - return LoginChallengeOut(salt=salt, challenge=challenge, version=user.userinfo.version) + return MsgpackResponse(LoginChallengeOut(salt=salt, challenge=challenge, version=user.userinfo.version)) @authentication_router.post("/login/", response_model=LoginOut) @@ -198,7 +198,7 @@ def login(data: Login, request: Request): validate_login_request(data.response_data, data, user, "login", host) ret = LoginOut.from_orm(user) user_logged_in.send(sender=user.__class__, request=None, user=user) - return ret + return MsgpackResponse(ret) @authentication_router.post("/logout/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses) @@ -223,7 +223,7 @@ def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_u ret = { "url": get_dashboard_url(CallbackContext(request.path_params, user=user)), } - return ret + return MsgpackResponse(ret) def signup_save(data: SignupIn, request: Request) -> UserType: @@ -261,4 +261,4 @@ def signup(data: SignupIn, request: Request): user = signup_save(data, request) ret = LoginOut.from_orm(user) user_signed_up.send(sender=user.__class__, request=None, user=user) - return ret + return MsgpackResponse(ret) diff --git a/etebase_server/fastapi/routers/collection.py b/etebase_server/fastapi/routers/collection.py index 04fb8cd..9fe74f1 100644 --- a/etebase_server/fastapi/routers/collection.py +++ b/etebase_server/fastapi/routers/collection.py @@ -13,7 +13,7 @@ from etebase_server.myauth.models import UserType from ..db_hack import django_db_cleanup_decorator from ..dependencies import get_collection, get_collection_queryset, get_item_queryset from ..exceptions import HttpError, PermissionDenied, ValidationError, transform_validation_error -from ..msgpack import MsgpackRoute +from ..msgpack import MsgpackRequest, MsgpackResponse, MsgpackRoute from ..redis import redisw from ..sendfile import sendfile from ..stoken_handler import filter_by_stoken, filter_by_stoken_and_limit, get_queryset_stoken, get_stoken_obj @@ -135,7 +135,7 @@ class CollectionListResponse(BaseModel): stoken: t.Optional[str] done: bool - removedMemberships: t.Optional[t.List[RemovedMembershipOut]] + removedMemberships: t.Optional[t.List[RemovedMembershipOut]] = None class CollectionItemListResponse(BaseModel): @@ -275,7 +275,7 @@ def list_multi( Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True) ) - return collection_list_common(queryset, user, stoken, limit, prefetch) + return MsgpackResponse(collection_list_common(queryset, user, stoken, limit, prefetch)) @collection_router.get("/", response_model=CollectionListResponse, dependencies=PERMISSIONS_READ) @@ -286,7 +286,7 @@ def collection_list( user: UserType = Depends(get_authenticated_user), queryset: CollectionQuerySet = Depends(get_collection_queryset), ): - return collection_list_common(queryset, user, stoken, limit, prefetch) + return MsgpackResponse(collection_list_common(queryset, user, stoken, limit, prefetch)) def process_revisions_for_item(item: models.CollectionItem, revision_data: CollectionItemRevisionInOut): @@ -365,7 +365,7 @@ def collection_get( user: UserType = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery, ): - return CollectionOut.from_orm_context(obj, Context(user, prefetch)) + return MsgpackResponse(CollectionOut.from_orm_context(obj, Context(user, prefetch))) def item_create(item_model: CollectionItemIn, collection: models.Collection, validate_etag: bool): @@ -418,7 +418,7 @@ def item_get( prefetch: Prefetch = PrefetchQuery, ): obj = queryset.get(uid=item_uid) - return CollectionItemOut.from_orm_context(obj, Context(user, prefetch)) + return MsgpackResponse(CollectionItemOut.from_orm_context(obj, Context(user, prefetch))) def item_list_common( @@ -450,7 +450,7 @@ def item_list( queryset = queryset.filter(parent__isnull=True) response = item_list_common(queryset, user, stoken, limit, prefetch) - return response + return MsgpackResponse(response) @item_router.post("/item/subscription-ticket/", response_model=TicketOut, dependencies=PERMISSIONS_READ) @@ -459,7 +459,7 @@ async def item_list_subscription_ticket( user: UserType = Depends(get_authenticated_user), ): """Get an authentication ticket that can be used with the websocket endpoint""" - return await get_ticket(TicketRequest(collection=collection.uid), user) + return MsgpackResponse(await get_ticket(TicketRequest(collection=collection.uid), user)) def item_bulk_common( @@ -527,10 +527,12 @@ def item_revisions( ret_data = [CollectionItemRevisionInOut.from_orm_context(revision, context) for revision in result] iterator = ret_data[-1].uid if len(result) > 0 else None - return CollectionItemRevisionListResponse( - data=ret_data, - iterator=iterator, - done=done, + return MsgpackResponse( + CollectionItemRevisionListResponse( + data=ret_data, + iterator=iterator, + done=done, + ) ) @@ -560,10 +562,12 @@ def fetch_updates( new_stoken = new_stoken or stoken_rev_uid context = Context(user, prefetch) - return CollectionItemListResponse( - data=[CollectionItemOut.from_orm_context(item, context) for item in queryset], - stoken=new_stoken, - done=True, # we always return all the items, so it's always done + return MsgpackResponse( + CollectionItemListResponse( + data=[CollectionItemOut.from_orm_context(item, context) for item in queryset], + stoken=new_stoken, + done=True, # we always return all the items, so it's always done + ) ) @@ -575,7 +579,9 @@ def item_transaction( stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user), ): - return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True, background_tasks=background_tasks) + return MsgpackResponse( + item_bulk_common(data, user, stoken, collection_uid, validate_etag=True, background_tasks=background_tasks) + ) @item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE]) @@ -586,7 +592,9 @@ def item_batch( stoken: t.Optional[str] = None, user: UserType = Depends(get_authenticated_user), ): - return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False, background_tasks=background_tasks) + return MsgpackResponse( + item_bulk_common(data, user, stoken, collection_uid, validate_etag=False, background_tasks=background_tasks) + ) # Chunks @@ -611,7 +619,12 @@ async def chunk_update( collection: models.Collection = Depends(get_collection), ): # IGNORED FOR NOW: col_it = get_object_or_404(col.items, uid=collection_item_uid) - content_file = ContentFile(await request.body()) + if isinstance(request, MsgpackRequest): + body = await request.raw_body() + else: + body = await request.body() + + content_file = ContentFile(body) try: await chunk_save(chunk_uid, collection, content_file) except IntegrityError: diff --git a/etebase_server/fastapi/routers/invitation.py b/etebase_server/fastapi/routers/invitation.py index b5d841b..43a182f 100644 --- a/etebase_server/fastapi/routers/invitation.py +++ b/etebase_server/fastapi/routers/invitation.py @@ -10,7 +10,7 @@ from etebase_server.myauth.models import UserType, get_typed_user_model from ..db_hack import django_db_cleanup_decorator from ..exceptions import HttpError, PermissionDenied -from ..msgpack import MsgpackRoute +from ..msgpack import MsgpackResponse, MsgpackRoute from ..utils import ( PERMISSIONS_READ, PERMISSIONS_READWRITE, @@ -34,7 +34,7 @@ class UserInfoOut(BaseModel): pubkey: bytes class Config: - from_attributes= True + from_attributes = True @classmethod def from_orm(cls: t.Type["UserInfoOut"], obj: models.UserInfo) -> "UserInfoOut": @@ -121,7 +121,7 @@ def list_common( iterator = ret_data[-1].uid if len(result) > 0 else None return InvitationListResponse( - data=ret_data, + data=[CollectionInvitationOut.from_orm(x) for x in ret_data], iterator=iterator, done=done, ) @@ -133,7 +133,7 @@ def incoming_list( limit: int = 50, queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): - return list_common(queryset, iterator, limit) + return MsgpackResponse(list_common(queryset, iterator, limit)) @invitation_incoming_router.get( @@ -144,7 +144,7 @@ def incoming_get( queryset: InvitationQuerySet = Depends(get_incoming_queryset), ): obj = get_object_or_404(queryset, uid=invitation_uid) - return CollectionInvitationOut.from_orm(obj) + return MsgpackResponse(CollectionInvitationOut.from_orm(obj)) @invitation_incoming_router.delete( @@ -219,7 +219,7 @@ def outgoing_list( limit: int = 50, queryset: InvitationQuerySet = Depends(get_outgoing_queryset), ): - return list_common(queryset, iterator, limit) + return MsgpackResponse(list_common(queryset, iterator, limit)) @invitation_outgoing_router.delete( @@ -242,4 +242,4 @@ def outgoing_fetch_user_profile( kwargs = get_user_username_email_kwargs(username) user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs) user_info = get_object_or_404(models.UserInfo.objects.all(), owner=user) - return UserInfoOut.from_orm(user_info) + return MsgpackResponse(UserInfoOut.from_orm(user_info)) diff --git a/etebase_server/fastapi/routers/member.py b/etebase_server/fastapi/routers/member.py index e913fbf..71da4c7 100644 --- a/etebase_server/fastapi/routers/member.py +++ b/etebase_server/fastapi/routers/member.py @@ -8,7 +8,7 @@ from etebase_server.django import models from etebase_server.myauth.models import UserType, get_typed_user_model from ..db_hack import django_db_cleanup_decorator -from ..msgpack import MsgpackRoute +from ..msgpack import MsgpackResponse, MsgpackRoute from ..stoken_handler import filter_by_stoken_and_limit from ..utils import PERMISSIONS_READ, PERMISSIONS_READWRITE, BaseModel, get_object_or_404, permission_responses from .authentication import get_authenticated_user @@ -66,10 +66,12 @@ def member_list( ) new_stoken = new_stoken_obj and new_stoken_obj.uid - return MemberListResponse( - data=[CollectionMemberOut.from_orm(item) for item in result], - iterator=new_stoken, - done=done, + return MsgpackResponse( + MemberListResponse( + data=[CollectionMemberOut.from_orm(item) for item in result], + iterator=new_stoken, + done=done, + ) )