Fixed static typing issue: select the predictions column by its configured name

This commit is contained in:
Michael Duong 2023-09-10 16:03:18 +01:00
parent 26c517594b
commit f76c238ac6
4 changed files with 16 additions and 9 deletions

View file

@ -1,2 +1,3 @@
test_data_filepath: ./data/prepared_data/test.parquet
predictions_output_filepath: ./data/predictions/predictions.parquet
predictions_column_name: predictions

View file

@ -59,16 +59,17 @@ stages:
nfiles: 2
- path: generate_predictions.py
hash: md5
md5: b3250f5b597fe33bf57e8bc225606be7
size: 3268
md5: 219c7b2aa920b14d4bb7a1ef7df0ea1b
size: 3420
params:
configs/generate_predictions.yaml:
predictions_column_name: predictions
predictions_output_filepath: ./data/predictions/predictions.parquet
test_data_filepath: ./data/prepared_data/test.parquet
outs:
- path: data/predictions/
hash: md5
md5: 339924cbd0435a59be599c06fd2b25e6.dir
md5: 25bb58c06ce3bc7ef20de87298db1567.dir
size: 2949
nfiles: 1
generate_metrics:
@ -76,7 +77,7 @@ stages:
deps:
- path: data/predictions
hash: md5
md5: 339924cbd0435a59be599c06fd2b25e6.dir
md5: 25bb58c06ce3bc7ef20de87298db1567.dir
size: 2949
nfiles: 1
- path: data/prepared_data
@ -86,8 +87,8 @@ stages:
nfiles: 2
- path: generate_metrics.py
hash: md5
md5: 8c78578a8c45edf4b93a85c42c2b2192
size: 3561
md5: 81f8eec20ffb542f27bde28dc028bace
size: 3741
params:
configs/generate_metrics.yaml:
metrics_output_filepath: ./metrics/metrics.json

View file

@ -41,6 +41,7 @@ def generate_metrics(
target: str,
test_data_filepath: str,
predictions_output_filepath: str,
predictions_column_name: str,
metrics_output_filepath: str,
):
"""
@ -59,14 +60,15 @@ def generate_metrics(
logger.info("---------------------------")
# TODO: replace with client loader here
predictions = pd.Series(pd.read_parquet(predictions_output_filepath))
predictions = pd.read_parquet(predictions_output_filepath)
logger.info("--------------------------")
logger.info("--- Generating metrics ---")
logger.info("--------------------------")
metrics_output = metrics.generate_metrics(
target=test_data[target], predictions=predictions
target=test_data[target],
predictions=pd.Series(predictions[predictions_column_name]),
)
logger.info("----------------------")
@ -101,6 +103,7 @@ if __name__ == "__main__":
predictions_output_filepath=generate_predictions_params[
"predictions_output_filepath"
],
predictions_column_name=generate_predictions_params["predictions_column_name"],
metrics_output_filepath=generate_metrics_params["metrics_output_filepath"],
)

View file

@ -37,6 +37,7 @@ def generate_predictions(
model_filepath: str,
test_data_filepath: str,
predictions_output_filepath: str,
predictions_column_name: str,
):
"""
For a given model, we generate predictions and evaluate them against the true target
@ -75,7 +76,7 @@ def generate_predictions(
if not Path(predictions_output_filepath).parent.exists():
os.mkdir(Path(predictions_output_filepath).parent)
pd.DataFrame(predictions, columns=["predictions"]).to_parquet(
pd.DataFrame(predictions, columns=[predictions_column_name]).to_parquet(
predictions_output_filepath
)
@ -98,6 +99,7 @@ if __name__ == "__main__":
predictions_output_filepath=generate_predictions_params[
"predictions_output_filepath"
],
predictions_column_name=generate_predictions_params["predictions_column_name"],
)
logger.info("-------------------------------")