diff --git a/backend/app/db/connection.py b/backend/app/db/connection.py index bff63ae1..a0bbe238 100644 --- a/backend/app/db/connection.py +++ b/backend/app/db/connection.py @@ -18,8 +18,8 @@ db_string = connection_string.format( # each lambda doesn't hog all connections db_engine = create_engine( db_string, - pool_size=1, - max_overflow=0, # Limit the number of extra connections. With this and pool size, we allow 1 connection per lambda + pool_size=3, + max_overflow=5, # Limit the number of extra connections. With this and pool size, we allow 1 connection per lambda pool_pre_ping=True, pool_recycle=300, # Forces SQLAlchemy to close and reopen any connection older than 300 seconds ) diff --git a/backend/app/plan/router.py b/backend/app/plan/router.py index e6e6052f..5de6b74e 100644 --- a/backend/app/plan/router.py +++ b/backend/app/plan/router.py @@ -2,7 +2,8 @@ import boto3 import json import math import asyncio -import random +from contextlib import contextmanager +from sqlmodel import Session from datetime import datetime @@ -29,6 +30,19 @@ router = APIRouter( sqs_client = boto3.client("sqs") +@contextmanager +def db_session(): + session = Session(db_engine) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + @router.post("/trigger", status_code=202) async def trigger_plan_entrypoint(body: PlanTriggerRequest): """ @@ -57,28 +71,27 @@ async def trigger_plan_entrypoint(body: PlanTriggerRequest): scenario_id = data.get("scenario_id") if not scenario_id: created_at = datetime.now().isoformat() - session = sessionmaker(bind=db_engine)() - - # Create a new scenario - new_scenario = create_scenario( - session=session, - scenario={ - "name": body.scenario_name, - "created_at": created_at, - "budget": body.budget, - "portfolio_id": body.portfolio_id, - "housing_type": body.housing_type, - "goal": body.goal, - "goal_value": body.goal_value, - "trigger_file_path": body.trigger_file_path, - "already_installed_file_path": body.already_installed_file_path, - "patches_file_path": body.patches_file_path, - "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path, - "exclusions": body.exclusions, - "multi_plan": body.multi_plan - } - ) - scenario_id = new_scenario.id + with db_session() as session: + # Create a new scenario + new_scenario = create_scenario( + session=session, + scenario={ + "name": body.scenario_name, + "created_at": created_at, + "budget": body.budget, + "portfolio_id": body.portfolio_id, + "housing_type": body.housing_type, + "goal": body.goal, + "goal_value": body.goal_value, + "trigger_file_path": body.trigger_file_path, + "already_installed_file_path": body.already_installed_file_path, + "patches_file_path": body.patches_file_path, + "non_invasive_recommendations_file_path": body.non_invasive_recommendations_file_path, + "exclusions": body.exclusions, + "multi_plan": body.multi_plan + } + ) + scenario_id = new_scenario.id # Insert the scenario ID into the data payload data["scenario_id"] = scenario_id