Files
budget-tracker/backend/app/auth/router.py
T

152 lines
5.0 KiB
Python

import uuid
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.dependencies import get_current_user
from app.auth.security import (
create_access_token,
create_refresh_token,
decode_token,
hash_password,
hash_token,
verify_password,
)
from app.config import settings
from app.database import get_session
from app.models.refresh_token import RefreshToken
from app.models.user import User
from app.schemas.auth import Token, TokenRefresh, TokenRefreshRequest, UserCreate, UserResponse
from app.services.category_service import create_default_categories
from app.utils import utcnow
# Router collecting the endpoints declared in this module; each route path
# below is registered under the "/auth" prefix and tagged "auth".
router = APIRouter(
    prefix="/auth",
    tags=["auth"],
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
    user_data: UserCreate,
    session: AsyncSession = Depends(get_session),
) -> UserResponse:
    """Create a new user account and seed its default categories.

    Args:
        user_data: Registration payload (email, password, full name).
        session: Database session supplied by the ``get_session`` dependency.

    Returns:
        The newly created user, validated into a ``UserResponse``.

    Raises:
        HTTPException: 409 if a user with the same email already exists.
    """
    email_lookup = select(User).where(User.email == user_data.email)
    already_registered = (await session.execute(email_lookup)).scalar_one_or_none()
    if already_registered is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already registered",
        )

    new_user = User(
        email=user_data.email,
        hashed_password=hash_password(user_data.password),
        full_name=user_data.full_name,
    )
    session.add(new_user)

    # Flush before seeding so new_user.id is populated for the category rows.
    await session.flush()
    await create_default_categories(session, new_user.id)

    await session.commit()
    await session.refresh(new_user)
    return UserResponse.model_validate(new_user)
@router.post("/login", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    session: AsyncSession = Depends(get_session),
) -> Token:
    """Authenticate with email and password and issue a token pair.

    The form's ``username`` field is matched against ``User.email``. On
    success a ``RefreshToken`` row holding the hash of the refresh token is
    persisted alongside its expiry.

    Args:
        form_data: OAuth2 password form (``username`` carries the email).
        session: Database session supplied by the ``get_session`` dependency.

    Returns:
        A ``Token`` with the access token, refresh token and type "bearer".

    Raises:
        HTTPException: 401 if the email is unknown or the password does not
            verify; 403 if the account is not active.
    """
    by_email = select(User).where(User.email == form_data.username)
    account = (await session.execute(by_email)).scalar_one_or_none()

    # Same response for "no such user" and "wrong password".
    credentials_ok = account is not None and verify_password(
        form_data.password, account.hashed_password
    )
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    if not account.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account inactive")

    subject = str(account.id)
    access = create_access_token({"sub": subject})
    refresh = create_refresh_token({"sub": subject})

    # Only the hash of the refresh token is stored, never the token itself.
    refresh_hash = hash_token(refresh)
    refresh_expiry = utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    session.add(
        RefreshToken(
            user_id=account.id,
            token_hash=refresh_hash,
            expires_at=refresh_expiry,
        )
    )
    await session.commit()

    return Token(access_token=access, refresh_token=refresh, token_type="bearer")
@router.post("/refresh", response_model=TokenRefresh)
async def refresh_token(
    data: TokenRefreshRequest,
    session: AsyncSession = Depends(get_session),
) -> TokenRefresh:
    """Exchange a valid refresh token for a new access token.

    The presented token must decode successfully, carry a UUID ``sub`` claim,
    match (by hash) a stored ``RefreshToken`` row that is neither revoked nor
    expired, and belong to an account that still exists and is active. The
    stored row is then marked revoked, so a refresh token can be used once.

    Args:
        data: Request body carrying the refresh token string.
        session: Database session supplied by the ``get_session`` dependency.

    Returns:
        A ``TokenRefresh`` holding the new access token and type "bearer".

    Raises:
        HTTPException: 401 if the token is malformed, unknown, expired or
            revoked, or if its owner no longer exists; 403 if the owner's
            account is inactive.
    """
    try:
        payload = decode_token(data.refresh_token)
        # A missing or empty subject becomes "", which uuid.UUID rejects with
        # ValueError, so it takes the same 401 path as a malformed one.
        user_id = uuid.UUID(str(payload.get("sub") or ""))
    except (JWTError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        ) from exc

    token_hash = hash_token(data.refresh_token)
    result = await session.execute(
        select(RefreshToken).where(
            RefreshToken.token_hash == token_hash,
            RefreshToken.revoked_at.is_(None),
        )
    )
    token_entry = result.scalar_one_or_none()
    if token_entry is None or token_entry.expires_at < utcnow():
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired or revoked"
        )

    # Login refuses inactive accounts; enforce the same rule here so that a
    # deactivated or deleted user cannot keep minting access tokens from a
    # refresh token that was issued earlier.
    owner_result = await session.execute(
        select(User).where(User.id == token_entry.user_id)
    )
    owner = owner_result.scalar_one_or_none()
    if owner is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        )
    if not owner.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account inactive")

    # Revoke consumed token (rotation).
    # NOTE(review): the response below carries only a new access token, so
    # after this call the client holds no usable refresh token and must log
    # in again once the access token expires. Confirm that this is intended,
    # or extend the response schema to return a replacement refresh token.
    token_entry.revoked_at = utcnow()
    await session.commit()

    new_access_token = create_access_token({"sub": str(user_id)})
    return TokenRefresh(access_token=new_access_token, token_type="bearer")
@router.get("/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)) -> UserResponse:
    """Return the authenticated user, validated into a ``UserResponse``.

    Args:
        current_user: User resolved by the ``get_current_user`` dependency.
    """
    me = UserResponse.model_validate(current_user)
    return me
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
    data: TokenRefreshRequest,
    session: AsyncSession = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> None:
    """Revoke one of the authenticated user's refresh tokens.

    The stored token is looked up by hash and restricted to rows owned by the
    caller. A token that is not found is ignored; nothing is raised.

    Args:
        data: Request body carrying the refresh token to revoke.
        session: Database session supplied by the ``get_session`` dependency.
        current_user: User resolved by the ``get_current_user`` dependency.
    """
    owned_token_query = select(RefreshToken).where(
        RefreshToken.token_hash == hash_token(data.refresh_token),
        RefreshToken.user_id == current_user.id,
    )
    stored_token = (await session.execute(owned_token_query)).scalar_one_or_none()
    if stored_token is None:
        # Unknown token, or one belonging to another user: nothing to revoke.
        return

    stored_token.revoked_at = utcnow()
    await session.commit()