153 lines
4.4 KiB
Python
153 lines
4.4 KiB
Python
import uuid
|
|
from datetime import date
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.models.transaction import Transaction, TransactionType
|
|
from app.schemas.transaction import TransactionCreate, TransactionUpdate
|
|
from app.utils import utcnow
|
|
|
|
|
|
def _month_bounds(month: str) -> tuple[date, date]:
    """Return ``(start_inclusive, end_exclusive)`` date bounds for a ``YYYY-MM`` string.

    ``start`` is the first day of the given month and ``end`` is the first day
    of the following month, so the filter ``start <= d < end`` covers exactly
    that month. December rolls over into January of the next year.

    Args:
        month: Two dash-separated integers, year then month (e.g. ``"2024-03"``).

    Returns:
        A ``(start, end)`` tuple of ``date`` objects.

    Raises:
        ValueError: if ``month`` is not two dash-separated integers forming a
            valid year/month, or if the following month cannot be represented
            as a ``date`` (e.g. ``"9999-12"``). The underlying error is chained
            as the cause.
    """
    try:
        year_part, month_part = month.split("-")
        year, mon = int(year_part), int(month_part)
        start = date(year, mon, 1)  # also validates 1 <= mon <= 12
        # divmod maps month 12 -> (carry 1, index 0), i.e. January of next year.
        carry, next_mon_index = divmod(mon, 12)
        end = date(year + carry, next_mon_index + 1, 1)
    except ValueError as exc:
        raise ValueError(f"Invalid month {month!r}; expected 'YYYY-MM'") from exc
    return start, end
|
|
|
|
|
|
async def create_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    data: TransactionCreate,
) -> Transaction:
    """Insert a new transaction for ``user_id``, commit, and return it.

    The returned object is re-selected after the commit with its ``category``
    relationship eager-loaded.

    Args:
        session: Active async database session; this function commits it.
        user_id: Owner of the new transaction.
        data: Validated creation payload.

    Returns:
        The persisted ``Transaction`` with ``category`` loaded.
    """
    # NOTE(review): data.category_id is not checked for ownership here — confirm
    # it is validated by the caller or schema.
    tx = Transaction(
        user_id=user_id,
        category_id=data.category_id,
        amount_cents=data.amount_cents,
        type=data.type,
        description=data.description,
        transaction_date=data.transaction_date,
    )
    session.add(tx)

    # Flush so the primary key is assigned, and capture it BEFORE committing:
    # with expire_on_commit enabled, reading tx.id after commit would trigger
    # implicit IO, which is not allowed on an AsyncSession.
    await session.flush()
    tx_id = tx.id
    await session.commit()

    # Reload with the category relationship eager-loaded.
    result = await session.execute(
        select(Transaction)
        .options(selectinload(Transaction.category))
        .where(Transaction.id == tx_id)
        .execution_options(populate_existing=True)
    )
    return result.scalar_one()
|
|
|
|
|
|
async def list_transactions(
    session: AsyncSession,
    user_id: uuid.UUID,
    *,
    month: str | None = None,
    category_id: uuid.UUID | None = None,
    type_filter: TransactionType | None = None,
    page: int = 1,
    per_page: int = 20,
) -> tuple[list[Transaction], int]:
    """Return one page of a user's non-deleted transactions and the total match count.

    Rows with ``deleted_at`` set are excluded. Optional filters narrow results
    by calendar month, category, and transaction type. Rows are ordered newest
    first by ``transaction_date``, then ``created_at``, with ``id`` as a final
    tiebreaker so page boundaries are stable between requests.

    Args:
        session: Active async database session.
        user_id: Only this user's transactions are returned.
        month: Optional ``YYYY-MM`` string; restricts to that calendar month.
        category_id: Optional category filter.
        type_filter: Optional transaction type filter.
        page: 1-based page number; values below 1 are treated as page 1.
        per_page: Maximum number of rows per page.

    Returns:
        ``(items, total)``, where ``total`` counts all matching rows before
        pagination.

    Raises:
        ValueError: if ``month`` is malformed (raised by ``_month_bounds``).
    """
    base_query = select(Transaction).where(
        Transaction.user_id == user_id,
        Transaction.deleted_at.is_(None),
    )

    if month is not None:
        start, end = _month_bounds(month)
        base_query = base_query.where(
            Transaction.transaction_date >= start,
            Transaction.transaction_date < end,
        )
    if category_id is not None:
        base_query = base_query.where(Transaction.category_id == category_id)
    if type_filter is not None:
        base_query = base_query.where(Transaction.type == type_filter)

    # Total count (before pagination).
    count_result = await session.execute(
        select(func.count()).select_from(base_query.subquery())
    )
    total = count_result.scalar_one()

    # Clamp so OFFSET is never negative (PostgreSQL rejects a negative OFFSET).
    offset = max(page - 1, 0) * per_page

    # Paginated rows with category eager-loaded. ``id`` makes the ordering
    # total, so rows that tie on date and created_at cannot move between pages.
    items_query = (
        base_query.options(selectinload(Transaction.category))
        .order_by(
            Transaction.transaction_date.desc(),
            Transaction.created_at.desc(),
            Transaction.id.desc(),
        )
        .offset(offset)
        .limit(per_page)
    )
    result = await session.execute(items_query)
    return list(result.scalars().all()), total
|
|
|
|
|
|
async def get_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
) -> Transaction:
    """Fetch a single non-deleted transaction owned by ``user_id``.

    The ``category`` relationship is eager-loaded on the returned object.

    Args:
        session: Active async database session.
        user_id: The transaction must belong to this user.
        transaction_id: Primary key of the transaction to fetch.

    Returns:
        The matching ``Transaction``.

    Raises:
        HTTPException: 404 if no transaction matches. This covers a missing id,
            a transaction owned by another user, and a soft-deleted transaction.
    """
    stmt = (
        select(Transaction)
        .where(
            Transaction.id == transaction_id,
            Transaction.user_id == user_id,
            Transaction.deleted_at.is_(None),
        )
        .options(selectinload(Transaction.category))
    )
    found = (await session.execute(stmt)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Transaction not found")
|
|
|
|
|
|
async def update_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
    data: TransactionUpdate,
) -> Transaction:
    """Apply a partial update to a user's transaction, commit, and return it reloaded.

    Only fields of ``data`` that are not ``None`` are written. ``None`` means
    "leave unchanged", so a field cannot be cleared through this function.
    ``updated_at`` is always bumped.

    Args:
        session: Active async database session; this function commits it.
        user_id: The transaction must belong to this user.
        transaction_id: Primary key of the transaction to update.
        data: Partial update payload.

    Returns:
        The updated ``Transaction`` with a freshly loaded ``category``.

    Raises:
        HTTPException: 404 if the transaction is missing, owned by another user,
            or soft-deleted (raised by ``get_transaction``).
    """
    tx = await get_transaction(session, user_id, transaction_id)

    if data.amount_cents is not None:
        tx.amount_cents = data.amount_cents
    if data.type is not None:
        tx.type = data.type
    if data.category_id is not None:
        # NOTE(review): the new category's ownership is not checked here —
        # confirm it is validated by the caller or schema.
        tx.category_id = data.category_id
    if data.description is not None:
        tx.description = data.description
    if data.transaction_date is not None:
        tx.transaction_date = data.transaction_date
    tx.updated_at = utcnow()

    await session.commit()

    # Reload by the id we already have. Reading tx.id after commit could trigger
    # implicit IO when expire_on_commit is enabled.
    # populate_existing is needed because the ``category`` relationship was
    # already loaded by get_transaction. Without it, the identity-mapped object
    # keeps the OLD category even after category_id changed.
    result = await session.execute(
        select(Transaction)
        .options(selectinload(Transaction.category))
        .where(Transaction.id == transaction_id)
        .execution_options(populate_existing=True)
    )
    return result.scalar_one()
|
|
|
|
|
|
async def delete_transaction(
    session: AsyncSession,
    user_id: uuid.UUID,
    transaction_id: uuid.UUID,
) -> None:
    """Soft-delete a user's transaction and commit.

    The row is kept. ``deleted_at`` is set, which makes queries that filter on
    ``deleted_at IS NULL`` skip it.

    Args:
        session: Active async database session; this function commits it.
        user_id: The transaction must belong to this user.
        transaction_id: Primary key of the transaction to delete.

    Raises:
        HTTPException: 404 if the transaction is missing, owned by another user,
            or already soft-deleted (raised by ``get_transaction``).
    """
    tx = await get_transaction(session, user_id, transaction_id)
    # Take one timestamp so deleted_at and updated_at are identical.
    now = utcnow()
    tx.deleted_at = now
    tx.updated_at = now
    await session.commit()
|