Merge pull request #35 from Hestia-Homes/model-test

add fit metrics
This commit is contained in:
quandanrepo 2023-09-17 21:09:33 +01:00 committed by GitHub
commit 26c011069a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 24 deletions

View file

@ -9,10 +9,13 @@ import pandas as pd
from typing import Union
from pathlib import Path
from core.Logger import logger
from core.interface.InterfaceMetrics import MLMetrics
from core.interface.InterfaceModels import MLModel
from core.interface.InterfaceDataClient import DataClient
from core.DataClient import dataclient_factory
from core.MLModels import model_factory
from core.MLMetrics import metrics_factory
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
@ -25,13 +28,18 @@ build_model_params = yaml.safe_load(open(build_model_path))
# Stage configuration is loaded eagerly at import time from the "configs"
# directory that sits next to this script.
#
# Path.read_text() opens the file, reads it and closes it again, so no file
# handle is left open for the garbage collector to clean up (the previous
# bare open() call was never closed explicitly).
feature_process_path = Path(__file__).parent / "configs" / "feature_processor.yaml"
feature_process_params = yaml.safe_load(feature_process_path.read_text())

generate_metrics_path = Path(__file__).parent / "configs" / "generate_metrics.yaml"
generate_metrics_params = yaml.safe_load(generate_metrics_path.read_text())
def build_model(
dataclient: DataClient,
model: MLModel,
metrics: MLMetrics,
target: str,
model_save_location: str,
model_hyperparameters: dict,
fit_metrics_filepath: str,
train_filepath: Union[str, None] = None,
test_filepath: Union[str, None] = None,
train_data: Union[pd.DataFrame, None] = None,
@ -60,12 +68,35 @@ def build_model(
data=train_data, target=target, model_hyperparameters=model_hyperparameters
)
logger.info("------------------------------")
logger.info("--- Generating predictions ---")
logger.info("------------------------------")
prediction_data = train_data.drop(columns=target)
predictions = model.predict(data=prediction_data)
logger.info("------------------------------")
logger.info("--- Generating fit metrics ---")
logger.info("------------------------------")
metrics_output = metrics.generate_metrics(
target=train_data[target],
predictions=pd.Series(predictions),
)
logger.info("--------------------")
logger.info("--- Saving model ---")
logger.info("--------------------")
model.save_model(path=Path(model_save_location))
logger.info("--------------------------")
logger.info("--- Saving fit metrics ---")
logger.info("--------------------------")
dataclient.save_data(obj=metrics_output, location=fit_metrics_filepath)
if __name__ == "__main__":
@ -87,6 +118,12 @@ if __name__ == "__main__":
model_type = build_model_params["model_type"]
model = model_factory(model_type)
logger.info("-------------------------")
logger.info(f"--- Initiate Metrics ---")
logger.info("-------------------------")
metrics = metrics_factory(generate_metrics_params["metrics_type"])
logger.info("--------------------------")
logger.info(f"--- Build Model Stage ---")
logger.info("--------------------------")
@ -94,11 +131,13 @@ if __name__ == "__main__":
build_model(
dataclient=dataclient,
model=model,
metrics=metrics,
target=feature_process_params["feature_processor_config"]["target"],
model_save_location=build_model_params["model_save_filepath"],
model_hyperparameters=build_model_params[model_type],
train_filepath=prepare_data_params["output_train_filepath"],
test_filepath=prepare_data_params["output_test_filepath"],
fit_metrics_filepath=build_model_params["fit_metrics_filepath"],
)
logger.info("-------------------------------")

View file

@ -1,5 +1,6 @@
model_type: SKLearnLinearRegression
model_save_filepath: ./data/model/model.joblib
fit_metrics_filepath: ./metrics/fit_metrics.json
SKLearnLinearRegression: null

View file

@ -5,8 +5,8 @@ stages:
deps:
- path: prepare_data.py
hash: md5
md5: 2cfe9e3012280e0cecdb84da12c974d9
size: 5009
md5: 7531a931a405650dc4e8b5d8c1fd3c66
size: 4959
params:
configs/prepare_data.yaml:
output_test_filepath: ./data/prepared_data/test.parquet
@ -15,20 +15,20 @@ stages:
outs:
- path: data/prepared_data/
hash: md5
md5: ea0a2baf3931e692d6344ba609331089.dir
size: 13232732
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
size: 13238511
nfiles: 2
build_model:
cmd: python build_model.py
deps:
- path: build_model.py
hash: md5
md5: 46bcc34f20c6851cd987640889eefde6
size: 3671
md5: c07ce0b8fdaf337ddfb7115684932157
size: 5048
- path: data/prepared_data
hash: md5
md5: ea0a2baf3931e692d6344ba609331089.dir
size: 13232732
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
size: 13238511
nfiles: 2
params:
configs/build_model.yaml:
@ -43,31 +43,36 @@ stages:
SKLearnLinearRegression:
SKLearnSVMRegression:
kernel: linear
fit_metrics_filepath: ./metrics/fit_metrics.json
model_save_filepath: ./data/model/model.joblib
model_type: SKLearnLinearRegression
outs:
- path: data/model/
hash: md5
md5: eb2b910dec66481e75bb6058622f6e55.dir
md5: 2ace0835c28543512982b69d383b3c49.dir
size: 1832
nfiles: 1
- path: metrics/fit_metrics.json
hash: md5
md5: c8c5a40863e2ced7f5f5a844ba203d80
size: 180
generate_predictions:
cmd: python generate_predictions.py
deps:
- path: data/model
hash: md5
md5: eb2b910dec66481e75bb6058622f6e55.dir
md5: 2ace0835c28543512982b69d383b3c49.dir
size: 1832
nfiles: 1
- path: data/prepared_data
hash: md5
md5: ea0a2baf3931e692d6344ba609331089.dir
size: 13232732
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
size: 13238511
nfiles: 2
- path: generate_predictions.py
hash: md5
md5: d412c8c9b48b59a29f569633280a6e7f
size: 4237
md5: ab603e9a526a73f2fe17603e6fe6c0a4
size: 4261
params:
configs/generate_predictions.yaml:
input_dataclient_type: local
@ -78,26 +83,26 @@ stages:
outs:
- path: data/predictions/
hash: md5
md5: 85ec3fa0cb387a7775eccd23185f7966.dir
size: 643406
md5: e87d96ed77d01ab2f24aeab5aaafe344.dir
size: 643838
nfiles: 1
generate_metrics:
cmd: python generate_metrics.py
deps:
- path: data/predictions
hash: md5
md5: 85ec3fa0cb387a7775eccd23185f7966.dir
size: 643406
md5: e87d96ed77d01ab2f24aeab5aaafe344.dir
size: 643838
nfiles: 1
- path: data/prepared_data
hash: md5
md5: ea0a2baf3931e692d6344ba609331089.dir
size: 13232732
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
size: 13238511
nfiles: 2
- path: generate_metrics.py
hash: md5
md5: 5577a28107458dc1e6bcaaa098388095
size: 4144
md5: 78a9b9b25d0a7deaf44277f9afad5f98
size: 4139
params:
configs/generate_metrics.yaml:
dataclient_type: local
@ -108,8 +113,8 @@ stages:
outs:
- path: metrics/metrics.json
hash: md5
md5: d79f798a272e6b50597be4d08ae48fa8
size: 180
md5: f494881710a057f90f82c0bd3a40a41d
size: 183
startup_cleanup:
cmd: python startup_cleanup.py
deps:

View file

@ -29,6 +29,7 @@ stages:
- configs/build_model.yaml:
outs:
- data/model/
- metrics/fit_metrics.json
always_changed: true
generate_predictions:
cmd: python generate_predictions.py
@ -54,3 +55,4 @@ stages:
always_changed: true
metrics:
- metrics/metrics.json
- metrics/fit_metrics.json

View file

@ -1 +1,2 @@
/fit_metrics.json
/metrics.json