diff --git a/etebase_fastapi/collection.py b/etebase_fastapi/collection.py index f687127..5656fb6 100644 --- a/etebase_fastapi/collection.py +++ b/etebase_fastapi/collection.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from django_etebase import models from django_etebase.models import Collection, AccessLevels, CollectionMember from .authentication import get_authenticated_user -from .exceptions import ValidationError +from .exceptions import ValidationError, transform_validation_error from .msgpack import MsgpackRoute, MsgpackResponse from .stoken_handler import filter_by_stoken_and_limit @@ -79,7 +79,7 @@ class CollectionItemOut(CollectionItemCommon): version=obj.version, encryptionKey=obj.encryptionKey, etag=obj.etag, - content=CollectionItemRevisionOut.from_orm_context(obj.content, context), + content=CollectionItemRevision.from_orm_context(obj.content, context), ) @@ -125,11 +125,26 @@ class ItemDepIn(BaseModel): etag: str uid: str + def validate_db(self): + item = models.CollectionItem.objects.get(uid=self.uid) + etag = self.etag + if item.etag != etag: + raise ValidationError( + "wrong_etag", + "Wrong etag. Expected {} got {}".format(item.etag, etag), + status_code=status.HTTP_409_CONFLICT, + ) + class ItemBatchIn(BaseModel): items: t.List[CollectionItemIn] deps: t.Optional[ItemDepIn] + def validate_db(self): + if self.deps is not None: + for key, _value in self.deps: + getattr(self.deps, key).validate_db() + @sync_to_async def list_common( @@ -172,7 +187,7 @@ async def collection_list( pass -def process_revisions_for_item(item: models.CollectionItem, revision_data: CollectionItemRevisionOut): +def process_revisions_for_item(item: models.CollectionItem, revision_data: CollectionItemRevision): chunks_objs = [] revision = models.CollectionItemRevision(**revision_data.dict(exclude={"chunks"}), item=item) @@ -250,16 +265,60 @@ def get_collection(uid: str, user: User = Depends(get_authenticated_user), prefe return MsgpackResponse(ret) -def item_bulk_common(data: ItemBatchIn, user: User, stoken: str, uid: str, validate_etag: bool): +def item_create(item_model: CollectionItemIn, validate_etag: bool): + """Function that's called when this serializer creates an item""" + etag = item_model.etag + revision_data = item_model.content + uid = item_model.uid + + Model = models.CollectionItem + + with transaction.atomic(): + instance, created = Model.objects.get_or_create( + uid=uid, defaults=item_model.dict(exclude={"uid", "etag", "content"}) + ) + cur_etag = instance.etag if not created else None + + # If we are trying to update an up to date item, abort early and consider it a success + if cur_etag == revision_data.uid: + return instance + + if validate_etag and cur_etag != etag: + raise ValidationError( + "wrong_etag", + "Wrong etag. Expected {} got {}".format(cur_etag, etag), + status_code=status.HTTP_409_CONFLICT, + ) + + if not created: + # We don't have to use select_for_update here because the unique constraint on current guards against + # the race condition. But it's a good idea because it'll lock and wait rather than fail. + current_revision = instance.revisions.filter(current=True).select_for_update().first() + current_revision.current = None + current_revision.save() + + try: + process_revisions_for_item(instance, revision_data) + except django_exceptions.ValidationError as e: + transform_validation_error("content", e) + + return instance + + +def item_bulk_common(data: ItemBatchIn, user: User, stoken: t.Optional[str], uid: str, validate_etag: bool): queryset = get_collection_queryset(user, default_queryset) with transaction.atomic(): # We need this for locking the collection object collection_object = queryset.select_for_update().get(uid=uid) + if stoken is not None and stoken != collection_object.stoken: raise ValidationError("stale_stoken", "Stoken is too old", status_code=status.HTTP_409_CONFLICT) + # XXX-TOM: make sure we return compatible errors + data.validate_db() + for item in data.items: + item_create(item, validate_etag) -def item_create(): - pass # + return MsgpackResponse({}) @collection_router.post("/{uid}/item/transaction/")