diff --git a/etebase_fastapi/collection.py b/etebase_fastapi/collection.py index 0ceb988..2b0b876 100644 --- a/etebase_fastapi/collection.py +++ b/etebase_fastapi/collection.py @@ -1,3 +1,4 @@ +import dataclasses import typing as t from asgiref.sync import sync_to_async @@ -26,6 +27,12 @@ Prefetch = t.Literal["auto", "medium"] PrefetchQuery = Query(default="auto") +@dataclasses.dataclass +class Context: + user: t.Optional[User] + prefetch: t.Optional[Prefetch] + + class ListMulti(BaseModel): collectionTypes: t.List[bytes] @@ -40,11 +47,11 @@ class CollectionItemRevisionOut(BaseModel): orm_mode = True @classmethod - def from_orm_user( - cls: t.Type["CollectionItemRevisionOut"], obj: models.CollectionItemRevision, prefetch: Prefetch + def from_orm_context( + cls: t.Type["CollectionItemRevisionOut"], obj: models.CollectionItemRevision, context: Context ) -> "CollectionItemRevisionOut": chunk_obj = obj.chunks_relation.get().chunk - if prefetch == "auto": + if context.prefetch == "auto": with open(chunk_obj.chunkFile.path, "rb") as f: chunks = chunk_obj.uid, f.read() else: @@ -63,15 +70,15 @@ class CollectionItemOut(BaseModel): orm_mode = True @classmethod - def from_orm_user( - cls: t.Type["CollectionItemOut"], obj: models.CollectionItem, prefetch: Prefetch + def from_orm_context( + cls: t.Type["CollectionItemOut"], obj: models.CollectionItem, context: Context ) -> "CollectionItemOut": return cls( uid=obj.uid, version=obj.version, encryptionKey=obj.encryptionKey, etag=obj.etag, - content=CollectionItemRevisionOut.from_orm_user(obj.content, prefetch), + content=CollectionItemRevisionOut.from_orm_context(obj.content, context), ) @@ -83,15 +90,15 @@ class CollectionOut(BaseModel): item: CollectionItemOut @classmethod - def from_orm_user(cls: t.Type["CollectionOut"], obj: Collection, user: User, prefetch: Prefetch) -> "CollectionOut": - member: CollectionMember = obj.members.get(user=user) + def from_orm_context(cls: t.Type["CollectionOut"], obj: Collection, context: Context) -> "CollectionOut": + member: CollectionMember = obj.members.get(user=context.user) collection_type = member.collectionType ret = cls( collectionType=collection_type and collection_type.uid, collectionKey=member.encryptionKey, accessLevel=member.accessLevel, stoken=obj.stoken, - item=CollectionItemOut.from_orm_user(obj.main_item, prefetch), + item=CollectionItemOut.from_orm_context(obj.main_item, context), ) return ret @@ -121,7 +128,8 @@ def list_common( ) -> MsgpackResponse: result, new_stoken_obj, done = filter_by_stoken_and_limit(stoken, limit, queryset, Collection.stoken_annotation) new_stoken = new_stoken_obj and new_stoken_obj.uid - data: t.List[CollectionOut] = [CollectionOut.from_orm_user(item, user, prefetch) for item in queryset] + context = Context(user, prefetch) + data: t.List[CollectionOut] = [CollectionOut.from_orm_context(item, context) for item in queryset] ret = ListResponse(data=data, stoken=new_stoken, done=done) return MsgpackResponse(content=ret) @@ -221,5 +229,5 @@ async def create(data: CollectionIn, user: User = Depends(get_authenticated_user @collection_router.get("/{uid}/") def get_collection(uid: str, user: User = Depends(get_authenticated_user), prefetch: Prefetch = PrefetchQuery): obj = get_collection_queryset(user, default_queryset).get(uid=uid) - ret = CollectionOut.from_orm_user(obj, user, prefetch) + ret = CollectionOut.from_orm_context(obj, Context(user, prefetch)) return MsgpackResponse(ret)