diff --git a/backend/app/db/functions/recommendations_functions.py b/backend/app/db/functions/recommendations_functions.py index 1864a330..2f85cbec 100644 --- a/backend/app/db/functions/recommendations_functions.py +++ b/backend/app/db/functions/recommendations_functions.py @@ -1,8 +1,9 @@ -from typing import List -from sqlalchemy import text -from sqlalchemy import insert, delete +from typing import Any, List, Optional +from sqlalchemy import text, insert, delete, select, update from sqlalchemy.orm import Session from sqlalchemy.exc import SQLAlchemyError +from sqlmodel import Session + from backend.app.db.models.recommendations import ( PlanModel, Recommendation, @@ -618,12 +619,26 @@ def clear_portfolio_in_batches( def get_plans_by_portfolio_id(portfolio_id: int) -> List[PlanModel]: - raise NotImplementedError + stmt = select(PlanModel).where(PlanModel.portfolio_id == portfolio_id) + with db_read_session() as session: + session_any: Any = session # Typehint as Any to satisfy Pylance... + return session_any.exec(stmt).all() -def get_scenario(scenario_id: int) -> ScenarioModel: - raise NotImplementedError +def get_scenario(scenario_id: int) -> Optional[ScenarioModel]: + stmt = select(ScenarioModel).where(ScenarioModel.id == scenario_id) + with db_read_session() as session: + session_any: Any = session # Typehint as Any to satisfy Pylance... + return session_any.exec(stmt).scalar_one_or_none() def set_plan_default(plan_id: int, is_default: bool) -> bool: - raise NotImplementedError + with db_read_session() as session: + stmt = ( + update(PlanModel) + .where(PlanModel.id == plan_id) + .values(is_default=is_default) + ) + result = session.exec(stmt) + session.commit() + return result.rowcount > 0