From bd9e553e35c562e80007e1c057e6aa245b3a417f Mon Sep 17 00:00:00 2001 From: Daniel Roth Date: Fri, 13 Feb 2026 14:50:48 +0000 Subject: [PATCH] bulk update of plans --- .../db/functions/recommendations_functions.py | 65 ++++++++++++------- backend/categorisation/processor.py | 23 +++---- 2 files changed, 49 insertions(+), 39 deletions(-) diff --git a/backend/app/db/functions/recommendations_functions.py b/backend/app/db/functions/recommendations_functions.py index 6816e25b..e690991a 100644 --- a/backend/app/db/functions/recommendations_functions.py +++ b/backend/app/db/functions/recommendations_functions.py @@ -1,6 +1,6 @@ -from typing import Any, List, Optional -from sqlalchemy import text, insert, delete, select, update -from sqlalchemy.orm import Session +from typing import Any, Dict, List, Optional +from sqlalchemy import inspect, text, insert, delete, select, update +from sqlalchemy.orm import Session, Mapper from sqlalchemy.exc import SQLAlchemyError from sqlmodel import Session @@ -632,30 +632,45 @@ def get_scenario(scenario_id: int) -> Optional[ScenarioModel]: return session_any.exec(stmt).scalar_one_or_none() -def update_plan(plan_model: PlanModel, scenario_model: ScenarioModel) -> bool: +def bulk_update_plans( + plan_models: List[PlanModel], + scenario_models: List[ScenarioModel], +) -> int: + if not plan_models: + return 0 + with db_read_session() as session: - plan_values = { - c.name: getattr(plan_model, c.name) - for c in plan_model.__table__.columns - if c.name != "id" - } - scenario_values = { - c.name: getattr(scenario_model, c.name) - for c in scenario_model.__table__.columns - if c.name not in {"id", "portfolio_id"} - } - plan_stmt = ( - update(PlanModel).where(PlanModel.id == plan_model.id).values(**plan_values) - ) - plan_result = session.exec(plan_stmt) + plan_mapper: Mapper[Any] = inspect(PlanModel) + scenario_mapper: Mapper[Any] = inspect(ScenarioModel) - scenario_stmt = ( - update(ScenarioModel) - .where(ScenarioModel.id == scenario_model.id) - .values(**scenario_values) - ) - session.exec(scenario_stmt) + plan_mappings: List[Dict[str, Any]] = ( + [] + ) # Typehint as Any to satisfy Pylance... + for plan in plan_models: + data: Dict[str, Any] = { + c.name: getattr(plan, c.name) + for c in plan.__table__.columns + if c.name != "id" + } + data["id"] = plan.id + plan_mappings.append(data) + + session.bulk_update_mappings(plan_mapper, plan_mappings) + + scenario_mappings: List[Dict[str, Any]] = ( + [] + ) # Typehint as Any to satisfy Pylance... + for scenario in scenario_models: + data: Dict[str, Any] = { + c.name: getattr(scenario, c.name) + for c in scenario.__table__.columns + if c.name not in {"id", "portfolio_id"} + } + data["id"] = scenario.id + scenario_mappings.append(data) + + session.bulk_update_mappings(scenario_mapper, scenario_mappings) session.commit() - return plan_result.rowcount > 0 + return len(plan_models) diff --git a/backend/categorisation/processor.py b/backend/categorisation/processor.py index 704dfc07..445bbbc4 100644 --- a/backend/categorisation/processor.py +++ b/backend/categorisation/processor.py @@ -1,10 +1,10 @@ from collections import defaultdict -from typing import Dict, List, Tuple, cast +from typing import Dict, List from backend.app.db.functions.recommendations_functions import ( + bulk_update_plans, get_plans_by_portfolio_id, get_scenario, - update_plan, ) from backend.app.db.models.recommendations import PlanModel, ScenarioModel from backend.app.domain.classes.plan import Plan @@ -73,18 +73,13 @@ def _choose_cheapest_relevant_plan(plans: List[Plan]) -> Plan: def _update_default_flags(plans: List[Plan], cheapest_plan: Plan) -> None: + plan_models: List[PlanModel] = [] + scenario_models: List[ScenarioModel] = [] + for plan in plans: - if plan.id is None: - raise ValueError("Cannot update Plan with missing ID") - plan.set_default(plan.id == cheapest_plan.id) - print( - f"Setting plan of id {plan.id}, scenario name {plan.scenario.record.name} to is_default value {plan.id == cheapest_plan.id}" - ) + plan_model, scenario_model = plan.to_sqlalchemy() + plan_models.append(plan_model) + scenario_models.append(scenario_model) - plan_model, scenario_model = cast( - Tuple[PlanModel, ScenarioModel], - plan.to_sqlalchemy(), - ) - - update_plan(plan_model, scenario_model) + bulk_update_plans(plan_models, scenario_models)