mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
slice 14a: ml_training_data pkg + sample.py (CSV filter + random sample)
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
3abcee6a53
commit
eb42cb88a1
9 changed files with 155 additions and 1 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -241,6 +241,7 @@ fabric.properties
|
|||
# Locally stored data
|
||||
local_data/*
|
||||
/local_data/*
|
||||
/data/ml_training/
|
||||
etl/epc/local_data/*
|
||||
/backend/condition/sample_data/lbwf/*
|
||||
/backend/condition/sample_data/peabody/*
|
||||
|
|
|
|||
|
|
@ -10,4 +10,5 @@ members = [
|
|||
"packages/fetchers",
|
||||
"packages/utils",
|
||||
"services/ara",
|
||||
"services/ml_training_data",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[pytest]
|
||||
pythonpath = . packages/domain/src
|
||||
pythonpath = . packages/domain/src services/ml_training_data/src
|
||||
log_cli = true
|
||||
log_cli_level = INFO
|
||||
addopts = --cov-report term-missing --cov=etl/epc --cov=recommendations --cov=backend --cov=etl/epc_clean --cov=etl/spatial
|
||||
|
|
|
|||
16
services/ml_training_data/pyproject.toml
Normal file
16
services/ml_training_data/pyproject.toml
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
[project]
|
||||
name = "ml-training-data"
|
||||
version = "0.1.0"
|
||||
description = "Pipeline that turns the EPC open-data CSV into ML training parquet + baseline models."
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"pandas>=2.0",
|
||||
"pandas-stubs",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/ml_training_data"]
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""EPC CSV → training-data pipeline.
|
||||
|
||||
Produces parquet + schema.json + manifest.json for baseline LightGBM training.
|
||||
See ara_backend_design.md (repo root) for the pipeline shape.
|
||||
"""
|
||||
33
services/ml_training_data/src/ml_training_data/sample.py
Normal file
33
services/ml_training_data/src/ml_training_data/sample.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Sample certificate rows from the EPC flat-register CSV.
|
||||
|
||||
The flat-register CSV (2.4M rows) is the *only* exhaustive list of certificate
|
||||
numbers; per-certificate detail is fetched separately downstream. sample() returns
|
||||
a thin DataFrame keyed by certificate_number so later stages know which records to
|
||||
fetch and how to join their per-cert JSON back to register-side metadata.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def sample(
|
||||
csv_path: Path,
|
||||
n: int,
|
||||
seed: int,
|
||||
filters: Optional[Mapping[str, Sequence[str]]] = None,
|
||||
) -> pd.DataFrame:
|
||||
df: pd.DataFrame = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
|
||||
|
||||
if filters:
|
||||
for column, allowed in filters.items():
|
||||
if column not in df.columns:
|
||||
raise KeyError(f"filter column not present in CSV: {column!r}")
|
||||
df = df[df[column].isin(list(allowed))]
|
||||
|
||||
if len(df) <= n:
|
||||
return df.reset_index(drop=True)
|
||||
|
||||
return df.sample(n=n, random_state=seed).reset_index(drop=True)
|
||||
0
services/ml_training_data/tests/__init__.py
Normal file
0
services/ml_training_data/tests/__init__.py
Normal file
0
services/ml_training_data/tests/unit/__init__.py
Normal file
0
services/ml_training_data/tests/unit/__init__.py
Normal file
98
services/ml_training_data/tests/unit/test_sample.py
Normal file
98
services/ml_training_data/tests/unit/test_sample.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Tests for sample.sample() — random + filterable selection over the EPC flat-register CSV.
|
||||
|
||||
sample() is the entry point of the training-data pipeline: it produces a thin DataFrame
|
||||
of certificate rows that downstream stages (fetch -> build_features -> write_parquet)
|
||||
operate on. Filtering supports excluding obviously-wrong cohorts before paying the
|
||||
per-cert fetch cost.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from ml_training_data.sample import sample
|
||||
|
||||
|
||||
def _write_csv(path: Path, rows: list[dict[str, str]]) -> None:
|
||||
df = pd.DataFrame(rows)
|
||||
df.to_csv(path, index=False)
|
||||
|
||||
|
||||
def test_sample_returns_n_rows_when_no_filter(tmp_path: Path) -> None:
|
||||
# Arrange
|
||||
csv = tmp_path / "register.csv"
|
||||
rows = [
|
||||
{"certificate_number": f"CN-{i:04d}", "property_type": "House", "postcode": "AB1 2CD"}
|
||||
for i in range(100)
|
||||
]
|
||||
_write_csv(csv, rows)
|
||||
|
||||
# Act
|
||||
out = sample(csv_path=csv, n=10, seed=42)
|
||||
|
||||
# Assert
|
||||
assert len(out) == 10
|
||||
assert set(out.columns) == {"certificate_number", "property_type", "postcode"}
|
||||
|
||||
|
||||
def test_sample_is_deterministic_with_same_seed(tmp_path: Path) -> None:
|
||||
# Arrange
|
||||
csv = tmp_path / "register.csv"
|
||||
_write_csv(
|
||||
csv,
|
||||
[{"certificate_number": f"CN-{i:04d}", "property_type": "House"} for i in range(200)],
|
||||
)
|
||||
|
||||
# Act
|
||||
first = sample(csv_path=csv, n=20, seed=7)
|
||||
second = sample(csv_path=csv, n=20, seed=7)
|
||||
|
||||
# Assert
|
||||
assert first["certificate_number"].tolist() == second["certificate_number"].tolist()
|
||||
|
||||
|
||||
def test_sample_filter_selects_only_matching_rows(tmp_path: Path) -> None:
|
||||
# Arrange
|
||||
csv = tmp_path / "register.csv"
|
||||
rows: list[dict[str, str]] = []
|
||||
for i in range(50):
|
||||
rows.append({"certificate_number": f"H-{i:03d}", "property_type": "House"})
|
||||
for i in range(50):
|
||||
rows.append({"certificate_number": f"F-{i:03d}", "property_type": "Flat"})
|
||||
_write_csv(csv, rows)
|
||||
|
||||
# Act
|
||||
out = sample(csv_path=csv, n=30, seed=1, filters={"property_type": ["House"]})
|
||||
|
||||
# Assert
|
||||
assert len(out) == 30
|
||||
assert (out["property_type"] == "House").all()
|
||||
|
||||
|
||||
def test_sample_returns_fewer_than_n_when_filter_too_strict(tmp_path: Path) -> None:
|
||||
# Arrange
|
||||
csv = tmp_path / "register.csv"
|
||||
rows: list[dict[str, str]] = []
|
||||
for i in range(3):
|
||||
rows.append({"certificate_number": f"BG-{i}", "property_type": "Bungalow"})
|
||||
for i in range(50):
|
||||
rows.append({"certificate_number": f"H-{i:03d}", "property_type": "House"})
|
||||
_write_csv(csv, rows)
|
||||
|
||||
# Act
|
||||
out = sample(csv_path=csv, n=100, seed=1, filters={"property_type": ["Bungalow"]})
|
||||
|
||||
# Assert
|
||||
assert len(out) == 3
|
||||
assert (out["property_type"] == "Bungalow").all()
|
||||
|
||||
|
||||
def test_sample_raises_when_filter_column_missing(tmp_path: Path) -> None:
|
||||
# Arrange
|
||||
csv = tmp_path / "register.csv"
|
||||
_write_csv(csv, [{"certificate_number": "CN-001", "property_type": "House"}])
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(KeyError):
|
||||
sample(csv_path=csv, n=1, seed=1, filters={"nonexistent": ["x"]})
|
||||
Loading…
Add table
Reference in a new issue