Model/repositories/tasks/subtask_postgres_repository.py
2026-05-19 16:35:09 +00:00

89 lines
3 KiB
Python

import json
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import UUID
from sqlmodel import Session, select
from domain.tasks.subtasks import SubTask, SubTaskStatus
from infrastructure.postgres.subtask_table import SubTaskRow
from repositories.tasks.subtask_repository import SubTaskRepository
from utilities.private import private
class SubTaskPostgresRepository(SubTaskRepository):
    """Postgres-backed ``SubTaskRepository`` built on a SQLModel ``Session``.

    Domain ``SubTask`` objects are mapped to ``SubTaskRow`` records: the
    status enum is stored by its ``.value`` and the ``inputs``/``outputs``
    payloads are stored as JSON text (``None`` stays ``None``). Every
    mutating method commits immediately; if the commit fails the session is
    rolled back before the error is re-raised, so the injected session stays
    usable for the caller.
    """

    def __init__(self, session: Session) -> None:
        """Wrap *session*; the caller owns it and this class never closes it."""
        self._session = session

    def create(self, subtask: SubTask) -> SubTask:
        """Insert *subtask* as a new row and return it as stored.

        The row is refreshed after the commit, so the returned object is
        rebuilt from what the database holds (including any columns the
        database populated itself).
        """
        row = self._to_row(subtask)
        self._session.add(row)
        self._commit()
        self._session.refresh(row)
        return self._to_domain(row)

    def get(self, subtask_id: UUID) -> SubTask:
        """Return the subtask whose primary key is *subtask_id*.

        Raises:
            ValueError: if no such row exists.
        """
        return self._to_domain(self._require_row(subtask_id))

    def save(self, subtask: SubTask) -> None:
        """Persist the mutable fields of an already-stored *subtask*.

        Copies status, job start/completion values, the JSON payloads and the
        cloud logs URL onto the existing row, stamps ``updated_at`` with the
        current UTC time, and commits. ``task_id`` is never changed here.

        Raises:
            ValueError: if no row with ``subtask.id`` exists.
        """
        row = self._require_row(subtask.id)
        row.status = subtask.status.value
        row.job_started = subtask.job_started
        row.job_completed = subtask.job_completed
        row.inputs = _dumps_or_none(subtask.inputs)
        row.outputs = _dumps_or_none(subtask.outputs)
        row.cloud_logs_url = subtask.cloud_logs_url
        row.updated_at = datetime.now(timezone.utc)
        self._session.add(row)
        self._commit()

    def list_by_task(self, task_id: UUID) -> list[SubTask]:
        """Return every subtask belonging to *task_id*.

        The query has no ORDER BY, so result order is whatever the database
        happens to return.
        """
        statement = select(SubTaskRow).where(SubTaskRow.task_id == task_id)
        rows = self._session.exec(statement).all()
        return [self._to_domain(row) for row in rows]

    @private
    def _require_row(self, subtask_id: UUID) -> SubTaskRow:
        """Load the row for *subtask_id*, raising ``ValueError`` if absent."""
        row = self._session.get(SubTaskRow, subtask_id)
        if row is None:
            raise ValueError(f"SubTask {subtask_id} not found")
        return row

    @private
    def _commit(self) -> None:
        """Commit the session, rolling it back if the commit fails.

        After a failed flush SQLAlchemy refuses further work on the session
        until ``rollback()`` is called; doing it here keeps the shared session
        usable. The original exception is re-raised unchanged.
        """
        try:
            self._session.commit()
        except Exception:
            self._session.rollback()
            raise

    @private
    def _to_row(self, subtask: SubTask) -> SubTaskRow:
        """Build a new, not-yet-added ``SubTaskRow`` from a domain ``SubTask``."""
        return SubTaskRow(
            id=subtask.id,
            task_id=subtask.task_id,
            status=subtask.status.value,
            inputs=_dumps_or_none(subtask.inputs),
            outputs=_dumps_or_none(subtask.outputs),
            cloud_logs_url=subtask.cloud_logs_url,
            job_started=subtask.job_started,
            job_completed=subtask.job_completed,
        )

    @private
    def _to_domain(self, row: SubTaskRow) -> SubTask:
        """Convert a ``SubTaskRow`` back into a domain ``SubTask``."""
        return SubTask(
            id=row.id,
            task_id=row.task_id,
            # Lower-cased before the enum lookup, presumably so status text
            # stored in another case still parses.
            # NOTE(review): confirm SubTaskStatus values are all lower-case.
            status=SubTaskStatus(row.status.lower()),
            inputs=_loads_or_none(row.inputs),
            outputs=_loads_or_none(row.outputs),
            cloud_logs_url=row.cloud_logs_url,
            job_started=row.job_started,
            job_completed=row.job_completed,
        )


def _dumps_or_none(value: Any) -> Optional[str]:
    """Encode *value* as JSON text for storage, passing ``None`` through."""
    return None if value is None else json.dumps(value)
def _loads_or_none(s: Optional[str]) -> Optional[dict[str, Any]]:
return json.loads(s) if s else None