mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
slice 14e: write_training_dataset emits parquet + schema.json + manifest.json
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
20fd55d5a1
commit
23ba2ef271
2 changed files with 152 additions and 0 deletions
|
|
@ -0,0 +1,54 @@
|
|||
"""Persist a training-feature DataFrame as parquet + schema.json + manifest.json.
|
||||
|
||||
The output triple is the artefact contract this pipeline hands to downstream model
|
||||
training: parquet for the data, schema.json for dtype intent, manifest.json for
|
||||
run provenance. Writes go through Storage so the same code lands on local-fs or S3
|
||||
without a callsite change.
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from domain.ml.schema import ColumnSpec, TransformSchema
|
||||
from ml_training_data.storage import Storage
|
||||
|
||||
|
||||
def write_training_dataset(
    df: pd.DataFrame,
    storage: Storage,
    run_key: str,
    *,
    schema: TransformSchema,
    source_info: Optional[dict[str, Any]] = None,
) -> None:
    """Write ``df`` plus its schema and run manifest under the ``run_key`` prefix.

    Three objects are written through ``storage``, in this order:

    * ``{run_key}data.parquet`` -- ``df`` serialised with pyarrow, index dropped.
    * ``{run_key}schema.json`` -- the transform version and one JSON entry per
      feature column and per target column of ``schema``.
    * ``{run_key}manifest.json`` -- transform version, row count, UTC write
      timestamp and ``source_info``.

    Every payload is serialised (or, for the manifest, validated) before the
    first write, so a serialisation failure cannot leave a partial triple
    behind in storage.

    Args:
        df: Rows to persist; ``len(df)`` is recorded as ``row_count``.
        storage: Destination. Only ``write_bytes(key, data)`` is called on it.
        run_key: Key prefix. It is concatenated verbatim with each file name
            (no separator is inserted), so pass a trailing ``/`` for a
            directory-style layout.
        schema: Supplies ``transform_version``, ``feature_columns`` and
            ``target_columns``.
        source_info: Optional provenance mapping stored in the manifest; an
            omitted or empty value is stored as ``{}``. Must be
            JSON-serialisable.

    Raises:
        TypeError: Raised by ``json.dumps`` when the schema document or
            ``source_info`` contains a value it cannot encode. This now
            happens before anything is written.
    """
    provenance = source_info or {}

    # --- Serialise first: nothing below this section touches storage. ---
    parquet_buf = io.BytesIO()
    df.to_parquet(parquet_buf, engine="pyarrow", index=False)

    schema_doc = {
        "transform_version": schema.transform_version,
        "features": {name: _column_to_json(spec) for name, spec in schema.feature_columns.items()},
        "targets": {name: _column_to_json(spec) for name, spec in schema.target_columns.items()},
    }
    schema_bytes = json.dumps(schema_doc, indent=2).encode("utf-8")

    # Pre-flight only (result discarded): source_info is caller-supplied, so
    # make sure it encodes before the first write. The manifest itself is
    # built later so that written_at is still taken after the data has landed.
    json.dumps(provenance)

    # --- Write: same order as before, manifest last. ---
    storage.write_bytes(f"{run_key}data.parquet", parquet_buf.getvalue())
    storage.write_bytes(f"{run_key}schema.json", schema_bytes)

    manifest = {
        "transform_version": schema.transform_version,
        "row_count": len(df),
        "written_at": datetime.now(timezone.utc).isoformat(),
        "source_info": provenance,
    }
    storage.write_bytes(f"{run_key}manifest.json", json.dumps(manifest, indent=2).encode("utf-8"))
|
||||
|
||||
|
||||
def _column_to_json(spec: ColumnSpec) -> dict[str, object]:
    """Flatten one column spec into the JSON-ready mapping stored in schema.json.

    The dtype is recorded by its ``__name__``; ``nullable``, ``categorical``
    and ``description`` are copied across unchanged.
    """
    dtype_name = spec.dtype.__name__
    return dict(
        dtype=dtype_name,
        nullable=spec.nullable,
        categorical=spec.categorical,
        description=spec.description,
    )
|
||||
98
services/ml_training_data/tests/unit/test_write_parquet.py
Normal file
98
services/ml_training_data/tests/unit/test_write_parquet.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Tests for write_training_dataset() — parquet + schema.json + manifest.json.
|
||||
|
||||
The output triple is the contract handed to the AutoGluon training repo:
|
||||
- data.parquet: feature+target rows
|
||||
- schema.json: column specs (categorical flags, target list) so the consumer
|
||||
can reconstruct dtype intent without re-parsing the transform module.
|
||||
- manifest.json: run metadata (when, transform version, source bulk-ZIP info).
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from domain.ml.schema import ColumnSpec, TransformSchema
|
||||
from ml_training_data.storage import LocalStorage
|
||||
from ml_training_data.write_parquet import write_training_dataset
|
||||
|
||||
|
||||
def _toy_schema() -> "TransformSchema":
    """Build a minimal schema: one float feature, one categorical feature, one int target."""
    floor_area = ColumnSpec(dtype=float, nullable=False, description="floor area")
    property_type = ColumnSpec(dtype=str, nullable=True, description="cat", categorical=True)
    sap_score = ColumnSpec(dtype=int, nullable=False, description="target")

    features = {
        "total_floor_area_m2": floor_area,
        "property_type": property_type,
    }
    targets = {"sap_score": sap_score}
    return TransformSchema(
        transform_version="0.1.0",
        feature_columns=features,
        target_columns=targets,
    )
|
||||
|
||||
|
||||
def test_write_training_dataset_persists_dataframe_to_parquet(tmp_path: Path) -> None:
    """data.parquet lands under the run prefix and round-trips the written columns."""
    # Arrange
    local_store = LocalStorage(root=tmp_path)
    columns = {
        "certificate_number": ["A-1", "A-2"],
        "total_floor_area_m2": [80.0, 120.0],
        "sap_score": [70, 85],
    }
    frame = pd.DataFrame(columns)

    # Act
    write_training_dataset(
        df=frame,
        storage=local_store,
        run_key="runs/2026-05-16/",
        schema=_toy_schema(),
    )

    # Assert
    assert local_store.exists("runs/2026-05-16/data.parquet")
    reloaded = pd.read_parquet(tmp_path / "runs/2026-05-16/data.parquet")
    assert reloaded["certificate_number"].tolist() == ["A-1", "A-2"]
    assert reloaded["sap_score"].tolist() == [70, 85]
|
||||
|
||||
|
||||
def test_write_training_dataset_writes_schema_json_alongside_parquet(tmp_path: Path) -> None:
    """schema.json carries the transform version, feature entries, categorical flag and target dtype."""
    # Arrange
    local_store = LocalStorage(root=tmp_path)
    frame = pd.DataFrame(
        {
            "total_floor_area_m2": [80.0],
            "property_type": ["House"],
            "sap_score": [70],
        }
    )

    # Act
    write_training_dataset(
        df=frame,
        storage=local_store,
        run_key="runs/2026-05-16/",
        schema=_toy_schema(),
    )

    # Assert
    raw_schema = local_store.read_bytes("runs/2026-05-16/schema.json")
    parsed = json.loads(raw_schema)
    assert parsed["transform_version"] == "0.1.0"
    assert "total_floor_area_m2" in parsed["features"]
    assert parsed["features"]["property_type"]["categorical"] is True
    assert parsed["targets"]["sap_score"]["dtype"] == "int"
|
||||
|
||||
|
||||
def test_write_training_dataset_writes_manifest_with_row_count_and_source(tmp_path: Path) -> None:
    """manifest.json records row count, transform version, the given source_info and a timestamp."""
    # Arrange
    local_store = LocalStorage(root=tmp_path)
    frame = pd.DataFrame(
        {
            "total_floor_area_m2": [80.0, 95.0, 110.0],
            "sap_score": [70, 75, 80],
        }
    )
    provenance = {
        "bulk_zip_last_updated": "2026-05-14T09:59:32Z",
        "bulk_zip_size_bytes": 15_642_371_075,
    }

    # Act
    write_training_dataset(
        df=frame,
        storage=local_store,
        run_key="runs/2026-05-16/",
        schema=_toy_schema(),
        source_info=provenance,
    )

    # Assert
    raw_manifest = local_store.read_bytes("runs/2026-05-16/manifest.json")
    parsed = json.loads(raw_manifest)
    assert parsed["row_count"] == 3
    assert parsed["transform_version"] == "0.1.0"
    assert parsed["source_info"] == provenance
    assert "written_at" in parsed
|
||||
Loading…
Add table
Reference in a new issue