mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
188 lines
6 KiB
Python
188 lines
6 KiB
Python
"""
|
|
Script to load MLModel class and generate predictions
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import argparse
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
from typing import Optional
|
|
from datetime import datetime
|
|
from MLModel.Models import AutogluonModel
|
|
from core.Logger import logger
|
|
from core.DataLoader import dataloader_factory
|
|
from core.CloudClient import BotoClient
|
|
from core.Metrics import Metrics
|
|
from core.RegistryHandler import RegistryHandler
|
|
from core.Settings import (
|
|
BASE_REGISTRY_PATH,
|
|
REGISTRY_FILE,
|
|
PREDICTION_LOCATION,
|
|
PREDICTION_FILE,
|
|
METADATA_FILE,
|
|
TIMESTAMP_FORMAT,
|
|
MODEL_DIRECTORY,
|
|
)
|
|
|
|
# Name of the environment this process runs in; "dev" is used when the
# RUNTIME_ENVIRONMENT variable is not set.
RUNTIME_ENVIRONMENT = os.getenv("RUNTIME_ENVIRONMENT", "dev")

# Formatted once, when the module is imported, using the project-wide format.
TIMESTAMP = datetime.now().strftime(TIMESTAMP_FORMAT)

# Shared cloud client, built for the selected runtime environment and passed to
# the registry handler, data loader and model loader below.
CLIENT = BotoClient(runtime_environment=RUNTIME_ENVIRONMENT)
|
|
|
|
# FOR TESTING
|
|
# For now, load the data first and then pass it into the function (i.e. as if we had received JSON data and
# converted it to a DataFrame)
|
|
# TEST_DATA = DataLoader.load(filepath="../simulation_system/model_build_data/change_data/rdsap_full/test_data.parquet")
|
|
# DATA = TEST_DATA.sample(1)
|
|
|
|
# For testing in dev s3
|
|
# Data path can be passed as so:
|
|
# python3 predictions.py --data-path s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data.parquet
|
|
# data_path="s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data_with_id.parquet"
|
|
|
|
|
|
def ingest_arguments() -> argparse.Namespace:
|
|
"""
|
|
Helper function to take in arguments from script start
|
|
"""
|
|
|
|
parser = argparse.ArgumentParser(description="Inputs for training script")
|
|
parser.add_argument(
|
|
"--target-column",
|
|
type=str,
|
|
help="The response variable you are predicting for",
|
|
choices=["RDSAP_CHANGE", "HEAT_DEMAND_CHANGE"],
|
|
default="RDSAP_CHANGE",
|
|
)
|
|
parser.add_argument(
|
|
"--model-path",
|
|
type=str,
|
|
help="If you wish to use a specific model, specify the model path here",
|
|
)
|
|
parser.add_argument("--data", type=str, help="Json data for predictions")
|
|
parser.add_argument(
|
|
"--data-path", type=str, help="Location of Parquet dataset to load for training"
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
return args
|
|
|
|
|
|
def _resolve_model(target_column: str, model_path: Optional[str]) -> tuple[str, str, str]:
    """Return ``(model_location, model_type, model_name)`` for the model to load."""
    if model_path is not None:
        logger.info("User specified a model to load - ignoring registry")
        # Autogluon is the only model type prediction() can build. Using the
        # path itself as the type (as before) made the type check fail for
        # every real path, so an explicit path is treated as an Autogluon model.
        return model_path, "autogluon", model_path

    # TODO: Think about where registry will sit/ type
    logger.info("Loading best model from registry")

    metrics = Metrics()
    registry_handler = RegistryHandler()

    registry_path = Path(MODEL_DIRECTORY) / target_column / REGISTRY_FILE
    registry_df = registry_handler.load_registry(
        registry_path=registry_path, client=CLIENT, metrics=metrics
    )

    # Keep only the row(s) whose "best_model" flag is set.
    best_model_df = registry_df[registry_df["best_model"]]
    if len(best_model_df) == 0:
        # IndexError is kept so the failure type matches what the bare
        # `.values[0]` lookup raised; only the message is made explicit.
        raise IndexError(f"No model is flagged as best in registry {registry_path}")

    return (
        best_model_df["model_location"].values[0],
        best_model_df["model_type"].values[0],
        best_model_df["model_name"].values[0],
    )


def _load_prediction_data(
    data: Optional[pd.DataFrame | str], data_path: Optional[str]
) -> pd.DataFrame:
    """Return the rows to predict on, from in-memory ``data`` or from ``data_path``."""
    if data is None and not data_path:
        logger.error("No Data/Data Path passed")
        # Same effect as the former `exit(1)`, without relying on the
        # `site`-provided builtin being available.
        raise SystemExit(1)

    if data is None:
        logger.info("Loading data from provided path")
        dataloader = dataloader_factory(runtime_environment=RUNTIME_ENVIRONMENT)
        loaded = dataloader.load(client=CLIENT, filepath=data_path, index_col=None)
        if loaded is None:
            raise ValueError("No data loaded")
        return loaded

    # In-memory data takes precedence over data_path when both are supplied.
    logger.info("Using data provided")
    if isinstance(data, pd.DataFrame):
        # Already tabular: round-tripping through str()/json.loads would fail.
        return data

    payload = json.loads(str(data))
    # A JSON object is one record; a JSON array is taken as a list of records.
    records = payload if isinstance(payload, list) else [payload]
    return pd.DataFrame(records)


def prediction(
    target_column: str = "RDSAP_CHANGE",
    model_path: Optional[str] = None,
    data: Optional[pd.DataFrame | str] = None,
    data_path: Optional[str] = None,
) -> pd.DataFrame:
    """
    Main pipeline function: pick a model, load the input rows and predict.

    Args:
        target_column: Response variable; selects which registry file under
            ``MODEL_DIRECTORY`` is consulted when no ``model_path`` is given.
        model_path: Location of a specific model to load. When set, the
            registry is ignored and the model is loaded as an Autogluon model.
        data: Rows to predict on, either as a DataFrame or as a JSON string
            (a single object or an array of objects). Takes precedence over
            ``data_path``.
        data_path: Location of a dataset to load when ``data`` is not given.

    Returns:
        The input's ``id`` column concatenated column-wise with the output of
        the model's ``generate_predictions``.

    Raises:
        SystemExit: Neither ``data`` nor ``data_path`` was supplied.
        ValueError: The data loader returned nothing, the JSON is invalid, or
            the resolved model type is not ``"autogluon"``.
        IndexError: The registry has no row flagged as the best model.
        KeyError: The input has no ``id`` column (raised by the final column
            selection).
    """
    model_location, model_type, model_name = _resolve_model(target_column, model_path)

    logger.info("--- Model Info: ---")
    logger.info(f"Model type: {model_type}")
    logger.info(f"Model name: {model_name}")
    logger.info(f"Model location: {model_location}")

    logger.info("--- Loading Data ---")
    frame = _load_prediction_data(data, data_path)

    logger.info("--- Loading Model ---")
    if model_type != "autogluon":
        raise ValueError("No other model currently")
    logger.info("Using an Autogluon model")
    model = AutogluonModel()

    # In lambda, only the /tmp folder is writable
    model_folder = "/tmp" if RUNTIME_ENVIRONMENT in ["dev", "prod"] else "local_model"
    model.load_model(filepath=model_location, client=CLIENT, model_folder=model_folder)

    logger.info("--- Generating Predictions ---")
    # Named `predictions` so the local does not shadow this function's name.
    predictions = model.generate_predictions(data=frame)

    # TODO: Check how we want to structure and persist outputs. The earlier
    # draft wrote the prediction and a metadata record (model type, name,
    # location, settings) under
    # PREDICTION_LOCATION / target_column / <uprn> / TIMESTAMP.
    return pd.concat([frame["id"], predictions], axis=1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Example invocations:
    #   JSON string:   python3 predictions.py --data '{"TOTAL_FLOOR_AREA": 1}'
    #   Dataset path:  python3 predictions.py --data-path
    #                  ./model_build_data/change_data/rdsap_full/test_data.parquet
    cli_args = ingest_arguments()

    # Forward every parsed option to the pipeline by keyword.
    prediction(
        data_path=cli_args.data_path,
        data=cli_args.data,
        model_path=cli_args.model_path,
        target_column=cli_args.target_column,
    )
|