diff --git a/etebase_fastapi/authentication.py b/etebase_fastapi/authentication.py index 3ae6c4b..287b46e 100644 --- a/etebase_fastapi/authentication.py +++ b/etebase_fastapi/authentication.py @@ -18,15 +18,16 @@ from fastapi import APIRouter, Depends, status, Request, Response from fastapi.security import APIKeyHeader from pydantic import BaseModel -from django_etebase import app_settings +from django_etebase import app_settings, models from django_etebase.exceptions import EtebaseValidationError from django_etebase.models import UserInfo from django_etebase.serializers import UserSerializer +from django_etebase.signals import user_signed_up from django_etebase.token_auth.models import AuthToken from django_etebase.token_auth.models import get_default_expiry from django_etebase.utils import create_user from django_etebase.views import msgpack_encode, msgpack_decode -from .execptions import AuthenticationFailed +from .execptions import AuthenticationFailed, transform_validation_error, ValidationError from .msgpack import MsgpackResponse, MsgpackRoute User = get_user_model() @@ -272,7 +273,7 @@ async def change_password(data: ChangePassword, request: Request, user: User = D @sync_to_async -def signup_save(data: SignupIn): +def signup_save(data: SignupIn) -> User: user_data = data.user with transaction.atomic(): try: @@ -290,17 +291,26 @@ def signup_save(data: SignupIn): except EtebaseValidationError as e: raise e except django_exceptions.ValidationError as e: - self.transform_validation_error("user", e) + transform_validation_error("user", e) except Exception as e: raise EtebaseValidationError("generic", str(e)) if hasattr(instance, "userinfo"): - raise EtebaseValidationError("user_exists", "User already exists", status_code=status.HTTP_409_CONFLICT) + raise ValidationError("user_exists", "User already exists", status_code=status.HTTP_409_CONFLICT) - models.UserInfo.objects.create(**validated_data, owner=instance) + models.UserInfo.objects.create(**data.dict(exclude={"user"}), owner=instance) return instance +@sync_to_async +def send_user_signed_up_async(user: User, request): + user_signed_up.send(sender=user.__class__, request=request, user=user) + + @authentication_router.post("/signup/") async def signup(data: SignupIn): - pass + user = await signup_save(data) + # XXX-TOM + data = await login_response_data(user) + await send_user_signed_up_async(user, None) + return MsgpackResponse(content=data, status_code=status.HTTP_201_CREATED) diff --git a/etebase_fastapi/execptions.py b/etebase_fastapi/execptions.py index 2b35634..fa76c45 100644 --- a/etebase_fastapi/execptions.py +++ b/etebase_fastapi/execptions.py @@ -1,8 +1,23 @@ from fastapi import status +import typing as t + +from pydantic import BaseModel from django_etebase.exceptions import EtebaseValidationError +class ValidationErrorField(BaseModel): + field: str + code: str + detail: str + + +class ValidationErrorOut(BaseModel): + code: str + detail: str + errors: t.Optional[t.List[ValidationErrorField]] + + class CustomHttpException(Exception): def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): self.status_code = status_code @@ -44,12 +59,27 @@ class PermissionDenied(CustomHttpException): super().__init__(code=code, detail=detail, status_code=status_code) +from django_etebase.exceptions import EtebaseValidationError + + class ValidationError(CustomHttpException): - def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): + def __init__( + self, + code: str, + detail: str, + status_code: int = status.HTTP_400_BAD_REQUEST, + field: t.Optional[str] = None, + errors: t.Optional[t.List["ValidationError"]] = None, + ): + self.errors = errors super().__init__(code=code, detail=detail, status_code=status_code) + @property + def as_dict(self) -> dict: + return ValidationErrorOut(code=self.code, errors=self.errors, detail=self.detail).dict() + -def flatten_errors(field_name, errors): +def flatten_errors(field_name, errors) -> t.List[ValidationError]: ret = [] if isinstance(errors, dict): for error_key in errors: @@ -61,13 +91,7 @@ def flatten_errors(field_name, errors): message = error.messages[0] else: message = str(error) - ret.append( - { - "field": field_name, - "code": error.code, - "detail": message, - } - ) + ret.append(dict(code=error.code, detail=message, field=field_name)) return ret @@ -78,11 +102,4 @@ def transform_validation_error(prefix, err): errors = flatten_errors(prefix, err.error_list) else: raise EtebaseValidationError(err.code, err.message) - raise ValidationError(code="field_errors", detail="Field validations failed.") - raise serializers.ValidationError( - { - "code": "field_errors", - "detail": "Field validations failed.", - "errors": errors, - } - ) + raise ValidationError(code="field_errors", detail="Field validations failed.", errors=errors)