import json
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import UUID

from sqlmodel import Session, select

from domain.tasks.subtasks import SubTask, SubTaskStatus
from infrastructure.postgres.subtask_table import SubTaskRow
from repositories.tasks.subtask_repository import SubTaskRepository
from utilities.private import private


class SubTaskPostgresRepository(SubTaskRepository):
    """SubTaskRepository implementation backed by a SQLModel ``Session``.

    Domain ``SubTask`` objects are mapped to/from ``SubTaskRow`` records;
    ``inputs`` and ``outputs`` are stored as JSON-encoded strings.
    """

    def __init__(self, session: Session) -> None:
        """Store the session used for all reads and writes."""
        self._session = session

    def create(self, subtask: SubTask) -> SubTask:
        """Insert ``subtask`` and return it as re-read from the committed row.

        Raises:
            Exception: whatever the session raises on commit; the session is
                rolled back before the exception propagates.
        """
        row = self._to_row(subtask)
        self._session.add(row)
        self._commit()
        # Refresh so database-populated columns are reflected in the result.
        self._session.refresh(row)
        return self._to_domain(row)

    def get(self, subtask_id: UUID) -> SubTask:
        """Return the subtask with primary key ``subtask_id``.

        Raises:
            ValueError: if no row exists for ``subtask_id``.
        """
        row = self._session.get(SubTaskRow, subtask_id)
        if row is None:
            raise ValueError(f"SubTask {subtask_id} not found")
        return self._to_domain(row)

    def save(self, subtask: SubTask) -> None:
        """Update the mutable fields of an existing subtask row.

        Writes status, job_started, job_completed, inputs, outputs and
        cloud_logs_url, and stamps ``updated_at`` with the current UTC time.
        ``task_id`` is not modified.

        Raises:
            ValueError: if no row exists for ``subtask.id``.
            Exception: whatever the session raises on commit; the session is
                rolled back before the exception propagates.
        """
        row = self._session.get(SubTaskRow, subtask.id)
        if row is None:
            raise ValueError(f"SubTask {subtask.id} not found")
        row.status = subtask.status.value
        row.job_started = subtask.job_started
        row.job_completed = subtask.job_completed
        row.inputs = _dumps_or_none(subtask.inputs)
        row.outputs = _dumps_or_none(subtask.outputs)
        row.cloud_logs_url = subtask.cloud_logs_url
        row.updated_at = datetime.now(timezone.utc)
        self._session.add(row)
        self._commit()

    def list_by_task(self, task_id: UUID) -> list[SubTask]:
        """Return every subtask belonging to ``task_id``.

        The query has no ORDER BY, so the result order is not guaranteed.
        """
        rows = self._session.exec(
            select(SubTaskRow).where(SubTaskRow.task_id == task_id)
        ).all()
        return [self._to_domain(r) for r in rows]

    @private
    def _commit(self) -> None:
        """Commit the session, rolling back and re-raising on failure.

        A failed flush/commit leaves a SQLAlchemy session unusable until it is
        rolled back; doing it here keeps the injected session usable for later
        calls. The original exception is propagated unchanged.
        """
        try:
            self._session.commit()
        except Exception:
            self._session.rollback()
            raise

    @private
    def _to_row(self, subtask: SubTask) -> SubTaskRow:
        """Build a new ``SubTaskRow`` from a domain ``SubTask``."""
        return SubTaskRow(
            id=subtask.id,
            task_id=subtask.task_id,
            status=subtask.status.value,
            inputs=_dumps_or_none(subtask.inputs),
            outputs=_dumps_or_none(subtask.outputs),
            cloud_logs_url=subtask.cloud_logs_url,
            job_started=subtask.job_started,
            job_completed=subtask.job_completed,
        )

    @private
    def _to_domain(self, row: SubTaskRow) -> SubTask:
        """Convert a ``SubTaskRow`` into a domain ``SubTask``."""
        return SubTask(
            id=row.id,
            task_id=row.task_id,
            # NOTE(review): status is lower-cased before the enum lookup,
            # presumably to tolerate upper-case values in the table — confirm.
            status=SubTaskStatus(row.status.lower()),
            inputs=_loads_or_none(row.inputs),
            outputs=_loads_or_none(row.outputs),
            cloud_logs_url=row.cloud_logs_url,
            job_started=row.job_started,
            job_completed=row.job_completed,
        )


def _dumps_or_none(value: Optional[dict[str, Any]]) -> Optional[str]:
    """JSON-encode ``value``, or return None when ``value`` is None."""
    return json.dumps(value) if value is not None else None


def _loads_or_none(s: Optional[str]) -> Optional[dict[str, Any]]:
    """Decode a JSON string; None or an empty string yields None."""
    return json.loads(s) if s else None