import uuid
from datetime import date

from fastapi import HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.transaction import Transaction, TransactionType
from app.schemas.transaction import TransactionCreate, TransactionUpdate
from app.utils import utcnow


def _month_bounds(month: str) -> tuple[date, date]:
    """Return (start_inclusive, end_exclusive) date bounds for a YYYY-MM string.

    Raises:
        ValueError: if ``month`` is not exactly two ``-``-separated integers
            forming a valid year/month (e.g. ``"2024-13"`` or ``"2024"``).
    """
    year, mon = map(int, month.split("-"))
    start = date(year, mon, 1)
    if mon == 12:
        # December rolls over to January of the following year.
        end = date(year + 1, 1, 1)
    else:
        end = date(year, mon + 1, 1)
    return start, end


async def _load_with_category(
    session: AsyncSession,
    transaction_id: uuid.UUID,
) -> Transaction:
    """Re-select one transaction by id with its ``category`` relationship eager-loaded.

    ``populate_existing`` overwrites the identity-map instance with fresh row
    data (relationships included), so a ``category`` loaded before a
    ``category_id`` change is not returned stale.
    """
    result = await session.execute(
        select(Transaction)
        .options(selectinload(Transaction.category))
        .where(Transaction.id == transaction_id)
        .execution_options(populate_existing=True)
    )
    return result.scalar_one()


async def create_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    data: TransactionCreate,
) -> Transaction:
    """Insert a transaction owned by ``user_id`` and return it with its category loaded.

    NOTE(review): ``data.category_id`` is stored as given; nothing here checks
    that the category belongs to ``user_id`` — confirm this is enforced elsewhere.
    """
    tx = Transaction(
        user_id=user_id,
        category_id=data.category_id,
        amount_cents=data.amount_cents,
        type=data.type,
        description=data.description,
        transaction_date=data.transaction_date,
    )
    session.add(tx)
    # Flush so the primary key is assigned, and capture it *before* commit:
    # if the session expires instances on commit (SQLAlchemy's default),
    # reading ``tx.id`` afterwards would need implicit IO, which an
    # AsyncSession cannot perform on attribute access.
    await session.flush()
    tx_id = tx.id
    await session.commit()
    return await _load_with_category(session, tx_id)


async def list_transactions(
    session: AsyncSession,
    user_id: uuid.UUID,
    *,
    month: str | None = None,
    category_id: uuid.UUID | None = None,
    type_filter: TransactionType | None = None,
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[Transaction], int]:
    """Return one page of a user's non-deleted transactions plus the total match count.

    Args:
        session: Active async DB session.
        user_id: Owner whose transactions are listed.
        month: Optional ``YYYY-MM`` filter on ``transaction_date``.
        category_id: Optional category filter.
        type_filter: Optional transaction-type filter.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size; negative values are treated as 0.

    Returns:
        ``(items, total)``. ``items`` is newest-first, and each item has its
        category eager-loaded. ``total`` counts every row matching the filters,
        before pagination.

    Raises:
        ValueError: if ``month`` is malformed (see ``_month_bounds``).
    """
    # Negative OFFSET/LIMIT values are rejected by databases such as PostgreSQL.
    page = max(page, 1)
    per_page = max(per_page, 0)

    base_query = select(Transaction).where(
        Transaction.user_id == user_id,
        Transaction.deleted_at.is_(None),
    )
    if month is not None:
        start, end = _month_bounds(month)
        base_query = base_query.where(
            Transaction.transaction_date >= start,
            Transaction.transaction_date < end,
        )
    if category_id is not None:
        base_query = base_query.where(Transaction.category_id == category_id)
    if type_filter is not None:
        base_query = base_query.where(Transaction.type == type_filter)

    # Total count (before pagination)
    count_result = await session.execute(
        select(func.count()).select_from(base_query.subquery())
    )
    total = count_result.scalar_one()

    # Paginated rows with category eager-loaded. ``id`` is the final sort key
    # so ties on (date, created_at) still order deterministically and rows
    # neither repeat nor vanish across page boundaries.
    items_query = (
        base_query.options(selectinload(Transaction.category))
        .order_by(
            Transaction.transaction_date.desc(),
            Transaction.created_at.desc(),
            Transaction.id.desc(),
        )
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    result = await session.execute(items_query)
    return list(result.scalars().all()), total


async def get_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
) -> Transaction:
    """Fetch one non-deleted transaction owned by ``user_id``, with its category loaded.

    Raises:
        HTTPException: 404 if no such transaction exists for this user, or if
            it has been soft-deleted.
    """
    result = await session.execute(
        select(Transaction)
        .options(selectinload(Transaction.category))
        .where(
            Transaction.id == transaction_id,
            Transaction.user_id == user_id,
            Transaction.deleted_at.is_(None),
        )
    )
    tx = result.scalar_one_or_none()
    if tx is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Transaction not found")
    return tx


async def update_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
    data: TransactionUpdate,
) -> Transaction:
    """Apply the non-``None`` fields of ``data`` to a user's transaction and return it.

    Fields left as ``None`` in ``data`` are not changed, so this function cannot
    clear a field to NULL. ``updated_at`` is always bumped.

    Raises:
        HTTPException: 404 if the transaction is missing, not owned by
            ``user_id``, or soft-deleted.
    """
    tx = await get_transaction(session, user_id, transaction_id)
    if data.amount_cents is not None:
        tx.amount_cents = data.amount_cents
    if data.type is not None:
        tx.type = data.type
    if data.category_id is not None:
        tx.category_id = data.category_id
    if data.description is not None:
        tx.description = data.description
    if data.transaction_date is not None:
        tx.transaction_date = data.transaction_date
    tx.updated_at = utcnow()
    await session.commit()
    # Reload by the caller-supplied id, so nothing reads attributes of a
    # possibly expired instance. populate_existing refreshes the category
    # if category_id changed.
    return await _load_with_category(session, transaction_id)


async def delete_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
) -> None:
    """Soft-delete a user's transaction by stamping ``deleted_at``.

    Raises:
        HTTPException: 404 if the transaction is missing, not owned by
            ``user_id``, or already soft-deleted.
    """
    tx = await get_transaction(session, user_id, transaction_id)
    # One timestamp so deleted_at and updated_at agree exactly.
    now = utcnow()
    tx.deleted_at = now
    tx.updated_at = now
    await session.commit()