""" Generate metrics and enable regeneration of metrics if new metrics are generated Key tasks: - Specify metric functions that take in prediction vs actual to generate a metric value - Given a model and test data, produce a suite of all metrics """ import os import pandas as pd from pathlib import Path import seaborn as sns import matplotlib.pyplot as plt from core.CloudClient import BotoClient from core.Logger import logger from core.Settings import ( RESIDUAL_TRUE_LABEL, RESIDUAL_PREDICTION_LABEL, SEABORN_RESIDUAL_AXIS_FONTSIZE, SEABORN_RESIDUAL_TITLE_FONTSIZE, SEABORN_RESIDUAL_STYLE, SEABORN_RESIDUAL_ASPECT_RATIO, SEABORN_RESIDUAL_PLOT_DPI, SEABORN_RESIDUAL_RANGE, SEABORN_RESIDUAL_LINE_COLOUR, SEABORN_RESIDUAL_LINE_WIDTH, ) from sklearn.metrics import ( mean_absolute_error, median_absolute_error, mean_squared_error, mean_absolute_percentage_error, ) # Dummy example of new metric that can be added - must be true and prediction as arguments def max_error(y_true: pd.Series, y_pred: pd.Series): return max(y_true - y_pred) METRIC_TO_APPLY = [ mean_absolute_error, median_absolute_error, mean_squared_error, mean_absolute_percentage_error, # max_error ] def sort_by_metric( data: pd.DataFrame, optimse_metric: str, best_model_column_name: str ) -> pd.DataFrame: """ Helper function to sort data frame by metric and append a best model flag """ # Ascending as we want lowest error values data = data.sort_values(optimse_metric, ascending=True).reset_index(drop=True) data[best_model_column_name] = [False] * len(data) data.loc[0, best_model_column_name] = True return data class Metrics: """ All metric functions used to generate a dictionary of metrics """ def upload_metrics(self, output_filepath: Path, client: BotoClient) -> None: """ Providing a path, this function will save the metrics folders/files. """ if client.client is None: logger.info("In local development mode - no need to upload") else: logger.info(f"Saving metrics into s3") s3_location = client.model_bucket + "/" + str(output_filepath) self.directory_upload( client=client, local_directory=str(output_filepath), bucket_name=client.model_bucket, ) logger.info("Save complete") def directory_upload(self, client, local_directory, bucket_name): # Iterate through the local directory and upload each file for root, dirs, files in os.walk(local_directory): for file in files: # Determine the local file path and S3 object key local_file_path = os.path.join(root, file) s3_object_key = os.path.relpath(local_file_path, local_directory) # Upload the file to S3 client.client.upload_file(local_file_path, bucket_name, local_file_path) logger.info( f"Uploaded {local_file_path} to {bucket_name}/{local_file_path}" ) @staticmethod def list_metric_functions() -> list: """ Gather all metric functions to run """ return [metric_to_apply.__name__ for metric_to_apply in METRIC_TO_APPLY] @staticmethod def generate_metric_suite(actuals: pd.Series, predictions: pd.Series) -> pd.Series: """ For the model, test data and target, generate predictions and then iterative over all metrics to generate a Series of metric values """ metric_dict = {} for metric_function in METRIC_TO_APPLY: metric_dict[metric_function.__name__] = metric_function( actuals, predictions ) metrics = pd.Series(metric_dict) return metrics @staticmethod def generate_plot_suite(): """ Can do all metric ploting """ @staticmethod def generate_residual_plot( actuals: pd.Series, predictions: pd.Series, target_column: str, output_filepath: Path | str, ): # TODO: can have a model.metric_outputs method # FOr not just do it here residual_df = pd.DataFrame( list(zip(actuals, predictions)), columns=[RESIDUAL_TRUE_LABEL, RESIDUAL_PREDICTION_LABEL], ) # image formatting sns.set(style=SEABORN_RESIDUAL_STYLE) ax = sns.scatterplot( x=RESIDUAL_TRUE_LABEL, y=RESIDUAL_PREDICTION_LABEL, data=residual_df ) ax.set_aspect(SEABORN_RESIDUAL_ASPECT_RATIO) ax.set_xlabel(f"True {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE) ax.set_ylabel( f"Predicted {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE ) # ylabel ax.set_title("Residuals", fontsize=SEABORN_RESIDUAL_TITLE_FONTSIZE) # Square aspect ratio ax.plot( SEABORN_RESIDUAL_RANGE, SEABORN_RESIDUAL_RANGE, SEABORN_RESIDUAL_LINE_COLOUR, linewidth=SEABORN_RESIDUAL_LINE_WIDTH, ) plt.tight_layout() plt.savefig(output_filepath, dpi=SEABORN_RESIDUAL_PLOT_DPI)