diff --git a/.devcontainer/backend/requirements.txt b/.devcontainer/backend/requirements.txt index 5cd40ced..e7d1b099 100644 --- a/.devcontainer/backend/requirements.txt +++ b/.devcontainer/backend/requirements.txt @@ -21,6 +21,7 @@ ipykernel>=6.25,<7 dotenv psycopg[binary] pytest-postgresql +httpx # Formatting black==26.1.0 boto3-stubs \ No newline at end of file diff --git a/backend/app/db/connection.py b/backend/app/db/connection.py index f0649c71..c0656374 100644 --- a/backend/app/db/connection.py +++ b/backend/app/db/connection.py @@ -1,7 +1,7 @@ from sqlalchemy import create_engine from contextlib import contextmanager from backend.app.config import get_settings -from sqlmodel import Session +from sqlalchemy.orm import Session connection_string = ( "postgresql+{drivername}://{username}:{password}@{server}:{port}/{dbname}" @@ -56,3 +56,8 @@ def db_read_session(): yield session finally: session.close() + + +def get_session(): + with db_session() as session: + yield session diff --git a/backend/app/db/models/tasks.py b/backend/app/db/models/tasks.py index e97a939f..1eeeafaa 100644 --- a/backend/app/db/models/tasks.py +++ b/backend/app/db/models/tasks.py @@ -1,65 +1,130 @@ +# import enum +# from typing import Optional +# from datetime import datetime +# from uuid import UUID, uuid4 + +# from sqlalchemy import Column, Enum +# from sqlmodel import SQLModel, Field, Relationship + + +# class SourceEnum(enum.Enum): # TODO: move to domain? 
+# PORTFOLIO = "portfolio_id" + + +# class Task(SQLModel, table=True): +# __tablename__ = "tasks" + +# id: UUID = Field( +# default_factory=uuid4, +# primary_key=True, +# index=True, +# ) +# task_source: str +# job_started: Optional[datetime] = None +# job_completed: Optional[datetime] = None +# status: str = Field(default="In Progress") +# service: Optional[str] = None +# updated_at: datetime = Field(default_factory=datetime.utcnow) + +# # source: Mapped[Optional[SourceEnum]] = mapped_column(Enum(SourceEnum)) <- SQLAlchemy not SQLModel + +# source: Optional[SourceEnum] = Field( +# default=None, +# sa_column=Column( +# Enum( +# SourceEnum, +# name="source", +# values_callable=lambda e: [m.value for m in e], +# ), +# nullable=True, +# ), +# ) +# source_id: Optional[str] = None + +# sub_tasks: list["SubTask"] = Relationship(back_populates="task") + + +# class SubTask(SQLModel, table=True): +# __tablename__ = "sub_task" + +# id: UUID = Field( +# default_factory=uuid4, +# primary_key=True, +# index=True, +# ) + +# task_id: UUID = Field(foreign_key="tasks.id") +# job_started: Optional[datetime] = None +# job_completed: Optional[datetime] = None +# status: str = Field(default="In Progress") +# inputs: Optional[str] = None +# outputs: Optional[str] = None +# cloud_logs_url: Optional[str] = None +# updated_at: datetime = Field(default_factory=datetime.utcnow) + +# task: Optional["Task"] = Relationship(back_populates="sub_tasks") + + import enum -from typing import Optional +from typing import Optional, List from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import Column, Enum -from sqlmodel import SQLModel, Field, Relationship +from sqlalchemy import Enum, String, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import UUID as PG_UUID, TIMESTAMP + +from backend.app.db.base import Base -class SourceEnum(enum.Enum): # TODO: move to domain? 
+class SourceEnum(enum.Enum): PORTFOLIO = "portfolio_id" -class Task(SQLModel, table=True): +class Task(Base): __tablename__ = "tasks" - id: UUID = Field( - default_factory=uuid4, - primary_key=True, - index=True, + id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), primary_key=True, default=uuid4, index=True ) - task_source: str - job_started: Optional[datetime] = None - job_completed: Optional[datetime] = None - status: str = Field(default="In Progress") - service: Optional[str] = None - updated_at: datetime = Field(default_factory=datetime.utcnow) - - # source: Mapped[Optional[SourceEnum]] = mapped_column(Enum(SourceEnum)) <- SQLAlchemy not SQLModel - - source: Optional[SourceEnum] = Field( - default=None, - sa_column=Column( - Enum( - SourceEnum, - name="source", - values_callable=lambda e: [m.value for m in e], - ), - nullable=True, + task_source: Mapped[str] = mapped_column(String, nullable=False) + job_started: Mapped[Optional[datetime]] = mapped_column(TIMESTAMP, nullable=True) + job_completed: Mapped[Optional[datetime]] = mapped_column(TIMESTAMP, nullable=True) + status: Mapped[str] = mapped_column(String, nullable=False, default="In Progress") + service: Mapped[Optional[str]] = mapped_column(String, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + TIMESTAMP, nullable=False, default=datetime.utcnow + ) + source: Mapped[Optional[SourceEnum]] = mapped_column( + Enum( + SourceEnum, + name="source", + values_callable=lambda e: [m.value for m in e], ), + nullable=True, ) - source_id: Optional[str] = None + source_id: Mapped[Optional[str]] = mapped_column(String, nullable=True) - sub_tasks: list["SubTask"] = Relationship(back_populates="task") + sub_tasks: Mapped[List["SubTask"]] = relationship("SubTask", back_populates="task") -class SubTask(SQLModel, table=True): +class SubTask(Base): __tablename__ = "sub_task" - id: UUID = Field( - default_factory=uuid4, - primary_key=True, - index=True, + id: Mapped[UUID] = mapped_column( + 
PG_UUID(as_uuid=True), primary_key=True, default=uuid4, index=True + ) + task_id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey("tasks.id"), nullable=False + ) + job_started: Mapped[Optional[datetime]] = mapped_column(TIMESTAMP, nullable=True) + job_completed: Mapped[Optional[datetime]] = mapped_column(TIMESTAMP, nullable=True) + status: Mapped[str] = mapped_column(String, nullable=False, default="In Progress") + inputs: Mapped[Optional[str]] = mapped_column(String, nullable=True) + outputs: Mapped[Optional[str]] = mapped_column(String, nullable=True) + cloud_logs_url: Mapped[Optional[str]] = mapped_column(String, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + TIMESTAMP, nullable=False, default=datetime.utcnow ) - task_id: UUID = Field(foreign_key="tasks.id") - job_started: Optional[datetime] = None - job_completed: Optional[datetime] = None - status: str = Field(default="In Progress") - inputs: Optional[str] = None - outputs: Optional[str] = None - cloud_logs_url: Optional[str] = None - updated_at: datetime = Field(default_factory=datetime.utcnow) - - task: Optional["Task"] = Relationship(back_populates="sub_tasks") + task: Mapped[Optional["Task"]] = relationship("Task", back_populates="sub_tasks") diff --git a/backend/app/tasks/router.py b/backend/app/tasks/router.py index e8ec2686..88f68762 100644 --- a/backend/app/tasks/router.py +++ b/backend/app/tasks/router.py @@ -15,9 +15,12 @@ from backend.app.tasks.schema import ( # Correct location of interfaces from backend.app.db.functions.tasks.Tasks import TasksInterface, SubTaskInterface -from backend.app.db.connection import get_db_session +from backend.app.db.connection import get_db_session, get_session from backend.app.db.models.tasks import SourceEnum, Task, SubTask -from sqlmodel import select +from sqlalchemy.orm import Session +from sqlalchemy import select + +# from sqlmodel import Session, select router = APIRouter( @@ -74,23 +77,27 @@ async def get_task(task_id: 
UUID): "/by-source/{source}/{source_id}/{service}", summary="Get the most recent task by source, source_id, and service", ) -async def get_task_by_source(source: SourceEnum, source_id: str, service: str): - with get_db_session() as session: - task = session.exec( - select(Task) - .where( - Task.source == source, - Task.source_id == source_id, - Task.service == service, - ) - .order_by(Task.job_started.desc()) - .limit(1) - ).first() +async def get_task_by_source( + source: SourceEnum, + source_id: str, + service: str, + session: Session = Depends(get_session), +): + task = session.execute( + select(Task) + .where( + Task.source == source, + Task.source_id == source_id, + Task.service == service, + ) + .order_by(Task.job_started.desc()) + .limit(1) + ).scalars().first() - if not task: - raise HTTPException(status_code=404, detail="Task not found") + if not task: + raise HTTPException(status_code=404, detail="Task not found") - return {"task": task} + return {"task": task} # ============================================================ diff --git a/backend/app/tests/tasks/test_get_task.py b/backend/app/tests/tasks/test_get_task.py index 751ecd7d..b8e36be8 100644 --- a/backend/app/tests/tasks/test_get_task.py +++ b/backend/app/tests/tasks/test_get_task.py @@ -1,12 +1,17 @@ -from datetime import datetime import uuid +from datetime import datetime +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy import select +from sqlalchemy.orm import Session -from sqlmodel import Session - +from backend.app.db.connection import get_session, db_session as db_session_dependency +from backend.app.dependencies import validate_token from backend.app.db.models.tasks import SourceEnum, Task +from backend.app.tasks.router import router -def test_get_task_by_source(db_session: Session): +def test_get_task_by_source(db_session: Session) -> None: # arrange current_categorisation_task = Task( id=uuid.uuid4(), @@ -56,6 +61,36 @@ def test_get_task_by_source(db_session: 
Session): source_id="100", ) + db_session.add_all( + [ + current_categorisation_task, + previous_categorisation_task, + other_portfolio_categorisation_task, + engine_task, + ] + ) + db_session.commit() + # db_session.flush() + + # debug: confirm data is visible in this session + all_tasks = db_session.execute(select(Task)).scalars().all() + print(f"Tasks in db: {[(t.service, t.source_id, t.source) for t in all_tasks]}") + # act + test_app = FastAPI() + test_app.include_router(router) + + def override_get_session(): + yield db_session + + test_app.dependency_overrides[get_session] = override_get_session + test_app.dependency_overrides[validate_token] = lambda: None + + client = TestClient(test_app) + response = client.get("/tasks/by-source/portfolio_id/100/plan_categorisation") + + test_app.dependency_overrides.clear() # assert + assert response.status_code == 200 + assert response.json()["task"]["id"] == str(current_categorisation_task.id) diff --git a/pytest.ini b/pytest.ini index 608d5e0c..06eee3ae 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,5 +2,24 @@ pythonpath = . 
log_cli = true log_cli_level = INFO -addopts = --cov-report term-missing --cov=etl/epc --cov=recommendations --cov=backend --cov=etl/epc_clean --cov=etl/spatial -testpaths = recommendations/tests backend/tests etl/epc/tests etl/epc_clean/tests etl/spatial/tests backend/condition/tests backend/address2UPRN/tests backend/onboarders/tests backend/categorisation/tests backend/export/tests + +addopts = + --cov-report term-missing + --cov=etl/epc + --cov=etl/epc_clean + --cov=etl/spatial + --cov=recommendations + --cov=backend + +testpaths = + backend/tests + backend/address2UPRN/tests + backend/categorisation/tests + backend/condition/tests + backend/export/tests + backend/onboarders/tests + backend/app/tests + etl/epc/tests + etl/epc_clean/tests + etl/spatial/tests + recommendations/tests \ No newline at end of file diff --git a/test.requirements.txt b/test.requirements.txt index d8b8b777..8fa139d3 100644 --- a/test.requirements.txt +++ b/test.requirements.txt @@ -4,4 +4,5 @@ pytest-cov pytest-mock dotenv psycopg[binary] -pytest-postgresql \ No newline at end of file +pytest-postgresql +httpx \ No newline at end of file