diff --git a/backend/app/db/functions/magic_plan_functions.py b/backend/app/db/functions/magic_plan_functions.py new file mode 100644 index 00000000..129ec958 --- /dev/null +++ b/backend/app/db/functions/magic_plan_functions.py @@ -0,0 +1,116 @@ +from typing import Any, cast + +from sqlalchemy import delete, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlmodel import Session, col + +from datatypes.magicplan.domain.models import Floor, Plan +from backend.app.db.models.magic_plan import ( + MagicPlanDoor, + MagicPlanFloor, + MagicPlanPlan, + MagicPlanRoom, + MagicPlanWindow, +) + + +def save_plan(session: Session, plan: Plan) -> None: + plan_id = _upsert_plan(session, plan) + _delete_children(session, plan_id) + floor_ids = _insert_floors(session, plan.floors, plan_id) + room_ids = _insert_rooms(session, plan.floors, floor_ids) + _insert_windows_and_doors(session, plan.floors, room_ids) + + +def _upsert_plan(session: Session, plan: Plan) -> int: + stmt = ( + pg_insert(MagicPlanPlan) + .values( + magic_plan_uid=plan.uid, + name=plan.name, + address=plan.address, + postcode=plan.postcode, + ) + .on_conflict_do_update( + index_elements=["magic_plan_uid"], + set_={"name": plan.name, "address": plan.address, "postcode": plan.postcode}, + ) + .returning(col(MagicPlanPlan.id)) + ) + row_id: int = session.execute(stmt).scalar_one() + return row_id + + +def _delete_children(session: Session, plan_id: int) -> None: + floor_subq = ( + select(col(MagicPlanFloor.id)) + .where(col(MagicPlanFloor.magic_plan_plan_id) == plan_id) + .scalar_subquery() + ) + room_subq = ( + select(col(MagicPlanRoom.id)) + .where(col(MagicPlanRoom.magic_plan_floor_id).in_(floor_subq)) + .scalar_subquery() + ) + session.execute(delete(MagicPlanWindow).where(col(MagicPlanWindow.magic_plan_room_id).in_(room_subq))) + session.execute(delete(MagicPlanDoor).where(col(MagicPlanDoor.magic_plan_room_id).in_(room_subq))) + session.execute(delete(MagicPlanRoom).where(col(MagicPlanRoom.magic_plan_floor_id).in_(floor_subq))) + session.execute(delete(MagicPlanFloor).where(col(MagicPlanFloor.magic_plan_plan_id) == plan_id)) + + +def _insert_floors(session: Session, floors: list[Floor], plan_id: int) -> list[int]: + rows: list[dict[str, Any]] = [ + {"magic_plan_plan_id": plan_id, "level": floor.level} + for floor in floors + ] + result = session.execute( + pg_insert(MagicPlanFloor).values(rows).returning(col(MagicPlanFloor.id)) + ) + return cast(list[int], list(result.scalars().all())) + + +def _insert_rooms(session: Session, floors: list[Floor], floor_ids: list[int]) -> list[int]: + rows: list[dict[str, Any]] = [ + { + "magic_plan_floor_id": floor_id, + "name": room.name, + "width_m": room.width_m, + "length_m": room.length_m, + "area_m2": room.area_m2, + } + for floor, floor_id in zip(floors, floor_ids) + for room in floor.rooms + ] + result = session.execute( + pg_insert(MagicPlanRoom).values(rows).returning(col(MagicPlanRoom.id)) + ) + return cast(list[int], list(result.scalars().all())) + + +def _insert_windows_and_doors(session: Session, floors: list[Floor], room_ids: list[int]) -> None: + all_rooms = [room for floor in floors for room in floor.rooms] + + window_rows: list[dict[str, Any]] = [ + { + "magic_plan_room_id": room_id, + "width_m": window.width_m, + "height_m": window.height_m, + "area_m2": window.area_m2, + "opening_type": window.opening_type, + } + for room, room_id in zip(all_rooms, room_ids) + for window in room.windows + ] + door_rows: list[dict[str, Any]] = [ + { + "magic_plan_room_id": room_id, + "width_mm": door.width_mm, + } + for room, room_id in zip(all_rooms, room_ids) + for door in room.doors + ] + + if window_rows: + session.execute(pg_insert(MagicPlanWindow).values(window_rows)) + if door_rows: + session.execute(pg_insert(MagicPlanDoor).values(door_rows)) diff --git a/backend/app/db/functions/tests/test_magic_plan_functions.py b/backend/app/db/functions/tests/test_magic_plan_functions.py index c9785e26..42b42bba 100644 --- a/backend/app/db/functions/tests/test_magic_plan_functions.py +++ b/backend/app/db/functions/tests/test_magic_plan_functions.py @@ -101,24 +101,25 @@ def test_save_plan_deletes_floors_before_inserting(mock_session: MagicMock, doma assert floor_delete_idx < floor_insert_idx -def test_save_plan_inserts_correct_floor_count(mock_session: MagicMock, domain_plan: Plan) -> None: +def test_save_plan_floor_insert_contains_all_levels(mock_session: MagicMock, domain_plan: Plan) -> None: # Act save_plan(mock_session, domain_plan) - # Assert — floor INSERT contains values for both floors + # Assert — each floor's level value appears in the INSERT stmts = [_compiled(c[0][0]) for c in mock_session.execute.call_args_list] floor_insert = next(s for s in stmts if "INSERT" in s.upper() and "magic_plan_floor" in s) - # Each floor appears as a row — level values 0 and 1 from fixture - assert floor_insert.count("magic_plan_plan_id") >= len(domain_plan.floors) + for floor in domain_plan.floors: + if floor.level is not None: + assert str(floor.level) in floor_insert -def test_save_plan_inserts_correct_room_count(mock_session: MagicMock, domain_plan: Plan) -> None: +def test_save_plan_room_insert_uses_all_floor_ids(mock_session: MagicMock, domain_plan: Plan) -> None: # Act save_plan(mock_session, domain_plan) - # Assert — room INSERT contains values for all 14 rooms - total_rooms = sum(len(f.rooms) for f in domain_plan.floors) + # Assert — both mocked floor ids (10, 20) appear in the room INSERT stmts = [_compiled(c[0][0]) for c in mock_session.execute.call_args_list] room_insert = next(s for s in stmts if "INSERT" in s.upper() and "magic_plan_room" in s) - assert room_insert.count("magic_plan_floor_id") >= total_rooms + assert "10" in room_insert + assert "20" in room_insert def test_save_plan_windows_use_room_ids_from_insert(mock_session: MagicMock, domain_plan: Plan) -> None: