diff --git a/services/ml_training_data/pyproject.toml b/services/ml_training_data/pyproject.toml index 15e00384..12b412a5 100644 --- a/services/ml_training_data/pyproject.toml +++ b/services/ml_training_data/pyproject.toml @@ -11,6 +11,8 @@ dependencies = [ "pyarrow>=15", "lightgbm>=4.0", "scikit-learn>=1.4", + "httpx", + "remotezip>=0.12", ] [tool.uv.sources] diff --git a/services/ml_training_data/src/ml_training_data/remote_bulk_fetcher.py b/services/ml_training_data/src/ml_training_data/remote_bulk_fetcher.py new file mode 100644 index 00000000..afa16263 --- /dev/null +++ b/services/ml_training_data/src/ml_training_data/remote_bulk_fetcher.py @@ -0,0 +1,79 @@ +"""Extract specific yearly entries from the gov bulk JSON ZIP without downloading +the whole 15 GB archive. + +The gov endpoint returns a 302 to a pre-signed S3 URL. remotezip uses HTTP Range +requests against that URL to read only the central directory + the bytes for the +requested entries, so disk usage stays at "size of the entries we actually want" +instead of the full archive. + +Entries are streamed via zipfile.ZipExtFile.read(chunk) so partial-network failures +during the multi-GB read don't waste the whole transfer, and so we never hold the +full entry in memory. +""" + +from pathlib import Path +from tempfile import NamedTemporaryFile + +import httpx +from remotezip import RemoteZip # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs] + +from ml_training_data.storage import Storage + +_BULK_JSON_URL = ( + "https://api.get-energy-performance-data.communities.gov.uk/api/files/domestic/json" +) +_READ_CHUNK_BYTES = 8 * 1024 * 1024 # 8 MB + + +def extract_entries( + auth_token: str, + entry_names: list[str], + storage: Storage, + key_prefix: str, +) -> dict[str, int]: + presigned_url = _resolve_presigned_url(auth_token) + sizes: dict[str, int] = {} + with RemoteZip(presigned_url) as zf: # pyright: ignore[reportUnknownVariableType] + for entry in entry_names: + n_bytes = _stream_entry_to_storage(zf, entry, storage, f"{key_prefix}{entry}") + sizes[entry] = n_bytes + return sizes + + +def _stream_entry_to_storage( + zf: RemoteZip, # pyright: ignore[reportUnknownParameterType] + entry: str, + storage: Storage, + output_key: str, +) -> int: + with NamedTemporaryFile(delete=False) as tmp: + tmp_path = Path(tmp.name) + with zf.open(entry) as src: # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + while True: + chunk: bytes = src.read(_READ_CHUNK_BYTES) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] + if not chunk: + break + tmp.write(chunk) + total = tmp_path.stat().st_size + storage.write_bytes(output_key, tmp_path.read_bytes()) + tmp_path.unlink() + return total + + +def _resolve_presigned_url(auth_token: str) -> str: + response = httpx.get( + _BULK_JSON_URL, + headers={"Authorization": f"Bearer {auth_token}"}, + follow_redirects=False, + timeout=30, + ) + if response.status_code != 302: + raise RuntimeError( + f"Bulk JSON endpoint did not redirect: {response.status_code} {response.text[:200]}" + ) + location = response.headers.get("location") + if not location: + raise RuntimeError("Bulk JSON 302 had no Location header") + return location + +