diff --git a/applications/audit_generator/handler.py b/applications/audit_generator/handler.py index 939c408e..3bbed734 100644 --- a/applications/audit_generator/handler.py +++ b/applications/audit_generator/handler.py @@ -13,14 +13,11 @@ from infrastructure.postgres.engine import make_engine, make_session from infrastructure.s3.s3_client import S3Client from orchestration.audit_generator_orchestrator import AuditGeneratorOrchestrator from orchestration.audit_generator_unit_of_work import AuditGeneratorUnitOfWork -from orchestration.task_orchestrator import TaskOrchestrator from utilities.aws_lambda.subtask_handler import subtask_handler -@subtask_handler() -def handler( - body: dict[str, Any], context: Any, task_orchestrator: TaskOrchestrator -) -> None: +@subtask_handler(pass_task_orchestrator=False) +def handler(body: dict[str, Any], context: Any) -> None: trigger = AuditGeneratorTriggerRequest.model_validate(body) boto3_client: Any = ( diff --git a/tests/utilities/aws_lambda/test_subtask_handler.py b/tests/utilities/aws_lambda/test_subtask_handler.py index d671adc4..b79de3b3 100644 --- a/tests/utilities/aws_lambda/test_subtask_handler.py +++ b/tests/utilities/aws_lambda/test_subtask_handler.py @@ -228,6 +228,31 @@ def test_subtask_handler_records_cloudwatch_url_on_subtask( assert "$255B$2524LATEST$255D" in saved_url +def test_subtask_handler_completes_subtask_without_orchestrator_parameter( + harness: Harness, +) -> None: + # arrange + task, subtask = harness.orchestrator.create_task_with_subtask( + task_source="manual:test" + ) + + received: dict[str, Any] = {} + + @subtask_handler(orchestrator_cm=harness.factory, pass_task_orchestrator=False) + def handler(body: dict[str, Any], context: Any) -> None: + received["body"] = body + received["context"] = context + + # act + handler(_direct_event(task.id, subtask.id), context="ctx-sentinel") + + # assert — SubTask lifecycle completes and handler received correct args + assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE + assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE + assert received["context"] == "ctx-sentinel" + assert received["body"]["sub_task_id"] == str(subtask.id) + + def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda( harness: Harness, monkeypatch: pytest.MonkeyPatch ) -> None: diff --git a/utilities/aws_lambda/subtask_handler.py b/utilities/aws_lambda/subtask_handler.py index 592ffebf..e5ac086a 100644 --- a/utilities/aws_lambda/subtask_handler.py +++ b/utilities/aws_lambda/subtask_handler.py @@ -25,6 +25,7 @@ OrchestratorCM = Callable[[], AbstractContextManager[TaskOrchestrator]] def subtask_handler( *, orchestrator_cm: Optional[OrchestratorCM] = None, + pass_task_orchestrator: bool = True, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Run the wrapped function as the body of an existing SubTask. @@ -37,6 +38,8 @@ def subtask_handler( factory = orchestrator_cm or default_orchestrator def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + _wants_orchestrator = pass_task_orchestrator + @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> None: cloud_logs_url = _cloudwatch_url() @@ -45,18 +48,27 @@ def subtask_handler( body = _parse_body(record) trigger = SubtaskTriggerBody.model_validate(body) logger.info("Running subtask %s", trigger.sub_task_id) + + def _work_with( + _body: dict[str, Any] = body, + _o: TaskOrchestrator = orchestrator, + ) -> Any: + return func(_body, context, _o) + + def _work_without(_body: dict[str, Any] = body) -> Any: + return func(_body, context) + + work: Callable[[], Any] = ( + _work_with if _wants_orchestrator else _work_without + ) try: orchestrator.run_subtask( trigger.sub_task_id, - work=lambda body=body, o=orchestrator: func( - body, context, o - ), + work=work, cloud_logs_url=cloud_logs_url, ) except Exception: - logger.exception( - "Subtask %s failed", trigger.sub_task_id - ) + logger.exception("Subtask %s failed", trigger.sub_task_id) raise logger.info("Subtask %s completed", trigger.sub_task_id)