mirror of
https://github.com/etesync/server
synced 2024-11-26 02:38:15 +00:00
Subscriptions: implement live subscriptions for collection items
This commit is contained in:
parent
cd4131e890
commit
f52facad1c
@ -11,6 +11,8 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the GNU General Public License
|
# You should have received a copy of the GNU General Public License
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
import typing as t
|
||||||
|
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +33,10 @@ class AppSettings:
|
|||||||
|
|
||||||
return getattr(settings, self.prefix + name, dflt)
|
return getattr(settings, self.prefix + name, dflt)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def REDIS_URI(self) -> t.Optional[str]: # pylint: disable=invalid-name
|
||||||
|
return self._setting("REDIS_URI", None)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def API_PERMISSIONS_READ(self): # pylint: disable=invalid-name
|
def API_PERMISSIONS_READ(self): # pylint: disable=invalid-name
|
||||||
perms = self._setting("API_PERMISSIONS_READ", tuple())
|
perms = self._setting("API_PERMISSIONS_READ", tuple())
|
||||||
|
@ -63,6 +63,16 @@ class PermissionDenied(CustomHttpException):
|
|||||||
super().__init__(code=code, detail=detail, status_code=status_code)
|
super().__init__(code=code, detail=detail, status_code=status_code)
|
||||||
|
|
||||||
|
|
||||||
|
class NotSupported(CustomHttpException):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
code="not_implemented",
|
||||||
|
detail: str = "This server's configuration does not support this request.",
|
||||||
|
status_code: int = status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
):
|
||||||
|
super().__init__(code=code, detail=detail, status_code=status_code)
|
||||||
|
|
||||||
|
|
||||||
class HttpError(CustomHttpException):
|
class HttpError(CustomHttpException):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -5,12 +5,15 @@ from fastapi import FastAPI, Request
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||||
|
|
||||||
|
from django_etebase import app_settings
|
||||||
|
|
||||||
from .exceptions import CustomHttpException
|
from .exceptions import CustomHttpException
|
||||||
from .msgpack import MsgpackResponse
|
from .msgpack import MsgpackResponse
|
||||||
from .routers.authentication import authentication_router
|
from .routers.authentication import authentication_router
|
||||||
from .routers.collection import collection_router, item_router
|
from .routers.collection import collection_router, item_router
|
||||||
from .routers.member import member_router
|
from .routers.member import member_router
|
||||||
from .routers.invitation import invitation_incoming_router, invitation_outgoing_router
|
from .routers.invitation import invitation_incoming_router, invitation_outgoing_router
|
||||||
|
from .routers.websocket import websocket_router
|
||||||
|
|
||||||
|
|
||||||
def create_application(prefix="", middlewares=[]):
|
def create_application(prefix="", middlewares=[]):
|
||||||
@ -36,6 +39,7 @@ def create_application(prefix="", middlewares=[]):
|
|||||||
app.include_router(
|
app.include_router(
|
||||||
invitation_outgoing_router, prefix=f"{BASE_PATH}/invitation/outgoing", tags=["outgoing invitation"]
|
invitation_outgoing_router, prefix=f"{BASE_PATH}/invitation/outgoing", tags=["outgoing invitation"]
|
||||||
)
|
)
|
||||||
|
app.include_router(websocket_router, prefix=f"{BASE_PATH}/ws", tags=["websocket"])
|
||||||
|
|
||||||
if settings.DEBUG:
|
if settings.DEBUG:
|
||||||
from etebase_fastapi.routers.test_reset_view import test_reset_view_router
|
from etebase_fastapi.routers.test_reset_view import test_reset_view_router
|
||||||
@ -54,6 +58,18 @@ def create_application(prefix="", middlewares=[]):
|
|||||||
for middleware in middlewares:
|
for middleware in middlewares:
|
||||||
app.add_middleware(middleware)
|
app.add_middleware(middleware)
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def on_startup() -> None:
|
||||||
|
from .redis import redisw
|
||||||
|
|
||||||
|
await redisw.setup()
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def on_shutdown():
|
||||||
|
from .redis import redisw
|
||||||
|
|
||||||
|
await redisw.close()
|
||||||
|
|
||||||
@app.exception_handler(CustomHttpException)
|
@app.exception_handler(CustomHttpException)
|
||||||
async def custom_exception_handler(request: Request, exc: CustomHttpException):
|
async def custom_exception_handler(request: Request, exc: CustomHttpException):
|
||||||
return MsgpackResponse(status_code=exc.status_code, content=exc.as_dict)
|
return MsgpackResponse(status_code=exc.status_code, content=exc.as_dict)
|
||||||
|
27
etebase_fastapi/redis.py
Normal file
27
etebase_fastapi/redis.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import typing as t
|
||||||
|
import aioredis
|
||||||
|
|
||||||
|
from django_etebase import app_settings
|
||||||
|
|
||||||
|
|
||||||
|
class RedisWrapper:
|
||||||
|
redis: aioredis.Redis
|
||||||
|
|
||||||
|
def __init__(self, redis_uri: t.Optional[str]):
|
||||||
|
self.redis_uri = redis_uri
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
if self.redis_uri is not None:
|
||||||
|
self.redis = await aioredis.create_redis_pool(self.redis_uri)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
if self.redis is not None:
|
||||||
|
self.redis.close()
|
||||||
|
await self.redis.wait_closed()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_active(self):
|
||||||
|
return self.redis_uri is not None
|
||||||
|
|
||||||
|
|
||||||
|
redisw = RedisWrapper(app_settings.REDIS_URI)
|
@ -1,6 +1,6 @@
|
|||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async, async_to_sync
|
||||||
from django.core import exceptions as django_exceptions
|
from django.core import exceptions as django_exceptions
|
||||||
from django.core.files.base import ContentFile
|
from django.core.files.base import ContentFile
|
||||||
from django.db import transaction, IntegrityError
|
from django.db import transaction, IntegrityError
|
||||||
@ -10,6 +10,7 @@ from fastapi import APIRouter, Depends, status, Request
|
|||||||
from django_etebase import models
|
from django_etebase import models
|
||||||
from myauth.models import UserType
|
from myauth.models import UserType
|
||||||
from .authentication import get_authenticated_user
|
from .authentication import get_authenticated_user
|
||||||
|
from .websocket import get_ticket, TicketRequest, TicketOut
|
||||||
from ..exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError
|
from ..exceptions import HttpError, transform_validation_error, PermissionDenied, ValidationError
|
||||||
from ..msgpack import MsgpackRoute
|
from ..msgpack import MsgpackRoute
|
||||||
from ..stoken_handler import filter_by_stoken_and_limit, filter_by_stoken, get_stoken_obj, get_queryset_stoken
|
from ..stoken_handler import filter_by_stoken_and_limit, filter_by_stoken, get_stoken_obj, get_queryset_stoken
|
||||||
@ -19,6 +20,7 @@ from ..utils import (
|
|||||||
Prefetch,
|
Prefetch,
|
||||||
PrefetchQuery,
|
PrefetchQuery,
|
||||||
is_collection_admin,
|
is_collection_admin,
|
||||||
|
msgpack_encode,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
permission_responses,
|
permission_responses,
|
||||||
PERMISSIONS_READ,
|
PERMISSIONS_READ,
|
||||||
@ -26,6 +28,7 @@ from ..utils import (
|
|||||||
)
|
)
|
||||||
from ..dependencies import get_collection_queryset, get_item_queryset, get_collection
|
from ..dependencies import get_collection_queryset, get_item_queryset, get_collection
|
||||||
from ..sendfile import sendfile
|
from ..sendfile import sendfile
|
||||||
|
from ..redis import redisw
|
||||||
from ..db_hack import django_db_cleanup_decorator
|
from ..db_hack import django_db_cleanup_decorator
|
||||||
|
|
||||||
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
collection_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
@ -188,6 +191,16 @@ class ItemBatchIn(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME: make it a background task
|
||||||
|
def report_items_changed(col_uid: str, stoken: str, items: t.List[CollectionItemIn]):
|
||||||
|
if not redisw.is_active:
|
||||||
|
return
|
||||||
|
|
||||||
|
redis = redisw.redis
|
||||||
|
content = msgpack_encode(CollectionItemListResponse(data=items, stoken=stoken, done=True).dict())
|
||||||
|
async_to_sync(redis.publish)(f"col.{col_uid}", content)
|
||||||
|
|
||||||
|
|
||||||
def collection_list_common(
|
def collection_list_common(
|
||||||
queryset: CollectionQuerySet,
|
queryset: CollectionQuerySet,
|
||||||
user: UserType,
|
user: UserType,
|
||||||
@ -440,6 +453,15 @@ def item_list(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@item_router.post("/item/subscription-ticket/", response_model=TicketOut, dependencies=PERMISSIONS_READ)
|
||||||
|
async def item_list_subscription_ticket(
|
||||||
|
collection: models.Collection = Depends(get_collection),
|
||||||
|
user: UserType = Depends(get_authenticated_user),
|
||||||
|
):
|
||||||
|
"""Get an authentication ticket that can be used with the websocket endpoint"""
|
||||||
|
return await get_ticket(TicketRequest(collection=collection.uid), user)
|
||||||
|
|
||||||
|
|
||||||
def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool):
|
def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str], uid: str, validate_etag: bool):
|
||||||
queryset = get_collection_queryset(user)
|
queryset = get_collection_queryset(user)
|
||||||
with transaction.atomic(): # We need this for locking the collection object
|
with transaction.atomic(): # We need this for locking the collection object
|
||||||
@ -465,6 +487,8 @@ def item_bulk_common(data: ItemBatchIn, user: UserType, stoken: t.Optional[str],
|
|||||||
status_code=status.HTTP_409_CONFLICT,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
report_items_changed(collection_object.uid, collection_object.stoken, data.items)
|
||||||
|
|
||||||
|
|
||||||
@item_router.get(
|
@item_router.get(
|
||||||
"/item/{item_uid}/revision/", response_model=CollectionItemRevisionListResponse, dependencies=PERMISSIONS_READ
|
"/item/{item_uid}/revision/", response_model=CollectionItemRevisionListResponse, dependencies=PERMISSIONS_READ
|
||||||
|
114
etebase_fastapi/routers/websocket.py
Normal file
114
etebase_fastapi/routers/websocket.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import asyncio
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
import aioredis
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
|
||||||
|
import nacl.encoding
|
||||||
|
import nacl.utils
|
||||||
|
|
||||||
|
from django_etebase import models
|
||||||
|
from django_etebase.utils import CallbackContext, get_user_queryset
|
||||||
|
from myauth.models import UserType, get_typed_user_model
|
||||||
|
|
||||||
|
from ..exceptions import NotSupported
|
||||||
|
from ..msgpack import MsgpackRoute, msgpack_decode, msgpack_encode
|
||||||
|
from ..redis import redisw
|
||||||
|
from ..utils import BaseModel, permission_responses
|
||||||
|
|
||||||
|
|
||||||
|
User = get_typed_user_model()
|
||||||
|
websocket_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
|
||||||
|
CollectionQuerySet = QuerySet[models.Collection]
|
||||||
|
|
||||||
|
|
||||||
|
TICKET_VALIDITY_SECONDS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class TicketRequest(BaseModel):
|
||||||
|
collection: str
|
||||||
|
|
||||||
|
|
||||||
|
class TicketOut(BaseModel):
|
||||||
|
ticket: str
|
||||||
|
|
||||||
|
|
||||||
|
class TicketInner(BaseModel):
|
||||||
|
user: int
|
||||||
|
req: TicketRequest
|
||||||
|
|
||||||
|
|
||||||
|
async def get_ticket(
|
||||||
|
ticket_request: TicketRequest,
|
||||||
|
user: UserType,
|
||||||
|
):
|
||||||
|
"""Get an authentication ticket that can be used with the websocket endpoint for authentication"""
|
||||||
|
if not redisw.is_active:
|
||||||
|
raise NotSupported(detail="This end-point requires Redis to be configured")
|
||||||
|
|
||||||
|
uid = nacl.encoding.URLSafeBase64Encoder.encode(nacl.utils.random(32))
|
||||||
|
ticket_model = TicketInner(user=user.id, req=ticket_request)
|
||||||
|
ticket_raw = msgpack_encode(ticket_model.dict())
|
||||||
|
await redisw.redis.set(uid, ticket_raw, expire=TICKET_VALIDITY_SECONDS * 1000)
|
||||||
|
return TicketOut(ticket=uid)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_websocket_ticket(websocket: WebSocket, ticket: str) -> t.Optional[TicketInner]:
|
||||||
|
content = await redisw.redis.get(ticket)
|
||||||
|
if content is None:
|
||||||
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
|
return None
|
||||||
|
await redisw.redis.delete(ticket)
|
||||||
|
return TicketInner(**msgpack_decode(content))
|
||||||
|
|
||||||
|
|
||||||
|
def get_websocket_user(websocket: WebSocket, ticket_model: t.Optional[TicketInner] = Depends(load_websocket_ticket)):
|
||||||
|
if ticket_model is None:
|
||||||
|
return None
|
||||||
|
user_queryset = get_user_queryset(User.objects.all(), CallbackContext(websocket.path_params))
|
||||||
|
return user_queryset.get(id=ticket_model.user)
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_router.websocket("/{ticket}/")
|
||||||
|
async def websocket_endpoint(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user: t.Optional[UserType] = Depends(get_websocket_user),
|
||||||
|
ticket_model: TicketInner = Depends(load_websocket_ticket),
|
||||||
|
):
|
||||||
|
if user is None:
|
||||||
|
return
|
||||||
|
await websocket.accept()
|
||||||
|
await redis_connector(websocket, ticket_model)
|
||||||
|
|
||||||
|
|
||||||
|
async def redis_connector(websocket: WebSocket, ticket_model: TicketInner):
|
||||||
|
async def producer_handler(r: aioredis.Redis, ws: WebSocket):
|
||||||
|
channel_name = f"col.{ticket_model.req.collection}"
|
||||||
|
(channel,) = await r.psubscribe(channel_name)
|
||||||
|
assert isinstance(channel, aioredis.Channel)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# We wait on the websocket so we fail if web sockets fail or get data
|
||||||
|
receive = asyncio.create_task(websocket.receive())
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
{receive, channel.wait_message()}, return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
if receive in done:
|
||||||
|
# Web socket should never receieve any data
|
||||||
|
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
|
return
|
||||||
|
|
||||||
|
message_raw = t.cast(t.Optional[t.Tuple[str, bytes]], await channel.get())
|
||||||
|
if message_raw:
|
||||||
|
_, message = message_raw
|
||||||
|
await ws.send_bytes(message)
|
||||||
|
|
||||||
|
except aioredis.errors.ConnectionClosedError:
|
||||||
|
await websocket.close(code=status.WS_1012_SERVICE_RESTART)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
|
||||||
|
redis = redisw.redis
|
||||||
|
await producer_handler(redis, websocket)
|
Loading…
Reference in New Issue
Block a user