from __future__ import annotations

import json
import os
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, cast

import boto3
import pytest
from moto import mock_aws
from sqlalchemy import Engine
from sqlmodel import Session

from infrastructure.address2uprn_queue_client import Address2UprnQueueClient
from infrastructure.s3.csv_s3_client import CsvS3Client
from orchestration.postcode_splitter_orchestrator import PostcodeSplitterOrchestrator
from orchestration.task_orchestrator import TaskOrchestrator
from repositories.tasks.subtask_postgres_repository import SubTaskPostgresRepository
from repositories.tasks.task_postgres_repository import TaskPostgresRepository
from repositories.unstandardised_address.unstandardised_address_list_csv_s3_repository import (
    UnstandardisedAddressListCsvS3Repository,
)

BUCKET = "splitter-bucket"
REGION = "us-east-1"


def _make_boto_client(service_name: str) -> Any:
    """Create a boto3 client for ``service_name`` pinned to ``REGION``."""
    # boto3's ``client`` factory is only partially typed, so bind it as Any.
    # The pyright suppression must sit on the same physical line as the
    # attribute access; on a wrapped closing-paren line it suppresses nothing.
    factory: Any = boto3.client  # pyright: ignore[reportUnknownMemberType]
    return factory(service_name, region_name=REGION)


@pytest.fixture(autouse=True)
def _aws_creds() -> Iterator[None]:  # pyright: ignore[reportUnusedFunction]
    """Install dummy AWS credentials for moto, restoring the prior env after.

    Each key's previous value is remembered. A key that was unset before the
    test is removed again afterwards instead of being left as "testing".
    """
    keys = (
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_SESSION_TOKEN",
        "AWS_DEFAULT_REGION",
    )
    prev: dict[str, str | None] = {k: os.environ.get(k) for k in keys}
    os.environ["AWS_ACCESS_KEY_ID"] = "testing"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
    os.environ["AWS_SESSION_TOKEN"] = "testing"
    os.environ["AWS_DEFAULT_REGION"] = REGION
    try:
        yield
    finally:
        for k, v in prev.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


@dataclass
class Harness:
    """Bundle of wired collaborators shared by the splitter tests."""

    # System under test.
    splitter: PostcodeSplitterOrchestrator
    # Used by tests to create the parent task/subtask before splitting.
    task_orchestrator: TaskOrchestrator
    # Used to read back persisted child subtasks.
    subtasks: SubTaskPostgresRepository
    # Used to upload the fixture CSV and read back batch CSVs.
    csv_client: CsvS3Client
    # Raw moto-backed SQS client, used to drain published messages.
    boto_sqs: Any
    queue_url: str
    repo: UnstandardisedAddressListCsvS3Repository


@pytest.fixture
def harness(db_engine: Engine) -> Iterator[Harness]:
    """Wire a ``PostcodeSplitterOrchestrator`` for integration tests.

    S3 and SQS are provided by moto (``mock_aws``). Tasks and subtasks are
    persisted through a SQLModel session on ``db_engine``. The splitter is
    built with ``max_batch_size=3`` to match the fixture CSV's expectations.
    """
    with mock_aws():
        # Infra: S3 + SQS
        boto_s3 = _make_boto_client("s3")
        boto_s3.create_bucket(Bucket=BUCKET)
        boto_sqs = _make_boto_client("sqs")
        queue: dict[str, Any] = boto_sqs.create_queue(QueueName="address2uprn-queue")
        queue_url = cast(str, queue["QueueUrl"])

        csv_client = CsvS3Client(boto_s3, BUCKET)
        repo = UnstandardisedAddressListCsvS3Repository(csv_client, BUCKET)
        queue_client = Address2UprnQueueClient(boto_sqs, queue_url)

        # DB: ephemeral PostgreSQL TaskOrchestrator
        with Session(db_engine) as session:
            task_repo = TaskPostgresRepository(session=session)
            subtask_repo = SubTaskPostgresRepository(session=session)
            task_orchestrator = TaskOrchestrator(
                task_repo=task_repo, subtask_repo=subtask_repo
            )

            splitter = PostcodeSplitterOrchestrator(
                task_orchestrator=task_orchestrator,
                unstandardised_address_repo=repo,
                queue_client=queue_client,
                max_batch_size=3,
            )

            yield Harness(
                splitter=splitter,
                task_orchestrator=task_orchestrator,
                subtasks=subtask_repo,
                csv_client=csv_client,
                boto_sqs=boto_sqs,
                queue_url=queue_url,
                repo=repo,
            )


def _upload_fixture_csv(csv_client: CsvS3Client) -> str:
    """Upload the 7-row fixture CSV and return the URI from ``save_rows``."""
    # Three postcode groups:
    #   AA1 1AA × 2  (within cap)
    #   BB2 2BB × 4  (oversize: > max_batch_size=3)
    #   CC3 3CC × 1  (final flush)
    # Expected batching with cap=3 and the algorithm in
    # ``iter_postcode_grouped_batches``:
    #   batch 1: [AA1 1AA × 2]  (flushed because oversize follows)
    #   batch 2: [BB2 2BB × 4]  (oversize own batch)
    #   batch 3: [CC3 3CC × 1]  (final flush)
    rows: list[dict[str, str]] = []
    rows.extend(
        {
            "Address 1": f"{i} High St",
            "Address 2": "",
            "Address 3": "",
            "postcode": "AA1 1AA",
            "Internal Reference": f"AA-{i}",
        }
        for i in range(1, 3)
    )
    rows.extend(
        {
            "Address 1": f"{i} Long Road",
            "Address 2": "",
            "Address 3": "",
            "postcode": "BB2 2BB",
            "Internal Reference": f"BB-{i}",
        }
        for i in range(1, 5)
    )
    rows.append(
        {
            "Address 1": "1 Final Way",
            "Address 2": "",
            "Address 3": "",
            "postcode": "CC3 3CC",
            "Internal Reference": "CC-1",
        }
    )
    return csv_client.save_rows(rows, "uploads/input.csv")


def _drain_queue(boto_sqs: Any, queue_url: str) -> list[dict[str, Any]]:
    """Receive and delete queued messages, returning their JSON-decoded bodies.

    Messages are received in pages of up to 10 with no long-poll wait. The
    loop stops at the first empty receive. That is reliable against moto;
    against real SQS, short polling may return empty before the queue is
    actually drained.
    """
    bodies: list[dict[str, Any]] = []
    while True:
        received: dict[str, Any] = boto_sqs.receive_message(
            QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=0
        )
        messages = cast(list[dict[str, Any]], received.get("Messages", []))
        if not messages:
            break
        for message in messages:
            bodies.append(cast(dict[str, Any], json.loads(message["Body"])))
            boto_sqs.delete_message(
                QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"]
            )
    return bodies


def test_split_and_dispatch_creates_three_children_for_fixture(
    harness: Harness,
) -> None:
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    assert len(child_ids) == 3
    # All child ids are unique, and each one can be loaded back as a persisted
    # subtask belonging to the parent task. (Status is not asserted here.)
    assert len(set(child_ids)) == 3
    for cid in child_ids:
        child = harness.subtasks.get(cid)
        assert child.task_id == parent_task.id


def test_split_and_dispatch_persists_child_inputs_with_task_id_and_s3_uri(
    harness: Harness,
) -> None:
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # Every batch CSV must live under this parent's task/subtask prefix.
    prefix = (
        f"s3://{BUCKET}/ara_postcode_splitter_batches/"
        f"{parent_task.id}/{parent_subtask.id}/"
    )
    for cid in child_ids:
        child = harness.subtasks.get(cid)
        assert child.inputs is not None
        assert child.inputs["task_id"] == str(parent_task.id)
        batch_uri = child.inputs["s3_uri"]
        assert isinstance(batch_uri, str)
        assert batch_uri.startswith(prefix)
        assert batch_uri.endswith(".csv")


def test_split_and_dispatch_publishes_one_message_per_child_with_matching_ids(
    harness: Harness,
) -> None:
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    bodies = _drain_queue(harness.boto_sqs, harness.queue_url)
    assert len(bodies) == len(child_ids)

    # Match queue messages against persisted child inputs by ``sub_task_id``;
    # each message body's task_id/s3_uri must agree with the SubTask inputs.
    bodies_by_child = {body["sub_task_id"]: body for body in bodies}
    assert set(bodies_by_child.keys()) == {str(cid) for cid in child_ids}

    for cid in child_ids:
        child = harness.subtasks.get(cid)
        body = bodies_by_child[str(cid)]
        assert child.inputs is not None
        assert body == {
            "task_id": str(parent_task.id),
            "sub_task_id": str(cid),
            "s3_uri": child.inputs["s3_uri"],
        }


def test_split_and_dispatch_returns_child_ids_in_dispatch_order(
    harness: Harness,
) -> None:
    # arrange
    parent_task, parent_subtask = harness.task_orchestrator.create_task_with_subtask(
        task_source="manual:postcode-splitter-int"
    )
    input_uri = _upload_fixture_csv(harness.csv_client)

    # act
    child_ids = harness.splitter.split_and_dispatch(
        parent_task_id=parent_task.id,
        parent_subtask_id=parent_subtask.id,
        input_s3_uri=input_uri,
    )

    # assert
    # Re-load each child's saved batch and inspect the postcode_clean column
    # to confirm the dispatch order matches the postcode-batching algorithm:
    # AA-batch first, BB oversize batch second, CC final-flush third.
    postcodes_per_batch: list[set[str]] = []
    for cid in child_ids:
        child = harness.subtasks.get(cid)
        assert child.inputs is not None
        rows = harness.csv_client.read_rows(child.inputs["s3_uri"])
        postcodes_per_batch.append({row["postcode_clean"] for row in rows})

    assert postcodes_per_batch == [
        {"AA11AA"},
        {"BB22BB"},
        {"CC33CC"},
    ]