diff --git a/repositories/user_address/user_address_csv_s3_repository.py b/repositories/user_address/user_address_csv_s3_repository.py index 9b93b638..058fd5a5 100644 --- a/repositories/user_address/user_address_csv_s3_repository.py +++ b/repositories/user_address/user_address_csv_s3_repository.py @@ -53,6 +53,9 @@ class UserAddressCsvS3Repository(UserAddressRepository): {**addr.source_row, _POSTCODE_CLEAN_COLUMN: str(addr.postcode)} for addr in addresses ] + + # TODO: [New Starter Task] file_name generation can be standardised + # and also easier to read, test for future implementation. Buiild that! filename = ( f"{datetime.now(timezone.utc).isoformat()}_{uuid.uuid4().hex[:8]}.csv" ) diff --git a/repositories/user_address/user_address_repository.py b/repositories/user_address/user_address_repository.py index 170f34dd..b2c0f866 100644 --- a/repositories/user_address/user_address_repository.py +++ b/repositories/user_address/user_address_repository.py @@ -7,9 +7,7 @@ from domain.addresses.user_address import UserAddress class UserAddressRepository(ABC): @abstractmethod - def load_batch(self, s3_uri: str) -> list[UserAddress]: - ... + def load_batch(self, s3_uri: str) -> list[UserAddress]: ... @abstractmethod - def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str: - ... + def save_batch(self, addresses: list[UserAddress], path_prefix: str) -> str: ... diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..0a246372 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,48 @@ +"""Shared pytest fixtures for the ``tests/`` tree. + +Provides an ephemeral PostgreSQL engine for tests that exercise SQLModel +repositories. PostgreSQL has no true in-memory mode; ``pytest-postgresql`` +starts a real, throwaway server in a temp directory (the process is started +once per session and a fresh database is created/dropped per test). That is +the closest equivalent to "in-memory" and matches production behaviour far +better than SQLite (enums, JSONB, constraint semantics, etc.). +""" + +from __future__ import annotations + +import glob +from collections.abc import Iterator +from typing import Any + +import pytest +from psycopg import Connection +from pytest_postgresql import factories +from sqlalchemy import Engine +from sqlmodel import SQLModel, create_engine + +# Importing the SQLModel row modules registers their tables on +# SQLModel.metadata so ``create_all`` builds the full schema. Imports look +# unused; they aren't. + + +# pg_ctl ships under a versioned path and is not on PATH in the dev container. +_PG_CTL = next(iter(sorted(glob.glob("/usr/lib/postgresql/*/bin/pg_ctl"))), "pg_ctl") + +postgresql_proc = factories.postgresql_proc( + executable=_PG_CTL +) # pyright: ignore[reportUnknownMemberType] +postgresql = factories.postgresql("postgresql_proc") + + +@pytest.fixture +def db_engine(postgresql: Connection[Any]) -> Iterator[Engine]: + """A SQLModel engine bound to a fresh, ephemeral PostgreSQL database.""" + info = postgresql.info + url = f"postgresql+psycopg://{info.user}:@{info.host}:{info.port}/{info.dbname}" + engine = create_engine(url) + SQLModel.metadata.create_all(engine) + try: + yield engine + finally: + SQLModel.metadata.drop_all(engine) + engine.dispose() diff --git a/tests/domain/addresses/test_postcode_batching.py b/tests/domain/addresses/test_postcode_batching.py index c69722ba..8ffcf1b5 100644 --- a/tests/domain/addresses/test_postcode_batching.py +++ b/tests/domain/addresses/test_postcode_batching.py @@ -15,12 +15,16 @@ def _addrs(postcode: str, n: int) -> list[UserAddress]: def test_empty_input_yields_no_batches() -> None: + # act / assert assert list(iter_postcode_grouped_batches([])) == [] def test_single_batch_under_cap() -> None: + # arrange addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2) + # act batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500)) + # assert assert len(batches) == 1 assert batches[0] == addrs @@ -28,8 +32,11 @@ def test_single_batch_under_cap() -> None: def test_multiple_postcodes_packed_into_one_batch_up_to_cap() -> None: # Two groups whose total exactly equals the cap pack into a single # batch -- no premature flush. + # arrange addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 2) + # act batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5)) + # assert assert len(batches) == 1 assert len(batches[0]) == 5 @@ -37,8 +44,11 @@ def test_multiple_postcodes_packed_into_one_batch_up_to_cap() -> None: def test_flush_on_overflow_before_adding_next_postcode() -> None: # Cap is 5. First group fills 3 slots; second group of 3 would overflow, # so the buffer is flushed first and the next group starts a fresh batch. + # arrange addrs = _addrs("AA1 1AA", 3) + _addrs("BB2 2BB", 3) + # act batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5)) + # assert assert len(batches) == 2 assert [str(a.postcode) for a in batches[0]] == ["AA11AA"] * 3 assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 3 @@ -47,8 +57,11 @@ def test_flush_on_overflow_before_adding_next_postcode() -> None: def test_single_postcode_group_exceeding_cap_is_dispatched_whole() -> None: # An oversize single-postcode group goes out as one batch larger than # the cap -- the cap never splits a postcode. + # arrange addrs = _addrs("AA1 1AA", 7) + # act batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=5)) + # assert assert len(batches) == 1 assert len(batches[0]) == 7 @@ -56,12 +69,15 @@ def test_single_postcode_group_exceeding_cap_is_dispatched_whole() -> None: def test_oversize_group_flushes_existing_buffer_first() -> None: # Mirrors the legacy ``if buffer: flush`` branch when an oversize group # is encountered: buffered work must not be lost or interleaved. + # arrange small = _addrs("AA1 1AA", 2) big = _addrs("BB2 2BB", 7) tail = _addrs("CC3 3CC", 1) + # act batches = list( iter_postcode_grouped_batches(small + big + tail, max_batch_size=5) ) + # assert assert len(batches) == 3 assert [str(a.postcode) for a in batches[0]] == ["AA11AA", "AA11AA"] assert [str(a.postcode) for a in batches[1]] == ["BB22BB"] * 7 @@ -70,17 +86,23 @@ def test_oversize_group_flushes_existing_buffer_first() -> None: def test_final_flush_yields_remaining_buffer() -> None: # No overflow ever happens, but the trailing buffer must still come out. + # arrange addrs = _addrs("AA1 1AA", 2) + _addrs("BB2 2BB", 2) + # act batches = list(iter_postcode_grouped_batches(addrs, max_batch_size=500)) + # assert assert batches == [addrs] def test_postcode_grouping_preserves_first_seen_order() -> None: # Interleaved input must still group by postcode and emit in first-seen # order -- never alphabetical. + # arrange a1, a2 = _addrs("ZZ9 9ZZ", 2) b1, b2 = _addrs("AA1 1AA", 2) + # act batches = list(iter_postcode_grouped_batches([a1, b1, a2, b2])) + # assert assert len(batches) == 1 assert [str(a.postcode) for a in batches[0]] == [ "ZZ99ZZ", @@ -91,5 +113,6 @@ def test_postcode_grouping_preserves_first_seen_order() -> None: def test_invalid_max_batch_size_raises() -> None: + # act / assert with pytest.raises(ValueError, match="max_batch_size"): list(iter_postcode_grouped_batches([], max_batch_size=0)) diff --git a/tests/domain/addresses/test_user_address.py b/tests/domain/addresses/test_user_address.py index fa44ad61..8d092df3 100644 --- a/tests/domain/addresses/test_user_address.py +++ b/tests/domain/addresses/test_user_address.py @@ -7,35 +7,45 @@ from domain.postcode import Postcode def test_user_address_holds_postcode_value_object() -> None: + # act addr = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa")) + # assert assert addr.postcode == Postcode("SW1A1AA") def test_user_address_preserves_user_address_verbatim() -> None: # The free-text user_address string is intentionally NOT normalised -- # only the postcode is canonicalised, and that happens inside Postcode. + # act addr = UserAddress( user_address=" 1 The Street ", postcode=Postcode("SW1A1AA") ) + # assert assert addr.user_address == " 1 The Street " def test_user_address_internal_reference_defaults_to_none() -> None: + # act addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # assert assert addr.internal_reference is None def test_user_address_internal_reference_accepted() -> None: + # act addr = UserAddress( user_address="1 The Street", postcode=Postcode("SW1A1AA"), internal_reference="cust-42", ) + # assert assert addr.internal_reference == "cust-42" def test_user_address_is_frozen() -> None: + # arrange addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # act / assert with pytest.raises(dataclasses.FrozenInstanceError): addr.postcode = Postcode("OTHER") # type: ignore[misc] @@ -43,29 +53,37 @@ def test_user_address_is_frozen() -> None: def test_user_address_equality_uses_canonical_postcode() -> None: # Postcode sanitises eagerly, so addresses built from different surface # forms of the same postcode compare equal. + # arrange a = UserAddress(user_address="1 The Street", postcode=Postcode("sw1a 1aa")) b = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # act / assert assert a == b def test_user_address_source_row_defaults_to_empty_dict() -> None: + # act addr = UserAddress(user_address="1 The Street", postcode=Postcode("SW1A1AA")) + # assert assert addr.source_row == {} def test_user_address_carries_source_row() -> None: + # arrange row = {"Address 1": "1 The Street", "postcode": "SW1A 1AA", "SAP Score": "72"} + # act addr = UserAddress( user_address="1 The Street", postcode=Postcode("SW1A 1AA"), source_row=row, ) + # assert assert addr.source_row == row def test_user_address_equality_ignores_source_row() -> None: # source_row is excluded from equality (and hashing): identity stays # defined by the parsed fields. + # arrange a = UserAddress( user_address="1 The Street", postcode=Postcode("SW1A1AA"), @@ -76,4 +94,5 @@ def test_user_address_equality_ignores_source_row() -> None: postcode=Postcode("SW1A1AA"), source_row={"y": "2"}, ) + # act / assert assert a == b diff --git a/tests/domain/tasks/test_subtasks.py b/tests/domain/tasks/test_subtasks.py index 2721d38f..8cee4496 100644 --- a/tests/domain/tasks/test_subtasks.py +++ b/tests/domain/tasks/test_subtasks.py @@ -6,10 +6,13 @@ from domain.tasks.subtasks import SubTask, SubTaskStatus def test_create_subtask_starts_waiting() -> None: + # arrange task_id = uuid4() + # act st = SubTask.create(task_id=task_id, inputs={"foo": "bar"}) + # assert assert st.task_id == task_id assert st.status is SubTaskStatus.WAITING assert st.inputs == {"foo": "bar"} @@ -19,57 +22,74 @@ def test_create_subtask_starts_waiting() -> None: def test_start_transitions_to_in_progress_and_sets_cloud_logs_url() -> None: + # arrange st = SubTask.create(task_id=uuid4()) + # act st.start(cloud_logs_url="https://example/log") + # assert assert st.status is SubTaskStatus.IN_PROGRESS assert st.cloud_logs_url == "https://example/log" assert st.job_started is not None def test_start_is_idempotent_from_in_progress() -> None: + # arrange st = SubTask.create(task_id=uuid4()) st.start() first_start = st.job_started + # act st.start(cloud_logs_url="https://other") + # assert assert st.status is SubTaskStatus.IN_PROGRESS assert st.job_started == first_start # not overwritten assert st.cloud_logs_url == "https://other" def test_start_rejects_from_terminal_status() -> None: + # arrange st = SubTask.create(task_id=uuid4()) st.complete() + # act / assert with pytest.raises(ValueError): st.start() def test_complete_marks_outputs_and_job_completed() -> None: + # arrange st = SubTask.create(task_id=uuid4()) st.start() + # act st.complete({"uprn": "123"}) + # assert assert st.status is SubTaskStatus.COMPLETE assert st.outputs == {"result": {"uprn": "123"}} assert st.job_completed is not None def test_complete_without_result_leaves_outputs_unset() -> None: + # arrange st = SubTask.create(task_id=uuid4()) + # act st.complete() + # assert assert st.outputs is None def test_fail_records_error_in_outputs() -> None: + # arrange st = SubTask.create(task_id=uuid4()) err = RuntimeError("boom") + # act st.fail(err) + # assert assert st.status is SubTaskStatus.FAILED assert st.outputs == {"error": "boom"} assert st.job_completed is not None diff --git a/tests/domain/tasks/test_tasks.py b/tests/domain/tasks/test_tasks.py index f30c0aa1..ba82412b 100644 --- a/tests/domain/tasks/test_tasks.py +++ b/tests/domain/tasks/test_tasks.py @@ -5,12 +5,12 @@ from domain.tasks.tasks import Source, Task, TaskStatus def test_create_task_starts_waiting() -> None: - # Arrange / Act + # arrange / act t = Task.create( task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123" ) - # Assert + # assert assert t.status is TaskStatus.WAITING assert t.source is Source.PORTFOLIO assert t.source_id == "abc-123" @@ -19,86 +19,113 @@ def test_create_task_starts_waiting() -> None: def test_create_task_rejects_blank_task_source() -> None: + # act / assert with pytest.raises(ValueError, match="task_source"): Task.create(task_source=" ") def test_start_transitions_to_in_progress() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.start() + # assert assert t.status is TaskStatus.IN_PROGRESS def test_complete_marks_job_completed() -> None: + # arrange t = Task.create(task_source="manual:test") t.start() + # act t.complete() + # assert assert t.status is TaskStatus.COMPLETE assert t.job_completed is not None def test_fail_marks_job_completed() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.fail() + # assert assert t.status is TaskStatus.FAILED assert t.job_completed is not None def test_start_rejects_from_terminal_status() -> None: + # arrange t = Task.create(task_source="manual:test") t.complete() + # act / assert with pytest.raises(ValueError): t.start() def test_recalculate_with_empty_statuses_is_noop() -> None: + # arrange t = Task.create(task_source="manual:test") original_status = t.status original_completed = t.job_completed + # act t.recalculate_from_subtasks([]) + # assert assert t.status is original_status assert t.job_completed is original_completed def test_recalculate_all_waiting_keeps_waiting() -> None: + # arrange t = Task.create(task_source="manual:test") t.start() # task moved to IN_PROGRESS earlier t.complete() # then COMPLETE, with job_completed set + # act t.recalculate_from_subtasks([SubTaskStatus.WAITING, SubTaskStatus.WAITING]) + # assert assert t.status is TaskStatus.WAITING assert t.job_completed is None def test_recalculate_any_in_progress_marks_in_progress() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.recalculate_from_subtasks( [SubTaskStatus.WAITING, SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE] ) + # assert assert t.status is TaskStatus.IN_PROGRESS assert t.job_completed is None def test_recalculate_all_complete_marks_complete() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.recalculate_from_subtasks([SubTaskStatus.COMPLETE, SubTaskStatus.COMPLETE]) + # assert assert t.status is TaskStatus.COMPLETE assert t.job_completed is not None def test_recalculate_any_failed_marks_failed_even_with_others() -> None: + # arrange t = Task.create(task_source="manual:test") + # act t.recalculate_from_subtasks( [SubTaskStatus.IN_PROGRESS, SubTaskStatus.COMPLETE, SubTaskStatus.FAILED] ) + # assert assert t.status is TaskStatus.FAILED assert t.job_completed is not None diff --git a/tests/domain/test_postcode.py b/tests/domain/test_postcode.py index 89d5cdc8..f7ce9015 100644 --- a/tests/domain/test_postcode.py +++ b/tests/domain/test_postcode.py @@ -6,43 +6,54 @@ from domain.postcode import Postcode def test_postcode_uppercases() -> None: + # act / assert assert Postcode("sw1a1aa").value == "SW1A1AA" def test_postcode_strips_internal_spaces() -> None: + # act / assert assert Postcode("sw1a 1aa").value == "SW1A1AA" def test_postcode_strips_leading_and_trailing_whitespace() -> None: + # act / assert assert Postcode(" sw1a 1aa ").value == "SW1A1AA" def test_postcode_strips_tabs_and_newlines() -> None: # CSV ingestion occasionally introduces stray whitespace characters; the # canonical form must absorb them just like literal spaces. + # act / assert assert Postcode("sw1a\t1aa\n").value == "SW1A1AA" def test_postcode_construction_is_idempotent() -> None: + # arrange once = Postcode("sw1a 1aa") + # act / assert assert Postcode(once.value).value == "SW1A1AA" def test_postcode_empty_string() -> None: + # act / assert assert Postcode("").value == "" def test_postcode_str_returns_canonical_value() -> None: + # act / assert assert str(Postcode("sw1a 1aa")) == "SW1A1AA" def test_postcode_equality_ignores_surface_form() -> None: # Differing case / whitespace sanitise to the same canonical value, so # the value objects compare equal. + # act / assert assert Postcode("sw1a 1aa") == Postcode("SW1A1AA") def test_postcode_is_frozen() -> None: + # arrange postcode = Postcode("SW1A1AA") + # act / assert with pytest.raises(dataclasses.FrozenInstanceError): postcode.value = "OTHER" # type: ignore[misc] diff --git a/tests/infrastructure/test_address2uprn_queue_client.py b/tests/infrastructure/test_address2uprn_queue_client.py index b4114742..c8e89ece 100644 --- a/tests/infrastructure/test_address2uprn_queue_client.py +++ b/tests/infrastructure/test_address2uprn_queue_client.py @@ -28,12 +28,15 @@ def queue_setup() -> Iterator[tuple[Address2UprnQueueClient, Any, str]]: def test_publish_returns_message_id( queue_setup: tuple[Address2UprnQueueClient, Any, str], ) -> None: + # arrange client, _boto, _url = queue_setup + # act message_id = client.publish( parent_task_id=uuid4(), child_subtask_id=uuid4(), s3_uri="s3://my-bucket/path/to/chunk.csv", ) + # assert assert isinstance(message_id, str) assert message_id @@ -41,17 +44,20 @@ def test_publish_returns_message_id( def test_publish_body_uses_typed_shape( queue_setup: tuple[Address2UprnQueueClient, Any, str], ) -> None: + # arrange client, boto_client, queue_url = queue_setup parent_id = uuid4() child_id = uuid4() s3_uri = "s3://my-bucket/path/to/chunk.csv" + # act client.publish( parent_task_id=parent_id, child_subtask_id=child_id, s3_uri=s3_uri, ) + # assert received: dict[str, Any] = boto_client.receive_message( QueueUrl=queue_url, MaxNumberOfMessages=1 ) diff --git a/tests/infrastructure/test_csv_s3_client.py b/tests/infrastructure/test_csv_s3_client.py index 4b9fc199..30e27164 100644 --- a/tests/infrastructure/test_csv_s3_client.py +++ b/tests/infrastructure/test_csv_s3_client.py @@ -18,26 +18,34 @@ def csv_client() -> Iterator[CsvS3Client]: def test_save_rows_returns_s3_uri(csv_client: CsvS3Client) -> None: + # arrange rows = [{"address": "1 High St", "postcode": "AB1 2CD"}] + # act uri = csv_client.save_rows(rows, "uploads/addresses.csv") + # assert assert uri == f"s3://{BUCKET}/uploads/addresses.csv" def test_round_trip_preserves_rows(csv_client: CsvS3Client) -> None: + # arrange rows = [ {"address": "1 High St", "postcode": "AB1 2CD"}, {"address": "2 Low St", "postcode": "XY9 8ZW"}, ] + # act uri = csv_client.save_rows(rows, "uploads/addresses.csv") fetched = csv_client.read_rows(uri) + # assert assert fetched == rows def test_save_rows_rejects_empty_list(csv_client: CsvS3Client) -> None: + # act / assert with pytest.raises(ValueError, match="empty"): csv_client.save_rows([], "uploads/empty.csv") def test_read_rows_rejects_wrong_bucket(csv_client: CsvS3Client) -> None: + # act / assert with pytest.raises(ValueError, match="does not match client bucket"): csv_client.read_rows("s3://other-bucket/uploads/addresses.csv") diff --git a/tests/infrastructure/test_s3_client.py b/tests/infrastructure/test_s3_client.py index 7ed4c30b..67db4f58 100644 --- a/tests/infrastructure/test_s3_client.py +++ b/tests/infrastructure/test_s3_client.py @@ -18,14 +18,19 @@ def s3_client() -> Iterator[S3Client]: def test_put_object_returns_s3_uri(s3_client: S3Client) -> None: + # act uri = s3_client.put_object("folder/data.bin", b"payload") + # assert assert uri == f"s3://{BUCKET}/folder/data.bin" def test_get_object_returns_bytes_written_by_put_object(s3_client: S3Client) -> None: + # arrange s3_client.put_object("round/trip.bin", b"hello world") + # act / assert assert s3_client.get_object("round/trip.bin") == b"hello world" def test_bucket_property_exposes_configured_bucket(s3_client: S3Client) -> None: + # act / assert assert s3_client.bucket == BUCKET diff --git a/tests/infrastructure/test_s3_uri.py b/tests/infrastructure/test_s3_uri.py index 896c5959..32fd710f 100644 --- a/tests/infrastructure/test_s3_uri.py +++ b/tests/infrastructure/test_s3_uri.py @@ -4,29 +4,37 @@ from infrastructure.s3_uri import parse_s3_uri def test_parses_simple_s3_uri() -> None: + # act / assert assert parse_s3_uri("s3://my-bucket/file.csv") == ("my-bucket", "file.csv") def test_parses_s3_uri_with_nested_key() -> None: + # act bucket, key = parse_s3_uri("s3://my-bucket/nested/path/to/file.csv") + # assert assert (bucket, key) == ("my-bucket", "nested/path/to/file.csv") def test_rejects_s3_uri_without_key() -> None: + # act / assert with pytest.raises(ValueError, match="bucket and a key"): parse_s3_uri("s3://my-bucket") def test_rejects_s3_uri_with_empty_key() -> None: + # act / assert with pytest.raises(ValueError, match="bucket and a key"): parse_s3_uri("s3://my-bucket/") def test_parses_console_url_prefix() -> None: + # arrange url = "https://eu-west-2.console.aws.amazon.com/s3/object/my-bucket?prefix=nested%2Ffile.csv" + # act / assert assert parse_s3_uri(url) == ("my-bucket", "nested/file.csv") def test_rejects_unparseable_string() -> None: + # act / assert with pytest.raises(ValueError): parse_s3_uri("not-a-uri-at-all") diff --git a/tests/infrastructure/test_sqs_client.py b/tests/infrastructure/test_sqs_client.py index 7f1e8f78..44186bbb 100644 --- a/tests/infrastructure/test_sqs_client.py +++ b/tests/infrastructure/test_sqs_client.py @@ -19,17 +19,23 @@ def sqs_setup() -> Iterator[tuple[SqsClient, Any, str]]: def test_send_returns_message_id(sqs_setup: tuple[SqsClient, Any, str]) -> None: + # arrange client, _boto, _url = sqs_setup + # act message_id = client.send({"hello": "world"}) + # assert assert isinstance(message_id, str) assert message_id def test_send_json_serialises_body(sqs_setup: tuple[SqsClient, Any, str]) -> None: + # arrange client, boto_client, queue_url = sqs_setup body = {"hello": "world", "count": 3} + # act client.send(body) + # assert received: dict[str, Any] = boto_client.receive_message( QueueUrl=queue_url, MaxNumberOfMessages=1 ) diff --git a/tests/orchestration/test_postcode_splitter_orchestrator.py b/tests/orchestration/test_postcode_splitter_orchestrator.py index 4ee2315e..a718ffbc 100644 --- a/tests/orchestration/test_postcode_splitter_orchestrator.py +++ b/tests/orchestration/test_postcode_splitter_orchestrator.py @@ -9,7 +9,8 @@ from typing import Any, cast import boto3 import pytest from moto import mock_aws -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session from infrastructure.address2uprn_queue_client import Address2UprnQueueClient from infrastructure.csv_s3_client import CsvS3Client @@ -65,7 +66,7 @@ class Harness: @pytest.fixture -def harness() -> Iterator[Harness]: +def harness(db_engine: Engine) -> Iterator[Harness]: with mock_aws(): # Infra: S3 + SQS boto_s3 = _make_boto_client("s3") @@ -78,10 +79,8 @@ def harness() -> Iterator[Harness]: repo = UserAddressCsvS3Repository(csv_client, BUCKET) queue_client = Address2UprnQueueClient(boto_sqs, queue_url) - # DB: in-memory SQLite TaskOrchestrator - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as session: + # DB: ephemeral PostgreSQL TaskOrchestrator + with Session(db_engine) as session: task_repo = TaskPostgresRepository(session=session) subtask_repo = SubTaskPostgresRepository(session=session) task_orchestrator = TaskOrchestrator( @@ -169,6 +168,7 @@ def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]: def test_split_and_dispatch_creates_three_children_for_fixture( harness: Harness, ) -> None: + # arrange parent_task, parent_subtask = ( harness.task_orchestrator.create_task_with_subtask( task_source="manual:postcode-splitter-int" @@ -176,12 +176,14 @@ def test_split_and_dispatch_creates_three_children_for_fixture( ) input_uri = _upload_fixture_csv(harness.csv_client) + # act child_ids = harness.splitter.split_and_dispatch( parent_task_id=parent_task.id, parent_subtask_id=parent_subtask.id, input_s3_uri=input_uri, ) + # assert assert len(child_ids) == 3 # All child ids are unique and persisted as WAITING children of the # parent task. @@ -194,6 +196,7 @@ def test_split_and_dispatch_creates_three_children_for_fixture( def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri( harness: Harness, ) -> None: + # arrange parent_task, parent_subtask = ( harness.task_orchestrator.create_task_with_subtask( task_source="manual:postcode-splitter-int" @@ -201,12 +204,14 @@ def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri( ) input_uri = _upload_fixture_csv(harness.csv_client) + # act child_ids = harness.splitter.split_and_dispatch( parent_task_id=parent_task.id, parent_subtask_id=parent_subtask.id, input_s3_uri=input_uri, ) + # assert for cid in child_ids: child = harness.subtasks.get(cid) assert child.inputs is not None @@ -224,6 +229,7 @@ def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri( def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids( harness: Harness, ) -> None: + # arrange parent_task, parent_subtask = ( harness.task_orchestrator.create_task_with_subtask( task_source="manual:postcode-splitter-int" @@ -231,12 +237,14 @@ def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids( ) input_uri = _upload_fixture_csv(harness.csv_client) + # act child_ids = harness.splitter.split_and_dispatch( parent_task_id=parent_task.id, parent_subtask_id=parent_subtask.id, input_s3_uri=input_uri, ) + # assert bodies = _drain_queue(harness.boto_sqs, harness.queue_url) assert len(bodies) == len(child_ids) @@ -258,6 +266,7 @@ def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids( def test_split_and_dispatch_returns_child_ids_in_dispatch_order( harness: Harness, ) -> None: + # arrange parent_task, parent_subtask = ( harness.task_orchestrator.create_task_with_subtask( task_source="manual:postcode-splitter-int" @@ -265,12 +274,14 @@ def test_split_and_dispatch_returns_child_ids_in_dispatch_order( ) input_uri = _upload_fixture_csv(harness.csv_client) + # act child_ids = harness.splitter.split_and_dispatch( parent_task_id=parent_task.id, parent_subtask_id=parent_subtask.id, input_s3_uri=input_uri, ) + # assert # Re-load each child's saved batch and inspect the postcode_clean column # to confirm the dispatch order matches the postcode-batching algorithm: # AA-batch first, BB oversize batch second, CC final-flush third. diff --git a/tests/orchestration/test_task_orchestrator.py b/tests/orchestration/test_task_orchestrator.py index c0816d2d..ae89991d 100644 --- a/tests/orchestration/test_task_orchestrator.py +++ b/tests/orchestration/test_task_orchestrator.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from dataclasses import dataclass import pytest -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session from domain.tasks.subtasks import SubTask, SubTaskStatus from domain.tasks.tasks import Source, TaskStatus @@ -19,10 +20,8 @@ class Harness: @pytest.fixture -def harness() -> Iterator[Harness]: - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as session: +def harness(db_engine: Engine) -> Iterator[Harness]: + with Session(db_engine) as session: tasks = TaskPostgresRepository(session=session) subtasks = SubTaskPostgresRepository(session=session) yield Harness( @@ -35,6 +34,7 @@ def harness() -> Iterator[Harness]: def test_create_task_with_subtask_creates_both_in_waiting( harness: Harness, ) -> None: + # act task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test", inputs={"foo": "bar"}, @@ -42,6 +42,7 @@ def test_create_task_with_subtask_creates_both_in_waiting( source_id="abc", ) + # assert assert task.status is TaskStatus.WAITING assert subtask.status is SubTaskStatus.WAITING assert subtask.task_id == task.id @@ -49,27 +50,33 @@ def test_create_task_with_subtask_creates_both_in_waiting( def test_start_subtask_cascades_to_in_progress(harness: Harness) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) + # act started = harness.orchestrator.start_subtask( subtask.id, cloud_logs_url="https://example/log" ) + # assert assert started.status is SubTaskStatus.IN_PROGRESS assert started.cloud_logs_url == "https://example/log" assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS def test_complete_subtask_cascades_to_complete(harness: Harness) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) harness.orchestrator.start_subtask(subtask.id) + # act harness.orchestrator.complete_subtask(subtask.id, {"value": 42}) + # assert done_subtask = harness.subtasks.get(subtask.id) done_task = harness.tasks.get(task.id) assert done_subtask.outputs == {"result": {"value": 42}} @@ -78,12 +85,15 @@ def test_complete_subtask_cascades_to_complete(harness: Harness) -> None: def test_fail_subtask_cascades_to_failed(harness: Harness) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) + # act harness.orchestrator.fail_subtask(subtask.id, RuntimeError("boom")) + # assert failed_subtask = harness.subtasks.get(subtask.id) failed_task = harness.tasks.get(task.id) assert failed_subtask.outputs == {"error": "boom"} @@ -93,42 +103,51 @@ def test_fail_subtask_cascades_to_failed(harness: Harness) -> None: def test_failed_subtask_locks_task_failed_even_with_others_complete( harness: Harness, ) -> None: + # arrange task, first = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) second = SubTask.create(task_id=task.id) harness.subtasks.create(second) + # act harness.orchestrator.complete_subtask(first.id) harness.orchestrator.fail_subtask(second.id, RuntimeError("nope")) + # assert assert harness.tasks.get(task.id).status is TaskStatus.FAILED def test_mixed_complete_and_in_progress_keeps_task_in_progress( harness: Harness, ) -> None: + # arrange task, first = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) second = SubTask.create(task_id=task.id) harness.subtasks.create(second) + # act harness.orchestrator.complete_subtask(first.id) harness.orchestrator.start_subtask(second.id) + # assert assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS def test_run_subtask_happy_path_returns_result_and_cascades_complete( harness: Harness, ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) + # act result = harness.orchestrator.run_subtask(subtask.id, work=lambda: {"answer": 42}) + # assert assert result == {"answer": 42} assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE @@ -137,16 +156,19 @@ def test_run_subtask_happy_path_returns_result_and_cascades_complete( def test_create_child_subtask_adds_waiting_child_without_changing_parent_status( harness: Harness, ) -> None: + # arrange task, first = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) harness.orchestrator.start_subtask(first.id) assert harness.tasks.get(task.id).status is TaskStatus.IN_PROGRESS + # act child = harness.orchestrator.create_child_subtask( task.id, inputs={"split": "a"} ) + # assert persisted_child = harness.subtasks.get(child.id) assert persisted_child.task_id == task.id assert persisted_child.status is SubTaskStatus.WAITING @@ -159,6 +181,7 @@ def test_create_child_subtask_adds_waiting_child_without_changing_parent_status( def test_run_subtask_failing_work_marks_failed_and_reraises( harness: Harness, ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -166,6 +189,7 @@ def test_run_subtask_failing_work_marks_failed_and_reraises( def boom() -> None: raise RuntimeError("boom") + # act / assert with pytest.raises(RuntimeError, match="boom"): harness.orchestrator.run_subtask(subtask.id, work=boom) diff --git a/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py b/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py index ac39e089..9cec52ea 100644 --- a/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py +++ b/tests/repositories/tasks/postgres/test_subtask_postgres_repository.py @@ -1,33 +1,40 @@ from collections.abc import Iterator -from uuid import uuid4 +from uuid import UUID, uuid4 import pytest -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session -# Importing the SQLModel row modules registers their tables in -# SQLModel.metadata so create_all builds both. Imports look unused; they aren't. -import infrastructure.postgres.subtask_table # noqa: F401 # pyright: ignore[reportUnusedImport] -import infrastructure.postgres.task_table # noqa: F401 # pyright: ignore[reportUnusedImport] from domain.tasks.subtasks import SubTask, SubTaskStatus +from domain.tasks.tasks import Task from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository +from repositories.tasks.task_postgres_repository import TaskPostgresRepository @pytest.fixture -def session() -> Iterator[Session]: - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as s: +def session(db_engine: Engine) -> Iterator[Session]: + with Session(db_engine) as s: yield s +def _persisted_task_id(session: Session) -> UUID: + """Create a parent Task row so SubTask FK constraints are satisfied.""" + task = Task.create(task_source="manual:test") + TaskPostgresRepository(session=session).create(task) + return task.id + + def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) - task_id = uuid4() + task_id = _persisted_task_id(session) st = SubTask.create(task_id=task_id, inputs={"address": "68 Glendon Way"}) + # act repo.create(st) fetched = repo.get(st.id) + # assert assert fetched.id == st.id assert fetched.task_id == task_id assert fetched.status is SubTaskStatus.WAITING @@ -36,16 +43,21 @@ def test_create_and_get_round_trip_preserves_inputs(session: Session) -> None: def test_save_persists_status_and_outputs(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) - st = SubTask.create(task_id=uuid4()) + st = SubTask.create(task_id=_persisted_task_id(session)) repo.create(st) + # act st.start(cloud_logs_url="https://example/log") repo.save(st) + # assert assert repo.get(st.id).status is SubTaskStatus.IN_PROGRESS + # act st.complete({"uprn": "123"}) repo.save(st) + # assert done = repo.get(st.id) assert done.status is SubTaskStatus.COMPLETE assert done.outputs == {"result": {"uprn": "123"}} @@ -54,16 +66,19 @@ def test_save_persists_status_and_outputs(session: Session) -> None: def test_list_by_task_filters_by_task_id(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) - task_a = uuid4() - task_b = uuid4() + task_a = _persisted_task_id(session) + task_b = _persisted_task_id(session) repo.create(SubTask.create(task_id=task_a)) repo.create(SubTask.create(task_id=task_a)) repo.create(SubTask.create(task_id=task_b)) + # act a_results = repo.list_by_task(task_a) b_results = repo.list_by_task(task_b) + # assert assert len(a_results) == 2 assert len(b_results) == 1 assert all(s.task_id == task_a for s in a_results) @@ -71,11 +86,15 @@ def test_list_by_task_filters_by_task_id(session: Session) -> None: def test_list_by_task_returns_empty_for_unknown_task(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) + # act / assert assert repo.list_by_task(uuid4()) == [] def test_get_missing_raises(session: Session) -> None: + # arrange repo = SubTaskPostgresRepository(session=session) + # act / assert with pytest.raises(ValueError, match="not found"): repo.get(uuid4()) diff --git a/tests/repositories/tasks/postgres/test_task_postgres_repository.py b/tests/repositories/tasks/postgres/test_task_postgres_repository.py index 3e1aa226..8a49a861 100644 --- a/tests/repositories/tasks/postgres/test_task_postgres_repository.py +++ b/tests/repositories/tasks/postgres/test_task_postgres_repository.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from uuid import uuid4 import pytest -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session from domain.tasks.tasks import Source, Task, TaskStatus from infrastructure.postgres.task_table import TaskRow @@ -10,25 +11,23 @@ from repositories.tasks.task_postgres_repository import TaskPostgresRepository @pytest.fixture -def session() -> Iterator[Session]: - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as s: +def session(db_engine: Engine) -> Iterator[Session]: + with Session(db_engine) as s: yield s def test_create_and_get_round_trip(session: Session) -> None: - # Arrange + # arrange repo = TaskPostgresRepository(session=session) t = Task.create( task_source="manual:test", source=Source.PORTFOLIO, source_id="abc-123" ) - # Act + # act repo.create(t) fetched = repo.get(t.id) - # Assert + # assert assert fetched.id == t.id assert fetched.status is TaskStatus.WAITING assert fetched.source is Source.PORTFOLIO @@ -36,33 +35,43 @@ def test_create_and_get_round_trip(session: Session) -> None: def test_save_persists_status_transition(session: Session) -> None: + # arrange repo = TaskPostgresRepository(session=session) t = Task.create(task_source="manual:test") repo.create(t) + # act t.start() repo.save(t) + # assert assert repo.get(t.id).status is TaskStatus.IN_PROGRESS + # act t.complete() repo.save(t) + # assert done = repo.get(t.id) assert done.status is TaskStatus.COMPLETE assert done.job_completed is not None def test_get_missing_raises(session: Session) -> None: + # arrange repo = TaskPostgresRepository(session=session) + # act / assert with pytest.raises(ValueError, match="not found"): repo.get(uuid4()) def test_get_normalises_legacy_capitalised_status(session: Session) -> None: # Existing rows written by backend code use "In Progress" (capitalised). + # arrange repo = TaskPostgresRepository(session=session) row = TaskRow(task_source="manual:test", status="In Progress") session.add(row) session.commit() + # act fetched = repo.get(row.id) + # assert assert fetched.status is TaskStatus.IN_PROGRESS diff --git a/tests/repositories/user_address/test_user_address_csv_s3_repository.py b/tests/repositories/user_address/test_user_address_csv_s3_repository.py index c1acee32..9ffb250a 100644 --- a/tests/repositories/user_address/test_user_address_csv_s3_repository.py +++ b/tests/repositories/user_address/test_user_address_csv_s3_repository.py @@ -32,6 +32,7 @@ def _upload_csv( def test_load_batch_parses_address_postcode_and_reference( repo: UserAddressCsvS3Repository, ) -> None: + # arrange rows = [ { "Address 1": "1 High Street", @@ -43,8 +44,10 @@ def test_load_batch_parses_address_postcode_and_reference( ] uri = _upload_csv(repo, rows, "uploads/full.csv") + # act addresses = repo.load_batch(uri) + # assert assert len(addresses) == 1 address = addresses[0] assert address.user_address == "1 High Street, Flat 2, Townville" @@ -55,6 +58,7 @@ def test_load_batch_parses_address_postcode_and_reference( def test_load_batch_uses_only_address_1_when_others_missing( repo: UserAddressCsvS3Repository, ) -> None: + # arrange rows = [ { "Address 1": "10 Cardiff Road", @@ -66,8 +70,10 @@ def test_load_batch_uses_only_address_1_when_others_missing( ] uri = _upload_csv(repo, rows, "uploads/address1-only.csv") + # act addresses = repo.load_batch(uri) + # assert assert len(addresses) == 1 assert addresses[0].user_address == "10 Cardiff Road" assert addresses[0].postcode == Postcode("CF101AA") @@ -77,6 +83,7 @@ def test_load_batch_uses_only_address_1_when_others_missing( def test_load_batch_handles_missing_internal_reference( repo: UserAddressCsvS3Repository, ) -> None: + # arrange rows = [ { "Address 1": "5 Park Lane", @@ -88,8 +95,10 @@ def test_load_batch_handles_missing_internal_reference( ] uri = _upload_csv(repo, rows, "uploads/no-ref.csv") + # act addresses = repo.load_batch(uri) + # assert assert len(addresses) == 1 assert addresses[0].user_address == "5 Park Lane" assert addresses[0].postcode == Postcode("M11AA") @@ -101,6 +110,7 @@ def test_load_batch_captures_full_source_row( ) -> None: # A raw EPC-export-shaped row: the splitter must preserve every column, # not just the ones it parses into UserAddress fields. + # arrange row = { "Asset Reference": "511", "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX", @@ -110,17 +120,21 @@ def test_load_batch_captures_full_source_row( } uri = _upload_csv(repo, [row], "uploads/epc.csv") + # act addresses = repo.load_batch(uri) + # assert assert addresses[0].source_row == row def test_load_batch_raises_when_postcode_column_absent( repo: UserAddressCsvS3Repository, ) -> None: + # arrange rows = [{"Address 1": "1 High Street", "Property Type": "Flat"}] uri = _upload_csv(repo, rows, "uploads/no-postcode.csv") + # act / assert with pytest.raises(ValueError, match="no 'postcode' column"): repo.load_batch(uri) @@ -128,6 +142,7 @@ def test_load_batch_raises_when_postcode_column_absent( def test_save_batch_passes_through_all_columns_and_appends_postcode_clean( repo: UserAddressCsvS3Repository, ) -> None: + # arrange row = { "Asset Reference": "511", "Address 1": "9 Abingdon Road Padiham Lancashire BB12 7BX", @@ -137,9 +152,11 @@ def test_save_batch_passes_through_all_columns_and_appends_postcode_clean( uri = _upload_csv(repo, [row], "uploads/epc.csv") addresses = repo.load_batch(uri) + # act saved_uri = repo.save_batch(addresses, "tasks/passthrough") saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage] + # assert assert len(saved_rows) == 1 saved = saved_rows[0] # Every original column survives, byte-for-byte. @@ -152,6 +169,7 @@ def test_save_batch_passes_through_all_columns_and_appends_postcode_clean( def test_save_batch_returns_uri_under_path_prefix( repo: UserAddressCsvS3Repository, ) -> None: + # arrange addresses = [ UserAddress( user_address="1 High Street", @@ -160,8 +178,10 @@ def test_save_batch_returns_uri_under_path_prefix( ), ] + # act uri = repo.save_batch(addresses, "tasks/abc/batches") + # assert assert uri.startswith(f"s3://{BUCKET}/tasks/abc/batches/") assert uri.endswith(".csv") @@ -169,6 +189,7 @@ def test_save_batch_returns_uri_under_path_prefix( def test_save_then_reload_round_trip_preserves_columns( repo: UserAddressCsvS3Repository, ) -> None: + # arrange rows = [ { "Address 1": "1 High Street", @@ -184,9 +205,11 @@ def test_save_then_reload_round_trip_preserves_columns( uri = _upload_csv(repo, rows, "uploads/round-trip.csv") addresses = repo.load_batch(uri) + # act saved_uri = repo.save_batch(addresses, "tasks/round-trip") saved_rows = repo._csv_client.read_rows(saved_uri) # pyright: ignore[reportPrivateUsage] + # assert # Original columns come back verbatim; postcode_clean is the only addition. assert [ {k: v for k, v in r.items() if k != "postcode_clean"} for r in saved_rows @@ -197,6 +220,7 @@ def test_save_then_reload_round_trip_preserves_columns( def test_save_batch_uses_unique_filename_per_call( repo: UserAddressCsvS3Repository, ) -> None: + # arrange addresses = [ UserAddress( user_address="1 High Street", @@ -205,7 +229,9 @@ def test_save_batch_uses_unique_filename_per_call( ), ] + # act uri_1 = repo.save_batch(addresses, "tasks/uniqueness") uri_2 = repo.save_batch(addresses, "tasks/uniqueness") + # assert assert uri_1 != uri_2 diff --git a/tests/utilities/aws_lambda/test_subtask_handler.py b/tests/utilities/aws_lambda/test_subtask_handler.py index 9cf68f28..d671adc4 100644 --- a/tests/utilities/aws_lambda/test_subtask_handler.py +++ b/tests/utilities/aws_lambda/test_subtask_handler.py @@ -6,7 +6,8 @@ from typing import Any from uuid import UUID import pytest -from sqlmodel import Session, SQLModel, create_engine +from sqlalchemy import Engine +from sqlmodel import Session from domain.tasks.subtasks import SubTaskStatus from domain.tasks.tasks import TaskStatus @@ -30,10 +31,8 @@ class Harness: @pytest.fixture -def harness() -> Iterator[Harness]: - engine = create_engine("sqlite://") - SQLModel.metadata.create_all(engine) - with Session(engine) as session: +def harness(db_engine: Engine) -> Iterator[Harness]: + with Session(db_engine) as session: tasks = TaskPostgresRepository(session=session) subtasks = SubTaskPostgresRepository(session=session) yield Harness( @@ -50,6 +49,7 @@ def _direct_event(task_id: UUID, subtask_id: UUID) -> dict[str, Any]: def test_subtask_handler_injects_orchestrator_as_third_positional_argument( harness: Harness, ) -> None: + # arrange _, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -64,8 +64,10 @@ def test_subtask_handler_injects_orchestrator_as_third_positional_argument( received["context"] = context received["orchestrator"] = orchestrator + # act handler(_direct_event(subtask.task_id, subtask.id), context="ctx-sentinel") + # assert assert received["orchestrator"] is harness.orchestrator assert received["context"] == "ctx-sentinel" assert received["body"]["sub_task_id"] == str(subtask.id) @@ -74,6 +76,7 @@ def test_subtask_handler_injects_orchestrator_as_third_positional_argument( def test_subtask_handler_completes_parent_subtask_on_success( harness: Harness, ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -84,8 +87,10 @@ def test_subtask_handler_completes_parent_subtask_on_success( ) -> None: return None + # act handler(_direct_event(task.id, subtask.id), context=None) + # assert assert harness.subtasks.get(subtask.id).status is SubTaskStatus.COMPLETE assert harness.tasks.get(task.id).status is TaskStatus.COMPLETE @@ -93,6 +98,7 @@ def test_subtask_handler_completes_parent_subtask_on_success( def test_subtask_handler_marks_parent_failed_and_reraises_on_error( harness: Harness, ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -103,6 +109,7 @@ def test_subtask_handler_marks_parent_failed_and_reraises_on_error( ) -> None: raise RuntimeError("boom") + # act / assert with pytest.raises(RuntimeError, match="boom"): handler(_direct_event(task.id, subtask.id), context=None) @@ -113,6 +120,7 @@ def test_subtask_handler_marks_parent_failed_and_reraises_on_error( def test_subtask_handler_injected_orchestrator_can_create_child_subtask( harness: Harness, ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -126,8 +134,10 @@ def test_subtask_handler_injected_orchestrator_can_create_child_subtask( child = orchestrator.create_child_subtask(task.id, inputs={"split": 1}) child_ids.append(child.id) + # act handler(_direct_event(task.id, subtask.id), context=None) + # assert assert len(child_ids) == 1 persisted_child = harness.subtasks.get(child_ids[0]) assert persisted_child.task_id == task.id @@ -137,6 +147,7 @@ def test_subtask_handler_injected_orchestrator_can_create_child_subtask( def test_subtask_handler_logs_subtask_lifecycle_on_success( harness: Harness, caplog: pytest.LogCaptureFixture ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -147,9 +158,11 @@ def test_subtask_handler_logs_subtask_lifecycle_on_success( ) -> None: return None + # act with caplog.at_level(logging.INFO, logger=_LOGGER_NAME): handler(_direct_event(task.id, subtask.id), context=None) + # assert assert f"Running subtask {subtask.id}" in caplog.text assert f"Subtask {subtask.id} completed" in caplog.text @@ -157,6 +170,7 @@ def test_subtask_handler_logs_subtask_lifecycle_on_success( def test_subtask_handler_logs_exception_on_failure( harness: Harness, caplog: pytest.LogCaptureFixture ) -> None: + # arrange task, subtask = harness.orchestrator.create_task_with_subtask( task_source="manual:test" ) @@ -167,6 +181,7 @@ def test_subtask_handler_logs_exception_on_failure( ) -> None: raise RuntimeError("boom") + # act / assert with caplog.at_level(logging.INFO, logger=_LOGGER_NAME): with pytest.raises(RuntimeError, match="boom"): handler(_direct_event(task.id, subtask.id), context=None) @@ -181,6 +196,7 @@ def test_subtask_handler_logs_exception_on_failure( def test_subtask_handler_records_cloudwatch_url_on_subtask( harness: Harness, monkeypatch: pytest.MonkeyPatch ) -> None: + # arrange monkeypatch.setenv("AWS_REGION", "eu-west-2") monkeypatch.setenv( "AWS_LAMBDA_LOG_GROUP_NAME", "/aws/lambda/postcode-splitter" @@ -198,8 +214,10 @@ def test_subtask_handler_records_cloudwatch_url_on_subtask( ) -> None: return None + # act handler(_direct_event(task.id, subtask.id), context=None) + # assert saved_url = harness.subtasks.get(subtask.id).cloud_logs_url assert saved_url is not None assert saved_url.startswith( @@ -213,6 +231,7 @@ def test_subtask_handler_records_cloudwatch_url_on_subtask( def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda( harness: Harness, monkeypatch: pytest.MonkeyPatch ) -> None: + # arrange for var in ( "AWS_REGION", "AWS_LAMBDA_LOG_GROUP_NAME", @@ -229,6 +248,8 @@ def test_subtask_handler_leaves_cloudwatch_url_unset_outside_lambda( ) -> None: return None + # act handler(_direct_event(task.id, subtask.id), context=None) + # assert assert harness.subtasks.get(subtask.id).cloud_logs_url is None