Model/model_data/simulation_system/regenerate_metrics.py
2023-08-31 14:46:10 +01:00

118 lines
3.3 KiB
Python

"""
Script to regenerate metrics for all the models in the model registry.

Key tasks:
- Load the model registry for the chosen target column.
- For each model in the registry, reload it and regenerate its metrics on the
  supplied test data. (Open question: what should happen if the test data
  changes between runs?)
- Rank the models by the optimisation metric and save the refreshed registry
  back to its S3 location.
"""
import argparse
import os
from typing import Optional, Sequence

import pandas as pd
from s3pathlib import S3Path
from tqdm import tqdm

from core.Logger import logger
from core.Metrics import Metrics, sort_by_metric
from core.DataLoader import dataloader_factory
from core.Settings import (
    OPTIMISE_METRIC,
    MODEL_DIRECTORY,
    REGISTRY_FILE,
    BEST_MODEL_COLUMN_NAME,
)
from MLModel.Models import model_factory
# Runtime environment name handed to ``dataloader_factory`` in
# ``regenerate_metrics``; read from the RUNTIME_ENVIRONMENT env var and
# falling back to "local" when it is unset.
RUNTIME_ENVIRONMENT = os.getenv("RUNTIME_ENVIRONMENT", "local")
def ingest_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """
    Parse the command-line arguments for the metric-regeneration script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, exactly as before. Passing an explicit
            list lets tests or other scripts drive the parser directly.

    Returns:
        Namespace with ``test_filepath`` (str, required) and
        ``target_column`` (one of ``"RDSAP_CHANGE"`` or
        ``"HEAT_DEMAND_CHANGE"``, default ``"RDSAP_CHANGE"``).

    Raises:
        SystemExit: raised by argparse when a required argument is missing
            or a value is not one of the allowed choices.
    """
    # The old description ("Inputs for training script") was copied from the
    # training entry point and was misleading in ``--help`` for this script.
    parser = argparse.ArgumentParser(
        description="Regenerate metrics for every model in the model registry"
    )
    parser.add_argument(
        "--test-filepath",
        type=str,
        help="Location of Parquet dataset to load for testing",
        required=True,
    )
    parser.add_argument(
        "--target-column",
        type=str,
        help="The response variable",
        choices=["RDSAP_CHANGE", "HEAT_DEMAND_CHANGE"],
        default="RDSAP_CHANGE",
    )
    return parser.parse_args(argv)
def regenerate_metrics(test_filepath: str, target_column: str) -> None:
    """
    Recompute the metric suite for every model in a target's registry.

    Loads the test dataset and the registry for ``target_column``. Each
    registered model is rebuilt via ``model_factory`` and loaded from its
    ``model_location``, then scored on the test data. The fresh metrics are
    attached to the registry rows, the rows are ranked with ``sort_by_metric``
    on ``OPTIMISE_METRIC``, and the registry is written back to the same
    location it was read from.

    Args:
        test_filepath: Location of the dataset every model is scored on.
        target_column: Response variable being modelled. It is passed to the
            metric suite and used as the sub-directory of ``MODEL_DIRECTORY``
            that holds the registry.
    """
    logger.info("--- Loading test data ---")
    dataloader = dataloader_factory(runtime_environment=RUNTIME_ENVIRONMENT)
    test_df = dataloader.load(filepath=test_filepath)

    logger.info("--- Loading model registry ---")
    logger.info(f"Loading registry for {target_column} models")
    # Built once; the same location is used for both reading and writing.
    registry_uri = S3Path(MODEL_DIRECTORY, target_column, REGISTRY_FILE).uri
    registry_df = dataloader.load(filepath=registry_uri)

    logger.info("Extract non-metric columns")
    # Keep only the identifying columns (stale metrics are recomputed below).
    # Reset the index so it is 0..n-1. The side-by-side concat further down
    # aligns on index labels, and any other index would misalign models and
    # their metrics.
    registry_df = registry_df[
        ["model_type", "model_name", "model_location"]
    ].reset_index(drop=True)

    logger.info("--- Regenerating metrics ---")
    metric_suite = Metrics()
    # Empty seed frame: fixes the metric column order and keeps those columns
    # present even when the registry has no rows.
    metrics_seed = pd.DataFrame(columns=metric_suite.list_metric_functions())
    metric_frames = []
    for _, row in tqdm(registry_df.iterrows(), total=len(registry_df)):
        logger.info(f"--- Loading Model ({row['model_name']}) ---")
        model = model_factory(model_type=row["model_type"])()
        model.load_model(filepath=row["model_location"])
        # NOTE(review): assumes this returns a one-row DataFrame per model,
        # as the row-wise concat requires; confirm against core.Metrics.
        metric_frames.append(
            metric_suite.generate_metric_suite(
                model=model, data=test_df, target_column=target_column
            )
        )

    # Concatenate once instead of once per model. Re-concatenating inside the
    # loop copies the growing frame each time (quadratic in registry size).
    # ignore_index gives a 0..n-1 index matching the registry rows.
    metrics_df = pd.concat([metrics_seed, *metric_frames], axis=0, ignore_index=True)
    # Add metrics df to registry df side by side (row i <-> model i)
    registry_df = pd.concat([registry_df, metrics_df], axis=1)

    logger.info(f"--- Sorting by Optimise Metric ({OPTIMISE_METRIC}) ---")
    registry_df = sort_by_metric(
        data=registry_df,
        optimse_metric=OPTIMISE_METRIC,  # keyword spelling is sort_by_metric's
        best_model_column_name=BEST_MODEL_COLUMN_NAME,
    )

    logger.info("--- Saving model metrics ---")
    # NOTE(review): unlike the loads above, this write goes straight through
    # pandas rather than the dataloader. It also includes the DataFrame index
    # as an extra column, because to_csv's default is index=True. Confirm
    # both are intended.
    registry_df.to_csv(registry_uri)
if __name__ == "__main__":
    # Script entry point: read the CLI flags, then refresh the metrics of
    # every registered model for the requested target column.
    logger.info("---Begin Pipeline to regenerate metrics---")
    logger.info("---Ingest Arguments---")
    cli_args = ingest_arguments()
    regenerate_metrics(
        target_column=cli_args.target_column,
        test_filepath=cli_args.test_filepath,
    )