diff --git a/services/ml_training_data/src/ml_training_data/build_features.py b/services/ml_training_data/src/ml_training_data/build_features.py index a5ce40f2..17a69cbc 100644 --- a/services/ml_training_data/src/ml_training_data/build_features.py +++ b/services/ml_training_data/src/ml_training_data/build_features.py @@ -5,7 +5,9 @@ JSON-encoded `document` payload. This function: - Filters out non-RdSAP assessments (our mapper only handles RdSAP-Schema-21.x). - Parses `document` and feeds it to EpcPropertyDataMapper.from_api_response. - - Calls EpcMlTransform.to_rows over the parsed properties. + - Calls EpcMlTransform.to_row per cert (streaming) and discards the heavy + EpcPropertyData immediately so memory stays O(row-dict) per cert rather than + O(EpcPropertyData * n) — critical for the 500k+ cert full-year runs. - Prepends a `certificate_number` column so every row is traceable to its source. """ @@ -15,7 +17,7 @@ from typing import Any, cast import pandas as pd from datatypes.epc.domain.mapper import EpcPropertyDataMapper -from datatypes.epc.domain.epc_property_data import EpcPropertyData +from domain.ml.schema import TransformSchema from domain.ml.transform import EpcMlTransform from ml_training_data.bulk_zip_reader import BulkZipReader @@ -29,7 +31,7 @@ def build_features( skip_unsupported_schemas: bool = True, ) -> pd.DataFrame: transform = EpcMlTransform() - properties: list[EpcPropertyData] = [] + rows: list[dict[str, Any]] = [] cert_nums: list[str] = [] for record in bulk_reader.iter_certificates_filtered(certificate_numbers): if record.get("assessment_type") != _RDSAP_ASSESSMENT_TYPE: @@ -47,8 +49,26 @@ def build_features( if skip_unsupported_schemas: continue raise - properties.append(prop) + rows.append(transform.to_row(prop)) cert_nums.append(str(record["certificate_number"])) - df = transform.to_rows(properties) + # prop and document drop out of scope; GC reclaims before the next iter. + df = _frame_from_rows(rows, transform.schema()) df["certificate_number"] = cert_nums return df[["certificate_number", *[c for c in df.columns if c != "certificate_number"]]] + + +def _frame_from_rows(rows: list[dict[str, Any]], schema: TransformSchema) -> pd.DataFrame: + """Build the typed DataFrame from streamed row dicts. + + Mirrors EpcMlTransform.to_rows post-processing: full column set even when empty, + and pd.Categorical casts for any column flagged categorical in the schema. + """ + all_columns = list(schema.feature_columns.keys()) + list(schema.target_columns.keys()) + df = pd.DataFrame(rows, columns=all_columns) + for name, spec in schema.feature_columns.items(): + if spec.categorical: + df[name] = df[name].astype("category") + for name, spec in schema.target_columns.items(): + if spec.categorical: + df[name] = df[name].astype("category") + return df