# ---- Standard Library ----
from typing import Optional, Dict, Any
from datetime import datetime, timezone
from uuid import UUID
import json

# ---- SQLModel / SQLAlchemy ----
from sqlmodel import Session, select

# ---- DB Session ----
from backend.app.db.connection import get_db_session

# ---- Models ----
from backend.app.db.models.tasks import Task, SubTask


# Statuses that mark a job as finished and therefore stamp ``job_completed``.
_TERMINAL_STATUSES = ("complete", "failed")


def _get_subtask_or_raise(session: Session, subtask_id: UUID) -> SubTask:
    """Return the SubTask with the given id, or raise ValueError if absent."""
    subtask = session.get(SubTask, subtask_id)
    if subtask is None:
        raise ValueError(f"SubTask {subtask_id} not found")
    return subtask


def _apply_result(
    subtask: SubTask,
    outputs: Optional[Dict[str, Any]],
    cloud_logs_url: Optional[str],
) -> None:
    """Set outputs (JSON-encoded) and logs URL, skipping arguments left as None."""
    if outputs is not None:
        subtask.outputs = json.dumps(outputs)
    if cloud_logs_url is not None:
        subtask.cloud_logs_url = cloud_logs_url


# ============================================================
# SubTask Interface
# ============================================================
class SubTaskInterface:
    """
    CRUD operations for SubTask + cascading Task progress updates.
    """

    # --------------------------------------------------------
    # CREATE SUBTASK
    # --------------------------------------------------------
    def create_subtask(
        self,
        task_id: UUID,
        inputs: Optional[Dict[str, Any]] = None,
        status: Optional[str] = None,
    ):
        """
        Create a SubTask under an existing Task and recalculate the parent
        task's status.

        :param task_id: UUID of the parent task
        :param inputs: Optional inputs, stored JSON-encoded
        :param status: Initial status; defaults to "waiting" when None
        :return: The id of the newly created subtask
        :raises ValueError: If the parent task does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            if session.get(Task, task_id) is None:
                raise ValueError(f"Task {task_id} not found")

            # We treat waiting as the default status
            status = "waiting" if status is None else status

            subtask = SubTask(
                task_id=task_id,
                # ``is not None`` so that an explicit empty dict is stored as
                # "{}" instead of being silently dropped.
                inputs=json.dumps(inputs) if inputs is not None else None,
                status=status,
                job_started=now,
                job_completed=None,
            )
            session.add(subtask)
            session.commit()
            session.refresh(subtask)

            # Capture the id now: the progress update below commits again,
            # which may expire the loaded attributes of ``subtask``.
            subtask_id = subtask.id

            # Recalculate parent task progress
            self._update_task_progress(session, task_id)

            return subtask_id

    # --------------------------------------------------------
    # UPDATE STATUS (in progress, complete, failed)
    # --------------------------------------------------------
    def update_subtask_status(
        self,
        subtask_id: UUID,
        status: str,
        outputs=None,
        cloud_logs_url=None
    ):
        """
        Update the status of a subtask, and recalculate the parent task progress.

        :param subtask_id: UUID of the subtask to update
        :param status: New status (in progress, complete, failed)
        :param outputs: Optional outputs to set
        :param cloud_logs_url: Optional cloud logs URL to set
        :return: The updated SubTask, refreshed from the database
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = _get_subtask_or_raise(session, subtask_id)
            normalized = status.lower()

            # When job really starts
            if normalized == "in progress" and subtask.job_started is None:
                subtask.job_started = now

            # Completed or failed
            if normalized in _TERMINAL_STATUSES:
                subtask.job_completed = now

            subtask.status = normalized
            subtask.updated_at = now
            _apply_result(subtask, outputs, cloud_logs_url)

            # Read the parent id before committing so it does not need to be
            # reloaded afterwards.
            task_id = subtask.task_id

            session.add(subtask)
            session.commit()

            # Recalculate task status
            self._update_task_progress(session, task_id)

            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE OUTPUTS
    # --------------------------------------------------------
    @staticmethod
    def update_subtask_output(subtask_id: UUID, outputs: Dict[str, Any]):
        """
        Overwrite a subtask's outputs with the JSON encoding of ``outputs``.

        :param subtask_id: UUID of the subtask to update
        :param outputs: Outputs to store (always written, JSON-encoded)
        :return: The updated SubTask, refreshed from the database
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = _get_subtask_or_raise(session, subtask_id)
            subtask.outputs = json.dumps(outputs)
            subtask.updated_at = now

            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE CLOUD LOGS URL
    # --------------------------------------------------------
    @staticmethod
    def update_subtask_logs(subtask_id: UUID, cloud_logs_url: str):
        """
        Overwrite a subtask's cloud logs URL.

        :param subtask_id: UUID of the subtask to update
        :param cloud_logs_url: URL to store (always written)
        :return: The updated SubTask, refreshed from the database
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = _get_subtask_or_raise(session, subtask_id)
            subtask.cloud_logs_url = cloud_logs_url
            subtask.updated_at = now

            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # SET BOTH OUTPUT + LOGS
    # --------------------------------------------------------
    @staticmethod
    def set_subtask_result(
        subtask_id: UUID,
        outputs: Optional[Dict[str, Any]] = None,
        cloud_logs_url: Optional[str] = None,
    ):
        """
        Set outputs and/or cloud logs URL on a subtask in one commit.
        Arguments left as None are not written. Status is not changed.

        :param subtask_id: UUID of the subtask to update
        :param outputs: Optional outputs, stored JSON-encoded
        :param cloud_logs_url: Optional cloud logs URL
        :return: The updated SubTask, refreshed from the database
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = _get_subtask_or_raise(session, subtask_id)
            _apply_result(subtask, outputs, cloud_logs_url)
            subtask.updated_at = now

            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # TASK PROGRESS CALCULATION
    # --------------------------------------------------------
    @staticmethod
    def _update_task_progress(session: Session, task_id: UUID):
        """
        Derive the parent task's status from its subtasks and commit it.

        Precedence: any "failed" -> "failed"; all "complete" -> "complete";
        any "in progress" -> "in progress"; otherwise "waiting". A task with
        no subtasks is "waiting". Does nothing if the task does not exist.

        :param session: Open session used for the query and the commit
        :param task_id: UUID of the task to recalculate
        """
        task = session.get(Task, task_id)
        if task is None:
            return

        subtasks = session.exec(
            select(SubTask).where(SubTask.task_id == task_id)
        ).all()
        # A missing status is treated as an empty string so it cannot crash
        # ``.lower()``; it then falls through to the "waiting" branch.
        statuses = [(s.status or "").lower() for s in subtasks]
        now = datetime.now(timezone.utc)

        if "failed" in statuses:
            task.status = "failed"
            task.job_completed = now
        # ``all()`` is vacuously true for an empty list, so require at least
        # one subtask before declaring the task complete.
        elif statuses and all(s == "complete" for s in statuses):
            task.status = "complete"
            task.job_completed = now
        elif "in progress" in statuses:
            task.status = "in progress"
            # Clear any completion time left over from an earlier terminal
            # state, matching the "waiting" branch below.
            task.job_completed = None
        else:
            # Nothing failed or running and not everything is complete
            # (all waiting, or a mix of complete and waiting).
            task.status = "waiting"
            task.job_completed = None

        task.updated_at = now
        session.add(task)
        session.commit()

    def finalize_subtask(
        self,
        subtask_id: UUID,
        status: str,
        outputs: Optional[Dict[str, Any]],
        cloud_logs_url: Optional[str]
    ):
        """
        Mark a subtask as finished, store its results, and recalculate the
        parent task's status.

        :param subtask_id: UUID of the subtask to finalize
        :param status: Final status; must be "complete" or "failed"
            (case-insensitive)
        :param outputs: Outputs to store JSON-encoded; skipped when None
        :param cloud_logs_url: Cloud logs URL to store; skipped when None
        :return: The updated SubTask, refreshed from the database
        :raises ValueError: If the subtask does not exist or the status is
            not a terminal status
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = _get_subtask_or_raise(session, subtask_id)

            normalized = status.lower()
            if normalized not in _TERMINAL_STATUSES:
                raise ValueError("Status must be 'complete' or 'failed'")

            # Set outputs + logs
            _apply_result(subtask, outputs, cloud_logs_url)

            # Status + timestamps
            subtask.status = normalized
            subtask.job_completed = now
            subtask.updated_at = now

            # Read the parent id before committing so it does not need to be
            # reloaded afterwards.
            task_id = subtask.task_id

            session.add(subtask)
            session.commit()

            # Update parent task (complete/failed)
            self._update_task_progress(session, task_id)

            session.refresh(subtask)
            return subtask


# ============================================================
# Task Interface
# ============================================================
class TasksInterface:
    """
    High-level operations for Task records.
    """

    @staticmethod
    def create_task(
        task_source: str,
        service: Optional[str] = None,
        inputs: Optional[Dict[str, Any]] = None,
        task_only: bool = False,
    ):
        """
        Create a new Task record, and an initial SubTask in waiting state.
        Can also be used to create just a task, without a subtask

        :param task_source: Text indicating source of task creation
            (e.g. file path + function name)
        :param service: Optional service name
        :param inputs: Inputs of the job being run
        :param task_only: If True, only create the Task record, without a SubTask
        :return: Tuple ``(task_id, subtask_id)``; ``subtask_id`` is None when
            ``task_only`` is True
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            task = Task(
                task_source=task_source,
                service=service,
                status="waiting",
                job_started=now,
                job_completed=None,
            )
            session.add(task)
            session.commit()
            session.refresh(task)
            # Keep the id as a plain value so it stays usable after this
            # session block ends.
            task_id = task.id

        if task_only:
            return task_id, None

        # Create first subtask in waiting state. The task is already
        # committed, so create_subtask's own session can look it up.
        subtask_id = SubTaskInterface().create_subtask(
            task_id=task_id,
            inputs=inputs,
        )
        return task_id, subtask_id

    @staticmethod
    def update_task_status(task_id: UUID, status: str):
        """
        Set a task's status directly, stamping start/completion times.

        :param task_id: UUID of the task to update
        :param status: New status; stored lower-cased
        :return: The updated Task, refreshed from the database
        :raises ValueError: If the task does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            task = session.get(Task, task_id)
            if task is None:
                raise ValueError(f"Task {task_id} not found")

            normalized = status.lower()

            if normalized == "in progress" and task.job_started is None:
                task.job_started = now

            # "failed" ends the job just like "complete" does, so both stamp
            # the completion time (as _update_task_progress already does).
            if normalized in _TERMINAL_STATUSES:
                task.job_completed = now

            task.status = normalized
            task.updated_at = now

            session.add(task)
            session.commit()
            session.refresh(task)
            return task