feat: backend core — models, auth, CRUD, tests
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.security import decode_token
|
||||
from app.database import get_session
|
||||
from app.models.user import User
|
||||
|
||||
bearer_scheme = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = decode_token(credentials.credentials)
|
||||
user_id_str: str | None = payload.get("sub")
|
||||
if user_id_str is None:
|
||||
raise credentials_exception
|
||||
user_id = uuid.UUID(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise credentials_exception
|
||||
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
return user
|
||||
@@ -0,0 +1,146 @@
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from jose import JWTError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.auth.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
hash_token,
|
||||
verify_password,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import get_session
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import Token, TokenRefresh, TokenRefreshRequest, UserCreate, UserResponse
|
||||
from app.services.category_service import create_default_categories
|
||||
from app.utils import utcnow
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_data: UserCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> UserResponse:
|
||||
existing = await session.execute(select(User).where(User.email == user_data.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail="Email already registered"
|
||||
)
|
||||
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
hashed_password=hash_password(user_data.password),
|
||||
full_name=user_data.full_name,
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush() # Populate user.id before creating categories
|
||||
|
||||
await create_default_categories(session, user.id)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Token:
|
||||
result = await session.execute(select(User).where(User.email == form_data.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account inactive")
|
||||
|
||||
access_token = create_access_token({"sub": str(user.id)})
|
||||
refresh_token_str = create_refresh_token({"sub": str(user.id)})
|
||||
|
||||
token_entry = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=hash_token(refresh_token_str),
|
||||
expires_at=utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
)
|
||||
session.add(token_entry)
|
||||
await session.commit()
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token_str,
|
||||
token_type="bearer",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenRefresh)
|
||||
async def refresh_token(
|
||||
data: TokenRefreshRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> TokenRefresh:
|
||||
try:
|
||||
payload = decode_token(data.refresh_token)
|
||||
user_id_str: str | None = payload.get("sub")
|
||||
if not user_id_str:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
user_id = uuid.UUID(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
token_hash = hash_token(data.refresh_token)
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.token_hash == token_hash,
|
||||
RefreshToken.revoked_at.is_(None),
|
||||
)
|
||||
)
|
||||
token_entry = result.scalar_one_or_none()
|
||||
|
||||
if token_entry is None or token_entry.expires_at < utcnow():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired or revoked"
|
||||
)
|
||||
|
||||
# Revoke consumed token (rotation)
|
||||
token_entry.revoked_at = utcnow()
|
||||
await session.commit()
|
||||
|
||||
new_access_token = create_access_token({"sub": str(user_id)})
|
||||
return TokenRefresh(access_token=new_access_token, token_type="bearer")
|
||||
|
||||
|
||||
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def logout(
|
||||
data: TokenRefreshRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> None:
|
||||
token_hash = hash_token(data.refresh_token)
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.token_hash == token_hash,
|
||||
RefreshToken.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
token_entry = result.scalar_one_or_none()
|
||||
if token_entry is not None:
|
||||
token_entry.revoked_at = utcnow()
|
||||
await session.commit()
|
||||
@@ -0,0 +1,46 @@
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config import settings
|
||||
from app.utils import utcnow
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(data: dict[str, Any]) -> str:
|
||||
to_encode = data.copy()
|
||||
to_encode["exp"] = utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict[str, Any]) -> str:
|
||||
to_encode = data.copy()
|
||||
to_encode["exp"] = utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
# jti ensures each token is unique even when issued for the same user at the same time
|
||||
to_encode["jti"] = str(uuid.uuid4())
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict[str, Any]:
|
||||
"""Decode and verify a JWT token. Raises jose.JWTError on failure."""
|
||||
return jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Return SHA-256 hex digest of a token for safe storage."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
+9
-1
@@ -1,11 +1,14 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.auth.router import router as auth_router
|
||||
from app.config import settings
|
||||
from app.database import engine
|
||||
from app.routers.categories import router as categories_router
|
||||
from app.routers.transactions import router as transactions_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -33,6 +36,11 @@ def create_app() -> FastAPI:
|
||||
async def health_check() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
api_prefix = "/api/v1"
|
||||
app.include_router(auth_router, prefix=api_prefix)
|
||||
app.include_router(transactions_router, prefix=api_prefix)
|
||||
app.include_router(categories_router, prefix=api_prefix)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
# Import all models so Alembic can discover them for autogenerate
|
||||
from app.models.budget import Budget
|
||||
from app.models.category import Category, CategoryType
|
||||
from app.models.refresh_token import RefreshToken
|
||||
from app.models.transaction import Transaction, TransactionType
|
||||
from app.models.user import User
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"Category",
|
||||
"CategoryType",
|
||||
"Transaction",
|
||||
"TransactionType",
|
||||
"Budget",
|
||||
"RefreshToken",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
Uuid,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.category import Category
|
||||
|
||||
|
||||
class Budget(Base):
|
||||
__tablename__ = "budgets"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "category_id", "month", name="uq_budgets_user_category_month"),
|
||||
CheckConstraint("limit_cents > 0", name="ck_budgets_limit_positive"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
category_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid, ForeignKey("categories.id"), nullable=False
|
||||
)
|
||||
# Format YYYY-MM
|
||||
month: Mapped[str] = mapped_column(String(7), nullable=False)
|
||||
limit_cents: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
category: Mapped["Category"] = relationship("Category", lazy="raise")
|
||||
@@ -0,0 +1,36 @@
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Enum, ForeignKey, String, Uuid, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class CategoryType(str, enum.Enum):
|
||||
income = "income"
|
||||
expense = "expense"
|
||||
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = "categories"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
type: Mapped[CategoryType] = mapped_column(
|
||||
Enum(CategoryType, name="categorytype", native_enum=False, length=10),
|
||||
nullable=False,
|
||||
)
|
||||
color: Mapped[str | None] = mapped_column(String(7), nullable=True)
|
||||
icon: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
is_default: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=False), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Uuid, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
# SHA-256 hash of the raw token (never store raw tokens)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=False), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -0,0 +1,62 @@
|
||||
import enum
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Uuid,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.category import Category
|
||||
|
||||
|
||||
class TransactionType(str, enum.Enum):
|
||||
income = "income"
|
||||
expense = "expense"
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
__tablename__ = "transactions"
|
||||
__table_args__ = (
|
||||
CheckConstraint("amount_cents > 0", name="ck_transactions_amount_positive"),
|
||||
Index("ix_transactions_user_date", "user_id", "transaction_date", "deleted_at"),
|
||||
Index("ix_transactions_user_category", "user_id", "category_id", "deleted_at"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid, ForeignKey("users.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
category_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid, ForeignKey("categories.id"), nullable=False
|
||||
)
|
||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
type: Mapped[TransactionType] = mapped_column(
|
||||
Enum(TransactionType, name="transactiontype", native_enum=False, length=10),
|
||||
nullable=False,
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
transaction_date: Mapped[date] = mapped_column(Date, nullable=False)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=False), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships — use selectinload() explicitly; lazy="raise" prevents accidental N+1
|
||||
category: Mapped["Category"] = relationship("Category", lazy="raise")
|
||||
@@ -0,0 +1,23 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, String, Uuid, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
full_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, server_default="true")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -0,0 +1,53 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.database import get_session
|
||||
from app.models.category import CategoryType
|
||||
from app.models.user import User
|
||||
from app.schemas.category import CategoryCreate, CategoryResponse, CategoryUpdate
|
||||
from app.services import category_service
|
||||
|
||||
router = APIRouter(prefix="/categories", tags=["categories"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[CategoryResponse])
|
||||
async def list_categories(
|
||||
type: CategoryType | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[CategoryResponse]:
|
||||
cats = await category_service.list_categories(session, current_user.id, type_filter=type)
|
||||
return [CategoryResponse.model_validate(c) for c in cats]
|
||||
|
||||
|
||||
@router.post("", response_model=CategoryResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_category(
|
||||
data: CategoryCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> CategoryResponse:
|
||||
cat = await category_service.create_category(session, current_user.id, data)
|
||||
return CategoryResponse.model_validate(cat)
|
||||
|
||||
|
||||
@router.put("/{category_id}", response_model=CategoryResponse)
|
||||
async def update_category(
|
||||
category_id: uuid.UUID,
|
||||
data: CategoryUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> CategoryResponse:
|
||||
cat = await category_service.update_category(session, current_user.id, category_id, data)
|
||||
return CategoryResponse.model_validate(cat)
|
||||
|
||||
|
||||
@router.delete("/{category_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_category(
|
||||
category_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> None:
|
||||
await category_service.delete_category(session, current_user.id, category_id)
|
||||
@@ -0,0 +1,87 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.dependencies import get_current_user
|
||||
from app.database import get_session
|
||||
from app.models.transaction import TransactionType
|
||||
from app.models.user import User
|
||||
from app.schemas.transaction import (
|
||||
PaginatedTransactions,
|
||||
TransactionCreate,
|
||||
TransactionResponse,
|
||||
TransactionUpdate,
|
||||
)
|
||||
from app.services import transaction_service
|
||||
|
||||
router = APIRouter(prefix="/transactions", tags=["transactions"])
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedTransactions)
|
||||
async def list_transactions(
|
||||
month: str | None = Query(None, description="Filter by month (YYYY-MM)"),
|
||||
category_id: uuid.UUID | None = Query(None),
|
||||
type: TransactionType | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> PaginatedTransactions:
|
||||
items, total = await transaction_service.list_transactions(
|
||||
session,
|
||||
current_user.id,
|
||||
month=month,
|
||||
category_id=category_id,
|
||||
type_filter=type,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
)
|
||||
return PaginatedTransactions(
|
||||
items=[TransactionResponse.model_validate(t) for t in items],
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=TransactionResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_transaction(
|
||||
data: TransactionCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> TransactionResponse:
|
||||
tx = await transaction_service.create_transaction(session, current_user.id, data)
|
||||
return TransactionResponse.model_validate(tx)
|
||||
|
||||
|
||||
@router.get("/{transaction_id}", response_model=TransactionResponse)
|
||||
async def get_transaction(
|
||||
transaction_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> TransactionResponse:
|
||||
tx = await transaction_service.get_transaction(session, current_user.id, transaction_id)
|
||||
return TransactionResponse.model_validate(tx)
|
||||
|
||||
|
||||
@router.put("/{transaction_id}", response_model=TransactionResponse)
|
||||
async def update_transaction(
|
||||
transaction_id: uuid.UUID,
|
||||
data: TransactionUpdate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> TransactionResponse:
|
||||
tx = await transaction_service.update_transaction(
|
||||
session, current_user.id, transaction_id, data
|
||||
)
|
||||
return TransactionResponse.model_validate(tx)
|
||||
|
||||
|
||||
@router.delete("/{transaction_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_transaction(
|
||||
transaction_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> None:
|
||||
await transaction_service.delete_transaction(session, current_user.id, transaction_id)
|
||||
@@ -0,0 +1,32 @@
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
full_name: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
email: str
|
||||
full_name: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenRefresh(BaseModel):
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class TokenRefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
@@ -0,0 +1,31 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.category import CategoryType
|
||||
|
||||
|
||||
class CategoryCreate(BaseModel):
|
||||
name: str
|
||||
type: CategoryType
|
||||
color: str | None = None
|
||||
icon: str | None = None
|
||||
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
color: str | None = None
|
||||
icon: str | None = None
|
||||
|
||||
|
||||
class CategoryResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
type: CategoryType
|
||||
color: str | None
|
||||
icon: str | None
|
||||
is_default: bool
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
@@ -0,0 +1,12 @@
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
items: list[T]
|
||||
total: int
|
||||
page: int
|
||||
per_page: int
|
||||
@@ -0,0 +1,60 @@
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from app.models.transaction import TransactionType
|
||||
from app.schemas.common import PaginatedResponse
|
||||
|
||||
|
||||
class CategoryBrief(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
color: str | None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class TransactionCreate(BaseModel):
|
||||
amount_cents: int
|
||||
type: TransactionType
|
||||
category_id: uuid.UUID
|
||||
description: str | None = None
|
||||
transaction_date: date
|
||||
|
||||
@field_validator("amount_cents")
|
||||
@classmethod
|
||||
def amount_must_be_positive(cls, v: int) -> int:
|
||||
if v <= 0:
|
||||
raise ValueError("amount_cents must be greater than 0")
|
||||
return v
|
||||
|
||||
|
||||
class TransactionUpdate(BaseModel):
|
||||
amount_cents: int | None = None
|
||||
type: TransactionType | None = None
|
||||
category_id: uuid.UUID | None = None
|
||||
description: str | None = None
|
||||
transaction_date: date | None = None
|
||||
|
||||
@field_validator("amount_cents")
|
||||
@classmethod
|
||||
def amount_must_be_positive(cls, v: int | None) -> int | None:
|
||||
if v is not None and v <= 0:
|
||||
raise ValueError("amount_cents must be greater than 0")
|
||||
return v
|
||||
|
||||
|
||||
class TransactionResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
amount_cents: int
|
||||
type: TransactionType
|
||||
description: str | None
|
||||
category: CategoryBrief
|
||||
transaction_date: date
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
PaginatedTransactions = PaginatedResponse[TransactionResponse]
|
||||
@@ -0,0 +1,134 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.category import Category, CategoryType
|
||||
from app.models.transaction import Transaction
|
||||
from app.schemas.category import CategoryCreate, CategoryUpdate
|
||||
from app.utils import utcnow
|
||||
|
||||
# Default categories created at registration
|
||||
_DEFAULT_CATEGORIES = [
|
||||
{"name": "Alimentation", "type": CategoryType.expense, "color": "#22c55e", "icon": "utensils"},
|
||||
{"name": "Transport", "type": CategoryType.expense, "color": "#3b82f6", "icon": "car"},
|
||||
{"name": "Logement", "type": CategoryType.expense, "color": "#f59e0b", "icon": "home"},
|
||||
{"name": "Santé", "type": CategoryType.expense, "color": "#ef4444", "icon": "heart-pulse"},
|
||||
{"name": "Loisirs", "type": CategoryType.expense, "color": "#a855f7", "icon": "gamepad-2"},
|
||||
{"name": "Divers", "type": CategoryType.expense, "color": "#6b7280", "icon": "package"},
|
||||
{"name": "Salaire", "type": CategoryType.income, "color": "#10b981", "icon": "briefcase"},
|
||||
{"name": "Freelance", "type": CategoryType.income, "color": "#06b6d4", "icon": "laptop"},
|
||||
{"name": "Remboursement", "type": CategoryType.income, "color": "#8b5cf6", "icon": "refresh-cw"},
|
||||
]
|
||||
|
||||
|
||||
async def create_default_categories(
|
||||
session: AsyncSession, user_id: uuid.UUID
|
||||
) -> list[Category]:
|
||||
"""Create the default categories for a newly registered user."""
|
||||
categories = []
|
||||
for data in _DEFAULT_CATEGORIES:
|
||||
cat = Category(user_id=user_id, is_default=True, **data)
|
||||
session.add(cat)
|
||||
categories.append(cat)
|
||||
await session.flush()
|
||||
return categories
|
||||
|
||||
|
||||
async def list_categories(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
*,
|
||||
type_filter: CategoryType | None = None,
|
||||
) -> list[Category]:
|
||||
query = select(Category).where(
|
||||
Category.user_id == user_id,
|
||||
Category.deleted_at.is_(None),
|
||||
)
|
||||
if type_filter is not None:
|
||||
query = query.where(Category.type == type_filter)
|
||||
query = query.order_by(Category.name)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
async def get_category(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
category_id: uuid.UUID,
|
||||
) -> Category:
|
||||
result = await session.execute(
|
||||
select(Category).where(
|
||||
Category.id == category_id,
|
||||
Category.user_id == user_id,
|
||||
Category.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
cat = result.scalar_one_or_none()
|
||||
if cat is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category not found")
|
||||
return cat
|
||||
|
||||
|
||||
async def create_category(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
data: CategoryCreate,
|
||||
) -> Category:
|
||||
cat = Category(
|
||||
user_id=user_id,
|
||||
name=data.name,
|
||||
type=data.type,
|
||||
color=data.color,
|
||||
icon=data.icon,
|
||||
is_default=False,
|
||||
)
|
||||
session.add(cat)
|
||||
await session.commit()
|
||||
await session.refresh(cat)
|
||||
return cat
|
||||
|
||||
|
||||
async def update_category(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
category_id: uuid.UUID,
|
||||
data: CategoryUpdate,
|
||||
) -> Category:
|
||||
cat = await get_category(session, user_id, category_id)
|
||||
if data.name is not None:
|
||||
cat.name = data.name
|
||||
if data.color is not None:
|
||||
cat.color = data.color
|
||||
if data.icon is not None:
|
||||
cat.icon = data.icon
|
||||
await session.commit()
|
||||
await session.refresh(cat)
|
||||
return cat
|
||||
|
||||
|
||||
async def delete_category(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
category_id: uuid.UUID,
|
||||
) -> None:
|
||||
cat = await get_category(session, user_id, category_id)
|
||||
|
||||
# Refuse deletion if active transactions exist
|
||||
count_result = await session.execute(
|
||||
select(func.count()).where(
|
||||
Transaction.category_id == category_id,
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
active_count = count_result.scalar_one()
|
||||
if active_count > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Cannot delete category: {active_count} active transaction(s) are linked to it",
|
||||
)
|
||||
|
||||
cat.deleted_at = utcnow()
|
||||
await session.commit()
|
||||
@@ -0,0 +1,134 @@
|
||||
> import uuid
|
||||
|
||||
> from fastapi import HTTPException, status
|
||||
> from sqlalchemy import func, select
|
||||
> from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
> from app.models.category import Category, CategoryType
|
||||
> from app.models.transaction import Transaction
|
||||
> from app.schemas.category import CategoryCreate, CategoryUpdate
|
||||
> from app.utils import utcnow
|
||||
|
||||
# Default categories created at registration
|
||||
> _DEFAULT_CATEGORIES = [
|
||||
> {"name": "Alimentation", "type": CategoryType.expense, "color": "#22c55e", "icon": "utensils"},
|
||||
> {"name": "Transport", "type": CategoryType.expense, "color": "#3b82f6", "icon": "car"},
|
||||
> {"name": "Logement", "type": CategoryType.expense, "color": "#f59e0b", "icon": "home"},
|
||||
> {"name": "Santé", "type": CategoryType.expense, "color": "#ef4444", "icon": "heart-pulse"},
|
||||
> {"name": "Loisirs", "type": CategoryType.expense, "color": "#a855f7", "icon": "gamepad-2"},
|
||||
> {"name": "Divers", "type": CategoryType.expense, "color": "#6b7280", "icon": "package"},
|
||||
> {"name": "Salaire", "type": CategoryType.income, "color": "#10b981", "icon": "briefcase"},
|
||||
> {"name": "Freelance", "type": CategoryType.income, "color": "#06b6d4", "icon": "laptop"},
|
||||
> {"name": "Remboursement", "type": CategoryType.income, "color": "#8b5cf6", "icon": "refresh-cw"},
|
||||
> ]
|
||||
|
||||
|
||||
> async def create_default_categories(
|
||||
> session: AsyncSession, user_id: uuid.UUID
|
||||
> ) -> list[Category]:
|
||||
> """Create the default categories for a newly registered user."""
|
||||
> categories = []
|
||||
> for data in _DEFAULT_CATEGORIES:
|
||||
> cat = Category(user_id=user_id, is_default=True, **data)
|
||||
> session.add(cat)
|
||||
> categories.append(cat)
|
||||
> await session.flush()
|
||||
! return categories
|
||||
|
||||
|
||||
> async def list_categories(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> *,
|
||||
> type_filter: CategoryType | None = None,
|
||||
> ) -> list[Category]:
|
||||
> query = select(Category).where(
|
||||
> Category.user_id == user_id,
|
||||
> Category.deleted_at.is_(None),
|
||||
> )
|
||||
> if type_filter is not None:
|
||||
> query = query.where(Category.type == type_filter)
|
||||
> query = query.order_by(Category.name)
|
||||
> result = await session.execute(query)
|
||||
! return list(result.scalars().all())
|
||||
|
||||
|
||||
> async def get_category(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> category_id: uuid.UUID,
|
||||
> ) -> Category:
|
||||
> result = await session.execute(
|
||||
> select(Category).where(
|
||||
> Category.id == category_id,
|
||||
> Category.user_id == user_id,
|
||||
> Category.deleted_at.is_(None),
|
||||
> )
|
||||
> )
|
||||
! cat = result.scalar_one_or_none()
|
||||
! if cat is None:
|
||||
! raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category not found")
|
||||
! return cat
|
||||
|
||||
|
||||
> async def create_category(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> data: CategoryCreate,
|
||||
> ) -> Category:
|
||||
> cat = Category(
|
||||
> user_id=user_id,
|
||||
> name=data.name,
|
||||
> type=data.type,
|
||||
> color=data.color,
|
||||
> icon=data.icon,
|
||||
> is_default=False,
|
||||
> )
|
||||
> session.add(cat)
|
||||
> await session.commit()
|
||||
! await session.refresh(cat)
|
||||
! return cat
|
||||
|
||||
|
||||
> async def update_category(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> category_id: uuid.UUID,
|
||||
> data: CategoryUpdate,
|
||||
> ) -> Category:
|
||||
> cat = await get_category(session, user_id, category_id)
|
||||
! if data.name is not None:
|
||||
! cat.name = data.name
|
||||
! if data.color is not None:
|
||||
! cat.color = data.color
|
||||
! if data.icon is not None:
|
||||
! cat.icon = data.icon
|
||||
! await session.commit()
|
||||
! await session.refresh(cat)
|
||||
! return cat
|
||||
|
||||
|
||||
> async def delete_category(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> category_id: uuid.UUID,
|
||||
> ) -> None:
|
||||
> cat = await get_category(session, user_id, category_id)
|
||||
|
||||
# Refuse deletion if active transactions exist
|
||||
! count_result = await session.execute(
|
||||
! select(func.count()).where(
|
||||
! Transaction.category_id == category_id,
|
||||
! Transaction.user_id == user_id,
|
||||
! Transaction.deleted_at.is_(None),
|
||||
! )
|
||||
! )
|
||||
! active_count = count_result.scalar_one()
|
||||
! if active_count > 0:
|
||||
! raise HTTPException(
|
||||
! status_code=status.HTTP_409_CONFLICT,
|
||||
! detail=f"Cannot delete category: {active_count} active transaction(s) are linked to it",
|
||||
! )
|
||||
|
||||
! cat.deleted_at = utcnow()
|
||||
! await session.commit()
|
||||
@@ -0,0 +1,152 @@
|
||||
import uuid
|
||||
from datetime import date
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.transaction import Transaction, TransactionType
|
||||
from app.schemas.transaction import TransactionCreate, TransactionUpdate
|
||||
from app.utils import utcnow
|
||||
|
||||
|
||||
def _month_bounds(month: str) -> tuple[date, date]:
|
||||
"""Return (start_inclusive, end_exclusive) date bounds for a YYYY-MM string."""
|
||||
year, mon = map(int, month.split("-"))
|
||||
start = date(year, mon, 1)
|
||||
if mon == 12:
|
||||
end = date(year + 1, 1, 1)
|
||||
else:
|
||||
end = date(year, mon + 1, 1)
|
||||
return start, end
|
||||
|
||||
|
||||
async def create_transaction(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
data: TransactionCreate,
|
||||
) -> Transaction:
|
||||
tx = Transaction(
|
||||
user_id=user_id,
|
||||
category_id=data.category_id,
|
||||
amount_cents=data.amount_cents,
|
||||
type=data.type,
|
||||
description=data.description,
|
||||
transaction_date=data.transaction_date,
|
||||
)
|
||||
session.add(tx)
|
||||
await session.commit()
|
||||
|
||||
# Reload with category relationship
|
||||
result = await session.execute(
|
||||
select(Transaction)
|
||||
.options(selectinload(Transaction.category))
|
||||
.where(Transaction.id == tx.id)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def list_transactions(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
*,
|
||||
month: str | None = None,
|
||||
category_id: uuid.UUID | None = None,
|
||||
type_filter: TransactionType | None = None,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
) -> tuple[list[Transaction], int]:
|
||||
base_query = select(Transaction).where(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
|
||||
if month is not None:
|
||||
start, end = _month_bounds(month)
|
||||
base_query = base_query.where(
|
||||
Transaction.transaction_date >= start,
|
||||
Transaction.transaction_date < end,
|
||||
)
|
||||
if category_id is not None:
|
||||
base_query = base_query.where(Transaction.category_id == category_id)
|
||||
if type_filter is not None:
|
||||
base_query = base_query.where(Transaction.type == type_filter)
|
||||
|
||||
# Total count (before pagination)
|
||||
count_result = await session.execute(
|
||||
select(func.count()).select_from(base_query.subquery())
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Paginated rows with category eager-loaded
|
||||
items_query = (
|
||||
base_query.options(selectinload(Transaction.category))
|
||||
.order_by(Transaction.transaction_date.desc(), Transaction.created_at.desc())
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
)
|
||||
result = await session.execute(items_query)
|
||||
return list(result.scalars().all()), total
|
||||
|
||||
|
||||
async def get_transaction(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
transaction_id: uuid.UUID,
|
||||
) -> Transaction:
|
||||
result = await session.execute(
|
||||
select(Transaction)
|
||||
.options(selectinload(Transaction.category))
|
||||
.where(
|
||||
Transaction.id == transaction_id,
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
tx = result.scalar_one_or_none()
|
||||
if tx is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Transaction not found")
|
||||
return tx
|
||||
|
||||
|
||||
async def update_transaction(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
transaction_id: uuid.UUID,
|
||||
data: TransactionUpdate,
|
||||
) -> Transaction:
|
||||
tx = await get_transaction(session, user_id, transaction_id)
|
||||
|
||||
if data.amount_cents is not None:
|
||||
tx.amount_cents = data.amount_cents
|
||||
if data.type is not None:
|
||||
tx.type = data.type
|
||||
if data.category_id is not None:
|
||||
tx.category_id = data.category_id
|
||||
if data.description is not None:
|
||||
tx.description = data.description
|
||||
if data.transaction_date is not None:
|
||||
tx.transaction_date = data.transaction_date
|
||||
tx.updated_at = utcnow()
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Reload with fresh category
|
||||
result = await session.execute(
|
||||
select(Transaction)
|
||||
.options(selectinload(Transaction.category))
|
||||
.where(Transaction.id == tx.id)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def delete_transaction(
|
||||
session: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
transaction_id: uuid.UUID,
|
||||
) -> None:
|
||||
tx = await get_transaction(session, user_id, transaction_id)
|
||||
tx.deleted_at = utcnow()
|
||||
tx.updated_at = utcnow()
|
||||
await session.commit()
|
||||
@@ -0,0 +1,152 @@
|
||||
> import uuid
|
||||
> from datetime import date
|
||||
|
||||
> from fastapi import HTTPException, status
|
||||
> from sqlalchemy import func, select
|
||||
> from sqlalchemy.ext.asyncio import AsyncSession
|
||||
> from sqlalchemy.orm import selectinload
|
||||
|
||||
> from app.models.transaction import Transaction, TransactionType
|
||||
> from app.schemas.transaction import TransactionCreate, TransactionUpdate
|
||||
> from app.utils import utcnow
|
||||
|
||||
|
||||
> def _month_bounds(month: str) -> tuple[date, date]:
|
||||
> """Return (start_inclusive, end_exclusive) date bounds for a YYYY-MM string."""
|
||||
> year, mon = map(int, month.split("-"))
|
||||
> start = date(year, mon, 1)
|
||||
> if mon == 12:
|
||||
! end = date(year + 1, 1, 1)
|
||||
> else:
|
||||
> end = date(year, mon + 1, 1)
|
||||
> return start, end
|
||||
|
||||
|
||||
> async def create_transaction(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> data: TransactionCreate,
|
||||
> ) -> Transaction:
|
||||
> tx = Transaction(
|
||||
> user_id=user_id,
|
||||
> category_id=data.category_id,
|
||||
> amount_cents=data.amount_cents,
|
||||
> type=data.type,
|
||||
> description=data.description,
|
||||
> transaction_date=data.transaction_date,
|
||||
> )
|
||||
> session.add(tx)
|
||||
> await session.commit()
|
||||
|
||||
# Reload with category relationship
|
||||
! result = await session.execute(
|
||||
! select(Transaction)
|
||||
! .options(selectinload(Transaction.category))
|
||||
! .where(Transaction.id == tx.id)
|
||||
! )
|
||||
! return result.scalar_one()
|
||||
|
||||
|
||||
> async def list_transactions(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> *,
|
||||
> month: str | None = None,
|
||||
> category_id: uuid.UUID | None = None,
|
||||
> type_filter: TransactionType | None = None,
|
||||
> page: int = 1,
|
||||
> per_page: int = 20,
|
||||
> ) -> tuple[list[Transaction], int]:
|
||||
> base_query = select(Transaction).where(
|
||||
> Transaction.user_id == user_id,
|
||||
> Transaction.deleted_at.is_(None),
|
||||
> )
|
||||
|
||||
> if month is not None:
|
||||
> start, end = _month_bounds(month)
|
||||
> base_query = base_query.where(
|
||||
> Transaction.transaction_date >= start,
|
||||
> Transaction.transaction_date < end,
|
||||
> )
|
||||
> if category_id is not None:
|
||||
> base_query = base_query.where(Transaction.category_id == category_id)
|
||||
> if type_filter is not None:
|
||||
> base_query = base_query.where(Transaction.type == type_filter)
|
||||
|
||||
# Total count (before pagination)
|
||||
> count_result = await session.execute(
|
||||
> select(func.count()).select_from(base_query.subquery())
|
||||
> )
|
||||
! total = count_result.scalar_one()
|
||||
|
||||
# Paginated rows with category eager-loaded
|
||||
! items_query = (
|
||||
! base_query.options(selectinload(Transaction.category))
|
||||
! .order_by(Transaction.transaction_date.desc(), Transaction.created_at.desc())
|
||||
! .offset((page - 1) * per_page)
|
||||
! .limit(per_page)
|
||||
! )
|
||||
! result = await session.execute(items_query)
|
||||
! return list(result.scalars().all()), total
|
||||
|
||||
|
||||
> async def get_transaction(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> transaction_id: uuid.UUID,
|
||||
> ) -> Transaction:
|
||||
> result = await session.execute(
|
||||
> select(Transaction)
|
||||
> .options(selectinload(Transaction.category))
|
||||
> .where(
|
||||
> Transaction.id == transaction_id,
|
||||
> Transaction.user_id == user_id,
|
||||
> Transaction.deleted_at.is_(None),
|
||||
> )
|
||||
> )
|
||||
! tx = result.scalar_one_or_none()
|
||||
! if tx is None:
|
||||
! raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Transaction not found")
|
||||
! return tx
|
||||
|
||||
|
||||
> async def update_transaction(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> transaction_id: uuid.UUID,
|
||||
> data: TransactionUpdate,
|
||||
> ) -> Transaction:
|
||||
> tx = await get_transaction(session, user_id, transaction_id)
|
||||
|
||||
! if data.amount_cents is not None:
|
||||
! tx.amount_cents = data.amount_cents
|
||||
! if data.type is not None:
|
||||
! tx.type = data.type
|
||||
! if data.category_id is not None:
|
||||
! tx.category_id = data.category_id
|
||||
! if data.description is not None:
|
||||
! tx.description = data.description
|
||||
! if data.transaction_date is not None:
|
||||
! tx.transaction_date = data.transaction_date
|
||||
! tx.updated_at = utcnow()
|
||||
|
||||
! await session.commit()
|
||||
|
||||
# Reload with fresh category
|
||||
! result = await session.execute(
|
||||
! select(Transaction)
|
||||
! .options(selectinload(Transaction.category))
|
||||
! .where(Transaction.id == tx.id)
|
||||
! )
|
||||
! return result.scalar_one()
|
||||
|
||||
|
||||
> async def delete_transaction(
|
||||
> session: AsyncSession,
|
||||
> user_id: uuid.UUID,
|
||||
> transaction_id: uuid.UUID,
|
||||
> ) -> None:
|
||||
> tx = await get_transaction(session, user_id, transaction_id)
|
||||
! tx.deleted_at = utcnow()
|
||||
! tx.updated_at = utcnow()
|
||||
! await session.commit()
|
||||
@@ -0,0 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Return current UTC datetime without timezone info (stored as naive UTC)."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
Reference in New Issue
Block a user