diff --git a/utilities/aws_lambda/subtask_handler.py b/utilities/aws_lambda/subtask_handler.py index 592ffebf..42c9e07a 100644 --- a/utilities/aws_lambda/subtask_handler.py +++ b/utilities/aws_lambda/subtask_handler.py @@ -4,6 +4,7 @@ Translates an AWS Lambda invocation (SQS-shaped or direct) into TaskOrchestrator.run_subtask(...) calls. """ +import inspect import json import logging import os @@ -37,6 +38,8 @@ def subtask_handler( factory = orchestrator_cm or default_orchestrator def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + _wants_orchestrator = len(inspect.signature(func).parameters) >= 3 + @wraps(func) def wrapper(event: dict[str, Any], context: Any) -> None: cloud_logs_url = _cloudwatch_url() @@ -45,12 +48,22 @@ def subtask_handler( body = _parse_body(record) trigger = SubtaskTriggerBody.model_validate(body) logger.info("Running subtask %s", trigger.sub_task_id) + def _work_with( + _body: dict[str, Any] = body, + _o: TaskOrchestrator = orchestrator, + ) -> Any: + return func(_body, context, _o) + + def _work_without(_body: dict[str, Any] = body) -> Any: + return func(_body, context) + + work: Callable[[], Any] = ( + _work_with if _wants_orchestrator else _work_without + ) try: orchestrator.run_subtask( trigger.sub_task_id, - work=lambda body=body, o=orchestrator: func( - body, context, o - ), + work=work, cloud_logs_url=cloud_logs_url, ) except Exception: