Fixes for fastapi.

pull/184/head
Tom Hacohen 3 weeks ago
parent 57e676baa1
commit 0be14a7b0e

@ -12,9 +12,12 @@ from .utils import msgpack_decode, msgpack_encode
class MsgpackRequest(Request):
media_type = "application/msgpack"
async def raw_body(self) -> bytes:
return await super().body()
async def body(self) -> bytes:
if not hasattr(self, "_json"):
body = await super().body()
body = await self.raw_body()
self._json = msgpack_decode(body)
return self._json

@ -23,7 +23,7 @@ from etebase_server.myauth.models import UserType, get_typed_user_model
from ..dependencies import AuthData, get_auth_data, get_authenticated_user
from ..exceptions import AuthenticationFailed, HttpError, transform_validation_error
from ..msgpack import MsgpackRoute
from ..msgpack import MsgpackResponse, MsgpackRoute
from ..utils import BaseModel, get_user_username_email_kwargs, msgpack_decode, msgpack_encode, permission_responses
User = get_typed_user_model()
@ -76,7 +76,7 @@ class LoginOut(BaseModel):
class Authentication(BaseModel):
class Config:
ignored_types= (cached_property,)
ignored_types = (cached_property,)
response: bytes
signature: bytes
@ -188,7 +188,7 @@ def login_challenge(user: UserType = Depends(get_login_user)):
"userId": user.id,
}
challenge = bytes(box.encrypt(msgpack_encode(challenge_data), encoder=nacl.encoding.RawEncoder))
return LoginChallengeOut(salt=salt, challenge=challenge, version=user.userinfo.version)
return MsgpackResponse(LoginChallengeOut(salt=salt, challenge=challenge, version=user.userinfo.version))
@authentication_router.post("/login/", response_model=LoginOut)
@ -198,7 +198,7 @@ def login(data: Login, request: Request):
validate_login_request(data.response_data, data, user, "login", host)
ret = LoginOut.from_orm(user)
user_logged_in.send(sender=user.__class__, request=None, user=user)
return ret
return MsgpackResponse(ret)
@authentication_router.post("/logout/", status_code=status.HTTP_204_NO_CONTENT, responses=permission_responses)
@ -223,7 +223,7 @@ def dashboard_url(request: Request, user: UserType = Depends(get_authenticated_u
ret = {
"url": get_dashboard_url(CallbackContext(request.path_params, user=user)),
}
return ret
return MsgpackResponse(ret)
def signup_save(data: SignupIn, request: Request) -> UserType:
@ -261,4 +261,4 @@ def signup(data: SignupIn, request: Request):
user = signup_save(data, request)
ret = LoginOut.from_orm(user)
user_signed_up.send(sender=user.__class__, request=None, user=user)
return ret
return MsgpackResponse(ret)

@ -13,7 +13,7 @@ from etebase_server.myauth.models import UserType
from ..db_hack import django_db_cleanup_decorator
from ..dependencies import get_collection, get_collection_queryset, get_item_queryset
from ..exceptions import HttpError, PermissionDenied, ValidationError, transform_validation_error
from ..msgpack import MsgpackRoute
from ..msgpack import MsgpackRequest, MsgpackResponse, MsgpackRoute
from ..redis import redisw
from ..sendfile import sendfile
from ..stoken_handler import filter_by_stoken, filter_by_stoken_and_limit, get_queryset_stoken, get_stoken_obj
@ -135,7 +135,7 @@ class CollectionListResponse(BaseModel):
stoken: t.Optional[str]
done: bool
removedMemberships: t.Optional[t.List[RemovedMembershipOut]]
removedMemberships: t.Optional[t.List[RemovedMembershipOut]] = None
class CollectionItemListResponse(BaseModel):
@ -275,7 +275,7 @@ def list_multi(
Q(members__collectionType__uid__in=data.collectionTypes) | Q(members__collectionType__isnull=True)
)
return collection_list_common(queryset, user, stoken, limit, prefetch)
return MsgpackResponse(collection_list_common(queryset, user, stoken, limit, prefetch))
@collection_router.get("/", response_model=CollectionListResponse, dependencies=PERMISSIONS_READ)
@ -286,7 +286,7 @@ def collection_list(
user: UserType = Depends(get_authenticated_user),
queryset: CollectionQuerySet = Depends(get_collection_queryset),
):
return collection_list_common(queryset, user, stoken, limit, prefetch)
return MsgpackResponse(collection_list_common(queryset, user, stoken, limit, prefetch))
def process_revisions_for_item(item: models.CollectionItem, revision_data: CollectionItemRevisionInOut):
@ -365,7 +365,7 @@ def collection_get(
user: UserType = Depends(get_authenticated_user),
prefetch: Prefetch = PrefetchQuery,
):
return CollectionOut.from_orm_context(obj, Context(user, prefetch))
return MsgpackResponse(CollectionOut.from_orm_context(obj, Context(user, prefetch)))
def item_create(item_model: CollectionItemIn, collection: models.Collection, validate_etag: bool):
@ -418,7 +418,7 @@ def item_get(
prefetch: Prefetch = PrefetchQuery,
):
obj = queryset.get(uid=item_uid)
return CollectionItemOut.from_orm_context(obj, Context(user, prefetch))
return MsgpackResponse(CollectionItemOut.from_orm_context(obj, Context(user, prefetch)))
def item_list_common(
@ -450,7 +450,7 @@ def item_list(
queryset = queryset.filter(parent__isnull=True)
response = item_list_common(queryset, user, stoken, limit, prefetch)
return response
return MsgpackResponse(response)
@item_router.post("/item/subscription-ticket/", response_model=TicketOut, dependencies=PERMISSIONS_READ)
@ -459,7 +459,7 @@ async def item_list_subscription_ticket(
user: UserType = Depends(get_authenticated_user),
):
"""Get an authentication ticket that can be used with the websocket endpoint"""
return await get_ticket(TicketRequest(collection=collection.uid), user)
return MsgpackResponse(await get_ticket(TicketRequest(collection=collection.uid), user))
def item_bulk_common(
@ -527,10 +527,12 @@ def item_revisions(
ret_data = [CollectionItemRevisionInOut.from_orm_context(revision, context) for revision in result]
iterator = ret_data[-1].uid if len(result) > 0 else None
return CollectionItemRevisionListResponse(
data=ret_data,
iterator=iterator,
done=done,
return MsgpackResponse(
CollectionItemRevisionListResponse(
data=ret_data,
iterator=iterator,
done=done,
)
)
@ -560,10 +562,12 @@ def fetch_updates(
new_stoken = new_stoken or stoken_rev_uid
context = Context(user, prefetch)
return CollectionItemListResponse(
data=[CollectionItemOut.from_orm_context(item, context) for item in queryset],
stoken=new_stoken,
done=True, # we always return all the items, so it's always done
return MsgpackResponse(
CollectionItemListResponse(
data=[CollectionItemOut.from_orm_context(item, context) for item in queryset],
stoken=new_stoken,
done=True, # we always return all the items, so it's always done
)
)
@ -575,7 +579,9 @@ def item_transaction(
stoken: t.Optional[str] = None,
user: UserType = Depends(get_authenticated_user),
):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=True, background_tasks=background_tasks)
return MsgpackResponse(
item_bulk_common(data, user, stoken, collection_uid, validate_etag=True, background_tasks=background_tasks)
)
@item_router.post("/item/batch/", dependencies=[Depends(has_write_access), *PERMISSIONS_READWRITE])
@ -586,7 +592,9 @@ def item_batch(
stoken: t.Optional[str] = None,
user: UserType = Depends(get_authenticated_user),
):
return item_bulk_common(data, user, stoken, collection_uid, validate_etag=False, background_tasks=background_tasks)
return MsgpackResponse(
item_bulk_common(data, user, stoken, collection_uid, validate_etag=False, background_tasks=background_tasks)
)
# Chunks
@ -611,7 +619,12 @@ async def chunk_update(
collection: models.Collection = Depends(get_collection),
):
# IGNORED FOR NOW: col_it = get_object_or_404(col.items, uid=collection_item_uid)
content_file = ContentFile(await request.body())
if isinstance(request, MsgpackRequest):
body = await request.raw_body()
else:
body = await request.body()
content_file = ContentFile(body)
try:
await chunk_save(chunk_uid, collection, content_file)
except IntegrityError:

@ -10,7 +10,7 @@ from etebase_server.myauth.models import UserType, get_typed_user_model
from ..db_hack import django_db_cleanup_decorator
from ..exceptions import HttpError, PermissionDenied
from ..msgpack import MsgpackRoute
from ..msgpack import MsgpackResponse, MsgpackRoute
from ..utils import (
PERMISSIONS_READ,
PERMISSIONS_READWRITE,
@ -34,7 +34,7 @@ class UserInfoOut(BaseModel):
pubkey: bytes
class Config:
from_attributes= True
from_attributes = True
@classmethod
def from_orm(cls: t.Type["UserInfoOut"], obj: models.UserInfo) -> "UserInfoOut":
@ -121,7 +121,7 @@ def list_common(
iterator = ret_data[-1].uid if len(result) > 0 else None
return InvitationListResponse(
data=ret_data,
data=[CollectionInvitationOut.from_orm(x) for x in ret_data],
iterator=iterator,
done=done,
)
@ -133,7 +133,7 @@ def incoming_list(
limit: int = 50,
queryset: InvitationQuerySet = Depends(get_incoming_queryset),
):
return list_common(queryset, iterator, limit)
return MsgpackResponse(list_common(queryset, iterator, limit))
@invitation_incoming_router.get(
@ -144,7 +144,7 @@ def incoming_get(
queryset: InvitationQuerySet = Depends(get_incoming_queryset),
):
obj = get_object_or_404(queryset, uid=invitation_uid)
return CollectionInvitationOut.from_orm(obj)
return MsgpackResponse(CollectionInvitationOut.from_orm(obj))
@invitation_incoming_router.delete(
@ -219,7 +219,7 @@ def outgoing_list(
limit: int = 50,
queryset: InvitationQuerySet = Depends(get_outgoing_queryset),
):
return list_common(queryset, iterator, limit)
return MsgpackResponse(list_common(queryset, iterator, limit))
@invitation_outgoing_router.delete(
@ -242,4 +242,4 @@ def outgoing_fetch_user_profile(
kwargs = get_user_username_email_kwargs(username)
user = get_object_or_404(get_user_queryset(User.objects.all(), CallbackContext(request.path_params)), **kwargs)
user_info = get_object_or_404(models.UserInfo.objects.all(), owner=user)
return UserInfoOut.from_orm(user_info)
return MsgpackResponse(UserInfoOut.from_orm(user_info))

@ -8,7 +8,7 @@ from etebase_server.django import models
from etebase_server.myauth.models import UserType, get_typed_user_model
from ..db_hack import django_db_cleanup_decorator
from ..msgpack import MsgpackRoute
from ..msgpack import MsgpackResponse, MsgpackRoute
from ..stoken_handler import filter_by_stoken_and_limit
from ..utils import PERMISSIONS_READ, PERMISSIONS_READWRITE, BaseModel, get_object_or_404, permission_responses
from .authentication import get_authenticated_user
@ -66,10 +66,12 @@ def member_list(
)
new_stoken = new_stoken_obj and new_stoken_obj.uid
return MemberListResponse(
data=[CollectionMemberOut.from_orm(item) for item in result],
iterator=new_stoken,
done=done,
return MsgpackResponse(
MemberListResponse(
data=[CollectionMemberOut.from_orm(item) for item in result],
iterator=new_stoken,
done=done,
)
)

Loading…
Cancel
Save