1
0
mirror of https://github.com/etesync/server synced 2024-11-19 15:28:08 +00:00

Fixes for fastapi.

This commit is contained in:
Tom Hacohen 2024-06-08 20:17:02 -04:00
parent 57e676baa1
commit 0be14a7b0e
5 changed files with 56 additions and 38 deletions

View File

@ -12,9 +12,12 @@ from .utils import msgpack_decode, msgpack_encode
class MsgpackRequest(Request): class MsgpackRequest(Request):
media_type = "application/msgpack" media_type = "application/msgpack"
async def raw_body(self) -> bytes:
return await super().body()
async def body(self) -> bytes: async def body(self) -> bytes:
if not hasattr(self, "_json"): if not hasattr(self, "_json"):
body = await super().body() body = await self.raw_body()
self._json = msgpack_decode(body) self._json = msgpack_decode(body)
return self._json return self._json

View File

@ -23,7 +23,7 @@ from etebase_server.myauth.models import UserType, get_typed_user_model
from ..dependencies import AuthData, get_auth_data, get_authenticated_user from ..dependencies import AuthData, get_auth_data, get_authenticated_user
from ..exceptions import AuthenticationFailed, HttpError, transform_validation_error from ..exceptions import AuthenticationFailed, HttpError, transform_validation_error
from ..msgpack import MsgpackRoute from ..msgpack import MsgpackResponse, MsgpackRoute
from ..utils import BaseModel, get_user_username_email_kwargs, msgpack_decode, msgpack_encode, permission_responses from ..utils import BaseModel, get_user_username_email_kwargs, msgpack_decode, msgpack_encode, permission_responses
User = get_typed_user_model() User = get_typed_user_model()
@ -76,7 +76,7 @@ class LoginOut(BaseModel):
class Authentication(BaseModel): class Authentication(BaseModel):
class Config: class Config:
ignored_types= (cached_property,) ignored_types = (cached_property,)
response: bytes response: bytes
signature: bytes signature: bytes
@ -188,7 +188,7 @@ def login_challenge(user: UserType = Depends(get_login_user)):
"userId": user.id, "userId": user.id,
} }
challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder)) challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
return LoginChallengeOut(salt=salt, challenge=challenge, version=user.userinfo.version) return MsgpackResponse(LoginChallengeOut(salt=salt, challenge=challenge, version=user.userinfo.version))
@authentication_router.post("/login/", response_model=LoginOut) @authentication_router.post("/login/", response_model=LoginOut)
@ -198,7 +198,7 @@ def login(data: Login, request: Request):
validate_login_request(data.response_data, data, user, "login", host) validate_login_request(data.response_data, data, user, "login", host)
ret = LoginOut.from_orm(user) ret = LoginOut.from_orm(user)
user_logged_in.send(sender=user.__class__, request=None, user=user) user_logged_in.send(sender=user.__class__, request=None, user=user)
return ret return MsgpackResponse(ret)
@authentication_router.post("/logout/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses) @authentication_router.post("/logout/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
@ -223,7 +223,7 @@ def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_u
ret = { ret = {
"url": get_dashboard_url(CallbackContext(request.path_params, user=user)), "url": get_dashboard_url(CallbackContext(request.path_params, user=user)),
} }
return ret return MsgpackResponse(ret)
def signup_save(data: SignupIn, request: Request) -> UserType: def signup_save(data: SignupIn, request: Request) -> UserType:
@ -261,4 +261,4 @@ def signup(data: SignupIn, request: Request):
user = signup_save(data, request) user = signup_save(data, request)
ret = LoginOut.from_orm(user) ret = LoginOut.from_orm(user)
user_signed_up.send(sender=user.__class__, request=None, user=user) user_signed_up.send(sender=user.__class__, request=None, user=user)
return ret return MsgpackResponse(ret)

View File

@ -13,7 +13,7 @@ from etebase_server.myauth.models import UserType
from ..db_hack import django_db_cleanup_decorator from ..db_hack import django_db_cleanup_decorator
from ..dependencies import get_collection, get_collection_queryset, get_item_queryset from ..dependencies import get_collection, get_collection_queryset, get_item_queryset
from ..exceptions import HttpError, PermissionDenied, ValidationError, transform_validation_error from ..exceptions import HttpError, PermissionDenied, ValidationError, transform_validation_error
from ..msgpack import MsgpackRoute from ..msgpack import MsgpackRequest, MsgpackResponse, MsgpackRoute
from ..redis import redisw from ..redis import redisw
from ..sendfile import sendfile from ..sendfile import sendfile
from ..stoken_handler import filter_by_stoken, filter_by_stoken_and_limit, get_queryset_stoken, get_stoken_obj from ..stoken_handler import filter_by_stoken, filter_by_stoken_and_limit, get_queryset_stoken, get_stoken_obj
@ -135,7 +135,7 @@ class CollectionListResponse(BaseModel):
stoken: t.Optional[str] stoken: t.Optional[str]
done: bool done: bool
removedMemberships: t.Optional[t.List[RemovedMembershipOut]] removedMemberships: t.Optional[t.List[RemovedMembershipOut]] = None
class CollectionItemListResponse(BaseModel): class CollectionItemListResponse(BaseModel):
@ -275,7 +275,7 @@ def list_multi(
Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True) Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True)
) )
return collection_list_common(queryset, user, stoken, limit, prefetch) return MsgpackResponse(collection_list_common(queryset, user, stoken, limit, prefetch))
@collection_router.get("/", response_model=CollectionListResponse, dependencies=PERMISSIONS_READ) @collection_router.get("/", response_model=CollectionListResponse, dependencies=PERMISSIONS_READ)
@ -286,7 +286,7 @@ def collection_list(
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
queryset: CollectionQuerySet = Depends(get_collection_queryset), queryset: CollectionQuerySet = Depends(get_collection_queryset),
): ):
return collection_list_common(queryset, user, stoken, limit, prefetch) return MsgpackResponse(collection_list_common(queryset, user, stoken, limit, prefetch))
def process_revisions_for_item(item: models.CollectionItem, revision_data: CollectionItemRevisionInOut): def process_revisions_for_item(item: models.CollectionItem, revision_data: CollectionItemRevisionInOut):
@ -365,7 +365,7 @@ def collection_get(
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
return CollectionOut.from_orm_context(obj, Context(user, prefetch)) return MsgpackResponse(CollectionOut.from_orm_context(obj, Context(user, prefetch)))
def item_create(item_model: CollectionItemIn, collection: models.Collection, validate_etag: bool): def item_create(item_model: CollectionItemIn, collection: models.Collection, validate_etag: bool):
@ -418,7 +418,7 @@ def item_get(
prefetch: Prefetch = PrefetchQuery, prefetch: Prefetch = PrefetchQuery,
): ):
obj = queryset.get(uid=item_uid) obj = queryset.get(uid=item_uid)
return CollectionItemOut.from_orm_context(obj, Context(user, prefetch)) return MsgpackResponse(CollectionItemOut.from_orm_context(obj, Context(user, prefetch)))
def item_list_common( def item_list_common(
@ -450,7 +450,7 @@ def item_list(
queryset = queryset.filter(parent__isnull=True) queryset = queryset.filter(parent__isnull=True)
response = item_list_common(queryset, user, stoken, limit, prefetch) response = item_list_common(queryset, user, stoken, limit, prefetch)
return response return MsgpackResponse(response)
@item_router.post("/item/subscription-ticket/", response_model=TicketOut, dependencies=PERMISSIONS_READ) @item_router.post("/item/subscription-ticket/", response_model=TicketOut, dependencies=PERMISSIONS_READ)
@ -459,7 +459,7 @@ async def item_list_subscription_ticket(
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
): ):
"""Get an authentication ticket that can be used with the websocket endpoint""" """Get an authentication ticket that can be used with the websocket endpoint"""
return await get_ticket(TicketRequest(collection=collection.uid), user) return MsgpackResponse(await get_ticket(TicketRequest(collection=collection.uid), user))
def item_bulk_common( def item_bulk_common(
@ -527,10 +527,12 @@ def item_revisions(
ret_data = [CollectionItemRevisionInOut.from_orm_context(revision, context) for revision in result] ret_data = [CollectionItemRevisionInOut.from_orm_context(revision, context) for revision in result]
iterator = ret_data[-1].uid if len(result) > 0 else None iterator = ret_data[-1].uid if len(result) > 0 else None
return CollectionItemRevisionListResponse( return MsgpackResponse(
data=ret_data, CollectionItemRevisionListResponse(
iterator=iterator, data=ret_data,
done=done, iterator=iterator,
done=done,
)
) )
@ -560,10 +562,12 @@ def fetch_updates(
new_stoken = new_stoken or stoken_rev_uid new_stoken = new_stoken or stoken_rev_uid
context = Context(user, prefetch) context = Context(user, prefetch)
return CollectionItemListResponse( return MsgpackResponse(
data=[CollectionItemOut.from_orm_context(item, context) for item in queryset], CollectionItemListResponse(
stoken=new_stoken, data=[CollectionItemOut.from_orm_context(item, context) for item in queryset],
done=True, # we always return all the items, so it's always done stoken=new_stoken,
done=True, # we always return all the items, so it's always done
)
) )
@ -575,7 +579,9 @@ def item_transaction(
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
): ):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True, background_tasks=background_tasks) return MsgpackResponse(
item_bulk_common(data, user, stoken, collection_uid, validate_etag=True, background_tasks=background_tasks)
)
@item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE]) @item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
@ -586,7 +592,9 @@ def item_batch(
stoken: t.Optional[str] = None, stoken: t.Optional[str] = None,
user: UserType = Depends(get_authenticated_user), user: UserType = Depends(get_authenticated_user),
): ):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False, background_tasks=background_tasks) return MsgpackResponse(
item_bulk_common(data, user, stoken, collection_uid, validate_etag=False, background_tasks=background_tasks)
)
# Chunks # Chunks
@ -611,7 +619,12 @@ async def chunk_update(
collection: models.Collection = Depends(get_collection), collection: models.Collection = Depends(get_collection),
): ):
# IGNORED FOR NOW: col_it = get_object_or_404(col.items, uid=collection_item_uid) # IGNORED FOR NOW: col_it = get_object_or_404(col.items, uid=collection_item_uid)
content_file = ContentFile(await request.body()) if isinstance(request, MsgpackRequest):
body = await request.raw_body()
else:
body = await request.body()
content_file = ContentFile(body)
try: try:
await chunk_save(chunk_uid, collection, content_file) await chunk_save(chunk_uid, collection, content_file)
except IntegrityError: except IntegrityError:

View File

@ -10,7 +10,7 @@ from etebase_server.myauth.models import UserType, get_typed_user_model
from ..db_hack import django_db_cleanup_decorator from ..db_hack import django_db_cleanup_decorator
from ..exceptions import HttpError, PermissionDenied from ..exceptions import HttpError, PermissionDenied
from ..msgpack import MsgpackRoute from ..msgpack import MsgpackResponse, MsgpackRoute
from ..utils import ( from ..utils import (
PERMISSIONS_READ, PERMISSIONS_READ,
PERMISSIONS_READWRITE, PERMISSIONS_READWRITE,
@ -34,7 +34,7 @@ class UserInfoOut(BaseModel):
pubkey: bytes pubkey: bytes
class Config: class Config:
from_attributes= True from_attributes = True
@classmethod @classmethod
def from_orm(cls: t.Type["UserInfoOut"], obj: models.UserInfo) -> "UserInfoOut": def from_orm(cls: t.Type["UserInfoOut"], obj: models.UserInfo) -> "UserInfoOut":
@ -121,7 +121,7 @@ def list_common(
iterator = ret_data[-1].uid if len(result) > 0 else None iterator = ret_data[-1].uid if len(result) > 0 else None
return InvitationListResponse( return InvitationListResponse(
data=ret_data, data=[CollectionInvitationOut.from_orm(x) for x in ret_data],
iterator=iterator, iterator=iterator,
done=done, done=done,
) )
@ -133,7 +133,7 @@ def incoming_list(
limit: int = 50, limit: int = 50,
queryset: InvitationQuerySet = Depends(get_incoming_queryset), queryset: InvitationQuerySet = Depends(get_incoming_queryset),
): ):
return list_common(queryset, iterator, limit) return MsgpackResponse(list_common(queryset, iterator, limit))
@invitation_incoming_router.get( @invitation_incoming_router.get(
@ -144,7 +144,7 @@ def incoming_get(
queryset: InvitationQuerySet = Depends(get_incoming_queryset), queryset: InvitationQuerySet = Depends(get_incoming_queryset),
): ):
obj = get_object_or_404(queryset, uid=invitation_uid) obj = get_object_or_404(queryset, uid=invitation_uid)
return CollectionInvitationOut.from_orm(obj) return MsgpackResponse(CollectionInvitationOut.from_orm(obj))
@invitation_incoming_router.delete( @invitation_incoming_router.delete(
@ -219,7 +219,7 @@ def outgoing_list(
limit: int = 50, limit: int = 50,
queryset: InvitationQuerySet = Depends(get_outgoing_queryset), queryset: InvitationQuerySet = Depends(get_outgoing_queryset),
): ):
return list_common(queryset, iterator, limit) return MsgpackResponse(list_common(queryset, iterator, limit))
@invitation_outgoing_router.delete( @invitation_outgoing_router.delete(
@ -242,4 +242,4 @@ def outgoing_fetch_user_profile(
kwargs = get_user_username_email_kwargs(username) kwargs = get_user_username_email_kwargs(username)
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs) user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)
user_info = get_object_or_404(models.UserInfo.objects.all(), owner=user) user_info = get_object_or_404(models.UserInfo.objects.all(), owner=user)
return UserInfoOut.from_orm(user_info) return MsgpackResponse(UserInfoOut.from_orm(user_info))

View File

@ -8,7 +8,7 @@ from etebase_server.django import models
from etebase_server.myauth.models import UserType, get_typed_user_model from etebase_server.myauth.models import UserType, get_typed_user_model
from ..db_hack import django_db_cleanup_decorator from ..db_hack import django_db_cleanup_decorator
from ..msgpack import MsgpackRoute from ..msgpack import MsgpackResponse, MsgpackRoute
from ..stoken_handler import filter_by_stoken_and_limit from ..stoken_handler import filter_by_stoken_and_limit
from ..utils import PERMISSIONS_READ, PERMISSIONS_READWRITE, BaseModel, get_object_or_404, permission_responses from ..utils import PERMISSIONS_READ, PERMISSIONS_READWRITE, BaseModel, get_object_or_404, permission_responses
from .authentication import get_authenticated_user from .authentication import get_authenticated_user
@ -66,10 +66,12 @@ def member_list(
) )
new_stoken = new_stoken_obj and new_stoken_obj.uid new_stoken = new_stoken_obj and new_stoken_obj.uid
return MemberListResponse( return MsgpackResponse(
data=[CollectionMemberOut.from_orm(item) for item in result], MemberListResponse(
iterator=new_stoken, data=[CollectionMemberOut.from_orm(item) for item in result],
done=done, iterator=new_stoken,
done=done,
)
) )