mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
234 lines
7.5 KiB
Python
234 lines
7.5 KiB
Python
import logging
|
|
from collections.abc import Generator, Iterator
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
from sqlmodel import Session, SQLModel, create_engine
|
|
|
|
from domain.tasks.subtasks import SubTaskStatus
|
|
from domain.tasks.tasks import TaskStatus
|
|
from orchestration.task_orchestrator import TaskOrchestrator
|
|
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
|
|
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
|
from utilities.aws_lambda.subtask_handler import subtask_handler
|
|
|
|
# Logger name handed to ``caplog.at_level`` by the logging tests in this module.
# It matches the module path ``subtask_handler`` is imported from; presumably
# that module logs via ``logging.getLogger(__name__)`` — TODO confirm.
_LOGGER_NAME = "utilities.aws_lambda.subtask_handler"
|
|
|
|
|
|
@dataclass
class Harness:
    """Bundle of the orchestrator under test and the repositories behind it.

    Tests read persisted state back through ``tasks`` and ``subtasks`` and pass
    ``factory`` to ``subtask_handler`` as its ``orchestrator_cm`` argument.
    """

    # Orchestrator instance that ``factory`` hands out.
    orchestrator: TaskOrchestrator
    # Repository used by the tests to look up tasks by id.
    tasks: TaskPostgresRepository
    # Repository used by the tests to look up subtasks by id.
    subtasks: SubTaskPostgresRepository

    @contextmanager
    def factory(self) -> Generator[TaskOrchestrator, None, None]:
        """Context manager yielding this harness's own orchestrator.

        Nothing is set up or torn down around the ``yield``; every entry hands
        out the same ``self.orchestrator`` object.
        """
        shared_orchestrator = self.orchestrator
        yield shared_orchestrator
|
|
|
|
|
|
@pytest.fixture
def harness() -> Iterator[Harness]:
    """Yield a ``Harness`` wired to a fresh ``sqlite://`` database.

    A new engine is created for every test, the ``SQLModel`` metadata tables
    are created on it, and both repositories share a single ``Session``. The
    session is closed by its ``with`` block and the engine is disposed once the
    test finishes, so no pooled connection outlives the test.

    Yields:
        Harness: orchestrator plus the task and subtask repositories it uses.
    """
    engine = create_engine("sqlite://")
    try:
        SQLModel.metadata.create_all(engine)
        with Session(engine) as session:
            tasks = TaskPostgresRepository(session=session)
            subtasks = SubTaskPostgresRepository(session=session)
            yield Harness(
                orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks),
                tasks=tasks,
                subtasks=subtasks,
            )
    finally:
        # Closing the session only returns its connection to the pool; dispose
        # the engine as well so the pooled connection is actually closed, even
        # when table creation or the test body raised.
        engine.dispose()
|
|
|
|
|
|
def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
    """Build the event dict the tests pass straight to a decorated handler.

    Both identifiers are converted to strings. Note the subtask key is spelled
    ``"sub_task_id"``.
    """
    return dict(task_id=str(task_id), sub_task_id=str(subtask_id))
|
|
|
|
|
|
def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
    harness: Harness,
) -> None:
    """The wrapped function gets the event body, the context and the orchestrator."""
    _task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    captured: dict[str, Any] = {}

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        # Stash every argument so the assertions below can inspect them.
        captured.update(body=body, context=context, orchestrator=orchestrator)

    event = _direct_event(subtask.task_id, subtask.id)
    handler(event, context="ctx-sentinel")

    # The injected orchestrator is the very object the harness factory yields.
    assert captured["orchestrator"] is harness.orchestrator
    assert captured["context"] == "ctx-sentinel"
    assert captured["body"]["sub_task_id"] == str(subtask.id)
|
|
|
|
|
|
def test_subtask_handler_completes_parent_subtask_on_success(
    harness: Harness,
) -> None:
    """A handler that returns normally leaves the subtask and its task COMPLETE."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        """Succeed without doing any work."""

    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.status is SubTaskStatus.COMPLETE
    stored_task = harness.tasks.get(task.id)
    assert stored_task.status is TaskStatus.COMPLETE
|
|
|
|
|
|
def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
    harness: Harness,
) -> None:
    """A raising handler propagates its error and leaves subtask and task FAILED."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        """Always fail."""
        raise RuntimeError("boom")

    event = _direct_event(task.id, subtask.id)
    # The decorator must not swallow the handler's exception.
    with pytest.raises(RuntimeError, match="boom"):
        handler(event, context=None)

    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.status is SubTaskStatus.FAILED
    stored_task = harness.tasks.get(task.id)
    assert stored_task.status is TaskStatus.FAILED
|
|
|
|
|
|
def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
    harness: Harness,
) -> None:
    """A child subtask created inside the handler is persisted as WAITING."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    created_child_ids: list[UUID] = []

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        # Use the injected orchestrator, not the harness one, to add the child.
        created_child_ids.append(
            orchestrator.create_child_subtask(task.id, inputs={"split": 1}).id
        )

    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    assert len(created_child_ids) == 1
    child_id = created_child_ids[0]
    stored_child = harness.subtasks.get(child_id)
    assert stored_child.task_id == task.id
    assert stored_child.status is SubTaskStatus.WAITING
|
|
|
|
|
|
def test_subtask_handler_logs_subtask_lifecycle_on_success(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A successful run logs both a "Running" and a "completed" message."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        """Succeed without doing any work."""

    event = _direct_event(task.id, subtask.id)
    # Capture INFO and above from the handler module's logger only.
    with caplog.at_level(logging.INFO, logger=_LOGGER_NAME):
        handler(event, context=None)

    started_message = f"Running subtask {subtask.id}"
    assert started_message in caplog.text
    finished_message = f"Subtask {subtask.id} completed"
    assert finished_message in caplog.text
|
|
|
|
|
|
def test_subtask_handler_logs_exception_on_failure(
    harness: Harness, caplog: pytest.LogCaptureFixture
) -> None:
    """A failing run emits an ERROR record carrying the exception info."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        """Always fail."""
        raise RuntimeError("boom")

    event = _direct_event(task.id, subtask.id)
    with caplog.at_level(logging.INFO, logger=_LOGGER_NAME), pytest.raises(
        RuntimeError, match="boom"
    ):
        handler(event, context=None)

    # Only ERROR-level records are relevant to the failure path.
    error_records = [
        record for record in caplog.records if record.levelno == logging.ERROR
    ]
    expected_message = f"Subtask {subtask.id} failed"
    assert any(expected_message in record.getMessage() for record in error_records)
    # At least one error record must have exception info attached to it.
    assert any(record.exc_info is not None for record in error_records)
|
|
|
|
|
|
def test_subtask_handler_records_cloudwatch_url_on_subtask(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With the AWS Lambda log env vars set, a CloudWatch URL is stored."""
    lambda_environment = {
        "AWS_REGION": "eu-west-2",
        "AWS_LAMBDA_LOG_GROUP_NAME": "/aws/lambda/postcode-splitter",
        "AWS_LAMBDA_LOG_STREAM_NAME": "2026/05/20/[$LATEST]abc123",
    }
    for variable, value in lambda_environment.items():
        monkeypatch.setenv(variable, value)

    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        """Succeed without doing any work."""

    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    stored_subtask = harness.subtasks.get(subtask.id)
    saved_url = stored_subtask.cloud_logs_url
    assert saved_url is not None
    console_prefix = "https://eu-west-2.console.aws.amazon.com/cloudwatch/home"
    assert saved_url.startswith(console_prefix)
    # Log group and stream appear console-encoded in the URL: "/" becomes
    # "$252F", and "[$LATEST]" becomes "$255B$2524LATEST$255D".
    assert "$252Faws$252Flambda$252Fpostcode-splitter" in saved_url
    assert "$255B$2524LATEST$255D" in saved_url
|
|
|
|
|
|
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
    harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Without the AWS Lambda log env vars, no CloudWatch URL is stored."""
    # Remove each variable if present; ``raising=False`` tolerates it being unset.
    monkeypatch.delenv("AWS_REGION", raising=False)
    monkeypatch.delenv("AWS_LAMBDA_LOG_GROUP_NAME", raising=False)
    monkeypatch.delenv("AWS_LAMBDA_LOG_STREAM_NAME", raising=False)

    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    @subtask_handler(orchestrator_cm=harness.factory)
    def handler(
        body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
    ) -> None:
        """Succeed without doing any work."""

    event = _direct_event(task.id, subtask.id)
    handler(event, context=None)

    stored_subtask = harness.subtasks.get(subtask.id)
    assert stored_subtask.cloud_logs_url is None
|