bulk update of plans

This commit is contained in:
Daniel Roth 2026-02-13 14:50:48 +00:00
parent 8e5016978e
commit bd9e553e35
2 changed files with 49 additions and 39 deletions

View file

@ -1,6 +1,6 @@
from typing import Any, List, Optional
from sqlalchemy import text, insert, delete, select, update
from sqlalchemy.orm import Session
from typing import Any, Dict, List, Optional
from sqlalchemy import inspect, text, insert, delete, select, update
from sqlalchemy.orm import Session, Mapper
from sqlalchemy.exc import SQLAlchemyError
from sqlmodel import Session
@ -632,30 +632,45 @@ def get_scenario(scenario_id: int) -> Optional[ScenarioModel]:
return session_any.exec(stmt).scalar_one_or_none()
def update_plan(plan_model: PlanModel, scenario_model: ScenarioModel) -> bool:
def bulk_update_plans(
plan_models: List[PlanModel],
scenario_models: List[ScenarioModel],
) -> int:
if not plan_models:
return 0
with db_read_session() as session:
plan_values = {
c.name: getattr(plan_model, c.name)
for c in plan_model.__table__.columns
if c.name != "id"
}
scenario_values = {
c.name: getattr(scenario_model, c.name)
for c in scenario_model.__table__.columns
if c.name not in {"id", "portfolio_id"}
}
plan_stmt = (
update(PlanModel).where(PlanModel.id == plan_model.id).values(**plan_values)
)
plan_result = session.exec(plan_stmt)
plan_mapper: Mapper[Any] = inspect(PlanModel)
scenario_mapper: Mapper[Any] = inspect(ScenarioModel)
scenario_stmt = (
update(ScenarioModel)
.where(ScenarioModel.id == scenario_model.id)
.values(**scenario_values)
)
session.exec(scenario_stmt)
plan_mappings: List[Dict[str, Any]] = (
[]
) # Typehint as Any to satisfy Pylance...
for plan in plan_models:
data: Dict[str, Any] = {
c.name: getattr(plan, c.name)
for c in plan.__table__.columns
if c.name != "id"
}
data["id"] = plan.id
plan_mappings.append(data)
session.bulk_update_mappings(plan_mapper, plan_mappings)
scenario_mappings: List[Dict[str, Any]] = (
[]
) # Typehint as Any to satisfy Pylance...
for scenario in scenario_models:
data: Dict[str, Any] = {
c.name: getattr(scenario, c.name)
for c in scenario.__table__.columns
if c.name not in {"id", "portfolio_id"}
}
data["id"] = scenario.id
scenario_mappings.append(data)
session.bulk_update_mappings(scenario_mapper, scenario_mappings)
session.commit()
return plan_result.rowcount > 0
return len(plan_models)

View file

@ -1,10 +1,10 @@
from collections import defaultdict
from typing import Dict, List, Tuple, cast
from typing import Dict, List
from backend.app.db.functions.recommendations_functions import (
bulk_update_plans,
get_plans_by_portfolio_id,
get_scenario,
update_plan,
)
from backend.app.db.models.recommendations import PlanModel, ScenarioModel
from backend.app.domain.classes.plan import Plan
@ -73,18 +73,13 @@ def _choose_cheapest_relevant_plan(plans: List[Plan]) -> Plan:
def _update_default_flags(plans: List[Plan], cheapest_plan: Plan) -> None:
plan_models: List[PlanModel] = []
scenario_models: List[ScenarioModel] = []
for plan in plans:
if plan.id is None:
raise ValueError("Cannot update Plan with missing ID")
plan.set_default(plan.id == cheapest_plan.id)
print(
f"Setting plan of id {plan.id}, scenario name {plan.scenario.record.name} to is_default value {plan.id == cheapest_plan.id}"
)
plan_model, scenario_model = plan.to_sqlalchemy()
plan_models.append(plan_model)
scenario_models.append(scenario_model)
plan_model, scenario_model = cast(
Tuple[PlanModel, ScenarioModel],
plan.to_sqlalchemy(),
)
update_plan(plan_model, scenario_model)
bulk_update_plans(plan_models, scenario_models)