mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
fixed static typing issue
This commit is contained in:
parent
26c517594b
commit
f76c238ac6
4 changed files with 16 additions and 9 deletions
|
|
@ -1,2 +1,3 @@
|
|||
test_data_filepath: ./data/prepared_data/test.parquet
|
||||
predictions_output_filepath: ./data/predictions/predictions.parquet
|
||||
predictions_column_name: predictions
|
||||
|
|
|
|||
|
|
@ -59,16 +59,17 @@ stages:
|
|||
nfiles: 2
|
||||
- path: generate_predictions.py
|
||||
hash: md5
|
||||
md5: b3250f5b597fe33bf57e8bc225606be7
|
||||
size: 3268
|
||||
md5: 219c7b2aa920b14d4bb7a1ef7df0ea1b
|
||||
size: 3420
|
||||
params:
|
||||
configs/generate_predictions.yaml:
|
||||
predictions_column_name: predictions
|
||||
predictions_output_filepath: ./data/predictions/predictions.parquet
|
||||
test_data_filepath: ./data/prepared_data/test.parquet
|
||||
outs:
|
||||
- path: data/predictions/
|
||||
hash: md5
|
||||
md5: 339924cbd0435a59be599c06fd2b25e6.dir
|
||||
md5: 25bb58c06ce3bc7ef20de87298db1567.dir
|
||||
size: 2949
|
||||
nfiles: 1
|
||||
generate_metrics:
|
||||
|
|
@ -76,7 +77,7 @@ stages:
|
|||
deps:
|
||||
- path: data/predictions
|
||||
hash: md5
|
||||
md5: 339924cbd0435a59be599c06fd2b25e6.dir
|
||||
md5: 25bb58c06ce3bc7ef20de87298db1567.dir
|
||||
size: 2949
|
||||
nfiles: 1
|
||||
- path: data/prepared_data
|
||||
|
|
@ -86,8 +87,8 @@ stages:
|
|||
nfiles: 2
|
||||
- path: generate_metrics.py
|
||||
hash: md5
|
||||
md5: 8c78578a8c45edf4b93a85c42c2b2192
|
||||
size: 3561
|
||||
md5: 81f8eec20ffb542f27bde28dc028bace
|
||||
size: 3741
|
||||
params:
|
||||
configs/generate_metrics.yaml:
|
||||
metrics_output_filepath: ./metrics/metrics.json
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ def generate_metrics(
|
|||
target: str,
|
||||
test_data_filepath: str,
|
||||
predictions_output_filepath: str,
|
||||
predictions_column_name: str,
|
||||
metrics_output_filepath: str,
|
||||
):
|
||||
"""
|
||||
|
|
@ -59,14 +60,15 @@ def generate_metrics(
|
|||
logger.info("---------------------------")
|
||||
|
||||
# TODO: replace with client loader here
|
||||
predictions = pd.Series(pd.read_parquet(predictions_output_filepath))
|
||||
predictions = pd.read_parquet(predictions_output_filepath)
|
||||
|
||||
logger.info("--------------------------")
|
||||
logger.info("--- Generating metrics ---")
|
||||
logger.info("--------------------------")
|
||||
|
||||
metrics_output = metrics.generate_metrics(
|
||||
target=test_data[target], predictions=predictions
|
||||
target=test_data[target],
|
||||
predictions=pd.Series(predictions[predictions_column_name]),
|
||||
)
|
||||
|
||||
logger.info("----------------------")
|
||||
|
|
@ -101,6 +103,7 @@ if __name__ == "__main__":
|
|||
predictions_output_filepath=generate_predictions_params[
|
||||
"predictions_output_filepath"
|
||||
],
|
||||
predictions_column_name=generate_predictions_params["predictions_column_name"],
|
||||
metrics_output_filepath=generate_metrics_params["metrics_output_filepath"],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ def generate_predictions(
|
|||
model_filepath: str,
|
||||
test_data_filepath: str,
|
||||
predictions_output_filepath: str,
|
||||
predictions_column_name: str,
|
||||
):
|
||||
"""
|
||||
For a given model, we generate prediction and evaluate this against the true target
|
||||
|
|
@ -75,7 +76,7 @@ def generate_predictions(
|
|||
if not Path(predictions_output_filepath).parent.exists():
|
||||
os.mkdir(Path(predictions_output_filepath).parent)
|
||||
|
||||
pd.DataFrame(predictions, columns=["predictions"]).to_parquet(
|
||||
pd.DataFrame(predictions, columns=[predictions_column_name]).to_parquet(
|
||||
predictions_output_filepath
|
||||
)
|
||||
|
||||
|
|
@ -98,6 +99,7 @@ if __name__ == "__main__":
|
|||
predictions_output_filepath=generate_predictions_params[
|
||||
"predictions_output_filepath"
|
||||
],
|
||||
predictions_column_name=generate_predictions_params["predictions_column_name"],
|
||||
)
|
||||
|
||||
logger.info("-------------------------------")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue