mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
118 lines
3.3 KiB
Python
118 lines
3.3 KiB
Python
"""
|
|
Script to regenerate metrics for all the models in the model registry
|
|
Key task:
|
|
- Load model registry
|
|
- For each model in the registry, regenerate the metrics on the supplied test data (open question: how should this behave if the test data changes between runs?)
|
|
- Save the updated registry, including the new metrics, back to the S3 bucket
|
|
"""
|
|
|
|
import os
|
|
import argparse
|
|
from s3pathlib import S3Path
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from core.Logger import logger
|
|
from core.Metrics import Metrics, sort_by_metric
|
|
from core.DataLoader import dataloader_factory
|
|
from core.Settings import (
|
|
OPTIMISE_METRIC,
|
|
MODEL_DIRECTORY,
|
|
REGISTRY_FILE,
|
|
BEST_MODEL_COLUMN_NAME,
|
|
)
|
|
from MLModel.Models import model_factory
|
|
|
|
# Runtime selector handed to dataloader_factory; falls back to "local" when
# the RUNTIME_ENVIRONMENT environment variable is not set.
RUNTIME_ENVIRONMENT = os.getenv("RUNTIME_ENVIRONMENT", "local")
|
|
|
|
|
|
def ingest_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """
    Parse the command-line arguments for the metric-regeneration script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, exactly as before; passing an explicit
            list lets tests or other scripts drive the parser directly.

    Returns:
        Namespace with ``test_filepath`` (str, required) and
        ``target_column`` (one of ``"RDSAP_CHANGE"`` /
        ``"HEAT_DEMAND_CHANGE"``, defaulting to ``"RDSAP_CHANGE"``).

    Raises:
        SystemExit: raised by argparse when ``--test-filepath`` is missing or
            ``--target-column`` is not one of the allowed choices.
    """
    # The description previously said "training script" (copy-paste from the
    # training entry point); this script regenerates registry metrics.
    parser = argparse.ArgumentParser(
        description="Inputs for the metric regeneration script"
    )

    parser.add_argument(
        "--test-filepath",
        type=str,
        help="Location of Parquet dataset to load for testing",
        required=True,
    )
    parser.add_argument(
        "--target-column",
        type=str,
        help="The response variable",
        choices=["RDSAP_CHANGE", "HEAT_DEMAND_CHANGE"],
        default="RDSAP_CHANGE",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
def regenerate_metrics(test_filepath: str, target_column: str) -> None:
    """
    Recompute the metric suite for every model in a target's registry.

    Loads the test dataset and the registry for ``target_column``, discards
    the registry's existing metric columns, scores each registered model on
    the test data, re-ranks the registry by ``OPTIMISE_METRIC`` and writes it
    back to the same registry location.

    Args:
        test_filepath: Location of the test dataset, passed to the
            runtime-specific dataloader.
        target_column: Response variable the models predict; also selects the
            registry sub-directory under ``MODEL_DIRECTORY``.
    """
    # The registry is read from and written back to the same location.
    registry_uri = S3Path(MODEL_DIRECTORY, target_column, REGISTRY_FILE).uri

    logger.info("--- Loading test data ---")
    dataloader = dataloader_factory(runtime_environment=RUNTIME_ENVIRONMENT)
    test_df = dataloader.load(filepath=test_filepath)

    logger.info("--- Loading model registry ---")
    logger.info(f"Loading registry for {target_column} models")
    registry_df = dataloader.load(filepath=registry_uri)

    logger.info("Extract non-metric columns")
    # Keep only the identifying columns (stale metrics are dropped) and reset
    # to a 0..n-1 index. The freshly computed metrics use that index, and the
    # side-by-side concat below aligns on index; without the reset a registry
    # loaded with any other index would be misaligned and filled with NaN.
    registry_df = registry_df[
        ["model_type", "model_name", "model_location"]
    ].reset_index(drop=True)

    logger.info("--- Regenerating metrics ---")

    metric_suite = Metrics()
    metric_columns = list(metric_suite.list_metric_functions())

    # Collect one metrics frame per model and concatenate once afterwards;
    # growing a DataFrame with pd.concat inside the loop is quadratic.
    per_model_metrics = []
    for _, row in tqdm(registry_df.iterrows(), total=len(registry_df)):
        logger.info(f"--- Loading Model ({row['model_name']}) ---")

        model = model_factory(model_type=row["model_type"])()
        model.load_model(filepath=row["model_location"])

        per_model_metrics.append(
            metric_suite.generate_metric_suite(
                model=model, data=test_df, target_column=target_column
            )
        )

    if per_model_metrics:
        metrics_df = pd.concat(per_model_metrics, axis=0, ignore_index=True)
        # Put the suite's metric columns first, in their declared order (as
        # the previous empty-template concat did); any extra columns follow.
        extra_columns = [c for c in metrics_df.columns if c not in metric_columns]
        metrics_df = metrics_df.reindex(columns=metric_columns + extra_columns)
    else:
        # Empty registry: pd.concat([]) would raise, so emit just the schema.
        metrics_df = pd.DataFrame(columns=metric_columns)

    # Add metrics df to registry df side by side (row i <-> model i).
    registry_df = pd.concat([registry_df, metrics_df], axis=1)

    logger.info(f"--- Sorting by Optimise Metric ({OPTIMISE_METRIC}) ---")

    # NOTE(review): the keyword is spelled `optimse_metric`; presumably this
    # matches sort_by_metric's signature — rename both together if fixing.
    registry_df = sort_by_metric(
        data=registry_df,
        optimse_metric=OPTIMISE_METRIC,
        best_model_column_name=BEST_MODEL_COLUMN_NAME,
    )

    logger.info("--- Saving model metrics ---")

    # NOTE(review): to_csv writes the pandas index as an unnamed leading
    # column. This script ignores it on reload (it selects columns by name);
    # confirm other registry consumers before switching to index=False.
    registry_df.to_csv(registry_uri)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: announce the run, parse the CLI flags, then
    # rebuild the metrics for the chosen target's model registry.
    logger.info("---Begin Pipeline to regenerate metrics---")

    logger.info("---Ingest Arguments---")
    cli_args = ingest_arguments()

    regenerate_metrics(
        test_filepath=cli_args.test_filepath,
        target_column=cli_args.target_column,
    )
|