Files
2026-03-17 16:16:08 +00:00

153 lines
4.4 KiB
Python

import uuid
from datetime import date
from fastapi import HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.transaction import Transaction, TransactionType
from app.schemas.transaction import TransactionCreate, TransactionUpdate
from app.utils import utcnow
def _month_bounds(month: str) -> tuple[date, date]:
"""Return (start_inclusive, end_exclusive) date bounds for a YYYY-MM string."""
year, mon = map(int, month.split("-"))
start = date(year, mon, 1)
if mon == 12:
end = date(year + 1, 1, 1)
else:
end = date(year, mon + 1, 1)
return start, end
async def create_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    data: TransactionCreate,
) -> Transaction:
    """Persist a new transaction owned by *user_id* and return it.

    The returned instance is re-selected after commit with its ``category``
    relationship eager-loaded.
    """
    # NOTE(review): data.category_id is not checked against user_id here;
    # confirm the schema or caller ensures the category belongs to this user.
    tx = Transaction(
        user_id=user_id,
        category_id=data.category_id,
        amount_cents=data.amount_cents,
        type=data.type,
        description=data.description,
        transaction_date=data.transaction_date,
    )
    session.add(tx)
    # Flush so the primary key is populated, then capture it *before* commit.
    # With the default expire_on_commit=True, commit expires ``tx``, and reading
    # an expired attribute on an AsyncSession-bound object triggers implicit IO,
    # which is not allowed in async code.
    await session.flush()
    tx_id = tx.id
    await session.commit()
    # Reload with category relationship; populate_existing makes the SELECT
    # overwrite the identity-mapped instance whatever the session's expiry config.
    result = await session.execute(
        select(Transaction)
        .options(selectinload(Transaction.category))
        .where(Transaction.id == tx_id)
        .execution_options(populate_existing=True)
    )
    return result.scalar_one()
async def list_transactions(
    session: AsyncSession,
    user_id: uuid.UUID,
    *,
    month: str | None = None,
    category_id: uuid.UUID | None = None,
    type_filter: TransactionType | None = None,
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[Transaction], int]:
    """Return one page of the user's non-deleted transactions and the total count.

    Args:
        session: Async DB session.
        user_id: Owner whose transactions are listed.
        month: Optional ``YYYY-MM``; restricts ``transaction_date`` to that month.
        category_id: Optional category filter.
        type_filter: Optional transaction type filter.
        page: 1-based page number.
        per_page: Page size.

    Returns:
        ``(items, total)``. *items* is the requested page, newest first, with
        ``category`` eager-loaded. *total* counts all matching rows before
        pagination.

    Raises:
        HTTPException: 400 if *month* is not a valid ``YYYY-MM`` string.
    """
    base_query = select(Transaction).where(
        Transaction.user_id == user_id,
        Transaction.deleted_at.is_(None),
    )
    if month is not None:
        try:
            start, end = _month_bounds(month)
        except (ValueError, OverflowError) as exc:
            # Without this, a malformed month escapes as an unhandled 500.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid month; expected YYYY-MM",
            ) from exc
        base_query = base_query.where(
            Transaction.transaction_date >= start,
            Transaction.transaction_date < end,
        )
    if category_id is not None:
        base_query = base_query.where(Transaction.category_id == category_id)
    if type_filter is not None:
        base_query = base_query.where(Transaction.type == type_filter)
    # Total count (before pagination)
    count_result = await session.execute(
        select(func.count()).select_from(base_query.subquery())
    )
    total = count_result.scalar_one()
    # Paginated rows with category eager-loaded
    items_query = (
        base_query.options(selectinload(Transaction.category))
        .order_by(
            Transaction.transaction_date.desc(),
            Transaction.created_at.desc(),
            # Unique tiebreaker. SQL leaves the order of tied rows unspecified,
            # so without it rows sharing date and created_at can repeat or go
            # missing across pages.
            Transaction.id.desc(),
        )
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    result = await session.execute(items_query)
    return list(result.scalars().all()), total
async def get_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
) -> Transaction:
    """Fetch one non-deleted transaction owned by *user_id*.

    The ``category`` relationship is eager-loaded on the returned object.

    Raises:
        HTTPException: 404 when no row matches *transaction_id* for this user,
            or the matching row has been soft-deleted.
    """
    stmt = (
        select(Transaction)
        .where(
            Transaction.id == transaction_id,
            Transaction.user_id == user_id,
            Transaction.deleted_at.is_(None),
        )
        .options(selectinload(Transaction.category))
    )
    found = (await session.execute(stmt)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Transaction not found")
async def update_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
    data: TransactionUpdate,
) -> Transaction:
    """Apply a partial update to a non-deleted transaction owned by *user_id*.

    Only fields of *data* whose value is not ``None`` are copied onto the row,
    so a field cannot be cleared to ``None`` through this function.
    ``updated_at`` is always bumped.

    Returns:
        The updated transaction, re-selected with a fresh ``category``.

    Raises:
        HTTPException: 404 if no matching non-deleted transaction exists
            (raised by get_transaction).
    """
    tx = await get_transaction(session, user_id, transaction_id)
    # NOTE(review): a new category_id is not checked against user_id here;
    # confirm the schema or caller ensures the category belongs to this user.
    updatable_fields = (
        "amount_cents",
        "type",
        "category_id",
        "description",
        "transaction_date",
    )
    for field in updatable_fields:
        value = getattr(data, field)
        if value is not None:
            setattr(tx, field, value)
    tx.updated_at = utcnow()
    await session.commit()
    # Reload with fresh category.
    # - Filter on transaction_id rather than reading tx.id. They are equal (the
    #   row was matched on it above), but commit may have expired tx, and
    #   reading an expired attribute on an AsyncSession triggers implicit IO.
    # - populate_existing is needed because, with expire_on_commit=False, the
    #   instance stays in the identity map with its *old* category loaded, and
    #   a plain SELECT would not overwrite it after category_id changed.
    result = await session.execute(
        select(Transaction)
        .options(selectinload(Transaction.category))
        .where(Transaction.id == transaction_id)
        .execution_options(populate_existing=True)
    )
    return result.scalar_one()
async def delete_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
) -> None:
    """Soft-delete a transaction owned by *user_id*.

    The row is kept and ``deleted_at`` is stamped. get_transaction and
    list_transactions filter on ``deleted_at IS NULL``, so they stop
    returning it.

    Raises:
        HTTPException: 404 if no matching non-deleted transaction exists
            (raised by get_transaction).
    """
    tx = await get_transaction(session, user_id, transaction_id)
    # Take a single timestamp so deleted_at and updated_at agree exactly;
    # two utcnow() calls would differ by a few microseconds.
    now = utcnow()
    tx.deleted_at = now
    tx.updated_at = now
    await session.commit()