diff --git a/backend/Outputs.py b/backend/Outputs.py index 7111e4d3..0a62cf95 100644 --- a/backend/Outputs.py +++ b/backend/Outputs.py @@ -11,7 +11,6 @@ from backend.app.db.models.portfolio import PropertyModel, PropertyDetailsEpcMod from backend.app.db.models.recommendations import ( Recommendation, PlanModel, - PlanRecommendations, ) @@ -124,20 +123,15 @@ class Outputs: return plans_data def get_recommendations_from_db(self, plan_ids): - # Get recommendations through PlanRecommendations for those plans and that are default + # Get default recommendations for those plans, linked by recommendation.plan_id recommendations_query = ( self.session.query(Recommendation, PlanModel.scenario_id) - .join( - PlanRecommendations, - Recommendation.id == PlanRecommendations.recommendation_id, - ) .join( PlanModel, - PlanModel.id - == PlanRecommendations.plan_id, # Join with Plan to access scenario_id + PlanModel.id == Recommendation.plan_id, # access scenario_id ) .filter( - PlanRecommendations.plan_id.in_(plan_ids), + Recommendation.plan_id.in_(plan_ids), Recommendation.default == True, # Filtering for default recommendations ) .all() diff --git a/backend/app/db/functions/portfolio_functions.py b/backend/app/db/functions/portfolio_functions.py index ae48afed..c9b15cd2 100644 --- a/backend/app/db/functions/portfolio_functions.py +++ b/backend/app/db/functions/portfolio_functions.py @@ -1,7 +1,6 @@ from sqlalchemy import func from backend.app.db.models.recommendations import ( PlanModel, - PlanRecommendations, Recommendation, ScenarioModel, ) @@ -26,11 +25,7 @@ def aggregate_portfolio_recommendations( ), func.sum(Recommendation.energy_cost_savings).label("energy_cost_savings"), ) - .join( - PlanRecommendations, - PlanRecommendations.recommendation_id == Recommendation.id, - ) - .join(PlanModel, PlanModel.id == PlanRecommendations.plan_id) + .join(PlanModel, PlanModel.id == Recommendation.plan_id) .filter( PlanModel.portfolio_id == portfolio_id, PlanModel.scenario_id == scenario_id, diff --git a/backend/export/property_scenarios/db_functions.py b/backend/export/property_scenarios/db_functions.py index e9b3d7e3..d18b97f6 100644 --- a/backend/export/property_scenarios/db_functions.py +++ b/backend/export/property_scenarios/db_functions.py @@ -8,7 +8,6 @@ from collections import defaultdict from backend.app.db.models.recommendations import ( Recommendation, PlanModel, - PlanRecommendations, RecommendationMaterials, ) from backend.app.db.models.portfolio import ( @@ -157,13 +156,9 @@ class DbMethods: stmt = ( select(Recommendation, PlanModel.scenario_id, PlanModel.name) - .join( - PlanRecommendations, - Recommendation.id == PlanRecommendations.recommendation_id, - ) - .join(PlanModel, PlanModel.id == PlanRecommendations.plan_id) + .join(PlanModel, PlanModel.id == Recommendation.plan_id) .where( - PlanRecommendations.plan_id.in_(plan_ids), + Recommendation.plan_id.in_(plan_ids), Recommendation.default.is_(True), Recommendation.already_installed.is_(False), ) diff --git a/backend/export/tests/test_export.py b/backend/export/tests/test_export.py index 42177749..973364fd 100644 --- a/backend/export/tests/test_export.py +++ b/backend/export/tests/test_export.py @@ -171,13 +171,17 @@ def test_default_export_integration(db_session): # 5) Insert recommendation # ---------------------------------------- + rec_to_plan = dict( + zip(plan_recs_df["recommendation_id"], plan_recs_df["plan_id"]) + ) recs = [ Recommendation( + plan_id=rec_to_plan.get(row["id"]), **{ col: row[col] for col in Recommendation.__table__.columns.keys() - if col in row - } + if col in row and col != "plan_id" + }, ) for _, row in recommendations_df.iterrows() ] @@ -607,9 +611,11 @@ def test_solar_with_battery_example(db_session): # ------------------------------------------------- recommendations_df.loc[0, "measure_type"] = "solar_pv" + rec_to_plan = dict(zip(plan_recs_df.recommendation_id, plan_recs_df.plan_id)) for row in recommendations_df.itertuples(index=False): rec = Recommendation( id=row.id, + plan_id=rec_to_plan.get(row.id), property_id=row.property_id, measure_type=row.measure_type, estimated_cost=row.estimated_cost,