mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
utilities/aws_lambda: @subtask_handler injects TaskOrchestrator as third positional arg
The wrapped function now receives the decorator-owned TaskOrchestrator as a third positional argument so handlers can compose their own use-case orchestrator that shares the session, instead of opening a second Postgres connection per invocation. Both existing callers (backend/ordnanceSurvey/main.py and backend/bulk_address2uprn_combiner/main.py) have their signatures extended to accept the new positional argument (typed Optional[TaskOrchestrator] so the legacy backend.utils.subtasks.subtask_handler — which only passes two args — keeps working until the migration to the new decorator lands). @task_handler is intentionally unchanged in this slice; symmetry is deferred per issue #1103.
This commit is contained in:
parent
d7f14033ba
commit
d70e8a9e53
6 changed files with 168 additions and 4 deletions
|
|
@ -2,7 +2,7 @@ import os
|
|||
import boto3
|
||||
import pandas as pd
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
|
@ -12,6 +12,7 @@ from backend.app.db.functions.bulk_address_uploads_functions import (
|
|||
set_combined_output_s3_uri,
|
||||
set_combining_status,
|
||||
)
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
|
@ -35,7 +36,16 @@ def download_csv(s3_client, bucket: str, key: str) -> pd.DataFrame:
|
|||
|
||||
|
||||
@subtask_handler()
|
||||
def handler(body: dict[str, Any], context: Any) -> str:
|
||||
def handler(
|
||||
body: dict[str, Any],
|
||||
context: Any,
|
||||
orchestrator: Optional[TaskOrchestrator] = None,
|
||||
) -> str:
|
||||
# `orchestrator` is injected by the new utilities.aws_lambda.subtask_handler
|
||||
# decorator; unused here but accepted so the contract is uniform across
|
||||
# callers (see issue #1103).
|
||||
del orchestrator
|
||||
|
||||
task_id_str: str = body.get("task_id", "")
|
||||
|
||||
if not task_id_str:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from backend.ordnanceSurvey.helpers import (
|
|||
os_places_results_to_dataframe,
|
||||
)
|
||||
from backend.app.config import get_settings
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
from sqlalchemy import select
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
|
@ -105,7 +106,16 @@ def save_results_to_s3(
|
|||
|
||||
|
||||
@subtask_handler() # This assumes task_id and subtask_id is defined in event.Records.body
|
||||
def handler(body: dict[str, Any], context: Any, local: bool = False) -> None:
|
||||
def handler(
|
||||
body: dict[str, Any],
|
||||
context: Any,
|
||||
orchestrator: Optional[TaskOrchestrator] = None,
|
||||
local: bool = False,
|
||||
) -> None:
|
||||
# `orchestrator` is injected by the new utilities.aws_lambda.subtask_handler
|
||||
# decorator; unused here but accepted so the contract is uniform across
|
||||
# callers (see issue #1103).
|
||||
del orchestrator
|
||||
|
||||
# delete this line after test
|
||||
# local = True
|
||||
|
|
|
|||
0
tests/utilities/__init__.py
Normal file
0
tests/utilities/__init__.py
Normal file
0
tests/utilities/aws_lambda/__init__.py
Normal file
0
tests/utilities/aws_lambda/__init__.py
Normal file
144
tests/utilities/aws_lambda/test_subtask_handler.py
Normal file
144
tests/utilities/aws_lambda/test_subtask_handler.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""Tests for the @subtask_handler decorator.
|
||||
|
||||
Covers the contract that the decorator owns the parent SubTask lifecycle and
|
||||
injects the decorator-owned TaskOrchestrator as a third positional argument
|
||||
to the wrapped function — so the handler can compose its own use-case
|
||||
orchestrator that shares the session.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator, Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from domain.tasks.subtasks import SubTaskStatus
|
||||
from domain.tasks.tasks import TaskStatus
|
||||
from orchestration.task_orchestrator import TaskOrchestrator
|
||||
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
|
||||
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
|
||||
from utilities.aws_lambda.subtask_handler import subtask_handler
|
||||
|
||||
|
||||
@dataclass
|
||||
class Harness:
|
||||
orchestrator: TaskOrchestrator
|
||||
tasks: TaskPostgresRepository
|
||||
subtasks: SubTaskPostgresRepository
|
||||
|
||||
@contextmanager
|
||||
def factory(self) -> Generator[TaskOrchestrator, None, None]:
|
||||
yield self.orchestrator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def harness() -> Iterator[Harness]:
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
tasks = TaskPostgresRepository(session=session)
|
||||
subtasks = SubTaskPostgresRepository(session=session)
|
||||
yield Harness(
|
||||
orchestrator=TaskOrchestrator(task_repo=tasks, subtask_repo=subtasks),
|
||||
tasks=tasks,
|
||||
subtasks=subtasks,
|
||||
)
|
||||
|
||||
|
||||
def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]:
|
||||
return {"task_id": str(task_id), "sub_task_id": str(subtask_id)}
|
||||
|
||||
|
||||
def test_subtask_handler_injects_orchestrator_as_third_positional_argument(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
"""The wrapped function receives the decorator-owned TaskOrchestrator
|
||||
so it can share the session with its own use-case orchestrator."""
|
||||
_, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
received: dict[str, Any] = {}
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
received["body"] = body
|
||||
received["context"] = context
|
||||
received["orchestrator"] = orchestrator
|
||||
|
||||
handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel")
|
||||
|
||||
assert received["orchestrator"] is harness.orchestrator
|
||||
assert received["context"] == "ctx-sentinel"
|
||||
assert received["body"]["sub_task_id"] == str(subtask.id)
|
||||
|
||||
|
||||
def test_subtask_handler_completes_parent_subtask_on_success(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
|
||||
|
||||
|
||||
def test_subtask_handler_marks_parent_failed_and_reraises_on_error(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.FAILED
|
||||
assert harness.tasks.get(task.id).status is TaskStatus.FAILED
|
||||
|
||||
|
||||
def test_subtask_handler_injected_orchestrator_can_create_child_subtask(
|
||||
harness: Harness,
|
||||
) -> None:
|
||||
"""Smoke check the share-the-session promise: the injected orchestrator
|
||||
is the same one the decorator owns, so a handler can use it to create
|
||||
child SubTasks under the same session."""
|
||||
task, subtask = harness.orchestrator.create_task_with_subtask(
|
||||
task_source="manual:test"
|
||||
)
|
||||
|
||||
child_ids: list[UUID] = []
|
||||
|
||||
@subtask_handler(orchestrator_cm=harness.factory)
|
||||
def handler(
|
||||
body: dict[str, Any], context: Any, orchestrator: TaskOrchestrator
|
||||
) -> None:
|
||||
child = orchestrator.create_child_subtask(task.id, inputs={"split": 1})
|
||||
child_ids.append(child.id)
|
||||
|
||||
handler(_direct_event(task.id, subtask.id), context=None)
|
||||
|
||||
assert len(child_ids) == 1
|
||||
persisted_child = harness.subtasks.get(child_ids[0])
|
||||
assert persisted_child.task_id == task.id
|
||||
assert persisted_child.status is SubTaskStatus.WAITING
|
||||
|
|
@ -39,7 +39,7 @@ def subtask_handler(
|
|||
trigger = SubtaskTriggerBody.model_validate(body)
|
||||
orchestrator.run_subtask(
|
||||
trigger.sub_task_id,
|
||||
work=lambda body=body: func(body, context),
|
||||
work=lambda body=body, o=orchestrator: func(body, context, o),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue