Ensure all defaults are unset before setting new ones; refactor the processor

This commit is contained in:
Daniel Roth 2026-02-20 15:26:40 +00:00
parent ec01e1d190
commit 96fbd7f24c
2 changed files with 98 additions and 104 deletions

View file

@ -1,5 +1,5 @@
from typing import Any, Dict, List, Optional
from sqlalchemy import inspect, text, insert, delete, select, update
from typing import Any, Dict, List, Tuple
from sqlalchemy import inspect, text, insert, delete, select
from sqlalchemy.orm import Session, Mapper
from sqlalchemy.exc import SQLAlchemyError
from sqlmodel import Session
@ -618,13 +618,6 @@ def clear_portfolio_in_batches(
print("Portfolio cleared in batches.")
def get_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]:
    """Return every plan whose ``portfolio_id`` matches the given portfolio.

    Args:
        portfolio_id: Value compared against ``PlanModel.portfolio_id``.

    Returns:
        All matching ``PlanModel`` rows, read through ``db_read_session``.
    """
    belongs_to_portfolio = PlanModel.portfolio_id == portfolio_id
    query = select(PlanModel).where(belongs_to_portfolio)
    with db_read_session() as read_session:
        # Typed as Any purely to satisfy Pylance.
        untyped_session: Any = read_session
        result = untyped_session.exec(query)
        return result.scalars().all()
def get_plans_by_scenario_ids(ids: List[int]) -> List[PlanModel]:
stmt = select(PlanModel).where(PlanModel.scenario_id.in_(ids))
with db_read_session() as session:
@ -632,13 +625,36 @@ def get_plans_by_scenario_ids(ids: List[int]) -> List[PlanModel]:
return session_any.exec(stmt).scalars().all()
def get_most_recent_plan_ids_by_scenario_ids(scenario_ids: List[int]) -> List[int]:
def get_most_recent_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]:
# NOTE: This statement works on PostgreSQL only: passing columns to .distinct() renders DISTINCT ON
stmt = (
select(PlanModel.id)
.where(PlanModel.scenario_id.in_(scenario_ids))
.distinct(PlanModel.scenario_id)
select(PlanModel)
.where(PlanModel.portfolio_id == portfolio_id)
.distinct(
PlanModel.property_id, PlanModel.scenario_id
) # one plan per property per scenario
.order_by(
PlanModel.property_id,
PlanModel.scenario_id,
PlanModel.created_at.desc(),
PlanModel.id.desc(),
)
)
with db_read_session() as session:
session_any: Any = session # Typehint as Any to satisfy Pylance...
return session_any.exec(stmt).scalars().all()
def get_most_recent_plans_by_scenario_ids(scenario_ids: List[int]) -> List[PlanModel]:
# NOTE: This statement works on PostgreSQL only: passing columns to .distinct() renders DISTINCT ON
stmt = (
select(PlanModel)
.where(PlanModel.scenario_id.in_(scenario_ids))
.distinct(
PlanModel.property_id, PlanModel.scenario_id
) # one plan per property per scenario
.order_by(
PlanModel.property_id,
PlanModel.scenario_id,
PlanModel.created_at.desc(),
PlanModel.id.desc(),
@ -646,7 +662,7 @@ def get_most_recent_plan_ids_by_scenario_ids(scenario_ids: List[int]) -> List[in
)
with db_read_session() as session:
session_any: Any = session # Typehint as Any to satisfy Pylance...
session_any: Any = session # Typehint as Any to satisfy Pylance
return session_any.exec(stmt).scalars().all()
@ -657,39 +673,22 @@ def get_scenarios_by_portfolio_id(portfolio_id: int) -> List[ScenarioModel]:
return session_any.exec(stmt).scalars().all()
def get_default_scenario_ids_for_portfolio(portfolio_id: int) -> List[int]:
# This should in reality always return exactly 1 ID, but there's currently
# no database constraint to enforce that, so account for 0 or >1
stmt = select(ScenarioModel.id).where(
def get_default_plans_and_scenarios(
portfolio_id: int,
) -> Tuple[List[PlanModel], List[ScenarioModel]]:
plan_stmt = select(PlanModel).where(
(PlanModel.portfolio_id == portfolio_id) & (PlanModel.is_default == True)
)
scenario_stmt = select(ScenarioModel).where(
(ScenarioModel.portfolio_id == portfolio_id)
& (ScenarioModel.is_default == True)
)
with db_read_session() as session:
session_any: Any = session # Typehint as Any to satisfy Pylance...
return session_any.exec(stmt).scalars().all()
def set_plan_and_scenario_default(plan_id: int, default: bool) -> bool:
    """Write ``is_default`` onto a plan and onto the scenario it references.

    Args:
        plan_id: Identifier passed to ``session.get`` to look up the plan.
        default: Value written to ``is_default`` on both rows.

    Returns:
        ``False`` when the plan lookup yields nothing (no writes are made);
        ``True`` after both updates have been issued and committed.
    """
    with db_session() as session:
        target_plan = session.get(PlanModel, plan_id)
        if not target_plan:
            return False

        flag: Dict[str, Any] = {"is_default": default}
        # Pair each mapper with its single-row payload; both mappers are
        # resolved up front and the plan is written before its scenario.
        pending_updates = (
            (inspect(PlanModel), {"id": target_plan.id, **flag}),
            (inspect(ScenarioModel), {"id": target_plan.scenario_id, **flag}),
        )
        for mapper, row in pending_updates:
            session.bulk_update_mappings(mapper, [row])

        session.commit()
        return True
plans: List[PlanModel] = session_any.exec(plan_stmt).scalars().all()
scenarios: List[ScenarioModel] = session_any.exec(scenario_stmt).scalars().all()
return (plans, scenarios)
def bulk_update_plans(

View file

@ -1,14 +1,12 @@
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
from backend.app.db.functions.recommendations_functions import (
bulk_update_plans,
get_default_scenario_ids_for_portfolio,
get_most_recent_plan_ids_by_scenario_ids,
get_plans_by_portfolio_id,
get_plans_by_scenario_ids,
get_default_plans_and_scenarios,
get_most_recent_plans_by_portfolio_id,
get_most_recent_plans_by_scenario_ids,
get_scenarios_by_portfolio_id,
set_plan_and_scenario_default,
)
from backend.app.db.models.recommendations import PlanModel, ScenarioModel
from backend.app.domain.classes.plan import Plan
@ -22,29 +20,38 @@ def process_portfolio(
portfolio_id: int,
scenarios_to_consider: Optional[List[int]] = None,
scenario_priority_order: Optional[List[int]] = None,
) -> None:
) -> None: # TODO: make this a class
logger.info(f"Processing portfolio {portfolio_id}")
plans_by_id: Dict[int, Plan] = {} # TODO: make this an in-memory repository class
if scenarios_to_consider:
if len(scenarios_to_consider) < 2:
raise ValueError(
"Cannot run auto categorisation for fewer than 2 scenarios"
)
if scenarios_to_consider is not None:
_unset_defaults_for_scenarios_not_being_considered(
portfolio_id, scenarios_to_consider
)
# first get all plans that we're interested in
plans_for_consideration: List[Plan] = _load_plans_for_portfolio(
portfolio_id, scenarios_to_consider
)
for plan in plans_for_consideration:
if plan.id is not None: # just in case
plans_by_id[plan.id] = plan
plans: List[Plan] = _load_plans_for_portfolio(portfolio_id, scenarios_to_consider)
# then unset existing defaults on domain objects regardless of whether they're under consideration or not
default_plans: List[Plan] = _get_default_plans(portfolio_id)
for plan in default_plans:
plan.set_default(False)
if plan.id is not None: # just in case
plans_by_id[plan.id] = plan
plans_by_property: Dict[int, List[Plan]] = _group_plans_by_property(plans)
updated_plan_models: List[PlanModel] = []
updated_scenario_models: List[ScenarioModel] = []
for property_id, property_plans in plans_by_property.items():
# then set new defaults on domain objects under consideration
plans_for_consideration_by_property: Dict[int, List[Plan]] = (
_group_plans_by_property(plans_for_consideration)
)
for property_id, property_plans in plans_for_consideration_by_property.items():
if not property_plans:
raise ValueError(f"No plans for property {property_id}")
@ -56,17 +63,13 @@ def process_portfolio(
logger.error(f"Failed to find cheapest plan for property {property_id}")
raise
updated_property_plan_models, updated_property_scenario_models = (
_update_plan_and_scenario_objects(property_plans, cheapest_plan)
)
property_plans = _update_plan_objects(property_plans, cheapest_plan)
for plan in property_plans:
if plan.id is not None: # just in case
plans_by_id[plan.id] = plan
updated_plan_models.extend(updated_property_plan_models)
updated_scenario_models.extend(updated_property_scenario_models)
if len(updated_plan_models) > 0:
logger.info(f"Updating {len(updated_plan_models)} Plans in database")
bulk_update_plans(updated_plan_models, updated_scenario_models)
logger.info("Successfully updated Plan default values in database")
# then pass all domain objects to database to update (regardless of whether they've changed)
_update_plans_in_db(list(plans_by_id.values()))
def choose_cheapest_relevant_plan(
@ -100,29 +103,17 @@ def choose_cheapest_relevant_plan(
return cheapest_plans[0]
def _unset_defaults_for_scenarios_not_being_considered(
portfolio_id: int, scenarios_to_consider: List[int]
) -> None:
default_scenario_ids: List[int] = get_default_scenario_ids_for_portfolio(
def _get_default_plans(portfolio_id: int) -> List[Plan]:
default_plan_models, default_scenario_models = get_default_plans_and_scenarios(
portfolio_id
)
scenarios_to_unset_default: List[int] = []
for id in default_scenario_ids:
if id not in scenarios_to_consider:
scenarios_to_unset_default.append(id)
if len(scenarios_to_unset_default) > 0:
logger.info(
f"Unsetting {scenarios_to_unset_default} as default scenario(s) as not included in provided list of scenarios to consider"
return [
Plan.from_sqlalchemy(
p, next(s for s in default_scenario_models if s.id == p.scenario_id)
)
if len(scenarios_to_unset_default) > 0:
plans_to_unset_default: List[int] = get_most_recent_plan_ids_by_scenario_ids(
scenarios_to_unset_default
)
for plan_id in plans_to_unset_default:
set_plan_and_scenario_default(plan_id, False) # TODO: do this in batch
for p in default_plan_models
]
def _load_plans_for_portfolio(
@ -131,13 +122,17 @@ def _load_plans_for_portfolio(
if scenarios_to_consider:
logger.info(f"Getting plans for {len(scenarios_to_consider)} scenarios")
plan_models: List[PlanModel] = get_plans_by_scenario_ids(scenarios_to_consider)
plan_models: List[PlanModel] = get_most_recent_plans_by_scenario_ids(
scenarios_to_consider
)
logger.info(f"Got {len(plan_models)} plan models from database")
else:
logger.info(
f"No list of Plans to consider provided. Getting all Plans for portfolio {portfolio_id}"
)
plan_models: List[PlanModel] = get_plans_by_portfolio_id(portfolio_id)
plan_models: List[PlanModel] = get_most_recent_plans_by_portfolio_id(
portfolio_id
)
plans: List[Plan] = []
@ -170,26 +165,26 @@ def _group_plans_by_property(plans: List[Plan]) -> Dict[int, List[Plan]]:
return grouped
def _update_plan_and_scenario_objects(
plans: List[Plan], cheapest_plan: Plan
) -> Tuple[List[PlanModel], List[ScenarioModel]]:
plans_to_update: List[Plan] = []
def _update_plan_objects(plans: List[Plan], cheapest_plan: Plan) -> List[Plan]:
for plan in plans:
should_be_default: bool = plan.id == cheapest_plan.id
if plan.record.is_default != should_be_default:
logger.info(
f"Setting Plan {plan.id} (Scenario Name: {plan.scenario.record.name}) to is_default: {should_be_default}"
)
plan.set_default(should_be_default)
plans_to_update.append(plan)
plan.set_default(should_be_default)
if should_be_default:
logger.debug(
f"Setting Plan {plan.id} (Scenario Name: {plan.scenario.record.name}) to default"
)
return plans
def _update_plans_in_db(plans: List[Plan]) -> None:
plan_models: List[PlanModel] = []
scenario_models: List[ScenarioModel] = []
for plan in plans_to_update:
for plan in plans:
plan_model, scenario_model = plan.to_sqlalchemy()
plan_models.append(plan_model)
scenario_models.append(scenario_model)
return (plan_models, scenario_models)
bulk_update_plans(plan_models, scenario_models)