mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
nearly there
This commit is contained in:
parent
4a6b7f3ed7
commit
1e70a3a582
4 changed files with 60 additions and 13 deletions
|
|
@ -1,5 +1,3 @@
|
|||
dataclient_type: local
|
||||
input_datahandler_type: parquet
|
||||
output_datahandler_type: json
|
||||
metrics_type: Regression
|
||||
metrics_output_filepath: ./metrics/metrics.json
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ def model_analysis(
|
|||
logger.info(f"--- Generate Feature Importance ---")
|
||||
logger.info("------------------------------------")
|
||||
|
||||
test_df = pd.read_parquet(output_test_filepath)
|
||||
test_df = dataclient.load_data(output_test_filepath)
|
||||
|
||||
test_df = test_df.head(permutation_subsample_amount)
|
||||
|
||||
|
|
|
|||
|
|
@ -36,30 +36,76 @@ feature_process_params = yaml.safe_load(open(feature_process_path))
|
|||
build_model_path = Path(__file__).parent / "configs" / "build_model.yaml"
|
||||
build_model_params = yaml.safe_load(open(build_model_path))
|
||||
|
||||
model_analysis_path = Path(__file__).parent / "configs" / "model_analysis.yaml"
|
||||
model_analysis_params = yaml.safe_load(open(model_analysis_path))
|
||||
|
||||
generate_predictions_path = (
|
||||
Path(__file__).parent / "configs" / "generate_predictions.yaml"
|
||||
prediction_analysis_path = (
|
||||
Path(__file__).parent / "configs" / "prediction_analysis.yaml"
|
||||
)
|
||||
generate_predictions_params = yaml.safe_load(open(generate_predictions_path))
|
||||
prediction_analysis_params = yaml.safe_load(open(prediction_analysis_path))
|
||||
|
||||
model = model_factory(build_model_params["model_type"])
|
||||
model.load_model(build_model_params["model_save_filepath"])
|
||||
|
||||
dataclient_type = model_analysis_params["dataclient_type"]
|
||||
dataclient_type = prediction_analysis_params["dataclient_type"]
|
||||
dataclient = dataclient_factory(
|
||||
dataclient_type=dataclient_type,
|
||||
dataclient_config=client_params[dataclient_type],
|
||||
)
|
||||
|
||||
output_test_filepath = prepare_data_params["output_test_filepath"]
|
||||
|
||||
def prediction_analysis(model: MLModel, dataclient: DataClient):
|
||||
|
||||
shap.kmeans()
|
||||
def prediction_analysis(
|
||||
model: MLModel, dataclient: DataClient, output_test_filepath: str
|
||||
):
|
||||
|
||||
test_df = dataclient.load_data(output_test_filepath)
|
||||
target = "SAP_ENDING"
|
||||
test_df_without_target = test_df.drop(columns=[target])
|
||||
|
||||
# test_df_summary = shap.kmeans(test_df, 10)
|
||||
# print("Baseline feature-values: \n", test_df_summary)
|
||||
class AutogluonWrapper:
|
||||
def __init__(self, predictor, feature_names):
|
||||
self.ag_model = predictor
|
||||
self.feature_names = feature_names
|
||||
|
||||
def predict(self, X):
|
||||
if isinstance(X, pd.Series):
|
||||
X = X.values.reshape(1, -1)
|
||||
if not isinstance(X, pd.DataFrame):
|
||||
X = pd.DataFrame(X, columns=self.feature_names)
|
||||
return self.ag_model.predict(X)
|
||||
|
||||
ag_wrapper = AutogluonWrapper(
|
||||
model.model, feature_names=test_df_without_target.columns
|
||||
)
|
||||
explainer = shap.KernelExplainer(ag_wrapper.predict, test_df_without_target)
|
||||
|
||||
NSHAP_SAMPLES = 100 # how many samples to use to approximate each Shapley value; larger values will be slower
|
||||
N_VAL = 30 # how many datapoints from the validation data to interpret predictions for; larger values will be slower
|
||||
|
||||
ROW_INDEX = 0 # index of an example datapoint
|
||||
single_datapoint = test_df_without_target.iloc[[ROW_INDEX]]
|
||||
single_prediction = ag_wrapper.predict(single_datapoint)
|
||||
|
||||
shap_values_single = explainer.shap_values(single_datapoint, nsamples=NSHAP_SAMPLES)
|
||||
shap.force_plot(
|
||||
explainer.expected_value,
|
||||
shap_values_single,
|
||||
test_df_without_target.iloc[ROW_INDEX, :],
|
||||
)
|
||||
...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prediction_analysis()
|
||||
|
||||
logger.info("----------------------------")
|
||||
logger.info(f"--- {__file__} - Start! ---")
|
||||
logger.info("----------------------------")
|
||||
|
||||
prediction_analysis(
|
||||
model=model, dataclient=dataclient, output_test_filepath=output_test_filepath
|
||||
)
|
||||
|
||||
logger.info("-------------------------------")
|
||||
logger.info(f"--- {__file__} - Complete! ---")
|
||||
logger.info("-------------------------------")
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ def prepare_data(
|
|||
train, test = train_test_split(
|
||||
data, train_size=train_proportion, test_size=(1 - train_proportion)
|
||||
)
|
||||
test = test.reset_index(drop=True)
|
||||
|
||||
train = train.reset_index(drop=True)
|
||||
|
||||
logger.info("-----------------------")
|
||||
logger.info("--- Outputting data ---")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue