import logging
from typing import Any, Callable, Optional
from uuid import UUID

from domain.tasks.subtasks import SubTask
from domain.tasks.tasks import Source, Task
from repositories.tasks.subtask_repository import SubTaskRepository
from repositories.tasks.task_repository import TaskRepository
from utilities.private import private

logger = logging.getLogger(__name__)


class TaskOrchestrator:
    """Coordinates Task + SubTask lifecycle.

    Exposes primitives (start/complete/fail_subtask) for handlers that want
    fine-grained control, and a high-level run_subtask wrapper that owns the
    try/except so it can replace the body of the legacy subtask_handler
    decorator in backend/utils/subtasks.py.

    Each primitive saves the SubTask, then recomputes the parent Task's status
    from all its children.
    """

    def __init__(
        self,
        task_repo: TaskRepository,
        subtask_repo: SubTaskRepository,
    ) -> None:
        """Store the repositories used to load and persist Tasks and SubTasks.

        Args:
            task_repo: Repository for parent ``Task`` records.
            subtask_repo: Repository for ``SubTask`` records.
        """
        self._tasks = task_repo
        self._subtasks = subtask_repo

    def create_task_with_subtask(
        self,
        *,
        task_source: str,
        inputs: Optional[dict[str, Any]] = None,
        service: Optional[str] = None,
        source: Optional[Source] = None,
        source_id: Optional[str] = None,
    ) -> tuple[Task, SubTask]:
        """Create and persist a new Task together with one child SubTask.

        Args:
            task_source: Passed through to ``Task.create``.
            inputs: Passed to ``SubTask.create`` as the subtask's inputs.
            service: Passed through to ``Task.create``.
            source: Passed through to ``Task.create``.
            source_id: Passed through to ``Task.create``.

        Returns:
            The newly created ``(task, subtask)`` pair.

        Note:
            The parent Task's status is not recomputed here (no cascade runs).
            NOTE(review): the two ``create`` calls are not wrapped in any visible
            transaction, so a failure creating the SubTask may leave the Task
            persisted on its own. Confirm whether the repositories handle this.
        """
        task = Task.create(
            task_source=task_source,
            service=service,
            source=source,
            source_id=source_id,
        )
        self._tasks.create(task)

        subtask = SubTask.create(task_id=task.id, inputs=inputs)
        self._subtasks.create(subtask)

        return task, subtask

    def start_subtask(
        self, subtask_id: UUID, cloud_logs_url: Optional[str] = None
    ) -> SubTask:
        """Mark a SubTask as started, persist it, and refresh its parent Task.

        Args:
            subtask_id: Id of the SubTask to start.
            cloud_logs_url: Optional URL forwarded to ``SubTask.start``.

        Returns:
            The updated SubTask.
        """
        subtask = self._subtasks.get(subtask_id)
        subtask.start(cloud_logs_url)
        self._subtasks.save(subtask)
        self._cascade(subtask.task_id)
        return subtask

    def complete_subtask(
        self, subtask_id: UUID, result: Any = None
    ) -> SubTask:
        """Mark a SubTask as completed, persist it, and refresh its parent Task.

        Args:
            subtask_id: Id of the SubTask to complete.
            result: Value forwarded to ``SubTask.complete``.

        Returns:
            The updated SubTask.
        """
        subtask = self._subtasks.get(subtask_id)
        subtask.complete(result)
        self._subtasks.save(subtask)
        self._cascade(subtask.task_id)
        return subtask

    def fail_subtask(self, subtask_id: UUID, error: BaseException) -> SubTask:
        """Mark a SubTask as failed, persist it, and refresh its parent Task.

        Args:
            subtask_id: Id of the SubTask to fail.
            error: The exception forwarded to ``SubTask.fail``.

        Returns:
            The updated SubTask.
        """
        subtask = self._subtasks.get(subtask_id)
        subtask.fail(error)
        self._subtasks.save(subtask)
        self._cascade(subtask.task_id)
        return subtask

    def run_subtask(
        self,
        subtask_id: UUID,
        work: Callable[[], Any],
        cloud_logs_url: Optional[str] = None,
    ) -> Any:
        """Run ``work`` while recording the SubTask's start/complete/fail lifecycle.

        The SubTask is started first. If ``work()`` succeeds, the SubTask is
        completed with its return value. If it raises an ``Exception``, the
        SubTask is failed and the original exception is re-raised.

        Args:
            subtask_id: Id of the SubTask being executed.
            work: Zero-argument callable doing the actual job.
            cloud_logs_url: Optional URL forwarded to ``start_subtask``.

        Returns:
            Whatever ``work()`` returned.

        Raises:
            Exception: Any exception raised by ``work()``, re-raised unchanged.
                This holds even if recording the failure itself errors; that
                secondary error is logged rather than propagated.

        Note:
            Only ``Exception`` subclasses are caught. Other ``BaseException``
            types (e.g. ``KeyboardInterrupt``, ``SystemExit``) propagate without
            the SubTask being marked failed.
        """
        self.start_subtask(subtask_id, cloud_logs_url)
        try:
            result = work()
        except Exception as work_error:
            try:
                self.fail_subtask(subtask_id, work_error)
            except Exception:
                # Recording the failure is bookkeeping. Don't let an error here
                # (e.g. a repository failure) mask the handler's own exception,
                # which is what callers need to see.
                logger.exception(
                    "Failed to record failure of subtask %s", subtask_id
                )
            # The inner handler has exited, so bare ``raise`` re-raises work_error.
            raise
        # Completion is recorded outside the try: a failure while persisting the
        # completion must not be reported as the work itself having failed.
        self.complete_subtask(subtask_id, result)
        return result

    @private
    def _cascade(self, task_id: UUID) -> None:
        """Recompute and persist the parent Task's status from all its SubTasks.

        Args:
            task_id: Id of the parent Task whose children are inspected.
        """
        statuses = [s.status for s in self._subtasks.list_by_task(task_id)]
        task = self._tasks.get(task_id)
        task.recalculate_from_subtasks(statuses)
        self._tasks.save(task)