Model/model_data/simulation_system/core/Metrics.py
2023-09-01 19:25:35 +01:00

167 lines
5.2 KiB
Python

"""
Generate metrics and enable regeneration of metrics when new metric functions are added
Key tasks:
- Specify metric functions that take in prediction vs actual to generate a metric value
- Given a model and test data, produce a suite of all metrics
"""
import os
import pandas as pd
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
from core.CloudClient import BotoClient
from core.Logger import logger
from core.Settings import (
RESIDUAL_TRUE_LABEL,
RESIDUAL_PREDICTION_LABEL,
SEABORN_RESIDUAL_AXIS_FONTSIZE,
SEABORN_RESIDUAL_TITLE_FONTSIZE,
SEABORN_RESIDUAL_STYLE,
SEABORN_RESIDUAL_ASPECT_RATIO,
SEABORN_RESIDUAL_PLOT_DPI,
SEABORN_RESIDUAL_RANGE,
SEABORN_RESIDUAL_LINE_COLOUR,
SEABORN_RESIDUAL_LINE_WIDTH,
)
from sklearn.metrics import (
mean_absolute_error,
median_absolute_error,
mean_squared_error,
mean_absolute_percentage_error,
)
# Dummy example of a new metric that can be added - it must take the true and predicted values as its arguments
def max_error(y_true: pd.Series, y_pred: pd.Series) -> float:
    """
    Return the largest absolute difference between true and predicted values.

    The residual is taken in absolute value so that over-predictions
    (negative ``y_true - y_pred``) count as errors too, instead of being
    hidden behind a smaller positive residual.

    Args:
        y_true: Observed target values.
        y_pred: Predicted target values.

    Returns:
        The maximum of ``|y_true - y_pred|``.

    Raises:
        ValueError: If the inputs are empty (raised by the builtin ``max``).
    """
    return max(abs(y_true - y_pred))
# Metric functions applied, in order, by Metrics.generate_metric_suite.
# Each entry is called as metric(actuals, predictions) and its result is
# reported under the function's __name__.
METRIC_TO_APPLY = [
    mean_absolute_error,
    median_absolute_error,
    mean_squared_error,
    mean_absolute_percentage_error,
    # max_error
]
def sort_by_metric(
    data: pd.DataFrame, optimse_metric: str, best_model_column_name: str
) -> pd.DataFrame:
    """
    Sort a data frame by a metric column and flag the best (first) row.

    Rows are ordered ascending because the metrics are error values, so the
    lowest value is the best. The input frame is not modified.

    Args:
        data: Frame containing at least the ``optimse_metric`` column.
        optimse_metric: Name of the column to sort by.
        best_model_column_name: Name of the boolean flag column to add.

    Returns:
        A new frame with a fresh 0..n-1 index and a boolean column that is
        True for the first row only. An empty input yields an empty frame
        that still carries the flag column.
    """
    # Ascending as we want lowest error values
    ranked = data.sort_values(optimse_metric, ascending=True).reset_index(drop=True)
    # After the index reset the best row is the one labelled 0. Comparing the
    # index (rather than assigning through .loc[0]) keeps an empty frame empty
    # instead of enlarging it with a fabricated "best" row.
    ranked[best_model_column_name] = ranked.index == 0
    return ranked
class Metrics:
    """
    Helpers for computing the metric suite, plotting residuals and uploading
    the resulting metric files.
    """

    def upload_metrics(self, output_filepath: Path, client: BotoClient) -> None:
        """
        Upload every file under ``output_filepath`` to the client's model bucket.

        Nothing is uploaded when ``client.client`` is ``None``, which this
        module treats as local development mode.

        Args:
            output_filepath: Local directory holding the metric folders/files.
            client: Wrapper exposing ``client`` (the object used to upload, or
                ``None``) and ``model_bucket`` (the destination bucket name).
        """
        if client.client is None:
            logger.info("In local development mode - no need to upload")
            return

        logger.info("Saving metrics into s3")
        self.directory_upload(
            client=client,
            local_directory=str(output_filepath),
            bucket_name=client.model_bucket,
        )
        logger.info("Save complete")

    def directory_upload(
        self, client: BotoClient, local_directory: str, bucket_name: str
    ) -> None:
        """
        Walk ``local_directory`` recursively and upload each file found.

        The object key is the file's local path (``local_directory`` included),
        so files keep the same relative layout inside the bucket.

        Args:
            client: Wrapper whose ``client`` attribute provides ``upload_file``.
            local_directory: Directory to walk.
            bucket_name: Destination bucket.
        """
        for root, _dirs, files in os.walk(local_directory):
            for file_name in files:
                local_file_path = os.path.join(root, file_name)
                # Object keys are "/"-delimited. This is a no-op on POSIX and
                # stops Windows backslashes from ending up inside the key.
                s3_object_key = local_file_path.replace(os.sep, "/")
                client.client.upload_file(local_file_path, bucket_name, s3_object_key)
                logger.info(
                    f"Uploaded {local_file_path} to {bucket_name}/{s3_object_key}"
                )

    @staticmethod
    def list_metric_functions() -> list[str]:
        """
        Return the names of all metric functions in ``METRIC_TO_APPLY``.
        """
        return [metric_to_apply.__name__ for metric_to_apply in METRIC_TO_APPLY]

    @staticmethod
    def generate_metric_suite(actuals: pd.Series, predictions: pd.Series) -> pd.Series:
        """
        Evaluate every function in ``METRIC_TO_APPLY`` on the given values.

        Args:
            actuals: Observed target values.
            predictions: Predicted target values.

        Returns:
            A Series indexed by metric function name holding each metric value.
        """
        return pd.Series(
            {
                metric_function.__name__: metric_function(actuals, predictions)
                for metric_function in METRIC_TO_APPLY
            }
        )

    @staticmethod
    def generate_plot_suite() -> None:
        """
        Placeholder for generating all metric plots; currently does nothing.
        """

    @staticmethod
    def generate_residual_plot(
        actuals: pd.Series,
        predictions: pd.Series,
        target_column: str,
        output_filepath: Path | str,
    ) -> None:
        """
        Save a scatter plot of predicted against true values.

        The plot is drawn on its own figure, which is closed afterwards, so
        repeated calls neither draw over each other nor accumulate open
        figures.

        Args:
            actuals: Observed target values (x axis).
            predictions: Predicted target values (y axis).
            target_column: Target name used in the axis labels.
            output_filepath: Destination image path handed to ``savefig``.
        """
        # TODO: this could live in a model.metric_outputs method; for now it
        # is done here.
        residual_df = pd.DataFrame(
            list(zip(actuals, predictions)),
            columns=[RESIDUAL_TRUE_LABEL, RESIDUAL_PREDICTION_LABEL],
        )

        # image formatting
        sns.set(style=SEABORN_RESIDUAL_STYLE)
        # Dedicated figure/axes instead of pyplot's implicit current axes.
        fig, ax = plt.subplots()
        try:
            sns.scatterplot(
                x=RESIDUAL_TRUE_LABEL,
                y=RESIDUAL_PREDICTION_LABEL,
                data=residual_df,
                ax=ax,
            )
            ax.set_aspect(SEABORN_RESIDUAL_ASPECT_RATIO)
            ax.set_xlabel(
                f"True {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE
            )
            ax.set_ylabel(
                f"Predicted {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE
            )
            ax.set_title("Residuals", fontsize=SEABORN_RESIDUAL_TITLE_FONTSIZE)
            # The same range is used for x and y, so this draws the y = x
            # reference line on which a perfect prediction would fall.
            ax.plot(
                SEABORN_RESIDUAL_RANGE,
                SEABORN_RESIDUAL_RANGE,
                SEABORN_RESIDUAL_LINE_COLOUR,
                linewidth=SEABORN_RESIDUAL_LINE_WIDTH,
            )
            fig.tight_layout()
            fig.savefig(output_filepath, dpi=SEABORN_RESIDUAL_PLOT_DPI)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)