pull/72/head
Tal Leibman 3 years ago committed by Tom Hacohen
parent 7d86459480
commit c90e92b0f0

@ -2,6 +2,7 @@ import dataclasses
import typing as t
from datetime import datetime
from functools import cached_property
from django.core import exceptions as django_exceptions
import nacl
import nacl.encoding
@ -11,16 +12,19 @@ import nacl.signing
from asgiref.sync import sync_to_async
from django.conf import settings
from django.contrib.auth import get_user_model, user_logged_out, user_logged_in
from django.db import transaction
from django.utils import timezone
from fastapi import APIRouter, Depends, status, Request, Response
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from django_etebase import app_settings
from django_etebase.exceptions import EtebaseValidationError
from django_etebase.models import UserInfo
from django_etebase.serializers import UserSerializer
from django_etebase.token_auth.models import AuthToken
from django_etebase.token_auth.models import get_default_expiry
from django_etebase.utils import create_user
from django_etebase.views import msgpack_encode, msgpack_decode
from .execptions import AuthenticationFailed
from .msgpack import MsgpackResponse, MsgpackRoute
@ -74,6 +78,19 @@ class ChangePassword(Authentication):
return ChangePasswordResponse(**msgpack_decode(self.response))
class UserSignup(BaseModel):
username: str
email: str
class SignupIn(BaseModel):
user: UserSignup
salt: bytes
loginPubkey: bytes
pubkey: bytes
encryptedContent: bytes
def __renew_token(auth_token: AuthToken):
current_expiry = auth_token.expiry
new_expiry = get_default_expiry()
@ -252,3 +269,38 @@ async def change_password(data: ChangePassword, request: Request, user: User = D
return bad_login_response
await save_changed_password(data, user)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@sync_to_async
def signup_save(data: SignupIn):
user_data = data.user
with transaction.atomic():
try:
# XXX-TOM
# view = self.context.get("view", None)
# user_queryset = get_user_queryset(User.objects.all(), view)
user_queryset = User.objects.all()
instance = user_queryset.get(**{User.USERNAME_FIELD: user_data.username.lower()})
except User.DoesNotExist:
# Create the user and save the casing the user chose as the first name
try:
# XXX-TOM
instance = create_user(**user_data.dict(), password=None, first_name=user_data.username, view=None)
instance.full_clean()
except EtebaseValidationError as e:
raise e
except django_exceptions.ValidationError as e:
self.transform_validation_error("user", e)
except Exception as e:
raise EtebaseValidationError("generic", str(e))
if hasattr(instance, "userinfo"):
raise EtebaseValidationError("user_exists", "User already exists", status_code=status.HTTP_409_CONFLICT)
models.UserInfo.objects.create(**validated_data, owner=instance)
return instance
@authentication_router.post("/signup/")
async def signup(data: SignupIn):
pass

@ -1,5 +1,7 @@
from fastapi import status
from django_etebase.exceptions import EtebaseValidationError
class CustomHttpException(Exception):
def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
@ -40,3 +42,47 @@ class PermissionDenied(CustomHttpException):
status_code: int = status.HTTP_403_FORBIDDEN,
):
super().__init__(code=code, detail=detail, status_code=status_code)
class ValidationError(CustomHttpException):
def __init__(self, code: str, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
super().__init__(code=code, detail=detail, status_code=status_code)
def flatten_errors(field_name, errors):
ret = []
if isinstance(errors, dict):
for error_key in errors:
error = errors[error_key]
ret.extend(flatten_errors("{}.{}".format(field_name, error_key), error))
else:
for error in errors:
if error.messages:
message = error.messages[0]
else:
message = str(error)
ret.append(
{
"field": field_name,
"code": error.code,
"detail": message,
}
)
return ret
def transform_validation_error(prefix, err):
if hasattr(err, "error_dict"):
errors = flatten_errors(prefix, err.error_dict)
elif not hasattr(err, "message"):
errors = flatten_errors(prefix, err.error_list)
else:
raise EtebaseValidationError(err.code, err.message)
raise ValidationError(code="field_errors", detail="Field validations failed.")
raise serializers.ValidationError(
{
"code": "field_errors",
"detail": "Field validations failed.",
"errors": errors,
}
)

Loading…
Cancel
Save