Merge pull request #1210 from Hestia-Homes/bug/audit-generator-deploy-fix

Update subtask_handler to work for handlers where TaskOrchestrator is optional
This commit is contained in:
Daniel Roth 2026-06-10 14:04:20 +01:00 committed by GitHub
commit e7954ad83a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 45 additions and 11 deletions

View file

@ -13,14 +13,11 @@ from infrastructure.postgres.engine import make_engine, make_session
from infrastructure.s3.s3_client import S3Client
from orchestration.audit_generator_orchestrator import AuditGeneratorOrchestrator
from orchestration.audit_generator_unit_of_work import AuditGeneratorUnitOfWork
from orchestration.task_orchestrator import TaskOrchestrator
from utilities.aws_lambda.subtask_handler import subtask_handler
@subtask_handler()
def handler(
body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator
) -> None:
@subtask_handler(pass_task_orchestrator=False)
def handler(body: dict[str, Any], context: Any) -> None:
trigger = AuditGeneratorTriggerRequest.model_validate(body)
boto3_client: Any = (

View file

@ -228,6 +228,31 @@ def test_subtask_handler_records_cloudwatch_url_on_subtask(
assert "$255B$2524LATEST$255D" in saved_url
def test_subtask_handler_completes_subtask_without_orchestrator_parameter(
harness: Harness,
) -> None:
# arrange
task, subtask = harness.orchestrator.create_task_with_subtask(
task_source="manual:test"
)
received: dict[str, Any] = {}
@subtask_handler(orchestrator_cm=harness.factory, pass_task_orchestrator=False)
def handler(body: dict[str, Any], context: Any) -> None:
received["body"] = body
received["context"] = context
# act
handler(_direct_event(task.id, subtask.id), context="ctx-sentinel")
# assert — SubTask lifecycle completes and handler received correct args
assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE
assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE
assert received["context"] == "ctx-sentinel"
assert received["body"]["sub_task_id"] == str(subtask.id)
def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda(
harness: Harness, monkeypatch: pytest.MonkeyPatch
) -> None:

View file

@ -25,6 +25,7 @@ OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]]
def subtask_handler(
*,
orchestrator_cm: Optional[OrchestratorCM] = None,
pass_task_orchestrator: bool = True,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Run the wrapped function as the body of an existing SubTask.
@ -37,6 +38,8 @@ def subtask_handler(
factory = orchestrator_cm or default_orchestrator
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
_wants_orchestrator = pass_task_orchestrator
@wraps(func)
def wrapper(event: dict[str, Any], context: Any) -> None:
cloud_logs_url = _cloudwatch_url()
@ -45,18 +48,27 @@ def subtask_handler(
body = _parse_body(record)
trigger = SubtaskTriggerBody.model_validate(body)
logger.info("Running subtask %s", trigger.sub_task_id)
def _work_with(
_body: dict[str, Any] = body,
_o: TaskOrchestrator = orchestrator,
) -> Any:
return func(_body, context, _o)
def _work_without(_body: dict[str, Any] = body) -> Any:
return func(_body, context)
work: Callable[[], Any] = (
_work_with if _wants_orchestrator else _work_without
)
try:
orchestrator.run_subtask(
trigger.sub_task_id,
work=lambda body=body, o=orchestrator: func(
body, context, o
),
work=work,
cloud_logs_url=cloud_logs_url,
)
except Exception:
logger.exception(
"Subtask %s failed", trigger.sub_task_id
)
logger.exception("Subtask %s failed", trigger.sub_task_id)
raise
logger.info("Subtask %s completed", trigger.sub_task_id)