Model/orchestration/task_orchestrator.py
2026-05-20 13:21:11 +00:00

106 lines
3.3 KiB
Python

import logging
from typing import Any, Callable, Optional
from uuid import UUID

from domain.tasks.subtasks import SubTask
from domain.tasks.tasks import Source, Task
from repositories.tasks.subtask_repository import SubTaskRepository
from repositories.tasks.task_repository import TaskRepository
from utilities.private import private
class TaskOrchestrator:
    """Coordinate the Task + SubTask lifecycle.

    Exposes primitives (``start_subtask`` / ``complete_subtask`` /
    ``fail_subtask``) for handlers that want fine-grained control, and a
    high-level ``run_subtask`` wrapper that owns the try/except so it can
    replace the body of the legacy ``subtask_handler`` decorator in
    ``backend/utils/subtasks.py``.

    Each primitive saves the SubTask, then recomputes the parent Task's
    status from all of its children.
    """

    def __init__(
        self,
        task_repo: TaskRepository,
        subtask_repo: SubTaskRepository,
    ) -> None:
        """Store the repositories used to load and persist tasks/subtasks.

        Args:
            task_repo: Repository for ``Task`` objects.
            subtask_repo: Repository for ``SubTask`` objects.
        """
        self._tasks = task_repo
        self._subtasks = subtask_repo

    def create_task_with_subtask(
        self,
        *,
        task_source: str,
        inputs: Optional[dict[str, Any]] = None,
        service: Optional[str] = None,
        source: Optional[Source] = None,
        source_id: Optional[str] = None,
    ) -> tuple[Task, SubTask]:
        """Create a new Task together with its first SubTask.

        Args:
            task_source: Forwarded to ``Task.create``.
            inputs: Inputs for the initial SubTask (not stored on the Task).
            service: Forwarded to ``Task.create``.
            source: Forwarded to ``Task.create``.
            source_id: Forwarded to ``Task.create``.

        Returns:
            The ``(task, subtask)`` pair, each already passed to its
            repository's ``create``.
        """
        task = Task.create(
            task_source=task_source,
            service=service,
            source=source,
            source_id=source_id,
        )
        self._tasks.create(task)
        # NOTE(review): the two creates are not wrapped in any transaction
        # here; if the SubTask create fails, the Task has already been
        # created without a child. Confirm whether the repositories share a
        # unit of work.
        subtask = SubTask.create(task_id=task.id, inputs=inputs)
        self._subtasks.create(subtask)
        return task, subtask

    def create_child_subtask(
        self,
        parent_task_id: UUID,
        *,
        inputs: Optional[dict[str, Any]] = None,
    ) -> SubTask:
        """Create an additional SubTask under an existing Task.

        The parent Task's status is NOT recomputed here; that happens the
        next time one of the lifecycle primitives touches a child.

        Args:
            parent_task_id: Id of the Task the new SubTask belongs to.
            inputs: Inputs for the new SubTask.

        Returns:
            The newly created SubTask.
        """
        subtask = SubTask.create(task_id=parent_task_id, inputs=inputs)
        self._subtasks.create(subtask)
        return subtask

    def start_subtask(
        self, subtask_id: UUID, cloud_logs_url: Optional[str] = None
    ) -> SubTask:
        """Mark a SubTask started, save it, and refresh its parent Task.

        Args:
            subtask_id: Id of the SubTask to start.
            cloud_logs_url: Forwarded to ``SubTask.start``.

        Returns:
            The updated SubTask.

        Raises:
            Whatever ``SubTaskRepository.get`` raises for an unknown id
            (not visible here — confirm in the repository).
        """
        subtask = self._subtasks.get(subtask_id)
        subtask.start(cloud_logs_url)
        self._subtasks.save(subtask)
        self._cascade(subtask.task_id)
        return subtask

    def complete_subtask(
        self, subtask_id: UUID, result: Any = None
    ) -> SubTask:
        """Mark a SubTask completed with ``result``, save it, refresh parent.

        Args:
            subtask_id: Id of the SubTask to complete.
            result: Value handed to ``SubTask.complete``.

        Returns:
            The updated SubTask.
        """
        subtask = self._subtasks.get(subtask_id)
        subtask.complete(result)
        self._subtasks.save(subtask)
        self._cascade(subtask.task_id)
        return subtask

    def fail_subtask(self, subtask_id: UUID, error: BaseException) -> SubTask:
        """Mark a SubTask failed with ``error``, save it, refresh parent.

        Args:
            subtask_id: Id of the SubTask to fail.
            error: The exception handed to ``SubTask.fail``.

        Returns:
            The updated SubTask.
        """
        subtask = self._subtasks.get(subtask_id)
        subtask.fail(error)
        self._subtasks.save(subtask)
        self._cascade(subtask.task_id)
        return subtask

    def run_subtask(
        self,
        subtask_id: UUID,
        work: Callable[[], Any],
        cloud_logs_url: Optional[str] = None,
    ) -> Any:
        """Run ``work`` inside the SubTask lifecycle and return its result.

        Marks the SubTask started, calls ``work()``, then marks it completed
        with the returned value. If ``work()`` raises an ``Exception``, the
        SubTask is marked failed with it and that same exception is
        re-raised.

        Args:
            subtask_id: Id of the SubTask to drive.
            work: Zero-argument callable performing the actual work.
            cloud_logs_url: Forwarded to ``start_subtask``.

        Returns:
            Whatever ``work()`` returned.

        Raises:
            Exception: The exception raised by ``work()``. It is always the
                one propagated, even if recording the failure itself fails
                (that secondary error is logged instead). Errors raised by
                ``start_subtask`` or ``complete_subtask`` propagate directly.
        """
        self.start_subtask(subtask_id, cloud_logs_url)
        try:
            result = work()
        # Only Exception subclasses are caught: KeyboardInterrupt/SystemExit
        # propagate without the SubTask being marked failed.
        except Exception as exc:
            try:
                self.fail_subtask(subtask_id, exc)
            except Exception:
                # Recording the failure is bookkeeping; if it breaks (e.g. a
                # repository error) the caller must still receive the work's
                # own exception rather than the bookkeeping error.
                logging.getLogger(__name__).exception(
                    "Could not record failure of subtask %s", subtask_id
                )
            # Bare raise re-raises ``exc``: the inner handler has finished,
            # so the active exception is the one from ``work()`` again.
            raise
        # Kept outside the try: an error while recording completion
        # propagates as-is instead of being recorded as a failure of work().
        self.complete_subtask(subtask_id, result)
        return result

    @private
    def _cascade(self, task_id: UUID) -> None:
        """Recompute and save the parent Task's status from its SubTasks.

        Collects the ``status`` of every SubTask returned by
        ``list_by_task(task_id)``, passes them to
        ``Task.recalculate_from_subtasks``, and saves the Task.
        """
        # NOTE(review): no locking here; concurrent cascades for the same
        # task are an unguarded read-modify-write. Confirm the repositories
        # (or the status recalculation) make that safe.
        statuses = [s.status for s in self._subtasks.list_by_task(task_id)]
        task = self._tasks.get(task_id)
        task.recalculate_from_subtasks(statuses)
        self._tasks.save(task)