fixes so it runs (as far as the database update), plus some temp prints

This commit is contained in:
Daniel Roth 2026-02-13 12:26:31 +00:00
parent 561594a6ca
commit e0e50d696a
6 changed files with 59 additions and 29 deletions

View file

@ -622,7 +622,7 @@ def get_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]:
stmt = select(PlanModel).where(PlanModel.portfolio_id == portfolio_id)
with db_read_session() as session:
session_any: Any = session # Typehint as Any to satisfy Pylance...
return session_any.exec(stmt).all()
return session_any.exec(stmt).scalars().all()
def get_scenario(scenario_id: int) -> Optional[ScenarioModel]:

View file

@ -1,4 +1,4 @@
from typing import Iterable, Optional
from typing import Iterable, List, NamedTuple, Optional, Type
from sqlalchemy import (
Column,
BigInteger,
@ -22,6 +22,10 @@ import enum
Base = declarative_base()
def portfolio_goal_values(enum_cls: Type[PortfolioGoal]) -> List[str]:
return [e.value for e in enum_cls]
class Recommendation(Base):
__tablename__ = "recommendation"
@ -152,7 +156,10 @@ class ScenarioModel(Base):
BigInteger, ForeignKey(Portfolio.id), nullable=False
)
housing_type: Mapped[str] = mapped_column(String, nullable=False)
goal: Mapped[PortfolioGoal] = mapped_column(Enum(PortfolioGoal), nullable=False)
goal: Mapped[PortfolioGoal] = mapped_column(
Enum(PortfolioGoal, values_callable=portfolio_goal_values, name="goal"),
nullable=False,
)
goal_value: Mapped[str] = mapped_column(String, nullable=False)
trigger_file_path: Mapped[str] = mapped_column(String, nullable=False)
already_installed_file_path: Mapped[Optional[str]] = mapped_column(String)
@ -252,3 +259,8 @@ class InstalledMeasure(Base):
def enum_values(e: Iterable[PlanTypeEnum]) -> list[str]:
return [m.value for m in e]
class PlanPersistence(NamedTuple):
plan: PlanModel
scenario: ScenarioModel

View file

@ -5,7 +5,11 @@ from typing import Optional
from sqlalchemy import Tuple
from backend.app.db.models.portfolio import PortfolioGoal
from backend.app.db.models.recommendations import PlanModel, ScenarioModel
from backend.app.db.models.recommendations import (
PlanModel,
PlanPersistence,
ScenarioModel,
)
from backend.app.domain.classes.scenario import Scenario
from backend.app.domain.records.plan_record import PlanRecord
from backend.app.utils import sap_to_epc
@ -58,7 +62,7 @@ class Plan:
case _:
raise NotImplementedError
def to_sqlalchemy(self) -> Tuple[PlanModel, ScenarioModel]:
def to_sqlalchemy(self) -> PlanPersistence:
scenario_record = self.scenario.record
scenario_model = ScenarioModel(
@ -129,7 +133,7 @@ class Plan:
contingency_cost=record.contingency_cost,
)
return Tuple(plan_model, scenario_model) # TODO: create a type for this
return PlanPersistence(plan=plan_model, scenario=scenario_model)
def set_default(self, value: bool) -> None:
self.record = replace(self.record, is_default=value)

View file

@ -1,12 +0,0 @@
from typing import List
from backend.app.domain.classes.plan import Plan
class CategorisationLogic:
@staticmethod
def get_compliant_plans(plans: List[Plan]) -> List[Plan]:
raise NotImplementedError
@staticmethod
def get_cheapest_plan(plans: List[Plan]) -> Plan:
raise NotImplementedError

View file

@ -1,5 +1,10 @@
from backend.categorisation.processor import process_portfolio
def main() -> None:
pass
portfolio_id = 556
process_portfolio(portfolio_id)
if __name__ == "__main__":

View file

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import List, Tuple, cast
from typing import Dict, List, Tuple, cast
from backend.app.db.functions.recommendations_functions import (
get_plans_by_portfolio_id,
@ -8,23 +8,30 @@ from backend.app.db.functions.recommendations_functions import (
)
from backend.app.db.models.recommendations import PlanModel, ScenarioModel
from backend.app.domain.classes.plan import Plan
from backend.categorisation.categorisation_logic import CategorisationLogic
from backend.app.domain.classes.scenario import Scenario
from utils.logger import setup_logger
logger = setup_logger()
def process_portfolio(portfolio_id: int) -> None:
plans = _load_plans_for_portfolio(portfolio_id)
plans_by_property = _group_plans_by_property(plans)
print(f"Processing portfolio {portfolio_id}")
plans: List[Plan] = _load_plans_for_portfolio(portfolio_id)
plans_by_property: Dict[int, List[Plan]] = _group_plans_by_property(plans)
for uprn, property_plans in plans_by_property.items():
if not property_plans:
raise ValueError(f"No plans for property {uprn}")
for property_plans in plans_by_property.values():
cheapest_plan = _choose_cheapest_relevant_plan(property_plans)
_update_default_flags(property_plans, cheapest_plan)
def _load_plans_for_portfolio(portfolio_id: int) -> List[Plan]:
plan_models = get_plans_by_portfolio_id(portfolio_id)
print(f"Got {len(plan_models)} plans from database")
plans: List[Plan] = []
for model in plan_models:
@ -33,12 +40,15 @@ def _load_plans_for_portfolio(portfolio_id: int) -> List[Plan]:
continue
scenario_model = get_scenario(model.scenario_id)
plans.append(Plan.from_sqlalchemy(model, scenario_model))
plans.append(
Plan.from_sqlalchemy(model, Scenario.from_sqlalchemy(scenario_model))
)
print("Successfully mapped plan and scenario to domain object")
return plans
def _group_plans_by_property(plans: List[Plan]) -> dict[int, List[Plan]]:
def _group_plans_by_property(plans: List[Plan]) -> Dict[int, List[Plan]]:
grouped: dict[int, List[Plan]] = defaultdict(list)
for plan in plans:
@ -48,10 +58,18 @@ def _group_plans_by_property(plans: List[Plan]) -> dict[int, List[Plan]]:
def _choose_cheapest_relevant_plan(plans: List[Plan]) -> Plan:
compliant_plans = CategorisationLogic.get_compliant_plans(plans)
plans_to_consider: List[Plan] = [p for p in plans if p.is_compliant] or plans
plans_to_consider = compliant_plans or plans
return CategorisationLogic.get_cheapest_plan(plans_to_consider)
def plan_cost(plan: Plan) -> float:
return (
plan.record.cost_of_works
if plan.record.cost_of_works is not None
else float("inf")
)
cheapest_plan = min(plans_to_consider, key=plan_cost)
return cheapest_plan
def _update_default_flags(plans: List[Plan], cheapest_plan: Plan) -> None:
@ -60,6 +78,9 @@ def _update_default_flags(plans: List[Plan], cheapest_plan: Plan) -> None:
raise ValueError("Cannot update Plan with missing ID")
plan.set_default(plan.id == cheapest_plan.id)
print(
f"Setting plan of id {plan.id}, scenario name {plan.scenario.record.name} to is_default value {plan.id == cheapest_plan.id}"
)
plan_model, scenario_model = cast(
Tuple[PlanModel, ScenarioModel],