"""Tests for the @subtask_handler decorator. Covers the contract that the decorator owns the parent SubTask lifecycle and injects the decorator-owned TaskOrchestrator as a third positional argument to the wrapped function — so the handler can compose its own use-case orchestrator that shares the session. """ from collections.abc import Generator, Iterator from contextlib import contextmanager from dataclasses import dataclass from typing import Any from uuid import UUID import pytest from sqlmodel import Session, SQLModel, create_engine from domain.tasks.subtasks import SubTaskStatus from domain.tasks.tasks import TaskStatus from orchestration.task_orchestrator import TaskOrchestrator from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository from repositories.tasks.task_postgres_repository import TaskPostgresRepository from utilities.aws_lambda.subtask_handler import subtask_handler @dataclass class Harness: orchestrator: TaskOrchestrator tasks: TaskPostgresRepository subtasks: SubTaskPostgresRepository @contextmanager def factory(self) -> Generator[TaskOrchestrator, None, None]: yield self.orchestrator @pytest.fixture def harness() -> Iterator[Harness]: engine = create_engine("sqlite://") SQLModel.metadata.create_all(engine) with Session(engine) as session: tasks = TaskPostgresRepository(session=session) subtasks = SubTaskPostgresRepository(session=session) yield Harness( orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks), tasks=tasks, subtasks=subtasks, ) def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]: return {"task_id": str(task_id), "sub_task_id": str(subtask_id)} def test_subtask_handler_injects_orchestrator_as_third_positional_argument( harness: Harness, ) -> None: """The wrapped function receives the decorator-owned TaskOrchestrator so it can share the session with its own use-case orchestrator.""" _, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) received: dict[str, Any] = {} @subtask_handler(orchestrator_cm=harness.factory) def handler( body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator ) -> None: received["body"] = body received["context"] = context received["orchestrator"] = orchestrator handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel") assert received["orchestrator"] is harness.orchestrator assert received["context"] == "ctx-sentinel" assert received["body"]["sub_task_id"] == str(subtask.id) def test_subtask_handler_completes_parent_subtask_on_success( harness: Harness, ) -> None: task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @subtask_handler(orchestrator_cm=harness.factory) def handler( body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator ) -> None: return None handler(_direct_event(task.id, subtask.id), context=None) assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE def test_subtask_handler_marks_parent_failed_and_reraises_on_error( harness: Harness, ) -> None: task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @subtask_handler(orchestrator_cm=harness.factory) def handler( body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator ) -> None: raise RuntimeError("boom") with pytest.raises(RuntimeError, match="boom"): handler(_direct_event(task.id, subtask.id), context=None) assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED assert harness.tasks.get(task.id).status is TaskStatus.FAILED def test_subtask_handler_injected_orchestrator_can_create_child_subtask( harness: Harness, ) -> None: """Smoke check the share-the-session promise: the injected orchestrator is the same one the decorator owns, so a handler can use it to create child SubTasks under the same session.""" task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) child_ids: list[UUID] = [] @subtask_handler(orchestrator_cm=harness.factory) def handler( body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator ) -> None: child = orchestrator.create_child_subtask(task.id, inputs={"split": 1}) child_ids.append(child.id) handler(_direct_event(task.id, subtask.id), context=None) assert len(child_ids) == 1 persisted_child = harness.subtasks.get(child_ids[0]) assert persisted_child.task_id == task.id assert persisted_child.status is SubTaskStatus.WAITING