mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
move seaborn settings to setting file
This commit is contained in:
parent
469938bb25
commit
05afffc05d
2 changed files with 45 additions and 12 deletions
|
|
@ -2,6 +2,19 @@
|
|||
# TODO: migrate to dynaconf
|
||||
from pathlib import Path
|
||||
|
||||
# TODO: remove these setting elsewhere for CML
|
||||
RESIDUAL_TRUE_LABEL = "true"
|
||||
RESIDUAL_PREDICTION_LABEL = "pred"
|
||||
RESIDUAL_FILE = "residual.png"
|
||||
SEABORN_RESIDUAL_AXIS_FONTSIZE = 12
|
||||
SEABORN_RESIDUAL_TITLE_FONTSIZE = 22
|
||||
SEABORN_RESIDUAL_STYLE = "whitegrid"
|
||||
SEABORN_RESIDUAL_ASPECT_RATIO = "equal"
|
||||
SEABORN_RESIDUAL_PLOT_DPI = 120
|
||||
SEABORN_RESIDUAL_RANGE = [-100, 100]
|
||||
SEABORN_RESIDUAL_LINE_COLOUR = "black"
|
||||
SEABORN_RESIDUAL_LINE_WIDTH = 1
|
||||
|
||||
# Can move to a hyperparmeters file
|
||||
# If anything we might want to have a file that can be loaded and sent to this script
|
||||
MODEL_HYPERPARAMETERS = {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,17 @@ from core.Settings import (
|
|||
SUBSAMPLE_FACTOR,
|
||||
MODEL_HYPERPARAMETERS,
|
||||
TIMESTAMP_FORMAT,
|
||||
RESIDUAL_TRUE_LABEL,
|
||||
RESIDUAL_PREDICTION_LABEL,
|
||||
RESIDUAL_FILE,
|
||||
SEABORN_RESIDUAL_AXIS_FONTSIZE,
|
||||
SEABORN_RESIDUAL_TITLE_FONTSIZE,
|
||||
SEABORN_RESIDUAL_STYLE,
|
||||
SEABORN_RESIDUAL_ASPECT_RATIO,
|
||||
SEABORN_RESIDUAL_PLOT_DPI,
|
||||
SEABORN_RESIDUAL_RANGE,
|
||||
SEABORN_RESIDUAL_LINE_COLOUR,
|
||||
SEABORN_RESIDUAL_LINE_WIDTH,
|
||||
)
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
|
@ -147,26 +158,35 @@ def training(
|
|||
# TODO: can have a model.metric_outputs method
|
||||
# FOr not just do it here
|
||||
residual_df = pd.DataFrame(
|
||||
list(zip(test_df[target_column], model.predictions)), columns=["true", "pred"]
|
||||
list(zip(test_df[target_column], model.predictions)),
|
||||
columns=[RESIDUAL_TRUE_LABEL, RESIDUAL_PREDICTION_LABEL],
|
||||
)
|
||||
|
||||
# image formatting
|
||||
# TODO: move to settings file , AXIS_FONT, TITLE_FONT
|
||||
axis_fs = 18 # fontsize
|
||||
title_fs = 22 # fontsize
|
||||
sns.set(style="whitegrid")
|
||||
ax = sns.scatterplot(x="true", y="pred", data=residual_df)
|
||||
ax.set_aspect("equal")
|
||||
ax.set_xlabel(f"True {target_column}", fontsize=axis_fs)
|
||||
ax.set_ylabel(f"Predicted {target_column}", fontsize=axis_fs) # ylabel
|
||||
ax.set_title("Residuals", fontsize=title_fs)
|
||||
sns.set(style=SEABORN_RESIDUAL_STYLE)
|
||||
ax = sns.scatterplot(
|
||||
x=RESIDUAL_TRUE_LABEL, y=RESIDUAL_PREDICTION_LABEL, data=residual_df
|
||||
)
|
||||
ax.set_aspect(SEABORN_RESIDUAL_ASPECT_RATIO)
|
||||
ax.set_xlabel(f"True {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE)
|
||||
ax.set_ylabel(
|
||||
f"Predicted {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE
|
||||
) # ylabel
|
||||
ax.set_title("Residuals", fontsize=SEABORN_RESIDUAL_TITLE_FONTSIZE)
|
||||
|
||||
# Square aspect ratio
|
||||
ax.plot([-100, 100], [-100, 100], "black", linewidth=1)
|
||||
ax.plot(
|
||||
SEABORN_RESIDUAL_RANGE,
|
||||
SEABORN_RESIDUAL_RANGE,
|
||||
SEABORN_RESIDUAL_LINE_COLOUR,
|
||||
linewidth=SEABORN_RESIDUAL_LINE_WIDTH,
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
RESIDUAL_FILE = "residuals.png"
|
||||
plt.savefig(output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=120)
|
||||
plt.savefig(
|
||||
output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=SEABORN_RESIDUAL_PLOT_DPI
|
||||
)
|
||||
|
||||
# TODO: for cml, we might want to have class that outputs all data and plots to add to the report
|
||||
# If we want residual plot/ any plots, we will need to self host
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue