diff --git a/model_data/simulation_system/core/Settings.py b/model_data/simulation_system/core/Settings.py index c46a7dc0..3b9c8abf 100644 --- a/model_data/simulation_system/core/Settings.py +++ b/model_data/simulation_system/core/Settings.py @@ -2,6 +2,19 @@ # TODO: migrate to dynaconf from pathlib import Path +# TODO: move these settings elsewhere for CML +RESIDUAL_TRUE_LABEL = "true" +RESIDUAL_PREDICTION_LABEL = "pred" +RESIDUAL_FILE = "residual.png" +SEABORN_RESIDUAL_AXIS_FONTSIZE = 12 +SEABORN_RESIDUAL_TITLE_FONTSIZE = 22 +SEABORN_RESIDUAL_STYLE = "whitegrid" +SEABORN_RESIDUAL_ASPECT_RATIO = "equal" +SEABORN_RESIDUAL_PLOT_DPI = 120 +SEABORN_RESIDUAL_RANGE = [-100, 100] +SEABORN_RESIDUAL_LINE_COLOUR = "black" +SEABORN_RESIDUAL_LINE_WIDTH = 1 + # Can move to a hyperparmeters file # If anything we might want to have a file that can be loaded and sent to this script MODEL_HYPERPARAMETERS = { diff --git a/model_data/simulation_system/training.py b/model_data/simulation_system/training.py index 11acdf57..6a9dae31 100644 --- a/model_data/simulation_system/training.py +++ b/model_data/simulation_system/training.py @@ -18,6 +18,17 @@ from core.Settings import ( SUBSAMPLE_FACTOR, MODEL_HYPERPARAMETERS, TIMESTAMP_FORMAT, + RESIDUAL_TRUE_LABEL, + RESIDUAL_PREDICTION_LABEL, + RESIDUAL_FILE, + SEABORN_RESIDUAL_AXIS_FONTSIZE, + SEABORN_RESIDUAL_TITLE_FONTSIZE, + SEABORN_RESIDUAL_STYLE, + SEABORN_RESIDUAL_ASPECT_RATIO, + SEABORN_RESIDUAL_PLOT_DPI, + SEABORN_RESIDUAL_RANGE, + SEABORN_RESIDUAL_LINE_COLOUR, + SEABORN_RESIDUAL_LINE_WIDTH, ) import seaborn as sns import matplotlib.pyplot as plt @@ -147,26 +158,35 @@ def training( # TODO: can have a model.metric_outputs method # FOr not just do it here residual_df = pd.DataFrame( - list(zip(test_df[target_column], model.predictions)), columns=["true", "pred"] + list(zip(test_df[target_column], model.predictions)), + columns=[RESIDUAL_TRUE_LABEL, RESIDUAL_PREDICTION_LABEL], ) # image formatting # TODO: move to settings file , 
AXIS_FONT, TITLE_FONT - axis_fs = 18 # fontsize - title_fs = 22 # fontsize - sns.set(style="whitegrid") - ax = sns.scatterplot(x="true", y="pred", data=residual_df) - ax.set_aspect("equal") - ax.set_xlabel(f"True {target_column}", fontsize=axis_fs) - ax.set_ylabel(f"Predicted {target_column}", fontsize=axis_fs) # ylabel - ax.set_title("Residuals", fontsize=title_fs) + sns.set(style=SEABORN_RESIDUAL_STYLE) + ax = sns.scatterplot( + x=RESIDUAL_TRUE_LABEL, y=RESIDUAL_PREDICTION_LABEL, data=residual_df + ) + ax.set_aspect(SEABORN_RESIDUAL_ASPECT_RATIO) + ax.set_xlabel(f"True {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE) + ax.set_ylabel( + f"Predicted {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE + ) # ylabel + ax.set_title("Residuals", fontsize=SEABORN_RESIDUAL_TITLE_FONTSIZE) # Square aspect ratio - ax.plot([-100, 100], [-100, 100], "black", linewidth=1) + ax.plot( + SEABORN_RESIDUAL_RANGE, + SEABORN_RESIDUAL_RANGE, + SEABORN_RESIDUAL_LINE_COLOUR, + linewidth=SEABORN_RESIDUAL_LINE_WIDTH, + ) plt.tight_layout() - RESIDUAL_FILE = "residuals.png" - plt.savefig(output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=120) + plt.savefig( + output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=SEABORN_RESIDUAL_PLOT_DPI + ) # TODO: for cml, we might want to have class that outputs all data and plots to add to the report # If we want residual plot/ any plots, we will need to self host