mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
326 lines
10 KiB
Python
326 lines
10 KiB
Python
# ---- Standard Library ----
|
|
from typing import Optional, Dict, Any
|
|
from datetime import datetime, timezone
|
|
from uuid import UUID
|
|
import json
|
|
|
|
# ---- SQLModel / SQLAlchemy ----
|
|
from sqlmodel import Session, select
|
|
|
|
# ---- DB Session ----
|
|
from backend.app.db.connection import get_db_session
|
|
|
|
# ---- Models ----
|
|
from backend.app.db.models.tasks import Task, SubTask
|
|
|
|
|
|
# ============================================================
|
|
# SubTask Interface
|
|
# ============================================================
|
|
class SubTaskInterface:
    """
    CRUD operations for SubTask + cascading Task progress updates.

    Methods that change a subtask's status also recompute the parent Task's
    status (``_update_task_progress``) and commit both rows in one transaction,
    so a failure between the two writes cannot leave the Task out of sync.
    """

    # --------------------------------------------------------
    # INTERNAL HELPERS
    # --------------------------------------------------------
    @staticmethod
    def _get_subtask(session: "Session", subtask_id: UUID):
        """
        Load a SubTask by primary key.

        :raises ValueError: if no SubTask with ``subtask_id`` exists
        """
        subtask = session.get(SubTask, subtask_id)
        if not subtask:
            raise ValueError(f"SubTask {subtask_id} not found")
        return subtask

    # --------------------------------------------------------
    # CREATE SUBTASK
    # --------------------------------------------------------
    def create_subtask(
        self,
        task_id: UUID,
        inputs: Optional[Dict[str, Any]] = None,
        status: Optional[str] = None,
    ):
        """
        Create a SubTask under an existing Task and recompute the Task's progress.

        :param task_id: UUID of the parent task
        :param inputs: Optional inputs, stored as a JSON string; a falsy value
            (None or an empty dict) is stored as None
        :param status: Initial status; "waiting" when omitted
        :return: UUID of the new subtask
        :raises ValueError: if the parent task does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            if not session.get(Task, task_id):
                raise ValueError(f"Task {task_id} not found")

            subtask = SubTask(
                task_id=task_id,
                inputs=json.dumps(inputs) if inputs else None,
                # We treat waiting as the default status
                status="waiting" if status is None else status,
                # Stamped at creation time, not when work actually begins.
                job_started=now,
                job_completed=None,
            )
            session.add(subtask)

            # Flush (not commit) so the new row is visible to the progress
            # query below while both writes still land in a single commit.
            session.flush()
            self._update_task_progress(session, task_id, commit=False)
            session.commit()

            session.refresh(subtask)
            return subtask.id

    # --------------------------------------------------------
    # UPDATE STATUS (in progress, complete, failed)
    # --------------------------------------------------------
    def update_subtask_status(
        self,
        subtask_id: UUID,
        status: str,
        outputs: Optional[Dict[str, Any]] = None,
        cloud_logs_url: Optional[str] = None,
    ):
        """
        Update the status of a subtask, and recalculate the parent task progress.

        :param subtask_id: UUID of the subtask to update
        :param status: New status (in progress, complete, failed); lower-cased before storing
        :param outputs: Optional outputs to set (stored as a JSON string)
        :param cloud_logs_url: Optional cloud logs URL to set
        :return: the refreshed SubTask
        :raises ValueError: if the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = self._get_subtask(session, subtask_id)
            normalized = status.lower()

            # When job really starts. Only fires while job_started is unset;
            # create_subtask stamps it at creation, so this mainly covers rows
            # created elsewhere.
            if normalized == "in progress" and subtask.job_started is None:
                subtask.job_started = now

            # Completed or failed
            if normalized in ("complete", "failed"):
                subtask.job_completed = now

            subtask.status = normalized
            subtask.updated_at = now

            if outputs is not None:
                subtask.outputs = json.dumps(outputs)
            if cloud_logs_url is not None:
                subtask.cloud_logs_url = cloud_logs_url

            session.add(subtask)

            # Recalculate task status in the same transaction as the subtask change.
            session.flush()
            self._update_task_progress(session, subtask.task_id, commit=False)
            session.commit()

            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE OUTPUTS
    # --------------------------------------------------------
    @staticmethod
    def update_subtask_output(subtask_id: UUID, outputs: Dict[str, Any]):
        """
        Replace a subtask's outputs (JSON-encoded) without touching its status.

        :raises ValueError: if the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = SubTaskInterface._get_subtask(session, subtask_id)
            subtask.outputs = json.dumps(outputs)
            subtask.updated_at = now

            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE CLOUD LOGS URL
    # --------------------------------------------------------
    @staticmethod
    def update_subtask_logs(subtask_id: UUID, cloud_logs_url: str):
        """
        Set a subtask's cloud logs URL without touching its status.

        :raises ValueError: if the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = SubTaskInterface._get_subtask(session, subtask_id)
            subtask.cloud_logs_url = cloud_logs_url
            subtask.updated_at = now

            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # SET BOTH OUTPUT + LOGS
    # --------------------------------------------------------
    @staticmethod
    def set_subtask_result(
        subtask_id: UUID,
        outputs: Optional[Dict[str, Any]] = None,
        cloud_logs_url: Optional[str] = None,
    ):
        """
        Set outputs and/or cloud logs URL; arguments left as None are not changed.

        :raises ValueError: if the subtask does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = SubTaskInterface._get_subtask(session, subtask_id)

            if outputs is not None:
                subtask.outputs = json.dumps(outputs)
            if cloud_logs_url is not None:
                subtask.cloud_logs_url = cloud_logs_url

            subtask.updated_at = now
            session.add(subtask)
            session.commit()
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # TASK PROGRESS CALCULATION
    # --------------------------------------------------------
    @staticmethod
    def _derive_task_status(statuses) -> str:
        """
        Map subtask statuses (case-insensitive) to the aggregate task status.

        - any "failed"                                 -> "failed"
        - every subtask "complete"                     -> "complete"
        - any "in progress", or some "complete" but
          not all (work started, not finished)         -> "in progress"
        - otherwise (nothing has started)              -> "waiting"

        An empty iterable yields "waiting".
        """
        normalized = [s.lower() for s in statuses]

        if not normalized:
            return "waiting"
        if "failed" in normalized:
            return "failed"
        if all(s == "complete" for s in normalized):
            return "complete"
        # A mix of complete + waiting means the task has progressed; it must
        # not fall through to "waiting".
        if "in progress" in normalized or "complete" in normalized:
            return "in progress"
        return "waiting"

    @staticmethod
    def _update_task_progress(session: "Session", task_id: UUID, *, commit: bool = True):
        """
        Recompute a Task's status and job_completed from all of its SubTasks.

        Does nothing if the task does not exist or has no subtasks (there is
        nothing to derive a status from).

        :param session: open session; pending subtask changes should be flushed
        :param task_id: UUID of the task to recompute
        :param commit: commit the session afterwards (default). Pass False when
            the caller commits the surrounding transaction itself.
        """
        task = session.get(Task, task_id)
        if not task:
            return

        subtasks = session.exec(select(SubTask).where(SubTask.task_id == task_id)).all()
        if not subtasks:
            return

        now = datetime.now(timezone.utc)
        task.status = SubTaskInterface._derive_task_status([s.status for s in subtasks])

        # Terminal states get a completion time; any non-terminal state clears a
        # stale one (e.g. a new subtask was added after the task had completed).
        task.job_completed = now if task.status in ("complete", "failed") else None

        task.updated_at = now
        session.add(task)
        if commit:
            session.commit()

    # --------------------------------------------------------
    # FINALIZE (status + outputs + logs in one call)
    # --------------------------------------------------------
    def finalize_subtask(
        self,
        subtask_id: UUID,
        status: str,
        outputs: Optional[Dict[str, Any]],
        cloud_logs_url: Optional[str],
    ):
        """
        Mark a subtask complete or failed, set its outputs/logs, and update the parent task.

        :param subtask_id: UUID of the subtask to finalize
        :param status: "complete" or "failed" (case-insensitive)
        :param outputs: outputs to store (JSON-encoded); None leaves them unchanged
        :param cloud_logs_url: logs URL to store; None leaves it unchanged
        :return: the refreshed SubTask
        :raises ValueError: if the subtask does not exist or the status is not terminal
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = self._get_subtask(session, subtask_id)

            normalized = status.lower()
            if normalized not in ("complete", "failed"):
                raise ValueError("Status must be 'complete' or 'failed'")

            if outputs is not None:
                subtask.outputs = json.dumps(outputs)
            if cloud_logs_url is not None:
                subtask.cloud_logs_url = cloud_logs_url

            # Status + timestamps
            subtask.status = normalized
            subtask.job_completed = now
            subtask.updated_at = now

            session.add(subtask)

            # Update parent task (complete/failed) in the same transaction.
            session.flush()
            self._update_task_progress(session, subtask.task_id, commit=False)
            session.commit()

            session.refresh(subtask)
            return subtask
|
|
|
|
|
|
# ============================================================
|
|
# Task Interface
|
|
# ============================================================
|
|
class TasksInterface:
    """
    High-level operations for Task records.
    """

    @staticmethod
    def create_task(
        task_source: str,
        service: Optional[str] = None,
        inputs: Optional[Dict[str, Any]] = None,
        task_only: bool = False,
    ):
        """
        Create a new Task record, and an initial SubTask in waiting state. Can also be used to create just
        a task, without a subtask.

        :param task_source: Text indicating source of task creation (e.g. file path + function name)
        :param service: Optional service name
        :param inputs: Inputs of the job being run; stored on the initial SubTask
            and ignored when ``task_only`` is True
        :param task_only: If True, only create the Task record, without a SubTask
        :return: ``(task_id, subtask_id)``; ``subtask_id`` is None when ``task_only`` is True
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            task = Task(
                task_source=task_source,
                service=service,
                status="waiting",
                job_started=now,
                job_completed=None,
            )
            session.add(task)
            session.commit()
            session.refresh(task)
            # Read the id while the session is still open.
            task_id = task.id

        if task_only:
            return task_id, None

        # Create first subtask in waiting state. This runs after the session
        # above has closed, so two DB sessions are never held at once.
        subtask_id = SubTaskInterface().create_subtask(
            task_id=task_id,
            inputs=inputs,
        )
        return task_id, subtask_id

    @staticmethod
    def _apply_status(task, status: str, now: datetime) -> None:
        """
        Set a task's status (lower-cased) and the matching timestamps in memory.

        job_started is filled on "in progress" only if unset; job_completed is
        stamped on either terminal status ("complete" or "failed"), matching
        how subtask statuses are handled.
        """
        normalized = status.lower()

        if normalized == "in progress" and task.job_started is None:
            task.job_started = now

        if normalized in ("complete", "failed"):
            task.job_completed = now

        task.status = normalized
        task.updated_at = now

    @staticmethod
    def update_task_status(task_id: UUID, status: str):
        """
        Set a task's status directly (bypassing subtask-derived progress).

        :param task_id: UUID of the task to update
        :param status: New status; lower-cased before storing
        :return: the refreshed Task
        :raises ValueError: if the task does not exist
        """
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            task = session.get(Task, task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")

            TasksInterface._apply_status(task, status, now)

            session.add(task)
            session.commit()
            session.refresh(task)
            return task
|