"""Tests for the ``subtask_handler`` decorator.

Each test wires a ``TaskOrchestrator`` to the task and subtask repositories
over a single ``Session`` and invokes a decorated handler with an event built
by ``_direct_event``.  The ``db_engine`` fixture is not defined in this
module -- presumably it is supplied by a shared ``conftest.py``.
"""

import logging
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from uuid import UUID

import pytest
from sqlalchemy import Engine
from sqlmodel import Session

from domain.tasks.subtasks import SubTaskStatus
from domain.tasks.tasks import TaskStatus
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from utilities.aws_lambda.subtask_handler import subtask_handler

# Name of the logger whose records the logging tests capture.
_LOGGER_NAME = "utilities.aws_lambda.subtask_handler"


@dataclass
class Harness:
    """Bundle the orchestrator with the repositories it was built from."""

    orchestrator: TaskOrchestrator
    tasks: TaskPostgresRepository
    subtasks: SubTaskPostgresRepository

    @contextmanager
    def factory(self) -> Generator[TaskOrchestrator, None, None]:
        """Yield the bundled orchestrator; used as ``orchestrator_cm``."""
        yield self.orchestrator


@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
    """Yield a ``Harness`` whose repositories share one open session."""
    with Session(db_engine) as db_session:
        task_repo = TaskPostgresRepository(session=db_session)
        subtask_repo = SubTaskPostgresRepository(session=db_session)
        orchestrator = TaskOrchestrator(
            task_repo=task_repo, subtask_repo=subtask_repo
        )
        yield Harness(
            orchestrator=orchestrator,
            tasks=task_repo,
            subtasks=subtask_repo,
        )


def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
    """Build the event dict carrying both ids as strings."""
    event: dict[str, Any] = {"task_id": str(task_id)}
    event["sub_task_id"] = str(subtask_id)
    return event


def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
    harness: Harness,
) -> None:
    """The handler receives the body, the context and the orchestrator."""
    # arrange
    _, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    captured: dict[str, Any] = {}

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        captured.update(
            {"body": body, "context": context, "orchestrator": orchestrator}
        )

    # act
    event = _direct_event(subtask.task_id, subtask.id)
    handler(event, context="ctx-sentinel")

    # assert
    assert captured["orchestrator"] is harness.orchestrator
    assert captured["context"] == "ctx-sentinel"
    assert captured["body"]["sub_task_id"] == str(subtask.id)


def test_subtask_handler_completes_parent_subtask_on_success(
    harness: Harness,
) -> None:
    """A handler that returns normally leaves subtask and task COMPLETE."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        return None

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.status is SubTaskStatus.COMPLETE
    stored_task = harness.tasks.get(task.id)
    assert stored_task.status is TaskStatus.COMPLETE


def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
    harness: Harness,
) -> None:
    """A raising handler propagates and leaves subtask and task FAILED."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        raise RuntimeError("boom")

    # act / assert
    event = _direct_event(task.id, subtask.id)
    with pytest.raises(RuntimeError, match="boom"):
        handler(event, context=None)

    # Checked outside the ``raises`` block so the assertions actually run.
    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.status is SubTaskStatus.FAILED
    stored_task = harness.tasks.get(task.id)
    assert stored_task.status is TaskStatus.FAILED


def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
    harness: Harness,
) -> None:
    """A child created via the injected orchestrator is stored as WAITING."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    created_ids: list[UUID] = []

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        new_child = orchestrator.create_child_subtask(
            task.id, inputs={"split": 1}
        )
        created_ids.append(new_child.id)

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    assert len(created_ids) == 1
    stored_child = harness.subtasks.get(created_ids[0])
    assert stored_child.task_id == task.id
    assert stored_child.status is SubTaskStatus.WAITING


def test_subtask_handler_logs_subtask_lifecycle_on_success(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A successful run logs both the start and the completion message."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        return None

    # act
    event = _direct_event(task.id, subtask.id)
    with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
        handler(event, context=None)

    # assert
    captured_text = caplog.text
    assert f"Running subtask {subtask.id}" in captured_text
    assert f"Subtask {subtask.id} completed" in captured_text


def test_subtask_handler_logs_exception_on_failure(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A failing run emits an ERROR record that carries exception info."""
    # arrange
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        raise RuntimeError("boom")

    # act / assert
    event = _direct_event(task.id, subtask.id)
    log_capture = caplog.at_level(logging.INFO, logger=_LOGGER_NAME)
    with log_capture, pytest.raises(RuntimeError, match="boom"):
        handler(event, context=None)

    error_records = [
        record for record in caplog.records if record.levelno == logging.ERROR
    ]
    expected_message = f"Subtask {subtask.id} failed"
    assert any(
        expected_message in record.getMessage() for record in error_records
    )
    assert any(record.exc_info is not None for record in error_records)


def test_subtask_handler_records_cloudwatch_url_on_subtask(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With the AWS env vars set, a CloudWatch URL is saved on the subtask."""
    # arrange
    lambda_env = {
        "AWS_REGION": "eu-west-2",
        "AWS_LAMBDA_LOG_GROUP_NAME": "/aws/lambda/postcode-splitter",
        "AWS_LAMBDA_LOG_STREAM_NAME": "2026/05/20/[$LATEST]abc123",
    }
    for name, value in lambda_env.items():
        monkeypatch.setenv(name, value)
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        return None

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    logs_url = harness.subtasks.get(subtask.id).cloud_logs_url
    assert logs_url is not None
    assert logs_url.startswith(
        "https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
    )
    # Log group / stream are console-encoded ("/" -> "$252F").
    assert "$252Faws$252Flambda$252Fpostcode-splitter" in logs_url
    assert "$255B$2524LATEST$255D" in logs_url


def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Without the AWS env vars, the subtask's ``cloud_logs_url`` is None."""
    # arrange
    lambda_env_names = (
        "AWS_REGION",
        "AWS_LAMBDA_LOG_GROUP_NAME",
        "AWS_LAMBDA_LOG_STREAM_NAME",
    )
    for name in lambda_env_names:
        monkeypatch.delenv(name, raising=False)
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        return None

    # act
    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    # assert
    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.cloud_logs_url is None