feat: backend core — models, auth, CRUD, tests
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.security import decode_token
|
||||
from app.database import get_session
|
||||
from app.models.user import User
|
||||
|
||||
bearer_scheme = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = decode_token(credentials.credentials)
|
||||
user_id_str: str | None = payload.get("sub")
|
||||
if user_id_str is None:
|
||||
raise credentials_exception
|
||||
user_id = uuid.UUID(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise credentials_exception
|
||||
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
return user
|
||||
@@ -0,0 +1,146 @@
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from jose import JWTError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.auth.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
hash_token,
|
||||
verify_password,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import get_session
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import Token, TokenRefresh, TokenRefreshRequest, UserCreate, UserResponse
|
||||
from app.services.category_service import create_default_categories
|
||||
from app.utils import utcnow
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_data: UserCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> UserResponse:
|
||||
existing = await session.execute(select(User).where(User.email == user_data.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail="Email already registered"
|
||||
)
|
||||
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
hashed_password=hash_password(user_data.password),
|
||||
full_name=user_data.full_name,
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush() # Populate user.id before creating categories
|
||||
|
||||
await create_default_categories(session, user.id)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Token:
|
||||
result = await session.execute(select(User).where(User.email == form_data.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account inactive")
|
||||
|
||||
access_token = create_access_token({"sub": str(user.id)})
|
||||
refresh_token_str = create_refresh_token({"sub": str(user.id)})
|
||||
|
||||
token_entry = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=hash_token(refresh_token_str),
|
||||
expires_at=utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
)
|
||||
session.add(token_entry)
|
||||
await session.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenRefresh)
|
||||
async def refresh_token(
|
||||
data: TokenRefreshRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> TokenRefresh:
|
||||
try:
|
||||
payload = decode_token(data.refresh_token)
|
||||
user_id_str: str | None = payload.get("sub")
|
||||
if not user_id_str:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
user_id = uuid.UUID(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
token_hash = hash_token(data.refresh_token)
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.token_hash == token_hash,
|
||||
RefreshToken.revoked_at.is_(None),
|
||||
)
|
||||
)
|
||||
token_entry = result.scalar_one_or_none()
|
||||
|
||||
if token_entry is None or token_entry.expires_at < utcnow():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired or revoked"
|
||||
)
|
||||
|
||||
# Revoke consumed token (rotation)
|
||||
token_entry.revoked_at = utcnow()
|
||||
await session.commit()
|
||||
|
||||
new_access_token = create_access_token({"sub": str(user_id)})
|
||||
return TokenRefresh(access_token=new_access_token, token_type="bearer")
|
||||
|
||||
|
||||
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def logout(
|
||||
data: TokenRefreshRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> None:
|
||||
token_hash = hash_token(data.refresh_token)
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.token_hash == token_hash,
|
||||
RefreshToken.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
token_entry = result.scalar_one_or_none()
|
||||
if token_entry is not None:
|
||||
token_entry.revoked_at = utcnow()
|
||||
await session.commit()
|
||||
@@ -0,0 +1,46 @@
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config import settings
|
||||
from app.utils import utcnow
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(data: dict[str, Any]) -> str:
|
||||
to_encode = data.copy()
|
||||
to_encode["exp"] = utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict[str, Any]) -> str:
|
||||
to_encode = data.copy()
|
||||
to_encode["exp"] = utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
# jti ensures each token is unique even when issued for the same user at the same time
|
||||
to_encode["jti"] = str(uuid.uuid4())
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict[str, Any]:
|
||||
"""Decode and verify a JWT token. Raises jose.JWTError on failure."""
|
||||
return jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Return SHA-256 hex digest of a token for safe storage."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
Reference in New Issue
Block a user