mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
add residual png
This commit is contained in:
parent
a31127c2bc
commit
0903a60af8
1 changed files with 26 additions and 0 deletions
|
|
@ -20,6 +20,8 @@ from core.Settings import (
|
|||
SUBSAMPLE_FACTOR,
|
||||
MODEL_HYPERPARAMETERS
|
||||
)
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
TIMESTAMP = datetime.now().strftime(format="%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
|
@ -121,6 +123,29 @@ def training(
|
|||
metrics_location = output_base / METRICS_FOLDER
|
||||
)
|
||||
|
||||
logger.info("--- Generate metric outputs using predictions ---")
|
||||
# TODO: can have a model.metric_outputs method
|
||||
# FOr not just do it here
|
||||
residual_df = pd.DataFrame(list(zip(test_df[target_column], model.predictions)), columns=['true', 'pred'])
|
||||
|
||||
# image formatting
|
||||
# TODO: move to settings file , AXIS_FONT, TITLE_FONT
|
||||
axis_fs = 18 #fontsize
|
||||
title_fs = 22 #fontsize
|
||||
sns.set(style="whitegrid")
|
||||
ax = sns.scatterplot(x="true", y="pred",data=residual_df)
|
||||
ax.set_aspect('equal')
|
||||
ax.set_xlabel(f'True {target_column}',fontsize = axis_fs)
|
||||
ax.set_ylabel(f'Predicted {target_column}', fontsize = axis_fs)#ylabel
|
||||
ax.set_title('Residuals', fontsize = title_fs)
|
||||
|
||||
# Square aspect ratio
|
||||
ax.plot([-100, 100], [-100, 100], 'black', linewidth=1)
|
||||
|
||||
plt.tight_layout()
|
||||
RESIDUAL_FILE = "residuals.png"
|
||||
plt.savefig(output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=120)
|
||||
|
||||
# TODO: introduce a seperate script for model optimisation, and from there, optimise for deployment
|
||||
# Imagining for now that the model trained here is the best model amongst all models built
|
||||
|
||||
|
|
@ -139,6 +164,7 @@ def training(
|
|||
logger.info("Registry file found - Loading into Dataframe")
|
||||
registry_df = pd.read_csv(registry_path, index_col=None)
|
||||
else:
|
||||
# TODO: Moved columns into settings: MODEL_DETAILS and Metrics class columns
|
||||
registry_df = pd.DataFrame(columns=['model_type', 'model_name', 'model_location', 'mean_absolute_error', 'root_mean_squared_error', 'mean_squared_error', 'r2', 'pearsonr', 'median_absolute_error', 'mape', 'best_model'])
|
||||
|
||||
model_details_df = pd.DataFrame(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue