From 46ca714ef992726a95e6b2a6e6aed85726ab399e Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Mon, 29 Jun 2026 15:06:47 +0000 Subject: [PATCH] =?UTF-8?q?Batch=20plan=20saves=20reduce=20RDS=20CPU=20dur?= =?UTF-8?q?ing=20bulk=20modelling=20runs=20=F0=9F=9F=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- repositories/plan/plan_postgres_repository.py | 75 ++++++++++++++++++- .../repositories/plan/test_plan_batch_save.py | 2 +- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/repositories/plan/plan_postgres_repository.py b/repositories/plan/plan_postgres_repository.py index 36253764..9bf22aea 100644 --- a/repositories/plan/plan_postgres_repository.py +++ b/repositories/plan/plan_postgres_repository.py @@ -1,5 +1,8 @@ from __future__ import annotations +from typing import Any + +from sqlalchemy import insert as _sa_insert from sqlmodel import Session, col, update from domain.modelling.plan import Plan @@ -7,6 +10,15 @@ from infrastructure.postgres.modelling import PlanModel, RecommendationModel from repositories.plan.plan_repository import PlanRepository, PlanSaveRequest +def _col_values(model: Any, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: + """Extract column-keyed values from a SQLModel instance for Core INSERT.""" + return { + c.name: getattr(model, c.name) + for c in model.__table__.c + if c.name not in exclude + } + + class PlanPostgresRepository(PlanRepository): """Maps a Plan and its Plan Measures onto the live ``plan`` / ``recommendation`` tables (ADR-0017). Does not commit — the Unit of Work @@ -65,4 +77,65 @@ class PlanPostgresRepository(PlanRepository): return plan_row.id def save_batch(self, requests: list[PlanSaveRequest]) -> list[int]: - raise NotImplementedError + """Persist all Plans in three statements regardless of batch size. + + 1. One demote UPDATE (only when any request has ``is_default=True``). + 2. One bulk plan INSERT with RETURNING to capture ids positionally. + 3. One bulk recommendation INSERT (skipped when no measures exist). + """ + if not requests: + return [] + + # Demote prior default Plans for every property in the batch that is + # receiving a new default Plan — one UPDATE for the whole batch. + default_pids = [r.property_id for r in requests if r.is_default] + if default_pids: + # scenario_id is uniform per batch (one scenario per SQS message). + scenario_id = requests[0].scenario_id + self._session.exec( # type: ignore[call-overload] + update(PlanModel) + .where( + col(PlanModel.property_id).in_(default_pids), + col(PlanModel.scenario_id) == scenario_id, + ) + .values(is_default=False) + ) + + # Bulk INSERT all plan rows; capture returned ids positionally. + plan_rows = [ + _col_values( + PlanModel.from_domain( + r.plan, + property_id=r.property_id, + scenario_id=r.scenario_id, + portfolio_id=r.portfolio_id, + is_default=r.is_default, + ), + exclude=frozenset({"id"}), + ) + for r in requests + ] + returned = self._session.execute( # type: ignore[deprecated] + _sa_insert(PlanModel).returning(PlanModel.__table__.c["id"]), # type: ignore[attr-defined] + plan_rows, + ).all() + plan_ids = [row[0] for row in returned] + + # Accumulate recommendation rows across all requests; properties with + # zero measures contribute nothing (no special-casing needed). + rec_rows = [ + _col_values( + RecommendationModel.from_domain( + measure, property_id=r.property_id, plan_id=plan_id + ), + exclude=frozenset({"id"}), + ) + for r, plan_id in zip(requests, plan_ids) + for measure in r.plan.measures + ] + if rec_rows: + self._session.execute( # type: ignore[deprecated] + _sa_insert(RecommendationModel), rec_rows + ) + + return plan_ids diff --git a/tests/repositories/plan/test_plan_batch_save.py b/tests/repositories/plan/test_plan_batch_save.py index e82d0f1e..400b2cdf 100644 --- a/tests/repositories/plan/test_plan_batch_save.py +++ b/tests/repositories/plan/test_plan_batch_save.py @@ -86,7 +86,7 @@ def test_single_request_save_batch_matches_save(db_engine: Engine) -> None: assert recs_save[0].type == recs_batch[0].type assert recs_save[0].estimated_cost == recs_batch[0].estimated_cost assert recs_save[0].sap_points == recs_batch[0].sap_points - assert recs_save[0].plan_id == batch_id + assert recs_batch[0].plan_id == batch_id # ---------------------------------------------------------------------------