Model/backend/app/db/functions/tasks/Tasks.py
2026-03-20 09:53:48 +00:00

334 lines
10 KiB
Python

# ---- Standard Library ----
from typing import Optional, Dict, Any
from datetime import datetime, timezone
from uuid import UUID
import json
# ---- SQLModel / SQLAlchemy ----
from sqlmodel import Session, select
# ---- DB Session ----
from backend.app.db.connection import get_db_session
# ---- Models ----
from backend.app.db.models.tasks import SourceEnum, Task, SubTask
# ============================================================
# SubTask Interface
# ============================================================
class SubTaskInterface:
    """
    CRUD operations for SubTask records, with cascading status updates on the
    parent Task.

    Every public method opens its own session via ``get_db_session()``.
    Methods that change a subtask's status also recompute the parent Task's
    status through ``_update_task_progress``.
    """

    # --------------------------------------------------------
    # INTERNAL HELPERS
    # --------------------------------------------------------
    @staticmethod
    def _get_subtask_or_raise(session: Session, subtask_id: UUID) -> SubTask:
        """
        Load a SubTask by primary key.

        :param session: Open database session
        :param subtask_id: UUID of the subtask to load
        :return: The SubTask instance attached to ``session``
        :raises ValueError: If no SubTask with ``subtask_id`` exists
        """
        subtask = session.get(SubTask, subtask_id)
        if not subtask:
            raise ValueError(f"SubTask {subtask_id} not found")
        return subtask

    # --------------------------------------------------------
    # CREATE SUBTASK
    # --------------------------------------------------------
    def create_subtask(
        self,
        task_id: UUID,
        inputs: Optional[Dict[str, Any]] = None,
        status: Optional[str] = None,
    ) -> UUID:
        """
        Create a SubTask under an existing Task and recompute the Task's status.

        :param task_id: UUID of the parent Task
        :param inputs: Optional inputs, stored JSON-encoded. Falsy values
            (``None`` or an empty dict) are stored as ``None``.
        :param status: Initial status, defaulting to ``"waiting"``. It is
            lower-cased to match the normalisation used by the update methods.
        :return: UUID of the new SubTask
        :raises ValueError: If the parent Task does not exist
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            task = session.get(Task, task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")
            # We treat waiting as the default status
            normalized = "waiting" if status is None else status.lower()
            subtask = SubTask(
                task_id=task_id,
                inputs=json.dumps(inputs) if inputs else None,
                status=normalized,
                # NOTE(review): job_started is set at creation even for
                # "waiting" subtasks. As a result, the "in progress" branch of
                # update_subtask_status (job_started is None) never fires.
                job_started=now,
                job_completed=None,
            )
            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            # Capture the id now. _update_task_progress commits again, which
            # (with SQLAlchemy's default expire_on_commit) expires `subtask`.
            new_subtask_id = subtask.id
            # Recalculate parent task progress
            self._update_task_progress(session, task_id)
            return new_subtask_id

    # --------------------------------------------------------
    # UPDATE STATUS (in progress, complete, failed)
    # --------------------------------------------------------
    def update_subtask_status(
        self,
        subtask_id: UUID,
        status: str,
        outputs: Optional[Dict[str, str] | str] = None,
        cloud_logs_url: Optional[str] = None,
    ) -> SubTask:
        """
        Update the status of a subtask and recalculate the parent Task's status.

        :param subtask_id: UUID of the subtask to update
        :param status: New status (e.g. "in progress", "complete", "failed"),
            stored lower-cased
        :param outputs: Optional outputs, JSON-encoded before storing. A str is
            stored as a JSON string literal; it is not parsed.
        :param cloud_logs_url: Optional cloud logs URL to set
        :return: The refreshed SubTask
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            subtask = self._get_subtask_or_raise(session, subtask_id)
            normalized = status.lower()
            # Record the start time only if none was recorded yet
            if normalized == "in progress" and subtask.job_started is None:
                subtask.job_started = now
            # Terminal states record a completion time
            if normalized in ("complete", "failed"):
                subtask.job_completed = now
            subtask.status = normalized
            subtask.updated_at = now
            if outputs is not None:
                subtask.outputs = json.dumps(outputs)
            if cloud_logs_url is not None:
                subtask.cloud_logs_url = cloud_logs_url
            session.add(subtask)
            session.commit()
            # Recalculate task status
            self._update_task_progress(session, subtask.task_id)
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE OUTPUTS
    # --------------------------------------------------------
    @staticmethod
    def update_subtask_output(subtask_id: UUID, outputs: Dict[str, Any]) -> SubTask:
        """
        Replace a subtask's outputs without touching its status.

        :param subtask_id: UUID of the subtask to update
        :param outputs: Outputs to store, always JSON-encoded
            (``None`` becomes the string ``"null"``)
        :return: The refreshed SubTask
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            subtask = SubTaskInterface._get_subtask_or_raise(session, subtask_id)
            subtask.outputs = json.dumps(outputs)
            subtask.updated_at = now
            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE CLOUD LOGS URL
    # --------------------------------------------------------
    @staticmethod
    def update_subtask_logs(subtask_id: UUID, cloud_logs_url: str) -> SubTask:
        """
        Replace a subtask's cloud logs URL without touching its status.

        :param subtask_id: UUID of the subtask to update
        :param cloud_logs_url: URL to store
        :return: The refreshed SubTask
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            subtask = SubTaskInterface._get_subtask_or_raise(session, subtask_id)
            subtask.cloud_logs_url = cloud_logs_url
            subtask.updated_at = now
            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # SET BOTH OUTPUT + LOGS
    # --------------------------------------------------------
    @staticmethod
    def set_subtask_result(
        subtask_id: UUID,
        outputs: Optional[Dict[str, Any]] = None,
        cloud_logs_url: Optional[str] = None,
    ) -> SubTask:
        """
        Set a subtask's outputs and/or cloud logs URL without touching its status.

        Arguments left as ``None`` leave the corresponding column unchanged.

        :param subtask_id: UUID of the subtask to update
        :param outputs: Optional outputs, JSON-encoded before storing
        :param cloud_logs_url: Optional cloud logs URL
        :return: The refreshed SubTask
        :raises ValueError: If the subtask does not exist
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            subtask = SubTaskInterface._get_subtask_or_raise(session, subtask_id)
            if outputs is not None:
                subtask.outputs = json.dumps(outputs)
            if cloud_logs_url is not None:
                subtask.cloud_logs_url = cloud_logs_url
            subtask.updated_at = now
            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # TASK PROGRESS CALCULATION
    # --------------------------------------------------------
    @staticmethod
    def _update_task_progress(session: Session, task_id: UUID) -> None:
        """
        Derive a Task's status from the statuses of all its SubTasks, then commit.

        Precedence is:
        - any "failed" gives failed;
        - a non-empty set where all are "complete" gives complete;
        - any "in progress" gives in progress;
        - otherwise the Task is waiting.

        A Task with no SubTasks is treated as waiting, not complete.

        :param session: Open database session; this method commits it
        :param task_id: UUID of the Task to update; a missing Task is ignored
        """
        task = session.get(Task, task_id)
        if not task:
            return
        subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all()
        statuses = [s.status.lower() for s in subtasks]
        now = datetime.now(timezone.utc)
        if "failed" in statuses:
            task.status = "failed"
            task.job_completed = now
        elif statuses and all(s == "complete" for s in statuses):
            # `statuses` must be non-empty. all() of an empty list is True,
            # which would mark a Task without subtasks as complete.
            task.status = "complete"
            task.job_completed = now
        elif "in progress" in statuses:
            task.status = "in progress"
        else:
            # Nothing failed or in progress, and not everything is complete.
            # That means all waiting, a waiting/complete mix, or no subtasks.
            # NOTE(review): a waiting/complete mix is also reported as
            # "waiting" - confirm that is intended.
            task.status = "waiting"
            task.job_completed = None
        task.updated_at = now
        session.add(task)
        session.commit()

    # --------------------------------------------------------
    # FINALIZE (status + outputs + logs in one write)
    # --------------------------------------------------------
    def finalize_subtask(
        self,
        subtask_id: UUID,
        status: str,
        outputs: Optional[Dict[str, Any]],
        cloud_logs_url: Optional[str],
    ) -> SubTask:
        """
        Move a subtask to a terminal state and store its result in one commit.

        The parent Task's status is then recalculated.

        :param subtask_id: UUID of the subtask to finalize
        :param status: "complete" or "failed" (case-insensitive)
        :param outputs: Optional outputs, JSON-encoded; ``None`` leaves them unchanged
        :param cloud_logs_url: Optional logs URL; ``None`` leaves it unchanged
        :return: The refreshed SubTask
        :raises ValueError: If the subtask does not exist or ``status`` is not terminal
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            subtask = self._get_subtask_or_raise(session, subtask_id)
            normalized = status.lower()
            if normalized not in ("complete", "failed"):
                raise ValueError("Status must be 'complete' or 'failed'")
            # Set outputs
            if outputs is not None:
                subtask.outputs = json.dumps(outputs)
            # Set logs
            if cloud_logs_url is not None:
                subtask.cloud_logs_url = cloud_logs_url
            # Status + timestamps
            subtask.status = normalized
            subtask.job_completed = now
            subtask.updated_at = now
            session.add(subtask)
            session.commit()
            # Update parent task (complete/failed)
            self._update_task_progress(session, subtask.task_id)
            session.refresh(subtask)
            return subtask
# ============================================================
# Task Interface
# ============================================================
class TasksInterface:
    """
    High-level operations for Task records.
    """

    @staticmethod
    def create_task(
        task_source: str,
        service: Optional[str] = None,
        inputs: Optional[Dict[str, Any]] = None,
        task_only: bool = False,
        source: Optional[SourceEnum] = None,
        source_id: Optional[str] = None,
    ):
        """
        Create a new Task record and, unless ``task_only`` is set, an initial
        SubTask in the waiting state.

        :param task_source: Text indicating where the task was created
            (e.g. file path + function name)
        :param service: Optional service name
        :param inputs: Inputs of the job being run; passed to the initial SubTask
            and ignored when ``task_only`` is True
        :param task_only: If True, only create the Task record, without a SubTask
        :param source: Optional SourceEnum value stored on the Task
        :param source_id: Optional identifier stored alongside ``source``
        :return: Tuple ``(task_id, subtask_id)``. ``subtask_id`` is ``None``
            when ``task_only`` is True.
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            task = Task(
                task_source=task_source,
                service=service,
                status="waiting",
                job_started=now,
                job_completed=None,
                source=source,
                source_id=source_id,
            )
            session.add(task)
            session.commit()
            session.refresh(task)
            task_id = task.id
        if task_only:
            return task_id, None
        # Create the first subtask (waiting state). The Task's session block
        # has already exited, so this code never holds two sessions at once.
        subtask_id = SubTaskInterface().create_subtask(
            task_id=task_id,
            inputs=inputs,
        )
        return task_id, subtask_id

    @staticmethod
    def update_task_status(task_id: UUID, status: str):
        """
        Set a Task's status directly, bypassing subtask-based recalculation.

        :param task_id: UUID of the Task to update
        :param status: New status (e.g. "in progress", "complete", "failed"),
            stored lower-cased
        :return: The refreshed Task
        :raises ValueError: If the Task does not exist
        """
        now = datetime.now(timezone.utc)
        with get_db_session() as session:
            task = session.get(Task, task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")
            normalized = status.lower()
            if normalized == "in progress" and task.job_started is None:
                task.job_started = now
            # Both terminal states record a completion time, matching how
            # subtasks and the derived task progress treat "failed".
            if normalized in ("complete", "failed"):
                task.job_completed = now
            task.status = normalized
            task.updated_at = now
            session.add(task)
            session.commit()
            session.refresh(task)
            return task