diff --git a/services/ml_training_data/pyproject.toml b/services/ml_training_data/pyproject.toml
index 232d2276..4519fd93 100644
--- a/services/ml_training_data/pyproject.toml
+++ b/services/ml_training_data/pyproject.toml
@@ -6,6 +6,7 @@ requires-python = ">=3.11"
 dependencies = [
     "pandas>=2.0",
     "pandas-stubs",
+    "ijson>=3.2",
 ]
 
 [build-system]
diff --git a/services/ml_training_data/src/ml_training_data/bulk_zip_reader.py b/services/ml_training_data/src/ml_training_data/bulk_zip_reader.py
new file mode 100644
index 00000000..0790d172
--- /dev/null
+++ b/services/ml_training_data/src/ml_training_data/bulk_zip_reader.py
@@ -0,0 +1,65 @@
+"""Stream EPC certificates from the gov bulk JSON ZIP.
+
+The bulk ZIP from /api/files/domestic/json is ZIP64, ~15 GB, with one JSON entry
+per inspection year (certificates-YYYY.json). Each entry is a JSON array of
+certificate dicts. The reader streams entries via ijson so 5GB+ yearly files
+never have to fit in memory.
+"""
+
+import zipfile
+from collections.abc import Iterator
+from typing import Any, cast
+
+import ijson  # type: ignore[import-untyped]
+
+from ml_training_data.storage import Storage
+
+
+class BulkZipReader:
+    """Read certificate records out of one bulk ZIP held in a ``Storage``."""
+
+    def __init__(self, storage: Storage, zip_key: str) -> None:
+        self._storage = storage
+        self._zip_key = zip_key
+
+    def list_entries(self) -> list[str]:
+        """Return the archive's member names as reported by ``namelist()``."""
+        # ZipFile.close() does not close a file object that was passed in, so
+        # the storage stream needs its own context manager or it leaks.
+        with self._storage.open_read(self._zip_key) as raw, zipfile.ZipFile(raw) as zf:
+            return zf.namelist()
+
+    def iter_certificates(self, entry: str) -> Iterator[dict[str, Any]]:
+        """Yield each element of the top-level JSON array in ``entry``.
+
+        The storage stream, the archive and the member all stay open until the
+        generator is exhausted or closed.
+        """
+        with (
+            self._storage.open_read(self._zip_key) as raw,
+            zipfile.ZipFile(raw) as zf,
+            zf.open(entry) as member,
+        ):
+            for item in ijson.items(member, "item"):
+                yield cast(dict[str, Any], item)
+
+    def iter_certificates_filtered(
+        self, certificate_numbers: set[str]
+    ) -> Iterator[dict[str, Any]]:
+        """Yield certificates whose ``certificate_number`` was requested.
+
+        Entries are scanned in ``list_entries()`` order and scanning stops once
+        every requested number has been seen. Each number is yielded at most
+        once; numbers absent from the archive are silently skipped.
+        """
+        remaining = set(certificate_numbers)  # copy: the caller's set is not mutated
+        for entry in self.list_entries():
+            if not remaining:
+                return
+            for cert in self.iter_certificates(entry):
+                cert_num = cert.get("certificate_number")
+                if cert_num in remaining:
+                    remaining.discard(cert_num)
+                    yield cert
+                    if not remaining:
+                        return
diff --git a/services/ml_training_data/src/ml_training_data/storage.py b/services/ml_training_data/src/ml_training_data/storage.py
index c5c76490..8e7e3994 100644
--- a/services/ml_training_data/src/ml_training_data/storage.py
+++ b/services/ml_training_data/src/ml_training_data/storage.py
@@ -7,7 +7,7 @@ the swap is a constructor change, not a callsite rewrite.
 
 from collections.abc import Iterator
 from pathlib import Path
-from typing import Protocol
+from typing import IO, Protocol
 
 
 class Storage(Protocol):
@@ -15,6 +15,7 @@ class Storage(Protocol):
     def read_bytes(self, key: str) -> bytes: ...
     def exists(self, key: str) -> bool: ...
     def iter_keys(self, prefix: str = "") -> Iterator[str]: ...
+    def open_read(self, key: str) -> IO[bytes]: ...
 
 
 class LocalStorage:
@@ -35,6 +36,9 @@ class LocalStorage:
     def exists(self, key: str) -> bool:
         return self._path(key).exists()
 
+    def open_read(self, key: str) -> IO[bytes]:
+        return self._path(key).open("rb")
+
     def iter_keys(self, prefix: str = "") -> Iterator[str]:
         if not self._root.exists():
             return
diff --git a/services/ml_training_data/tests/unit/test_bulk_zip_reader.py b/services/ml_training_data/tests/unit/test_bulk_zip_reader.py
new file mode 100644
index 00000000..a7ac0409
--- /dev/null
+++ b/services/ml_training_data/tests/unit/test_bulk_zip_reader.py
@@ -0,0 +1,89 @@
+"""Tests for BulkZipReader — stream EPC certificates from the gov bulk JSON ZIP.
+
+The real bulk ZIP is ZIP64, ~15 GB, with one JSON entry per inspection year
+(certificates-YYYY.json). Each entry is a JSON array of certificate dicts.
+We test against small synthetic ZIPs of the same shape.
+"""
+
+import io
+import json
+import zipfile
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from ml_training_data.bulk_zip_reader import BulkZipReader
+from ml_training_data.storage import LocalStorage
+
+
+def _write_zip(storage: LocalStorage, key: str, entries: dict[str, Any]) -> None:
+    buf = io.BytesIO()
+    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
+        for entry_name, payload in entries.items():
+            zf.writestr(entry_name, json.dumps(payload))
+    storage.write_bytes(key, buf.getvalue())
+
+
+def test_list_entries_returns_zip_member_names_in_archive_order(tmp_path: Path) -> None:
+    # Arrange
+    storage = LocalStorage(root=tmp_path)
+    _write_zip(
+        storage,
+        "bulk.zip",
+        {"certificates-2024.json": [], "certificates-2025.json": []},
+    )
+    reader = BulkZipReader(storage=storage, zip_key="bulk.zip")
+
+    # Act
+    entries = reader.list_entries()
+
+    # Assert
+    assert entries == ["certificates-2024.json", "certificates-2025.json"]
+
+
+def test_iter_certificates_yields_every_cert_in_entry(tmp_path: Path) -> None:
+    # Arrange
+    storage = LocalStorage(root=tmp_path)
+    certs = [
+        {"certificate_number": "CN-0001", "postcode": "AA1 1AA"},
+        {"certificate_number": "CN-0002", "postcode": "BB2 2BB"},
+        {"certificate_number": "CN-0003", "postcode": "CC3 3CC"},
+    ]
+    _write_zip(storage, "bulk.zip", {"certificates-2025.json": certs})
+    reader = BulkZipReader(storage=storage, zip_key="bulk.zip")
+
+    # Act
+    out = list(reader.iter_certificates("certificates-2025.json"))
+
+    # Assert
+    assert out == certs
+
+
+def test_iter_certificates_filtered_yields_only_requested_across_entries(tmp_path: Path) -> None:
+    # Arrange
+    storage = LocalStorage(root=tmp_path)
+    _write_zip(
+        storage,
+        "bulk.zip",
+        {
+            "certificates-2024.json": [
+                {"certificate_number": "A-1", "postcode": "AA1"},
+                {"certificate_number": "A-2", "postcode": "AA2"},
+            ],
+            "certificates-2025.json": [
+                {"certificate_number": "B-1", "postcode": "BB1"},
+                {"certificate_number": "B-2", "postcode": "BB2"},
+            ],
+        },
+    )
+    reader = BulkZipReader(storage=storage, zip_key="bulk.zip")
+
+    # Act
+    out = sorted(
+        reader.iter_certificates_filtered({"A-1", "B-2", "MISSING-9"}),
+        key=lambda c: c["certificate_number"],
+    )
+
+    # Assert
+    assert [c["certificate_number"] for c in out] == ["A-1", "B-2"]
diff --git a/services/ml_training_data/tests/unit/test_storage.py b/services/ml_training_data/tests/unit/test_storage.py
index 0ab08051..c254c2c9 100644
--- a/services/ml_training_data/tests/unit/test_storage.py
+++ b/services/ml_training_data/tests/unit/test_storage.py
@@ -74,3 +74,17 @@ def test_read_bytes_raises_filenotfound_for_missing_key(tmp_path: Path) -> None:
     # Act / Assert
     with pytest.raises(FileNotFoundError):
         storage.read_bytes("nope.bin")
+
+
+def test_open_read_returns_seekable_binary_stream(tmp_path: Path) -> None:
+    # Arrange
+    storage = LocalStorage(root=tmp_path)
+    storage.write_bytes("big.bin", b"abcdefghij")
+
+    # Act
+    with storage.open_read("big.bin") as f:
+        f.seek(4)
+        chunk = f.read(3)
+
+    # Assert
+    assert chunk == b"efg"