# Model/tests/orchestration/test_task_orchestrator.py
# Snapshot: 2026-05-19 16:35:09 +00:00 (151 lines, 4.9 KiB, Python)

from collections.abc import Iterator
from dataclasses import dataclass
import pytest
from sqlmodel import Session, SQLModel, create_engine
from domain.tasks.subtasks import SubTask, SubTaskStatus
from domain.tasks.tasks import Source, TaskStatus
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
@dataclass
class Harness:
    """Bundle of the orchestrator under test and the repositories behind it.

    Tests drive state transitions through ``orchestrator`` and read the
    persisted rows back through ``tasks`` and ``subtasks``.
    """

    # Orchestrator under test (the fixture wires it to the two repos below).
    orchestrator: TaskOrchestrator
    # Repository used to read back parent tasks.
    tasks: TaskPostgresRepository
    # Repository used to read back subtasks and to insert extra ones directly.
    subtasks: SubTaskPostgresRepository
@pytest.fixture
def harness() -> Iterator[Harness]:
    """Yield a ``Harness`` backed by a fresh in-memory SQLite database.

    Each invocation builds its own engine and schema, so tests do not share
    state. The session is closed by its context manager. The engine is
    disposed in ``finally`` so its pooled connections are released at
    teardown, and also if schema creation or session setup raises.
    """
    # "sqlite://" with no path selects an in-memory SQLite database.
    # NOTE(review): the *Postgres* repositories run against SQLite here;
    # presumably they only use portable SQL. Confirm if dialect issues appear.
    engine = create_engine("sqlite://")
    try:
        SQLModel.metadata.create_all(engine)
        with Session(engine) as session:
            tasks = TaskPostgresRepository(session=session)
            subtasks = SubTaskPostgresRepository(session=session)
            yield Harness(
                orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks),
                tasks=tasks,
                subtasks=subtasks,
            )
    finally:
        # Without this, every test leaks an engine whose pool keeps its
        # connection open until the engine is garbage-collected.
        engine.dispose()
def test_create_task_with_subtask_creates_both_in_waiting(
    harness: Harness,
) -> None:
    """A newly created task and its first subtask both start out WAITING."""
    orchestrator = harness.orchestrator
    new_task, new_subtask = orchestrator.create_task_with_subtask(
        task_source="manual:test",
        inputs={"foo": "bar"},
        source=Source.PORTFOLIO,
        source_id="abc",
    )

    assert new_task.status is TaskStatus.WAITING
    assert new_subtask.status is SubTaskStatus.WAITING
    # The subtask must point at the task it was created alongside...
    assert new_subtask.task_id == new_task.id
    # ...and carry the inputs that were passed in.
    assert new_subtask.inputs == {"foo": "bar"}
def test_start_subtask_cascades_to_in_progress(harness: Harness) -> None:
    """Starting a subtask moves it, and its parent task, to IN_PROGRESS."""
    log_url = "https://example/log"
    parent, child = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    running = harness.orchestrator.start_subtask(child.id, cloud_logs_url=log_url)

    assert running.status is SubTaskStatus.IN_PROGRESS
    assert running.cloud_logs_url == log_url
    # The transition must also be reflected on the persisted parent task.
    parent_status = harness.tasks.get(parent.id).status
    assert parent_status is TaskStatus.IN_PROGRESS
def test_complete_subtask_cascades_to_complete(harness: Harness) -> None:
    """Completing the only subtask marks it COMPLETE and completes the task."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    harness.orchestrator.start_subtask(subtask.id)
    harness.orchestrator.complete_subtask(subtask.id, {"value": 42})

    done_subtask = harness.subtasks.get(subtask.id)
    done_task = harness.tasks.get(task.id)
    # Check the subtask's own status, not just the cascade onto the task.
    assert done_subtask.status is SubTaskStatus.COMPLETE
    # The orchestrator stores the supplied result under a "result" key.
    assert done_subtask.outputs == {"result": {"value": 42}}
    assert done_task.status is TaskStatus.COMPLETE
    assert done_task.job_completed is not None
def test_fail_subtask_cascades_to_failed(harness: Harness) -> None:
    """Failing the only subtask records the error and fails the parent task."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )
    # The subtask is failed straight from WAITING; it is never started.
    harness.orchestrator.fail_subtask(subtask.id, RuntimeError("boom"))

    failed_subtask = harness.subtasks.get(subtask.id)
    failed_task = harness.tasks.get(task.id)
    # Check the subtask's own status, not just the cascade onto the task.
    assert failed_subtask.status is SubTaskStatus.FAILED
    # The exception message is persisted under an "error" key.
    assert failed_subtask.outputs == {"error": "boom"}
    assert failed_task.status is TaskStatus.FAILED
def test_failed_subtask_locks_task_failed_even_with_others_complete(
    harness: Harness,
) -> None:
    """One FAILED subtask makes the task FAILED despite a COMPLETE sibling."""
    orchestrator = harness.orchestrator
    task, completed = orchestrator.create_task_with_subtask(task_source="manual:test")
    # Attach a second subtask directly through the repository.
    failing = SubTask.create(task_id=task.id)
    harness.subtasks.create(failing)

    # NOTE(review): the sibling completes *before* the failure here. A
    # fail-then-complete ordering would exercise the "lock" more directly.
    orchestrator.complete_subtask(completed.id)
    orchestrator.fail_subtask(failing.id, RuntimeError("nope"))

    final_status = harness.tasks.get(task.id).status
    assert final_status is TaskStatus.FAILED
def test_mixed_complete_and_in_progress_keeps_task_in_progress(
    harness: Harness,
) -> None:
    """One COMPLETE plus one IN_PROGRESS subtask leaves the task IN_PROGRESS."""
    orchestrator = harness.orchestrator
    task, finished = orchestrator.create_task_with_subtask(task_source="manual:test")
    # Attach a second subtask directly through the repository.
    running = SubTask.create(task_id=task.id)
    harness.subtasks.create(running)

    orchestrator.complete_subtask(finished.id)
    orchestrator.start_subtask(running.id)

    final_status = harness.tasks.get(task.id).status
    assert final_status is TaskStatus.IN_PROGRESS
def test_run_subtask_happy_path_returns_result_and_cascades_complete(
    harness: Harness,
) -> None:
    """run_subtask returns the work's value and marks subtask and task COMPLETE."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    def work() -> dict[str, int]:
        return {"answer": 42}

    outcome = harness.orchestrator.run_subtask(subtask.id, work=work)

    assert outcome == {"answer": 42}
    assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
    assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
def test_run_subtask_failing_work_marks_failed_and_reraises(
    harness: Harness,
) -> None:
    """An exception from the work is recorded as FAILED and re-raised."""
    task, subtask = harness.orchestrator.create_task_with_subtask(
        task_source="manual:test"
    )

    def explode() -> None:
        raise RuntimeError("boom")

    # The original exception must propagate out of run_subtask...
    with pytest.raises(RuntimeError, match="boom"):
        harness.orchestrator.run_subtask(subtask.id, work=explode)

    # ...after the failure has been persisted on both subtask and task.
    assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED
    assert harness.tasks.get(task.id).status is TaskStatus.FAILED