diff --git a/backend/app/db/functions/recommendations_functions.py b/backend/app/db/functions/recommendations_functions.py index 28d82416..6816e25b 100644 --- a/backend/app/db/functions/recommendations_functions.py +++ b/backend/app/db/functions/recommendations_functions.py @@ -634,17 +634,26 @@ def get_scenario(scenario_id: int) -> Optional[ScenarioModel]: def update_plan(plan_model: PlanModel, scenario_model: ScenarioModel) -> bool: with db_read_session() as session: - stmt = ( - update(PlanModel) - .where(PlanModel.id == plan_model.id) - .values(**plan_model.model_dump(exclude={"id"}, exclude_unset=True)) + plan_values = { + c.name: getattr(plan_model, c.name) + for c in plan_model.__table__.columns + if c.name != "id" + } + scenario_values = { + c.name: getattr(scenario_model, c.name) + for c in scenario_model.__table__.columns + if c.name not in {"id", "portfolio_id"} + } + + plan_stmt = ( + update(PlanModel).where(PlanModel.id == plan_model.id).values(**plan_values) ) - plan_result = session.exec(stmt) + plan_result = session.exec(plan_stmt) scenario_stmt = ( update(ScenarioModel) .where(ScenarioModel.id == scenario_model.id) - .values(**scenario_model.model_dump(exclude={"id"}, exclude_unset=True)) + .values(**scenario_values) ) session.exec(scenario_stmt)