mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
fix shapley
This commit is contained in:
parent
d39600eaaa
commit
56cf9c33d4
3 changed files with 18 additions and 3 deletions
|
|
@ -0,0 +1 @@
|
|||
dataclient_type: local
|
||||
|
|
@ -58,6 +58,13 @@ def prediction_analysis(
|
|||
):
|
||||
|
||||
test_df = dataclient.load_data(output_test_filepath)
|
||||
predictions = dataclient.load_data("./data/predictions/predictions.parquet")
|
||||
|
||||
mix_df = test_df.copy()
|
||||
mix_df["predictions"] = predictions
|
||||
mix_df["residual"] = abs(mix_df["predictions"] - mix_df["SAP_ENDING"])
|
||||
mix_df = mix_df.sort_values("residual", ascending=False)
|
||||
|
||||
target = "SAP_ENDING"
|
||||
test_df_without_target = test_df.drop(columns=[target])
|
||||
|
||||
|
|
@ -78,12 +85,14 @@ def prediction_analysis(
|
|||
ag_wrapper = AutogluonWrapper(
|
||||
model.model, feature_names=test_df_without_target.columns
|
||||
)
|
||||
explainer = shap.KernelExplainer(ag_wrapper.predict, test_df_without_target)
|
||||
explainer = shap.KernelExplainer(
|
||||
ag_wrapper.predict, test_df_without_target.head(100)
|
||||
)
|
||||
|
||||
NSHAP_SAMPLES = 100 # how many samples to use to approximate each Shapley value, larger values will be slower
|
||||
N_VAL = 30 # how many datapoints from validation data should we interpret predictions for, larger values will be slower
|
||||
|
||||
ROW_INDEX = 0 # index of an example datapoint
|
||||
ROW_INDEX = 8541 # 23690 #21059 # index of an example datapoint
|
||||
single_datapoint = test_df_without_target.iloc[[ROW_INDEX]]
|
||||
single_prediction = ag_wrapper.predict(single_datapoint)
|
||||
|
||||
|
|
@ -93,7 +102,11 @@ def prediction_analysis(
|
|||
shap_values_single,
|
||||
test_df_without_target.iloc[ROW_INDEX, :],
|
||||
)
|
||||
...
|
||||
shap_single_prediciton_df = pd.DataFrame(
|
||||
shap_values_single, columns=test_df_without_target.columns
|
||||
).T
|
||||
shap_single_prediciton_df.columns = ["contribution"]
|
||||
shap_single_prediciton_df = shap_single_prediciton_df.sort_values("contribution")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -3,5 +3,6 @@ boto3==1.28.17
|
|||
pandas==1.5.3
|
||||
autogluon==0.8.2
|
||||
alibi==0.9.4
|
||||
shap==0.42.1
|
||||
pyarrow==13.0.0
|
||||
pre-commit==3.3.3
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue