diff --git a/modules/ml-pipeline/src/pipeline/src/configs/generate_predictions.yaml b/modules/ml-pipeline/src/pipeline/src/configs/generate_predictions.yaml index 89444ad..c7f1b2d 100644 --- a/modules/ml-pipeline/src/pipeline/src/configs/generate_predictions.yaml +++ b/modules/ml-pipeline/src/pipeline/src/configs/generate_predictions.yaml @@ -1,2 +1,3 @@ test_data_filepath: ./data/prepared_data/test.parquet predictions_output_filepath: ./data/predictions/predictions.parquet +predictions_column_name: predictions diff --git a/modules/ml-pipeline/src/pipeline/src/dvc.lock b/modules/ml-pipeline/src/pipeline/src/dvc.lock index 71172d8..96a8c32 100644 --- a/modules/ml-pipeline/src/pipeline/src/dvc.lock +++ b/modules/ml-pipeline/src/pipeline/src/dvc.lock @@ -59,16 +59,17 @@ stages: nfiles: 2 - path: generate_predictions.py hash: md5 - md5: b3250f5b597fe33bf57e8bc225606be7 - size: 3268 + md5: 219c7b2aa920b14d4bb7a1ef7df0ea1b + size: 3420 params: configs/generate_predictions.yaml: + predictions_column_name: predictions predictions_output_filepath: ./data/predictions/predictions.parquet test_data_filepath: ./data/prepared_data/test.parquet outs: - path: data/predictions/ hash: md5 - md5: 339924cbd0435a59be599c06fd2b25e6.dir + md5: 25bb58c06ce3bc7ef20de87298db1567.dir size: 2949 nfiles: 1 generate_metrics: @@ -76,7 +77,7 @@ stages: deps: - path: data/predictions hash: md5 - md5: 339924cbd0435a59be599c06fd2b25e6.dir + md5: 25bb58c06ce3bc7ef20de87298db1567.dir size: 2949 nfiles: 1 - path: data/prepared_data @@ -86,8 +87,8 @@ stages: nfiles: 2 - path: generate_metrics.py hash: md5 - md5: 8c78578a8c45edf4b93a85c42c2b2192 - size: 3561 + md5: 81f8eec20ffb542f27bde28dc028bace + size: 3741 params: configs/generate_metrics.yaml: metrics_output_filepath: ./metrics/metrics.json diff --git a/modules/ml-pipeline/src/pipeline/src/generate_metrics.py b/modules/ml-pipeline/src/pipeline/src/generate_metrics.py index 2167671..11b214c 100644 --- a/modules/ml-pipeline/src/pipeline/src/generate_metrics.py +++ b/modules/ml-pipeline/src/pipeline/src/generate_metrics.py @@ -41,6 +41,7 @@ def generate_metrics( target: str, test_data_filepath: str, predictions_output_filepath: str, + predictions_column_name: str, metrics_output_filepath: str, ): """ @@ -59,14 +60,15 @@ def generate_metrics( logger.info("---------------------------") # TODO: replace with client loader here - predictions = pd.Series(pd.read_parquet(predictions_output_filepath)) + predictions = pd.read_parquet(predictions_output_filepath) logger.info("--------------------------") logger.info("--- Generating metrics ---") logger.info("--------------------------") metrics_output = metrics.generate_metrics( - target=test_data[target], predictions=predictions + target=test_data[target], + predictions=pd.Series(predictions[predictions_column_name]), ) logger.info("----------------------") @@ -101,6 +103,7 @@ if __name__ == "__main__": predictions_output_filepath=generate_predictions_params[ "predictions_output_filepath" ], + predictions_column_name=generate_predictions_params["predictions_column_name"], metrics_output_filepath=generate_metrics_params["metrics_output_filepath"], ) diff --git a/modules/ml-pipeline/src/pipeline/src/generate_predictions.py b/modules/ml-pipeline/src/pipeline/src/generate_predictions.py index e390bed..00d6fce 100644 --- a/modules/ml-pipeline/src/pipeline/src/generate_predictions.py +++ b/modules/ml-pipeline/src/pipeline/src/generate_predictions.py @@ -37,6 +37,7 @@ def generate_predictions( model_filepath: str, test_data_filepath: str, predictions_output_filepath: str, + predictions_column_name: str, ): """ For a given model, we generate prediction and evaluate this against the true target @@ -75,7 +76,7 @@ def generate_predictions( if not Path(predictions_output_filepath).parent.exists(): os.mkdir(Path(predictions_output_filepath).parent) - pd.DataFrame(predictions, columns=["predictions"]).to_parquet( + pd.DataFrame(predictions, columns=[predictions_column_name]).to_parquet( predictions_output_filepath ) @@ -98,6 +99,7 @@ if __name__ == "__main__": predictions_output_filepath=generate_predictions_params[ "predictions_output_filepath" ], + predictions_column_name=generate_predictions_params["predictions_column_name"], ) logger.info("-------------------------------")