diff --git a/backend/app/db/functions/recommendations_functions.py b/backend/app/db/functions/recommendations_functions.py index 141ba2dd..09d6da83 100644 --- a/backend/app/db/functions/recommendations_functions.py +++ b/backend/app/db/functions/recommendations_functions.py @@ -673,22 +673,17 @@ def get_scenarios_by_portfolio_id(portfolio_id: int) -> List[ScenarioModel]: return session_any.exec(stmt).scalars().all() -def get_default_plans_and_scenarios( +def get_default_plans( portfolio_id: int, -) -> Tuple[List[PlanModel], List[ScenarioModel]]: +) -> List[PlanModel]: plan_stmt = select(PlanModel).where( (PlanModel.portfolio_id == portfolio_id) & (PlanModel.is_default == True) ) - scenario_stmt = select(ScenarioModel).where( - (ScenarioModel.portfolio_id == portfolio_id) - & (ScenarioModel.is_default == True) - ) with db_read_session() as session: session_any: Any = session # Typehint as Any to satisfy Pylance... plans: List[PlanModel] = session_any.exec(plan_stmt).scalars().all() - scenarios: List[ScenarioModel] = session_any.exec(scenario_stmt).scalars().all() - return (plans, scenarios) + return plans def bulk_update_plans( diff --git a/backend/categorisation/processor.py b/backend/categorisation/processor.py index 95d4de3a..09db2983 100644 --- a/backend/categorisation/processor.py +++ b/backend/categorisation/processor.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional from backend.app.db.functions.recommendations_functions import ( bulk_update_plans, - get_default_plans_and_scenarios, + get_default_plans, get_most_recent_plans_by_portfolio_id, get_most_recent_plans_by_scenario_ids, get_scenarios_by_portfolio_id, @@ -23,6 +23,7 @@ def process_portfolio( ) -> None: # TODO: make this a class logger.info(f"Processing portfolio {portfolio_id}") + all_scenarios: List[Scenario] = _load_scenarios_for_portfolio(portfolio_id) plans_by_id: Dict[int, Plan] = {} # TODO: make this an in-memory repository class if scenarios_to_consider: @@ -33,14 +34,14 @@ def process_portfolio( # first get all plans that we're interested in plans_for_consideration: List[Plan] = _load_plans_for_portfolio( - portfolio_id, scenarios_to_consider + portfolio_id, all_scenarios, scenarios_to_consider ) for plan in plans_for_consideration: if plan.id is not None: # just in case plans_by_id[plan.id] = plan # then unset existing defaults on domain objects regardless of whether they're under consideration or not - default_plans: List[Plan] = _get_default_plans(portfolio_id) + default_plans: List[Plan] = _get_default_plans(portfolio_id, all_scenarios) for plan in default_plans: plan.set_default(False) if plan.id is not None: # just in case @@ -108,26 +109,28 @@ def choose_cheapest_relevant_plan( return cheapest_plans[0] -def _get_default_plans(portfolio_id: int) -> List[Plan]: - default_plan_models, default_scenario_models = get_default_plans_and_scenarios( - portfolio_id - ) +def _get_default_plans(portfolio_id: int, scenarios: List[Scenario]) -> List[Plan]: + default_plan_models = get_default_plans(portfolio_id) + + scenario_map = {s.id: s for s in scenarios} return [ - Plan.from_sqlalchemy( - p, - next( - Scenario.from_sqlalchemy(s) - for s in default_scenario_models - if s.id == p.scenario_id - ), - ) + Plan.from_sqlalchemy(p, scenario_map[p.scenario_id]) for p in default_plan_models + if p.scenario_id in scenario_map ] +def _load_scenarios_for_portfolio(portfolio_id: int) -> List[Scenario]: + scenario_models: List[ScenarioModel] = get_scenarios_by_portfolio_id(portfolio_id) + + return [Scenario.from_sqlalchemy(s) for s in scenario_models] + + def _load_plans_for_portfolio( - portfolio_id: int, scenarios_to_consider: Optional[List[int]] = None + portfolio_id: int, + all_scenarios: List[Scenario], + scenarios_to_consider: Optional[List[int]] = None, ) -> List[Plan]: if scenarios_to_consider: @@ -146,21 +149,17 @@ def _load_plans_for_portfolio( plans: List[Plan] = [] - scenarios: List[ScenarioModel] = get_scenarios_by_portfolio_id(portfolio_id) - - if not scenarios: + if not all_scenarios: raise Exception(f"No scenarios found for Portfolio {portfolio_id}") for model in plan_models: - scenario_model = next((s for s in scenarios if s.id == model.scenario_id)) - if not scenario_model: + scenario = next((s for s in all_scenarios if s.id == model.scenario_id)) + if not scenario: logger.info(f"No Scenario associated with Plan of ID {model.id}") continue - plans.append( - Plan.from_sqlalchemy(model, Scenario.from_sqlalchemy(scenario_model)) - ) + plans.append(Plan.from_sqlalchemy(model, scenario)) logger.info(f"Got {len(plans)} Plans") return plans