move seaborn settings to setting file

This commit is contained in:
Michael Duong 2023-08-29 18:31:01 +01:00
parent 469938bb25
commit 05afffc05d
2 changed files with 45 additions and 12 deletions

View file

@ -2,6 +2,19 @@
# TODO: migrate to dynaconf
from pathlib import Path
# TODO: remove these setting elsewhere for CML
RESIDUAL_TRUE_LABEL = "true"
RESIDUAL_PREDICTION_LABEL = "pred"
RESIDUAL_FILE = "residual.png"
SEABORN_RESIDUAL_AXIS_FONTSIZE = 12
SEABORN_RESIDUAL_TITLE_FONTSIZE = 22
SEABORN_RESIDUAL_STYLE = "whitegrid"
SEABORN_RESIDUAL_ASPECT_RATIO = "equal"
SEABORN_RESIDUAL_PLOT_DPI = 120
SEABORN_RESIDUAL_RANGE = [-100, 100]
SEABORN_RESIDUAL_LINE_COLOUR = "black"
SEABORN_RESIDUAL_LINE_WIDTH = 1
# Can move to a hyperparmeters file
# If anything we might want to have a file that can be loaded and sent to this script
MODEL_HYPERPARAMETERS = {

View file

@ -18,6 +18,17 @@ from core.Settings import (
SUBSAMPLE_FACTOR,
MODEL_HYPERPARAMETERS,
TIMESTAMP_FORMAT,
RESIDUAL_TRUE_LABEL,
RESIDUAL_PREDICTION_LABEL,
RESIDUAL_FILE,
SEABORN_RESIDUAL_AXIS_FONTSIZE,
SEABORN_RESIDUAL_TITLE_FONTSIZE,
SEABORN_RESIDUAL_STYLE,
SEABORN_RESIDUAL_ASPECT_RATIO,
SEABORN_RESIDUAL_PLOT_DPI,
SEABORN_RESIDUAL_RANGE,
SEABORN_RESIDUAL_LINE_COLOUR,
SEABORN_RESIDUAL_LINE_WIDTH,
)
import seaborn as sns
import matplotlib.pyplot as plt
@ -147,26 +158,35 @@ def training(
# TODO: can have a model.metric_outputs method
# FOr not just do it here
residual_df = pd.DataFrame(
list(zip(test_df[target_column], model.predictions)), columns=["true", "pred"]
list(zip(test_df[target_column], model.predictions)),
columns=[RESIDUAL_TRUE_LABEL, RESIDUAL_PREDICTION_LABEL],
)
# image formatting
# TODO: move to settings file , AXIS_FONT, TITLE_FONT
axis_fs = 18 # fontsize
title_fs = 22 # fontsize
sns.set(style="whitegrid")
ax = sns.scatterplot(x="true", y="pred", data=residual_df)
ax.set_aspect("equal")
ax.set_xlabel(f"True {target_column}", fontsize=axis_fs)
ax.set_ylabel(f"Predicted {target_column}", fontsize=axis_fs) # ylabel
ax.set_title("Residuals", fontsize=title_fs)
sns.set(style=SEABORN_RESIDUAL_STYLE)
ax = sns.scatterplot(
x=RESIDUAL_TRUE_LABEL, y=RESIDUAL_PREDICTION_LABEL, data=residual_df
)
ax.set_aspect(SEABORN_RESIDUAL_ASPECT_RATIO)
ax.set_xlabel(f"True {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE)
ax.set_ylabel(
f"Predicted {target_column}", fontsize=SEABORN_RESIDUAL_AXIS_FONTSIZE
) # ylabel
ax.set_title("Residuals", fontsize=SEABORN_RESIDUAL_TITLE_FONTSIZE)
# Square aspect ratio
ax.plot([-100, 100], [-100, 100], "black", linewidth=1)
ax.plot(
SEABORN_RESIDUAL_RANGE,
SEABORN_RESIDUAL_RANGE,
SEABORN_RESIDUAL_LINE_COLOUR,
linewidth=SEABORN_RESIDUAL_LINE_WIDTH,
)
plt.tight_layout()
RESIDUAL_FILE = "residuals.png"
plt.savefig(output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=120)
plt.savefig(
output_base / METRICS_FOLDER / RESIDUAL_FILE, dpi=SEABORN_RESIDUAL_PLOT_DPI
)
# TODO: for cml, we might want to have class that outputs all data and plots to add to the report
# If we want residual plot/ any plots, we will need to self host