152 lines
5.0 KiB
Python
152 lines
5.0 KiB
Python
import uuid
from datetime import timedelta

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.dependencies import get_current_user
from app.auth.security import (
    create_access_token,
    create_refresh_token,
    decode_token,
    hash_password,
    hash_token,
    verify_password,
)
from app.config import settings
from app.database import get_session
from app.models.refresh_token import RefreshToken
from app.models.user import User
from app.schemas.auth import Token, TokenRefresh, TokenRefreshRequest, UserCreate, UserResponse
from app.services.category_service import create_default_categories
from app.utils import utcnow
|
|
|
|
# Every endpoint registered on this router is served under "/auth" and tagged "auth".
router = APIRouter(tags=["auth"], prefix="/auth")
|
|
|
|
|
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
    user_data: UserCreate,
    session: AsyncSession = Depends(get_session),
) -> UserResponse:
    """Create a user account and seed it with the default categories.

    Args:
        user_data: Email, plaintext password and full name of the new user.
            The password is stored only as the output of ``hash_password``.
        session: Database session; committed once the user and its default
            categories have been created.

    Returns:
        The newly created user, refreshed from the database.

    Raises:
        HTTPException: 409 if the email is already registered.
    """
    existing = await session.execute(select(User).where(User.email == user_data.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT, detail="Email already registered"
        )

    user = User(
        email=user_data.email,
        hashed_password=hash_password(user_data.password),
        full_name=user_data.full_name,
    )
    session.add(user)
    try:
        # Populate user.id before creating categories. The SELECT above is only
        # a fast path: a concurrent registration for the same email can insert
        # between that check and this flush.
        await session.flush()
    except IntegrityError:
        # NOTE(review): assumes the only constraint this INSERT can violate is a
        # unique index on users.email — confirm against the User model.
        await session.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT, detail="Email already registered"
        ) from None

    await create_default_categories(session, user.id)

    await session.commit()
    await session.refresh(user)
    return UserResponse.model_validate(user)
|
|
|
|
|
|
@router.post("/login", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    session: AsyncSession = Depends(get_session),
) -> Token:
    """Authenticate with email and password and issue an access/refresh token pair.

    The OAuth2 password form's ``username`` field carries the user's email.
    A hash of the issued refresh token is persisted as a ``RefreshToken`` row
    that expires after ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days.

    Args:
        form_data: OAuth2 password form (``username`` = email, ``password``).
        session: Database session; committed after the refresh token is stored.

    Returns:
        Bearer access token plus the raw refresh token string.

    Raises:
        HTTPException: 401 for an unknown email or wrong password (same message
            for both); 403 if the account is inactive.
    """
    result = await session.execute(select(User).where(User.email == form_data.username))
    user = result.scalar_one_or_none()

    if user is None:
        # Do comparable password-hashing work for unknown emails so response
        # latency does not reveal which addresses are registered.
        # NOTE(review): assumes hash_password and verify_password cost about
        # the same (true for bcrypt/argon2-style hashers) — confirm.
        hash_password(form_data.password)
        password_ok = False
    else:
        password_ok = verify_password(form_data.password, user.hashed_password)

    if user is None or not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account inactive")

    access_token = create_access_token({"sub": str(user.id)})
    refresh_token_str = create_refresh_token({"sub": str(user.id)})

    # Only a hash of the refresh token is stored, so a database leak does not
    # expose usable tokens.
    token_entry = RefreshToken(
        user_id=user.id,
        token_hash=hash_token(refresh_token_str),
        expires_at=utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
    )
    session.add(token_entry)
    await session.commit()

    return Token(
        access_token=access_token,
        refresh_token=refresh_token_str,
        token_type="bearer",
    )
|
|
|
|
|
|
def _refresh_unauthorized(detail: str) -> HTTPException:
    """Build the 401 response used when a refresh token is rejected."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


@router.post("/refresh", response_model=TokenRefresh)
async def refresh_token(
    data: TokenRefreshRequest,
    session: AsyncSession = Depends(get_session),
) -> TokenRefresh:
    """Exchange a refresh token for a new access token.

    The refresh token must satisfy all of the following:
    - it decodes via ``decode_token``;
    - its ``sub`` claim is a UUID string;
    - it matches a stored, unrevoked, unexpired ``RefreshToken`` row;
    - it belongs to a user that still exists and is active.

    The consumed token is revoked, so each refresh token works only once.

    NOTE(review): no replacement refresh token is issued (``TokenRefresh`` only
    carries an access token), so after one refresh the client must log in
    again. Confirm this is intended; full rotation would need a
    ``refresh_token`` field on ``TokenRefresh``.

    Raises:
        HTTPException: 401 if the token is malformed, unknown, expired, revoked,
            or its owner no longer exists; 403 if the owning account is inactive.
    """
    try:
        payload = decode_token(data.refresh_token)
        user_id_str = payload.get("sub")
        # A non-string ``sub`` would make uuid.UUID raise AttributeError (a 500),
        # so funnel it into the same rejection path as a malformed UUID.
        if not isinstance(user_id_str, str) or not user_id_str:
            raise ValueError("missing or non-string 'sub' claim")
        user_id = uuid.UUID(user_id_str)
    except (JWTError, ValueError):
        raise _refresh_unauthorized("Invalid refresh token") from None

    # Lock the row (on backends that support SELECT ... FOR UPDATE) so two
    # concurrent requests cannot both consume the same refresh token.
    result = await session.execute(
        select(RefreshToken)
        .where(
            RefreshToken.token_hash == hash_token(data.refresh_token),
            RefreshToken.revoked_at.is_(None),
        )
        .with_for_update()
    )
    token_entry = result.scalar_one_or_none()

    if token_entry is None or token_entry.expires_at < utcnow():
        raise _refresh_unauthorized("Refresh token expired or revoked")

    # Apply the same account checks as login, so a deleted or deactivated
    # account cannot keep minting access tokens from an old refresh token.
    user = await session.get(User, token_entry.user_id)
    if user is None:
        raise _refresh_unauthorized("Invalid refresh token")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account inactive")

    # Revoke the consumed token (single use).
    token_entry.revoked_at = utcnow()
    await session.commit()

    new_access_token = create_access_token({"sub": str(user_id)})
    return TokenRefresh(access_token=new_access_token, token_type="bearer")
|
|
|
|
|
|
@router.get("/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)) -> UserResponse:
    """Return the profile of the user resolved by ``get_current_user``."""
    profile = UserResponse.model_validate(current_user)
    return profile
|
|
|
|
|
|
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
    data: TokenRefreshRequest,
    session: AsyncSession = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> None:
    """Revoke one of the caller's refresh tokens.

    Only tokens owned by ``current_user`` that are not yet revoked are
    affected. Unknown, foreign or already-revoked tokens are silently
    ignored, so the endpoint always returns 204 and repeated calls keep the
    original revocation timestamp.

    Args:
        data: Body carrying the raw refresh token to revoke.
        session: Database session; committed at the end of the call.
        current_user: The authenticated caller.
    """
    result = await session.execute(
        select(RefreshToken).where(
            RefreshToken.token_hash == hash_token(data.refresh_token),
            RefreshToken.user_id == current_user.id,
            # Skip already-revoked rows so a repeat logout does not overwrite
            # when the token was actually revoked.
            RefreshToken.revoked_at.is_(None),
        )
    )
    token_entry = result.scalar_one_or_none()
    if token_entry is not None:
        token_entry.revoked_at = utcnow()
    await session.commit()
|