From f52facad1c0adaa84b9418428b35c8e4fb067bf8 Mon Sep 17 00:00:00 2001 From: Tom Hacohen Date: Mon, 4 Jan 2021 10:02:47 +0200 Subject: [PATCH] Subscriptions: implement live subscriptions for collection items --- django_etebase/app_settings_inner.py | 6 ++ etebase_fastapi/exceptions.py | 10 +++ etebase_fastapi/main.py | 16 ++++ etebase_fastapi/redis.py | 27 ++++++ etebase_fastapi/routers/collection.py | 26 +++++- etebase_fastapi/routers/websocket.py | 114 ++++++++++++++++++++++++++ 6 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 etebase_fastapi/redis.py create mode 100644 etebase_fastapi/routers/websocket.py diff --git a/django_etebase/app_settings_inner.py b/django_etebase/app_settings_inner.py index 90225a6..41fd910 100644 --- a/django_etebase/app_settings_inner.py +++ b/django_etebase/app_settings_inner.py @@ -11,6 +11,8 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import typing as t + from django.utils.functional import cached_property @@ -31,6 +33,10 @@ class AppSettings: return getattr(settings, self.prefix + name, dflt) + @cached_property + def REDIS_URI(self) -> t.Optional[str]: # pylint: disable=invalid-name + return self._setting("REDIS_URI", None) + @cached_property def API_PERMISSIONS_READ(self): # pylint: disable=invalid-name perms = self._setting("API_PERMISSIONS_READ", tuple()) diff --git a/etebase_fastapi/exceptions.py b/etebase_fastapi/exceptions.py index d38d50a..1a98fcb 100644 --- a/etebase_fastapi/exceptions.py +++ b/etebase_fastapi/exceptions.py @@ -63,6 +63,16 @@ class PermissionDenied(CustomHttpException): super().__init__(code=code, detail=detail, status_code=status_code) +class NotSupported(CustomHttpException): + def __init__( + self, + code="not_implemented", + detail: str = "This server's configuration does not support this request.", + status_code: int = status.HTTP_501_NOT_IMPLEMENTED, + ): + super().__init__(code=code, detail=detail, status_code=status_code) + + class HttpError(CustomHttpException): def __init__( self, diff --git a/etebase_fastapi/main.py b/etebase_fastapi/main.py index 8e8469c..d63c01d 100644 --- a/etebase_fastapi/main.py +++ b/etebase_fastapi/main.py @@ -5,12 +5,15 @@ from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware +from django_etebase import app_settings + from .exceptions import CustomHttpException from .msgpack import MsgpackResponse from .routers.authentication import authentication_router from .routers.collection import collection_router, item_router from .routers.member import member_router from .routers.invitation import invitation_incoming_router, invitation_outgoing_router +from .routers.websocket import websocket_router def create_application(prefix="", middlewares=[]): @@ -36,6 +39,7 @@ def create_application(prefix="", middlewares=[]): app.include_router( invitation_outgoing_router, prefix=f"{BASE_PATH}/invitation/outgoing", tags=["outgoing invitation"] ) + app.include_router(websocket_router, prefix=f"{BASE_PATH}/ws", tags=["websocket"]) if settings.DEBUG: from etebase_fastapi.routers.test_reset_view import test_reset_view_router @@ -54,6 +58,18 @@ def create_application(prefix="", middlewares=[]): for middleware in middlewares: app.add_middleware(middleware) + @app.on_event("startup") + async def on_startup() -> None: + from .redis import redisw + + await redisw.setup() + + @app.on_event("shutdown") + async def on_shutdown(): + from .redis import redisw + + await redisw.close() + @app.exception_handler(CustomHttpException) async def custom_exception_handler(request: Request, exc: CustomHttpException): return MsgpackResponse(status_code=exc.status_code, content=exc.as_dict) diff --git a/etebase_fastapi/redis.py b/etebase_fastapi/redis.py new file mode 100644 index 0000000..3735e36 --- /dev/null +++ b/etebase_fastapi/redis.py @@ -0,0 +1,27 @@ +import typing as t +import aioredis + +from django_etebase import app_settings + + +class RedisWrapper: + redis: aioredis.Redis + + def __init__(self, redis_uri: t.Optional[str]): + self.redis_uri = redis_uri + + async def setup(self): + if self.redis_uri is not None: + self.redis = await aioredis.create_redis_pool(self.redis_uri) + + async def close(self): + if self.redis is not None: + self.redis.close() + await self.redis.wait_closed() + + @property + def is_active(self): + return self.redis_uri is not None + + +redisw = RedisWrapper(app_settings.REDIS_URI) diff --git a/etebase_fastapi/routers/collection.py b/etebase_fastapi/routers/collection.py index 4dcb3c6..df25541 100644 --- a/etebase_fastapi/routers/collection.py +++ b/etebase_fastapi/routers/collection.py @@ -1,6 +1,6 @@ import typing as t -from asgiref.sync import sync_to_async +from asgiref.sync import sync_to_async, async_to_sync from django.core import exceptions as django_exceptions from django.core.files.base import ContentFile from django.db import transaction, IntegrityError @@ -10,6 +10,7 @@ from fastapi import APIRouter, Depends, status, Request from django_etebase import models from myauth.models import UserType from .authentication import get_authenticated_user +from .websocket import get_ticket, TicketRequest, TicketOut from ..exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError from ..msgpack import MsgpackRoute from ..stoken_handler import filter_by_stoken_and_limit, filter_by_stoken, get_stoken_obj, get_queryset_stoken @@ -19,6 +20,7 @@ from ..utils import ( Prefetch, PrefetchQuery, is_collection_admin, + msgpack_encode, BaseModel, permission_responses, PERMISSIONS_READ, @@ -26,6 +28,7 @@ from ..utils import ( ) from ..dependencies import get_collection_queryset, get_item_queryset, get_collection from ..sendfile import sendfile +from ..redis import redisw from ..db_hack import django_db_cleanup_decorator collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) @@ -188,6 +191,16 @@ class ItemBatchIn(BaseModel): ) +# FIXME: make it a background task +def report_items_changed(col_uid: str, stoken: str, items: t.List[CollectionItemIn]): + if not redisw.is_active: + return + + redis = redisw.redis + content = msgpack_encode(CollectionItemListResponse(data=items, stoken=stoken, done=True).dict()) + async_to_sync(redis.publish)(f"col.{col_uid}", content) + + def collection_list_common( queryset: CollectionQuerySet, user: UserType, @@ -440,6 +453,15 @@ def item_list( return response +@item_router.post("/item/subscription-ticket/", response_model=TicketOut, dependencies=PERMISSIONS_READ) +async def item_list_subscription_ticket( + collection: models.Collection = Depends(get_collection), + user: UserType = Depends(get_authenticated_user), +): + """Get an authentication ticket that can be used with the websocket endpoint""" + return await get_ticket(TicketRequest(collection=collection.uid), user) + + def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool): queryset = get_collection_queryset(user) with transaction.atomic(): # We need this for locking the collection object @@ -465,6 +487,8 @@ def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], status_code=status.HTTP_409_CONFLICT, ) + report_items_changed(collection_object.uid, collection_object.stoken, data.items) + @item_router.get( "/item/{item_uid}/revision/", response_model=CollectionItemRevisionListResponse, dependencies=PERMISSIONS_READ diff --git a/etebase_fastapi/routers/websocket.py b/etebase_fastapi/routers/websocket.py new file mode 100644 index 0000000..2d599db --- /dev/null +++ b/etebase_fastapi/routers/websocket.py @@ -0,0 +1,114 @@ +import asyncio +import typing as t + +import aioredis +from django.db.models import QuerySet +from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status +import nacl.encoding +import nacl.utils + +from django_etebase import models +from django_etebase.utils import CallbackContext, get_user_queryset +from myauth.models import UserType, get_typed_user_model + +from ..exceptions import NotSupported +from ..msgpack import MsgpackRoute, msgpack_decode, msgpack_encode +from ..redis import redisw +from ..utils import BaseModel, permission_responses + + +User = get_typed_user_model() +websocket_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses) +CollectionQuerySet = QuerySet[models.Collection] + + +TICKET_VALIDITY_SECONDS = 10 + + +class TicketRequest(BaseModel): + collection: str + + +class TicketOut(BaseModel): + ticket: str + + +class TicketInner(BaseModel): + user: int + req: TicketRequest + + +async def get_ticket( + ticket_request: TicketRequest, + user: UserType, +): + """Get an authentication ticket that can be used with the websocket endpoint for authentication""" + if not redisw.is_active: + raise NotSupported(detail="This end-point requires Redis to be configured") + + uid = nacl.encoding.URLSafeBase64Encoder.encode(nacl.utils.random(32)) + ticket_model = TicketInner(user=user.id, req=ticket_request) + ticket_raw = msgpack_encode(ticket_model.dict()) + await redisw.redis.set(uid, ticket_raw, expire=TICKET_VALIDITY_SECONDS * 1000) + return TicketOut(ticket=uid) + + +async def load_websocket_ticket(websocket: WebSocket, ticket: str) -> t.Optional[TicketInner]: + content = await redisw.redis.get(ticket) + if content is None: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return None + await redisw.redis.delete(ticket) + return TicketInner(**msgpack_decode(content)) + + +def get_websocket_user(websocket: WebSocket, ticket_model: t.Optional[TicketInner] = Depends(load_websocket_ticket)): + if ticket_model is None: + return None + user_queryset = get_user_queryset(User.objects.all(), CallbackContext(websocket.path_params)) + return user_queryset.get(id=ticket_model.user) + + +@websocket_router.websocket("/{ticket}/") +async def websocket_endpoint( + websocket: WebSocket, + user: t.Optional[UserType] = Depends(get_websocket_user), + ticket_model: TicketInner = Depends(load_websocket_ticket), +): + if user is None: + return + await websocket.accept() + await redis_connector(websocket, ticket_model) + + +async def redis_connector(websocket: WebSocket, ticket_model: TicketInner): + async def producer_handler(r: aioredis.Redis, ws: WebSocket): + channel_name = f"col.{ticket_model.req.collection}" + (channel,) = await r.psubscribe(channel_name) + assert isinstance(channel, aioredis.Channel) + try: + while True: + # We wait on the websocket so we fail if web sockets fail or get data + receive = asyncio.create_task(websocket.receive()) + done, pending = await asyncio.wait( + {receive, channel.wait_message()}, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if receive in done: + # Web socket should never receieve any data + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return + + message_raw = t.cast(t.Optional[t.Tuple[str, bytes]], await channel.get()) + if message_raw: + _, message = message_raw + await ws.send_bytes(message) + + except aioredis.errors.ConnectionClosedError: + await websocket.close(code=status.WS_1012_SERVICE_RESTART) + except WebSocketDisconnect: + pass + + redis = redisw.redis + await producer_handler(redis, websocket)