diff --git a/model_data/simulation_system/training.py b/model_data/simulation_system/training.py index 358abb41..04ef3493 100644 --- a/model_data/simulation_system/training.py +++ b/model_data/simulation_system/training.py @@ -20,6 +20,8 @@ from core.Settings import ( SUBSAMPLE_FACTOR, MODEL_HYPERPARAMETERS ) +import seaborn as sns +import matplotlib.pyplot as plt TIMESTAMP = datetime.now().strftime(format="%Y-%m-%d_%H-%M-%S") @@ -121,6 +123,29 @@ def training( metrics_location = output_base / METRICS_FOLDER ) + logger.info("--- Generate metric outputs using predictions ---") + # TODO: can have a model.metric_outputs method + # FOr not just do it here + residual_df = pd.DataFrame(list(zip(test_df[target_column], model.predictions)), columns=['true', 'pred']) + + # image formatting + # TODO: move to settings file , AXIS_FONT, TITLE_FONT + axis_fs = 18 #fontsize + title_fs = 22 #fontsize + sns.set(style="whitegrid") + ax = sns.scatterplot(x="true", y="pred",data=residual_df) + ax.set_aspect('equal') + ax.set_xlabel(f'True {target_column}',fontsize = axis_fs) + ax.set_ylabel(f'Predicted {target_column}', fontsize = axis_fs)#ylabel + ax.set_title('Residuals', fontsize = title_fs) + + # Square aspect ratio + ax.plot([-100, 100], [-100, 100], 'black', linewidth=1) + + plt.tight_layout() + RESIDUAL_FILE = "residuals.png" + plt.savefig(output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=120) + # TODO: introduce a seperate script for model optimisation, and from there, optimise for deployment # Imagining for now that the model trained here is the best model amongst all models built @@ -139,6 +164,7 @@ def training( logger.info("Registry file found - Loading into Dataframe") registry_df = pd.read_csv(registry_path, index_col=None) else: + # TODO: Moved columns into settings: MODEL_DETAILS and Metrics class columns registry_df = pd.DataFrame(columns=['model_type', 'model_name', 'model_location', 'mean_absolute_error', 'root_mean_squared_error', 'mean_squared_error', 'r2', 'pearsonr', 'median_absolute_error', 'mape', 'best_model']) model_details_df = pd.DataFrame(