# Model/tests/utilities/aws_lambda/test_subtask_handler.py
# Last modified: 2026-05-20 14:00:19 +00:00
#
# 255 lines
# 7.8 KiB
# Python

import logging
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from uuid import UUID
import pytest
from sqlalchemy import Engine
from sqlmodel import Session
from domain.tasks.subtasks import SubTaskStatus
from domain.tasks.tasks import TaskStatus
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from utilities.aws_lambda.subtask_handler import subtask_handler
# Logger name passed to ``caplog.at_level`` so the handler's log records are
# captured; presumably the module's ``logging.getLogger(__name__)`` logger,
# since it matches the import path of ``subtask_handler`` — confirm if renamed.
_LOGGER_NAME = "utilities.aws_lambda.subtask_handler"
@dataclass
class Harness:
    """Real orchestrator plus the repositories it was built from.

    Tests inject ``orchestrator`` into decorated handlers and then read state
    back through ``tasks`` and ``subtasks`` to assert on what was persisted.
    """

    orchestrator: TaskOrchestrator  # handed to handlers via ``factory``
    tasks: TaskPostgresRepository  # read-back access to tasks
    subtasks: SubTaskPostgresRepository  # read-back access to subtasks

    @contextmanager
    def factory(self) -> Generator[TaskOrchestrator, None, None]:
        """Context manager that yields the already-built orchestrator.

        Intended for ``subtask_handler(orchestrator_cm=...)`` so the handler
        under test receives this exact instance. No setup or teardown runs
        around the ``yield``.
        """
        shared = self.orchestrator
        yield shared
@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
    """Yield a ``Harness`` whose repositories share a single session.

    The session is opened on ``db_engine`` for one test and is closed by the
    ``with`` block after the test finishes.
    """
    with Session(db_engine) as session:
        task_repo = TaskPostgresRepository(session=session)
        subtask_repo = SubTaskPostgresRepository(session=session)
        orchestrator = TaskOrchestrator(
            task_repo=task_repo, subtask_repo=subtask_repo
        )
        yield Harness(
            orchestrator=orchestrator,
            tasks=task_repo,
            subtasks=subtask_repo,
        )
def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
    """Build the event dict that these tests pass straight to a handler.

    Both identifiers are converted with ``str()`` and stored under the
    ``"task_id"`` and ``"sub_task_id"`` keys, in that order.
    """
    identifiers = {"task_id": task_id, "sub_task_id": subtask_id}
    return {key: str(value) for key, value in identifiers.items()}
def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
    harness: Harness,
) -> None:
    """The decorated handler receives body, context and the orchestrator."""
    # arrange
    _, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    calls: list[tuple[dict[str, Any], Any, TaskOrchestrator]] = []

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        calls.append((body, context, orchestrator))

    # act
    event = _direct_event(subtask.task_id, subtask.id)
    handler(event, context="ctx-sentinel")

    # assert: inspect the most recent invocation
    seen_body, seen_context, seen_orchestrator = calls[-1]
    assert seen_orchestrator is harness.orchestrator
    assert seen_context == "ctx-sentinel"
    assert seen_body["sub_task_id"] == str(subtask.id)
def test_subtask_handler_completes_parent_subtask_on_success(
    harness: Harness,
) -> None:
    """A handler that returns normally leaves its subtask and task COMPLETE."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        pass

    # act
    handler(_direct_event(task.id, subtask.id), context=None)

    # assert
    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.status is SubTaskStatus.COMPLETE
    stored_task = harness.tasks.get(task.id)
    assert stored_task.status is TaskStatus.COMPLETE
def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
    harness: Harness,
) -> None:
    """A raising handler propagates the error and marks subtask/task FAILED."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    event = _direct_event(task.id, subtask.id)

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        raise RuntimeError("boom")

    # act / assert: a RuntimeError matching "boom" must escape the decorator
    with pytest.raises(RuntimeError, match="boom"):
        handler(event, context=None)

    failed_subtask = harness.subtasks.get(subtask.id)
    assert failed_subtask.status is SubTaskStatus.FAILED
    failed_task = harness.tasks.get(task.id)
    assert failed_task.status is TaskStatus.FAILED
def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
    harness: Harness,
) -> None:
    """A child subtask created via the injected orchestrator is persisted."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    created: list[UUID] = []

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        new_child = orchestrator.create_child_subtask(
            task.id, inputs={"split": 1}
        )
        created.append(new_child.id)

    # act
    handler(_direct_event(task.id, subtask.id), context=None)

    # assert: exactly one child, attached to the same task and still WAITING
    assert len(created) == 1
    (child_id,) = created
    stored_child = harness.subtasks.get(child_id)
    assert stored_child.task_id == task.id
    assert stored_child.status is SubTaskStatus.WAITING
def test_subtask_handler_logs_subtask_lifecycle_on_success(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A successful run logs both a start and a completion message."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        pass

    # act
    with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
        handler(_direct_event(task.id, subtask.id), context=None)

    # assert
    logged = caplog.text
    for expected in (
        f"Running subtask {subtask.id}",
        f"Subtask {subtask.id} completed",
    ):
        assert expected in logged
def test_subtask_handler_logs_exception_on_failure(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A failing handler is reported by an ERROR record carrying a traceback.

    The failure message and ``exc_info`` are asserted on the *same* record.
    Checking them independently would also pass if one ERROR record had the
    message and an unrelated ERROR record had the exception info.
    """
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        raise RuntimeError("boom")

    # act
    with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
        with pytest.raises(RuntimeError, match="boom"):
            handler(_direct_event(task.id, subtask.id), context=None)

    # assert
    expected = f"Subtask {subtask.id} failed"
    failure_records = [
        record
        for record in caplog.records
        if record.levelno == logging.ERROR and expected in record.getMessage()
    ]
    assert failure_records, f"no ERROR record containing {expected!r}"
    assert any(record.exc_info is not None for record in failure_records)
def test_subtask_handler_records_cloudwatch_url_on_subtask(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Inside a (simulated) Lambda, the subtask stores a CloudWatch console URL."""
    # arrange: simulate the environment variables Lambda provides
    lambda_env = {
        "AWS_REGION": "eu-west-2",
        "AWS_LAMBDA_LOG_GROUP_NAME": "/aws/lambda/postcode-splitter",
        "AWS_LAMBDA_LOG_STREAM_NAME": "2026/05/20/[$LATEST]abc123",
    }
    for env_name, env_value in lambda_env.items():
        monkeypatch.setenv(env_name, env_value)
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        pass

    # act
    handler(_direct_event(task.id, subtask.id), context=None)

    # assert
    saved_url = harness.subtasks.get(subtask.id).cloud_logs_url
    assert saved_url is not None
    assert saved_url.startswith(
        "https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
    )
    # Log group / stream names are console-encoded: each special character is
    # percent-encoded and the "%" is written as "$25", giving
    # "/" -> "$252F", "[" -> "$255B", "$" -> "$2524", "]" -> "$255D".
    for fragment in (
        "$252Faws$252Flambda$252Fpostcode-splitter",
        "$255B$2524LATEST$255D",
    ):
        assert fragment in saved_url
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Without the Lambda environment variables no CloudWatch URL is stored."""
    # arrange: make sure none of the Lambda variables leak in from the host
    lambda_vars = (
        "AWS_REGION",
        "AWS_LAMBDA_LOG_GROUP_NAME",
        "AWS_LAMBDA_LOG_STREAM_NAME",
    )
    for env_name in lambda_vars:
        monkeypatch.delenv(env_name, raising=False)
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        pass

    # act
    handler(_direct_event(task.id, subtask.id), context=None)

    # assert
    stored = harness.subtasks.get(subtask.id)
    assert stored.cloud_logs_url is None