slice 14a: ml_training_data pkg + sample.py (CSV filter + random sample)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Khalim Conn-Kowlessar 2026-05-16 17:39:43 +00:00
parent 3abcee6a53
commit eb42cb88a1
9 changed files with 155 additions and 1 deletions

1
.gitignore vendored
View file

@ -241,6 +241,7 @@ fabric.properties
# Locally stored data
local_data/*
/local_data/*
/data/ml_training/
etl/epc/local_data/*
/backend/condition/sample_data/lbwf/*
/backend/condition/sample_data/peabody/*

View file

@ -10,4 +10,5 @@ members = [
"packages/fetchers",
"packages/utils",
"services/ara",
"services/ml_training_data",
]

View file

@ -1,5 +1,5 @@
[pytest]
pythonpath = . packages/domain/src
pythonpath = . packages/domain/src services/ml_training_data/src
log_cli = true
log_cli_level = INFO
addopts = --cov-report term-missing --cov=etl/epc --cov=recommendations --cov=backend --cov=etl/epc_clean --cov=etl/spatial

View file

@ -0,0 +1,16 @@
# Package metadata for the ml-training-data distribution.
[project]
name = "ml-training-data"
version = "0.1.0"
description = "Pipeline that turns the EPC open-data CSV into ML training parquet + baseline models."
requires-python = ">=3.11"
dependencies = [
"pandas>=2.0",
# NOTE(review): pandas-stubs looks like a typing-only package; confirm whether
# it should live in a dev/optional dependency group rather than runtime deps.
"pandas-stubs",
]
# Build backend: hatchling.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
# The wheel is built from the package directory under src/.
[tool.hatch.build.targets.wheel]
packages = ["src/ml_training_data"]

View file

@ -0,0 +1,5 @@
"""EPC CSV → training-data pipeline.
Produces parquet + schema.json + manifest.json for baseline LightGBM training.
See ara_backend_design.md (repo root) for the pipeline shape.
"""

View file

@ -0,0 +1,33 @@
"""Sample certificate rows from the EPC flat-register CSV.
The flat-register CSV (2.4M rows) is the *only* exhaustive list of certificate
numbers; per-certificate detail is fetched separately downstream. sample() returns
a thin DataFrame keyed by certificate_number so later stages know which records to
fetch and how to join their per-cert JSON back to register-side metadata.
"""
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Optional
import pandas as pd
def sample(
    csv_path: Path,
    n: int,
    seed: int,
    filters: Optional[Mapping[str, Sequence[str]]] = None,
) -> pd.DataFrame:
    """Return up to ``n`` randomly chosen rows from the CSV at ``csv_path``.

    Every column is read as ``str`` and default NA parsing is disabled, so
    empty cells stay as empty strings and values can be compared against the
    filter strings exactly as they appear in the file.

    Args:
        csv_path: Location of the CSV to read. The whole file is loaded.
        n: Maximum number of rows to return.
        seed: Random state handed to ``DataFrame.sample``; the same seed over
            the same input selects the same rows.
        filters: Optional mapping of column name -> allowed values. A row is
            kept only when, for every listed column, its value is one of the
            allowed values. A bare string is treated as a single allowed
            value instead of being split into its characters.

    Returns:
        A DataFrame with a fresh ``0..k-1`` index. When the (filtered) data
        has ``n`` rows or fewer, all of them are returned in file order
        without sampling; otherwise exactly ``n`` sampled rows are returned.

    Raises:
        KeyError: If ``filters`` names a column that is not in the CSV.
    """
    df: pd.DataFrame = pd.read_csv(csv_path, dtype=str, keep_default_na=False)

    if filters:
        # Validate every column up front so a typo fails before any
        # filtering work is done on a large frame.
        for column in filters:
            if column not in df.columns:
                raise KeyError(f"filter column not present in CSV: {column!r}")

        # Combine all filters into one mask rather than re-slicing the frame
        # once per filter.
        keep = pd.Series(True, index=df.index, dtype=bool)
        for column, allowed in filters.items():
            # ``list("House")`` would yield single characters; a lone string
            # is almost certainly meant as one allowed value.
            values = [allowed] if isinstance(allowed, str) else list(allowed)
            keep &= df[column].isin(values)
        df = df[keep]

    # Not enough rows to sample from: hand back everything that matched.
    if len(df) <= n:
        return df.reset_index(drop=True)
    return df.sample(n=n, random_state=seed).reset_index(drop=True)

View file

@ -0,0 +1,98 @@
"""Tests for sample.sample() — random + filterable selection over the EPC flat-register CSV.
sample() is the entry point of the training-data pipeline: it produces a thin DataFrame
of certificate rows that downstream stages (fetch -> build_features -> write_parquet)
operate on. Filtering supports excluding obviously-wrong cohorts before paying the
per-cert fetch cost.
"""
from pathlib import Path
import pandas as pd
import pytest
from ml_training_data.sample import sample
def _write_csv(path: Path, rows: list[dict[str, str]]) -> None:
    """Write ``rows`` to ``path`` as a CSV with a header row and no index column."""
    pd.DataFrame(rows).to_csv(path, index=False)
def test_sample_returns_n_rows_when_no_filter(tmp_path: Path) -> None:
    """Without filters, asking for fewer rows than exist returns exactly ``n`` rows."""
    # Arrange
    register = tmp_path / "register.csv"
    _write_csv(
        register,
        [
            {"certificate_number": f"CN-{idx:04d}", "property_type": "House", "postcode": "AB1 2CD"}
            for idx in range(100)
        ],
    )
    expected_columns = {"certificate_number", "property_type", "postcode"}
    # Act
    sampled = sample(csv_path=register, n=10, seed=42)
    # Assert
    assert len(sampled) == 10
    assert set(sampled.columns) == expected_columns
def test_sample_is_deterministic_with_same_seed(tmp_path: Path) -> None:
    """Two calls with the same seed pick the same certificates in the same order."""
    # Arrange
    register = tmp_path / "register.csv"
    register_rows = [
        {"certificate_number": f"CN-{idx:04d}", "property_type": "House"}
        for idx in range(200)
    ]
    _write_csv(register, register_rows)
    # Act
    run_a = sample(csv_path=register, n=20, seed=7)
    run_b = sample(csv_path=register, n=20, seed=7)
    # Assert
    ids_a = run_a["certificate_number"].tolist()
    ids_b = run_b["certificate_number"].tolist()
    assert ids_a == ids_b
def test_sample_filter_selects_only_matching_rows(tmp_path: Path) -> None:
    """Filtering on one property type returns ``n`` rows, all of that type."""
    # Arrange
    register = tmp_path / "register.csv"
    houses: list[dict[str, str]] = [
        {"certificate_number": f"H-{idx:03d}", "property_type": "House"} for idx in range(50)
    ]
    flats: list[dict[str, str]] = [
        {"certificate_number": f"F-{idx:03d}", "property_type": "Flat"} for idx in range(50)
    ]
    _write_csv(register, houses + flats)
    # Act
    picked = sample(csv_path=register, n=30, seed=1, filters={"property_type": ["House"]})
    # Assert
    assert len(picked) == 30
    assert picked["property_type"].eq("House").all()
def test_sample_returns_fewer_than_n_when_filter_too_strict(tmp_path: Path) -> None:
    """When fewer than ``n`` rows survive the filter, all survivors are returned."""
    # Arrange
    register = tmp_path / "register.csv"
    bungalows: list[dict[str, str]] = [
        {"certificate_number": f"BG-{idx}", "property_type": "Bungalow"} for idx in range(3)
    ]
    houses: list[dict[str, str]] = [
        {"certificate_number": f"H-{idx:03d}", "property_type": "House"} for idx in range(50)
    ]
    _write_csv(register, bungalows + houses)
    # Act
    picked = sample(csv_path=register, n=100, seed=1, filters={"property_type": ["Bungalow"]})
    # Assert
    assert len(picked) == 3
    assert picked["property_type"].eq("Bungalow").all()
def test_sample_raises_when_filter_column_missing(tmp_path: Path) -> None:
    """A filter on a column absent from the CSV raises ``KeyError``."""
    # Arrange
    register = tmp_path / "register.csv"
    only_row = {"certificate_number": "CN-001", "property_type": "House"}
    _write_csv(register, [only_row])
    unknown_column_filter = {"nonexistent": ["x"]}
    # Act / Assert
    with pytest.raises(KeyError):
        sample(csv_path=register, n=1, seed=1, filters=unknown_column_filter)