Batch plan saves reduce RDS CPU during bulk modelling runs 🟩

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Daniel Roth 2026-06-29 15:06:47 +00:00
parent 9c6b477025
commit 46ca714ef9
2 changed files with 75 additions and 2 deletions

View file

@ -1,5 +1,8 @@
from __future__ import annotations
from typing import Any
from sqlalchemy import insert as _sa_insert
from sqlmodel import Session, col, update
from domain.modelling.plan import Plan
@ -7,6 +10,15 @@ from infrastructure.postgres.modelling import PlanModel, RecommendationModel
from repositories.plan.plan_repository import PlanRepository, PlanSaveRequest
def _col_values(model: Any, exclude: frozenset[str] = frozenset()) -> dict[str, Any]:
"""Extract column-keyed values from a SQLModel instance for Core INSERT."""
return {
c.name: getattr(model, c.name)
for c in model.__table__.c
if c.name not in exclude
}
class PlanPostgresRepository(PlanRepository):
"""Maps a Plan and its Plan Measures onto the live ``plan`` /
``recommendation`` tables (ADR-0017). Does not commit the Unit of Work
@ -65,4 +77,65 @@ class PlanPostgresRepository(PlanRepository):
return plan_row.id
def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]:
raise NotImplementedError
"""Persist all Plans in three statements regardless of batch size.
1. One demote UPDATE (only when any request has ``is_default=True``).
2. One bulk plan INSERT with RETURNING to capture ids positionally.
3. One bulk recommendation INSERT (skipped when no measures exist).
"""
if not requests:
return []
# Demote prior default Plans for every property in the batch that is
# receiving a new default Plan — one UPDATE for the whole batch.
default_pids = [r.property_id for r in requests if r.is_default]
if default_pids:
# scenario_id is uniform per batch (one scenario per SQS message).
scenario_id = requests[0].scenario_id
self._session.exec( # type: ignore[call-overload]
update(PlanModel)
.where(
col(PlanModel.property_id).in_(default_pids),
col(PlanModel.scenario_id) == scenario_id,
)
.values(is_default=False)
)
# Bulk INSERT all plan rows; capture returned ids positionally.
plan_rows = [
_col_values(
PlanModel.from_domain(
r.plan,
property_id=r.property_id,
scenario_id=r.scenario_id,
portfolio_id=r.portfolio_id,
is_default=r.is_default,
),
exclude=frozenset({"id"}),
)
for r in requests
]
returned = self._session.execute( # type: ignore[deprecated]
_sa_insert(PlanModel).returning(PlanModel.__table__.c["id"]), # type: ignore[attr-defined]
plan_rows,
).all()
plan_ids = [row[0] for row in returned]
# Accumulate recommendation rows across all requests; properties with
# zero measures contribute nothing (no special-casing needed).
rec_rows = [
_col_values(
RecommendationModel.from_domain(
measure, property_id=r.property_id, plan_id=plan_id
),
exclude=frozenset({"id"}),
)
for r, plan_id in zip(requests, plan_ids)
for measure in r.plan.measures
]
if rec_rows:
self._session.execute( # type: ignore[deprecated]
_sa_insert(RecommendationModel), rec_rows
)
return plan_ids

View file

@ -86,7 +86,7 @@ def test_single_request_save_batch_matches_save(db_engine: Engine) -> None:
assert recs_save[0].type == recs_batch[0].type
assert recs_save[0].estimated_cost == recs_batch[0].estimated_cost
assert recs_save[0].sap_points == recs_batch[0].sap_points
assert recs_save[0].plan_id == batch_id
assert recs_batch[0].plan_id == batch_id
# ---------------------------------------------------------------------------