mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
89 lines
3 KiB
Python
89 lines
3 KiB
Python
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Optional
|
|
from uuid import UUID
|
|
|
|
from sqlmodel import Session, select
|
|
|
|
from domain.tasks.subtasks import SubTask, SubTaskStatus
|
|
from infrastructure.postgres.subtask_table import SubTaskRow
|
|
from repositories.tasks.subtask_repository import SubTaskRepository
|
|
from utilities.private import private
|
|
|
|
|
|
class SubTaskPostgresRepository(SubTaskRepository):
    """Postgres-backed ``SubTaskRepository`` built on a SQLModel ``Session``.

    Maps between the ``SubTask`` domain model and the ``SubTaskRow`` table
    model. ``inputs`` and ``outputs`` are persisted as JSON text (``None``
    stays ``None``), and the status is persisted as ``SubTaskStatus.value``.
    Mutating operations (``create`` / ``save``) commit immediately.
    """

    def __init__(self, session: Session) -> None:
        """Wrap an existing *session*; this class commits on it but never closes it."""
        self._session = session

    def create(self, subtask: SubTask) -> SubTask:
        """Insert *subtask* as a new row and return it as re-read from the DB.

        Raises:
            Any exception raised by the session's commit. The session is
            rolled back before the exception propagates.
        """
        row = self._to_row(subtask)
        self._session.add(row)
        self._commit()
        # Reload the row's attributes from the database so the returned
        # domain object reflects what was actually stored.
        self._session.refresh(row)
        return self._to_domain(row)

    def get(self, subtask_id: UUID) -> SubTask:
        """Return the subtask whose primary key is *subtask_id*.

        Raises:
            ValueError: if no such row exists.
        """
        row = self._session.get(SubTaskRow, subtask_id)
        if row is None:
            raise ValueError(f"SubTask {subtask_id} not found")
        return self._to_domain(row)

    def save(self, subtask: SubTask) -> None:
        """Write the mutable fields of *subtask* onto its existing row.

        Updates status, job_started, job_completed, inputs, outputs and
        cloud_logs_url, and stamps ``updated_at`` with the current UTC time.
        ``task_id`` is not modified.

        Raises:
            ValueError: if no row exists with ``subtask.id``.
            Any exception raised by the session's commit. The session is
            rolled back before the exception propagates.
        """
        row = self._session.get(SubTaskRow, subtask.id)
        if row is None:
            raise ValueError(f"SubTask {subtask.id} not found")

        row.status = subtask.status.value
        row.job_started = subtask.job_started
        row.job_completed = subtask.job_completed
        row.inputs = self._dumps_or_none(subtask.inputs)
        row.outputs = self._dumps_or_none(subtask.outputs)
        row.cloud_logs_url = subtask.cloud_logs_url
        row.updated_at = datetime.now(timezone.utc)

        self._session.add(row)
        self._commit()

    def list_by_task(self, task_id: UUID) -> list[SubTask]:
        """Return every subtask belonging to *task_id*.

        No ORDER BY is applied, so the order of the result is unspecified.
        """
        rows = self._session.exec(
            select(SubTaskRow).where(SubTaskRow.task_id == task_id)
        ).all()
        return [self._to_domain(r) for r in rows]

    @private
    def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails.

        A failed flush/commit leaves a SQLAlchemy session unusable until
        ``rollback()`` is called. Rolling back here before re-raising keeps
        the injected session usable for later repository calls.
        """
        try:
            self._session.commit()
        except Exception:
            self._session.rollback()
            raise

    @private
    def _dumps_or_none(self, value: Any) -> Optional[str]:
        """Serialize *value* to JSON text, passing ``None`` through unchanged."""
        return json.dumps(value) if value is not None else None

    @private
    def _to_row(self, subtask: SubTask) -> SubTaskRow:
        """Build a new (not yet persisted) ``SubTaskRow`` from *subtask*."""
        return SubTaskRow(
            id=subtask.id,
            task_id=subtask.task_id,
            status=subtask.status.value,
            inputs=self._dumps_or_none(subtask.inputs),
            outputs=self._dumps_or_none(subtask.outputs),
            cloud_logs_url=subtask.cloud_logs_url,
            job_started=subtask.job_started,
            job_completed=subtask.job_completed,
        )

    @private
    def _to_domain(self, row: SubTaskRow) -> SubTask:
        """Convert a ``SubTaskRow`` into a ``SubTask`` domain object."""
        return SubTask(
            id=row.id,
            task_id=row.task_id,
            # NOTE(review): status is written as ``status.value`` unchanged but
            # lower-cased on read -- presumably to tolerate upper-case values
            # already in the table; confirm SubTaskStatus values are lower-case.
            status=SubTaskStatus(row.status.lower()),
            inputs=_loads_or_none(row.inputs),
            outputs=_loads_or_none(row.outputs),
            cloud_logs_url=row.cloud_logs_url,
            job_started=row.job_started,
            job_completed=row.job_completed,
        )
|
|
|
|
|
|
def _loads_or_none(s: Optional[str]) -> Optional[dict[str, Any]]:
|
|
return json.loads(s) if s else None
|