Mirror of https://github.com/Hestia-Homes/Model.git (synced 2026-06-08 11:17:27 +00:00) — 293 lines, 8.9 KiB, Python.
from __future__ import annotations
|
|
|
|
# ---- Standard Library ----
|
|
from typing import Optional, Dict, Any
|
|
from datetime import datetime, timezone
|
|
from uuid import UUID
|
|
import json
|
|
|
|
# ---- SQLModel / SQLAlchemy ----
|
|
from sqlmodel import Session, select
|
|
|
|
# ---- DB Session ----
|
|
from backend.app.db.connection import get_db_session
|
|
|
|
# ---- Models ----
|
|
from backend.app.db.models.tasks import Task, SubTask
|
|
|
|
|
|
# ============================================================
|
|
# SubTask Interface
|
|
# ============================================================
|
|
class SubTaskInterface:
    """
    CRUD operations for SubTask records, with cascading updates to the
    parent Task's aggregate status and timestamps.

    Every public method opens its own session via ``get_db_session()``,
    commits its changes, and refreshes the ``SubTask`` it returns so its
    attributes are loaded when the caller receives it. A missing Task or
    SubTask raises ``ValueError``.
    """

    # Statuses understood by the task roll-up in ``_derive_task_status``.
    _VALID_STATUSES = ("waiting", "in progress", "complete", "failed")

    # Statuses that mark a job as finished and stamp ``jobCompleted``.
    _TERMINAL_STATUSES = ("complete", "failed")

    # --------------------------------------------------------
    # INTERNAL HELPERS
    # --------------------------------------------------------

    @staticmethod
    def _get_subtask(session: Session, subtask_id: UUID) -> SubTask:
        """Load a SubTask by id; raise ``ValueError`` if it does not exist."""
        subtask = session.get(SubTask, subtask_id)
        if not subtask:
            raise ValueError(f"SubTask {subtask_id} not found")
        return subtask

    @staticmethod
    def _save(session: Session, subtask: SubTask) -> SubTask:
        """Stage, commit and reload ``subtask``, then return it."""
        session.add(subtask)
        session.commit()
        session.refresh(subtask)
        return subtask

    @staticmethod
    def _apply_result(
        subtask: SubTask,
        outputs: Optional[Dict[str, Any]],
        cloud_logs_url: Optional[str],
    ) -> None:
        """
        Write JSON-encoded ``outputs`` and/or ``cloud_logs_url`` onto ``subtask``.

        A ``None`` argument leaves the corresponding field untouched.
        """
        if outputs is not None:
            subtask.outputs = json.dumps(outputs)
        if cloud_logs_url is not None:
            subtask.cloudLogsURL = cloud_logs_url

    @staticmethod
    def _derive_task_status(statuses: list[str]) -> str:
        """
        Roll normalized (lower-case) subtask statuses up into one task status.

        - any "failed"                          -> "failed"
        - non-empty and every one "complete"    -> "complete"
        - any "in progress" or "complete"       -> "in progress" (work has begun)
        - otherwise (all waiting, or no rows)   -> "waiting"
        """
        if "failed" in statuses:
            return "failed"
        # Guard the empty case: all([]) is True, which would mark a task with
        # no subtasks as complete.
        if statuses and all(s == "complete" for s in statuses):
            return "complete"
        # A partially completed task has started, even if nothing is running
        # at this instant.
        if any(s in ("in progress", "complete") for s in statuses):
            return "in progress"
        return "waiting"

    # --------------------------------------------------------
    # CREATE SUBTASK
    # --------------------------------------------------------

    def create_subtask(self, task_id: UUID, inputs: Optional[Dict[str, Any]] = None):
        """
        Create a new "waiting" SubTask under ``task_id`` and recompute the
        parent task's progress.

        :param task_id: id of the parent Task.
        :param inputs: optional JSON-serializable inputs; stored JSON-encoded.
        :returns: the refreshed SubTask.
        :raises ValueError: if the Task does not exist.
        """
        with get_db_session() as session:
            task = session.get(Task, task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")

            subtask = SubTask(
                taskId=task_id,
                # `is not None` so an explicit empty dict is stored as "{}".
                inputs=json.dumps(inputs) if inputs is not None else None,
                status="waiting",
                jobStarted=None,
                jobCompleted=None,
            )

            session.add(subtask)
            session.commit()

            # Recalculate parent task progress. This commits again, which can
            # expire `subtask` (SQLAlchemy's default expire_on_commit)...
            self._update_task_progress(session, task_id)

            # ...so reload it afterwards, exactly as update_subtask_status does.
            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE STATUS (waiting, in progress, complete, failed)
    # --------------------------------------------------------

    def update_subtask_status(self, subtask_id: UUID, status: str):
        """
        Set a SubTask's status, maintain its job timestamps, and recompute
        the parent task's progress.

        :param status: one of "waiting", "in progress", "complete", "failed"
            (case-insensitive; surrounding whitespace ignored).
        :returns: the refreshed SubTask.
        :raises ValueError: if the status is unknown or the SubTask does not exist.
        """
        normalized = status.strip().lower()
        if normalized not in self._VALID_STATUSES:
            # Unknown statuses would otherwise be rolled up as "waiting".
            raise ValueError(
                f"Invalid status {status!r}; expected one of: "
                + ", ".join(self._VALID_STATUSES)
            )

        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = self._get_subtask(session, subtask_id)

            # Stamp the start time only the first time the job actually starts.
            if normalized == "in progress" and subtask.jobStarted is None:
                subtask.jobStarted = now

            if normalized in self._TERMINAL_STATUSES:
                subtask.jobCompleted = now
            else:
                # Back in a non-terminal state (e.g. a retry): drop the stale
                # completion time.
                subtask.jobCompleted = None

            subtask.status = normalized
            subtask.updatedAt = now

            session.add(subtask)
            session.commit()

            # Recalculate task status (commits again).
            self._update_task_progress(session, subtask.taskId)

            session.refresh(subtask)
            return subtask

    # --------------------------------------------------------
    # UPDATE OUTPUTS
    # --------------------------------------------------------

    def update_subtask_output(self, subtask_id: UUID, outputs: Dict[str, Any]):
        """
        Overwrite a SubTask's outputs with the JSON encoding of ``outputs``.

        :returns: the refreshed SubTask.
        :raises ValueError: if the SubTask does not exist.
        """
        with get_db_session() as session:
            subtask = self._get_subtask(session, subtask_id)
            subtask.outputs = json.dumps(outputs)
            subtask.updatedAt = datetime.now(timezone.utc)
            return self._save(session, subtask)

    # --------------------------------------------------------
    # UPDATE CLOUD LOGS URL
    # --------------------------------------------------------

    def update_subtask_logs(self, subtask_id: UUID, cloud_logs_url: str):
        """
        Overwrite a SubTask's cloud-logs URL.

        :returns: the refreshed SubTask.
        :raises ValueError: if the SubTask does not exist.
        """
        with get_db_session() as session:
            subtask = self._get_subtask(session, subtask_id)
            subtask.cloudLogsURL = cloud_logs_url
            subtask.updatedAt = datetime.now(timezone.utc)
            return self._save(session, subtask)

    # --------------------------------------------------------
    # SET BOTH OUTPUT + LOGS
    # --------------------------------------------------------

    def set_subtask_result(
        self,
        subtask_id: UUID,
        outputs: Optional[Dict[str, Any]] = None,
        cloud_logs_url: Optional[str] = None,
    ):
        """
        Set outputs and/or the cloud-logs URL without changing status.

        Arguments left as ``None`` leave the existing value untouched.

        :returns: the refreshed SubTask.
        :raises ValueError: if the SubTask does not exist.
        """
        with get_db_session() as session:
            subtask = self._get_subtask(session, subtask_id)
            self._apply_result(subtask, outputs, cloud_logs_url)
            subtask.updatedAt = datetime.now(timezone.utc)
            return self._save(session, subtask)

    # --------------------------------------------------------
    # TASK PROGRESS CALCULATION
    # --------------------------------------------------------

    def _update_task_progress(self, session: Session, task_id: UUID):
        """
        Recompute the parent Task's status and job timestamps from all of its
        subtasks and commit. Does nothing if the Task no longer exists.
        """
        task = session.get(Task, task_id)
        if not task:
            return

        subtasks = session.exec(
            select(SubTask).where(SubTask.taskId == task_id)
        ).all()

        statuses = [(s.status or "").strip().lower() for s in subtasks]
        new_status = self._derive_task_status(statuses)
        previous_status = (task.status or "").lower()
        now = datetime.now(timezone.utc)

        if new_status in self._TERMINAL_STATUSES:
            # Keep the original completion time if the task was already in
            # this terminal state; only stamp it on the transition.
            if previous_status != new_status or task.jobCompleted is None:
                task.jobCompleted = now

        elif new_status == "in progress":
            if task.jobStarted is None:
                task.jobStarted = now
            # The task may have been re-opened (e.g. a subtask added after it
            # completed), so a previous completion time no longer applies.
            task.jobCompleted = None

        else:
            # Nothing has started yet.
            task.jobStarted = None
            task.jobCompleted = None

        task.status = new_status
        task.updatedAt = now
        session.add(task)
        session.commit()

    # --------------------------------------------------------
    # FINALIZE (status + outputs + logs in one call)
    # --------------------------------------------------------

    def finalize_subtask(
        self,
        subtask_id: UUID,
        status: str,
        outputs: Optional[Dict[str, Any]],
        cloud_logs_url: Optional[str]
    ):
        """
        Mark a SubTask "complete" or "failed", optionally record its outputs and
        logs URL, and recompute the parent task's progress.

        :param status: "complete" or "failed" (case-insensitive).
        :param outputs: JSON-serializable outputs, or ``None`` to leave unchanged.
        :param cloud_logs_url: logs URL, or ``None`` to leave unchanged.
        :returns: the refreshed SubTask.
        :raises ValueError: if the status is not terminal or the SubTask does
            not exist.
        """
        # Validate before touching the database.
        normalized = status.strip().lower()
        if normalized not in self._TERMINAL_STATUSES:
            raise ValueError("Status must be 'complete' or 'failed'")

        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            subtask = self._get_subtask(session, subtask_id)

            self._apply_result(subtask, outputs, cloud_logs_url)

            # Status + timestamps
            subtask.status = normalized
            subtask.jobCompleted = now
            subtask.updatedAt = now

            session.add(subtask)
            session.commit()

            # Update parent task (complete/failed); commits again.
            self._update_task_progress(session, subtask.taskId)

            session.refresh(subtask)
            return subtask
|
|
|
|
|
|
# ============================================================
|
|
# Task Interface
|
|
# ============================================================
|
|
class TasksInterface:
    """
    High-level operations for Task records.

    A Task's status is normally recomputed from its subtasks by
    ``SubTaskInterface``; ``update_task_status`` sets it directly.
    """

    # Statuses a Task may hold (same set the subtask roll-up produces).
    _VALID_STATUSES = ("waiting", "in progress", "complete", "failed")

    # Statuses that mark a job as finished and stamp ``jobCompleted``.
    _TERMINAL_STATUSES = ("complete", "failed")

    @classmethod
    def _normalize_status(cls, status: str) -> str:
        """Lower-case/strip ``status``; raise ``ValueError`` if it is unknown."""
        normalized = status.strip().lower()
        if normalized not in cls._VALID_STATUSES:
            raise ValueError(
                f"Invalid status {status!r}; expected one of: "
                + ", ".join(cls._VALID_STATUSES)
            )
        return normalized

    @classmethod
    def _apply_status(cls, task: Task, status: str, now: datetime) -> None:
        """
        Set ``task.status`` to an already-normalized ``status`` and maintain
        its job timestamps.
        """
        # Stamp the start time only the first time the job actually starts.
        if status == "in progress" and task.jobStarted is None:
            task.jobStarted = now

        # "failed" is terminal too, matching the subtask roll-up.
        if status in cls._TERMINAL_STATUSES:
            task.jobCompleted = now
        else:
            # Re-opened: a previous completion time no longer applies.
            task.jobCompleted = None

        task.status = status
        task.updatedAt = now

    def create_task(
        self,
        *,
        task_source: str,
        service: Optional[str] = None,
        inputs: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a "waiting" Task plus its first "waiting" SubTask.

        :param task_source: value stored in ``Task.taskSource``.
        :param service: optional value stored in ``Task.service``.
        :param inputs: JSON-serializable inputs, stored on the first SubTask.
        :returns: ``(task_id, subtask_id)``.
        """
        with get_db_session() as session:
            task = Task(
                taskSource=task_source,
                service=service,
                status="waiting",
                jobStarted=None,
                jobCompleted=None,
            )

            session.add(task)
            session.commit()
            session.refresh(task)
            task_id = task.id

        # Create the first subtask in waiting state only after this session is
        # closed, so two sessions are never held open at once. create_subtask
        # opens its own session and recomputes the task's progress.
        # NOTE(review): task and subtask are committed in separate transactions;
        # if create_subtask fails, the task is left without subtasks.
        subtask = SubTaskInterface().create_subtask(
            task_id=task_id,
            inputs=inputs,
        )

        return task_id, subtask.id

    def update_task_status(self, task_id: UUID, status: str):
        """
        Set a Task's status directly and maintain its job timestamps.

        :param status: one of "waiting", "in progress", "complete", "failed"
            (case-insensitive; surrounding whitespace ignored).
        :returns: the refreshed Task.
        :raises ValueError: if the status is unknown or the Task does not exist.
        """
        normalized = self._normalize_status(status)
        now = datetime.now(timezone.utc)

        with get_db_session() as session:
            task = session.get(Task, task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")

            self._apply_status(task, normalized, now)

            session.add(task)
            session.commit()
            session.refresh(task)
            return task
|