mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
commit
26c011069a
5 changed files with 72 additions and 24 deletions
|
|
@ -9,10 +9,13 @@ import pandas as pd
|
|||
from typing import Union
|
||||
from pathlib import Path
|
||||
from core.Logger import logger
|
||||
from core.interface.InterfaceMetrics import MLMetrics
|
||||
from core.interface.InterfaceModels import MLModel
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.DataClient import dataclient_factory
|
||||
from core.MLModels import model_factory
|
||||
from core.MLMetrics import metrics_factory
|
||||
|
||||
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
|
||||
|
||||
|
|
@ -25,13 +28,18 @@ build_model_params = yaml.safe_load(open(build_model_path))
|
|||
feature_process_path = Path(__file__).parent / "configs" / "feature_processor.yaml"
|
||||
feature_process_params = yaml.safe_load(open(feature_process_path))
|
||||
|
||||
generate_metrics_path = Path(__file__).parent / "configs" / "generate_metrics.yaml"
|
||||
generate_metrics_params = yaml.safe_load(open(generate_metrics_path))
|
||||
|
||||
|
||||
def build_model(
|
||||
dataclient: DataClient,
|
||||
model: MLModel,
|
||||
metrics: MLMetrics,
|
||||
target: str,
|
||||
model_save_location: str,
|
||||
model_hyperparameters: dict,
|
||||
fit_metrics_filepath: str,
|
||||
train_filepath: Union[str, None] = None,
|
||||
test_filepath: Union[str, None] = None,
|
||||
train_data: Union[pd.DataFrame, None] = None,
|
||||
|
|
@ -60,12 +68,35 @@ def build_model(
|
|||
data=train_data, target=target, model_hyperparameters=model_hyperparameters
|
||||
)
|
||||
|
||||
logger.info("------------------------------")
|
||||
logger.info("--- Generating predictions ---")
|
||||
logger.info("------------------------------")
|
||||
|
||||
prediction_data = train_data.drop(columns=target)
|
||||
|
||||
predictions = model.predict(data=prediction_data)
|
||||
|
||||
logger.info("------------------------------")
|
||||
logger.info("--- Generating fit metrics ---")
|
||||
logger.info("------------------------------")
|
||||
|
||||
metrics_output = metrics.generate_metrics(
|
||||
target=train_data[target],
|
||||
predictions=pd.Series(predictions),
|
||||
)
|
||||
|
||||
logger.info("--------------------")
|
||||
logger.info("--- Saving model ---")
|
||||
logger.info("--------------------")
|
||||
|
||||
model.save_model(path=Path(model_save_location))
|
||||
|
||||
logger.info("--------------------------")
|
||||
logger.info("--- Saving fit metrics ---")
|
||||
logger.info("--------------------------")
|
||||
|
||||
dataclient.save_data(obj=metrics_output, location=fit_metrics_filepath)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
@ -87,6 +118,12 @@ if __name__ == "__main__":
|
|||
model_type = build_model_params["model_type"]
|
||||
model = model_factory(model_type)
|
||||
|
||||
logger.info("-------------------------")
|
||||
logger.info(f"--- Initiate Metrics ---")
|
||||
logger.info("-------------------------")
|
||||
|
||||
metrics = metrics_factory(generate_metrics_params["metrics_type"])
|
||||
|
||||
logger.info("--------------------------")
|
||||
logger.info(f"--- Build Model Stage ---")
|
||||
logger.info("--------------------------")
|
||||
|
|
@ -94,11 +131,13 @@ if __name__ == "__main__":
|
|||
build_model(
|
||||
dataclient=dataclient,
|
||||
model=model,
|
||||
metrics=metrics,
|
||||
target=feature_process_params["feature_processor_config"]["target"],
|
||||
model_save_location=build_model_params["model_save_filepath"],
|
||||
model_hyperparameters=build_model_params[model_type],
|
||||
train_filepath=prepare_data_params["output_train_filepath"],
|
||||
test_filepath=prepare_data_params["output_test_filepath"],
|
||||
fit_metrics_filepath=build_model_params["fit_metrics_filepath"],
|
||||
)
|
||||
|
||||
logger.info("-------------------------------")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
model_type: SKLearnLinearRegression
|
||||
model_save_filepath: ./data/model/model.joblib
|
||||
fit_metrics_filepath: ./metrics/fit_metrics.json
|
||||
|
||||
SKLearnLinearRegression: null
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ stages:
|
|||
deps:
|
||||
- path: prepare_data.py
|
||||
hash: md5
|
||||
md5: 2cfe9e3012280e0cecdb84da12c974d9
|
||||
size: 5009
|
||||
md5: 7531a931a405650dc4e8b5d8c1fd3c66
|
||||
size: 4959
|
||||
params:
|
||||
configs/prepare_data.yaml:
|
||||
output_test_filepath: ./data/prepared_data/test.parquet
|
||||
|
|
@ -15,20 +15,20 @@ stages:
|
|||
outs:
|
||||
- path: data/prepared_data/
|
||||
hash: md5
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
|
||||
size: 13238511
|
||||
nfiles: 2
|
||||
build_model:
|
||||
cmd: python build_model.py
|
||||
deps:
|
||||
- path: build_model.py
|
||||
hash: md5
|
||||
md5: 46bcc34f20c6851cd987640889eefde6
|
||||
size: 3671
|
||||
md5: c07ce0b8fdaf337ddfb7115684932157
|
||||
size: 5048
|
||||
- path: data/prepared_data
|
||||
hash: md5
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
|
||||
size: 13238511
|
||||
nfiles: 2
|
||||
params:
|
||||
configs/build_model.yaml:
|
||||
|
|
@ -43,31 +43,36 @@ stages:
|
|||
SKLearnLinearRegression:
|
||||
SKLearnSVMRegression:
|
||||
kernel: linear
|
||||
fit_metrics_filepath: ./metrics/fit_metrics.json
|
||||
model_save_filepath: ./data/model/model.joblib
|
||||
model_type: SKLearnLinearRegression
|
||||
outs:
|
||||
- path: data/model/
|
||||
hash: md5
|
||||
md5: eb2b910dec66481e75bb6058622f6e55.dir
|
||||
md5: 2ace0835c28543512982b69d383b3c49.dir
|
||||
size: 1832
|
||||
nfiles: 1
|
||||
- path: metrics/fit_metrics.json
|
||||
hash: md5
|
||||
md5: c8c5a40863e2ced7f5f5a844ba203d80
|
||||
size: 180
|
||||
generate_predictions:
|
||||
cmd: python generate_predictions.py
|
||||
deps:
|
||||
- path: data/model
|
||||
hash: md5
|
||||
md5: eb2b910dec66481e75bb6058622f6e55.dir
|
||||
md5: 2ace0835c28543512982b69d383b3c49.dir
|
||||
size: 1832
|
||||
nfiles: 1
|
||||
- path: data/prepared_data
|
||||
hash: md5
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
|
||||
size: 13238511
|
||||
nfiles: 2
|
||||
- path: generate_predictions.py
|
||||
hash: md5
|
||||
md5: d412c8c9b48b59a29f569633280a6e7f
|
||||
size: 4237
|
||||
md5: ab603e9a526a73f2fe17603e6fe6c0a4
|
||||
size: 4261
|
||||
params:
|
||||
configs/generate_predictions.yaml:
|
||||
input_dataclient_type: local
|
||||
|
|
@ -78,26 +83,26 @@ stages:
|
|||
outs:
|
||||
- path: data/predictions/
|
||||
hash: md5
|
||||
md5: 85ec3fa0cb387a7775eccd23185f7966.dir
|
||||
size: 643406
|
||||
md5: e87d96ed77d01ab2f24aeab5aaafe344.dir
|
||||
size: 643838
|
||||
nfiles: 1
|
||||
generate_metrics:
|
||||
cmd: python generate_metrics.py
|
||||
deps:
|
||||
- path: data/predictions
|
||||
hash: md5
|
||||
md5: 85ec3fa0cb387a7775eccd23185f7966.dir
|
||||
size: 643406
|
||||
md5: e87d96ed77d01ab2f24aeab5aaafe344.dir
|
||||
size: 643838
|
||||
nfiles: 1
|
||||
- path: data/prepared_data
|
||||
hash: md5
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
md5: e36ed6e937196ab64dcfe9b5b97b6e9f.dir
|
||||
size: 13238511
|
||||
nfiles: 2
|
||||
- path: generate_metrics.py
|
||||
hash: md5
|
||||
md5: 5577a28107458dc1e6bcaaa098388095
|
||||
size: 4144
|
||||
md5: 78a9b9b25d0a7deaf44277f9afad5f98
|
||||
size: 4139
|
||||
params:
|
||||
configs/generate_metrics.yaml:
|
||||
dataclient_type: local
|
||||
|
|
@ -108,8 +113,8 @@ stages:
|
|||
outs:
|
||||
- path: metrics/metrics.json
|
||||
hash: md5
|
||||
md5: d79f798a272e6b50597be4d08ae48fa8
|
||||
size: 180
|
||||
md5: f494881710a057f90f82c0bd3a40a41d
|
||||
size: 183
|
||||
startup_cleanup:
|
||||
cmd: python startup_cleanup.py
|
||||
deps:
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ stages:
|
|||
- configs/build_model.yaml:
|
||||
outs:
|
||||
- data/model/
|
||||
- metrics/fit_metrics.json
|
||||
always_changed: true
|
||||
generate_predictions:
|
||||
cmd: python generate_predictions.py
|
||||
|
|
@ -54,3 +55,4 @@ stages:
|
|||
always_changed: true
|
||||
metrics:
|
||||
- metrics/metrics.json
|
||||
- metrics/fit_metrics.json
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
/fit_metrics.json
|
||||
/metrics.json
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue