Merge pull request #86 from Hestia-Homes/heat-dev-model

Heat dev model
This commit is contained in:
quandanrepo 2023-11-27 22:16:44 +00:00 committed by GitHub
commit 676539e6a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 102 additions and 161 deletions

View file

@ -2,7 +2,7 @@ name: Sap Change Model Deploy
on:
push:
branches: [ sap-dev, sap-prod ]
branches: [ sap-dev, sap-prod, heat-dev, heat-prod, carbon-dev, carbon-prod]
jobs:
deploy:

View file

@ -42,7 +42,14 @@ jobs:
if [ -z "${latest_version}" ]; then
increment_version="1.0.0"
else
increment_version=$(echo ${latest_version} | awk -F'.' '{OFS="."; $1+=1; print}')
increment_version=$(echo ${latest_version} | awk 'BEGIN {
FS="\\." # Set the field separator to a period
OFS="." # Set the output field separator to a period
}
{
major = $1 + 1 # Increment the major version
print major, "0", "0" # Print the new version
}')
fi
new_tag=${REGISTER_MODEL_NAME}@v${increment_version}
@ -80,7 +87,14 @@ jobs:
if [ -z "${latest_version}" ]; then
increment_version="0.1.0"
else
increment_version=$(echo ${latest_version} | awk 'BEGIN{FS=OFS="."} {$2++; print}')
increment_version=$(echo ${latest_version} | awk 'BEGIN {
FS="\\." # Set the field separator to a period
OFS="." # Set the output field separator to a period
}
{
minor = $2 + 1 # Increment the minor version
print $1, minor, "0" # Print the new version
}')
fi
new_tag=${REGISTER_MODEL_NAME}@v${increment_version}
@ -118,7 +132,14 @@ jobs:
if [ -z "${latest_version}" ]; then
increment_version="0.0.1"
else
increment_version=$(echo ${latest_version} | awk 'BEGIN{FS=OFS="."} {$3++; print}')
increment_version=$(echo ${latest_version} | awk 'BEGIN {
FS="\\." # Set the field separator to a period
OFS="." # Set the output field separator to a period
}
{
patch = $3 + 1 # Increment the patch version
print $1, $2, patch # Print the new version
}')
fi
new_tag=${REGISTER_MODEL_NAME}@v${increment_version}
@ -188,7 +209,7 @@ jobs:
git config user.name "Github-Bot"
git config user.email "Github-Bot@no-reply.com"
latest_dev_version=$(gto history ${REGISTER_MODEL_NAME} --asc --plain | awk '{print $NF}' | awk '/dev/')
latest_dev_version=$(gto history ${REGISTER_MODEL_NAME} --asc --plain | awk '{print $NF}' | awk '/dev/' | awk 'END {print}')
if [ -z "${latest_dev_version}" ]; then
increment_version="1"
else
@ -196,7 +217,7 @@ jobs:
fi
new_tag=${REGISTER_MODEL_NAME}#dev#${increment_version}
latest_version=$(gto show model@latest --ref | awk -F"@" '{print $2}')
latest_version=$(gto show ${REGISTER_MODEL_NAME}@latest --ref | awk -F"@" '{print $2}')
echo ${new_tag}

View file

@ -8,9 +8,9 @@
"active": true
},
"sap": {
"version": "v0.0.4",
"version": "v0.1.0",
"stage": {
"dev": "v0.0.4"
"dev": "v0.1.0"
},
"registered": true,
"active": true

View file

@ -10,9 +10,9 @@ tracking and a model registry
- A bolt-on service that can implement model monitoring
There are multiple protected branches which adapt the generic pipeline to produce different models:
- sap_change-**
- heat_change-**
- carbon_change-**
- sap-{dev/staging/prod}-**
- heat-{dev/staging/prod}-**
- carbon-{dev/staging/prod}-**
These branches will differ by the configuration files that define the data used and the outputs of the ML-pipeline
- There can be different additional logic for each branch but the pipeline will be the same.

View file

@ -69,9 +69,7 @@ def handler(event, context):
storage_filepath = f"s3://{PREDICTIONS_BUCKET}/{portfolio_id}/{property_id}/{created_at}.parquet"
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
build_model_params = settings.build_model
client_params = settings.client
@ -80,17 +78,13 @@ def handler(event, context):
model = model_factory(build_model_params["model_type"])
logger.info("----------------------------")
logger.info(f"--- Initiate Input DataClient ---")
logger.info("----------------------------")
input_dataclient = dataclient_factory(
dataclient_type="aws-s3",
dataclient_config=client_params["aws-s3"],
)
logger.info("----------------------------")
logger.info(f"--- Initiate Output DataClient ---")
logger.info("----------------------------")
output_dataclient = dataclient_factory(
dataclient_type="aws-s3",
dataclient_config=client_params["aws-s3"],
@ -107,6 +101,7 @@ def handler(event, context):
predictions_column_name=generate_predictions_params[
"predictions_column_name"
],
identifier_column=generate_predictions_params["identifier_column"],
)
return {

View file

@ -9,16 +9,16 @@ init: dev-conda
.PHONY: dev-conda
dev-conda:
# conda deactivate || echo "Not in conda environment"
# conda remove --name $CONDA_ENV --all -y || echo "No environment created previously"
conda create --name $CONDA_ENV python=$(PYTHON_VERSION) -y
# conda remove --name ${CONDA_ENV} --all -y || echo "No environment created previously"
conda create --name ${CONDA_ENV} python=$(PYTHON_VERSION) -y
conda init bash
conda run -vvvv -n $CONDA_ENV pip install --upgrade pip
conda run -vvvv -n $CONDA_ENV pip install -r src/pipeline/requirements/training/requirements-dev.txt
conda run -vvvv -n $CONDA_ENV pip install -r src/pipeline/requirements/version_control/requirements.txt
conda run -vvvv -n $CONDA_ENV pre-commit install
conda run -vvvv -n $CONDA_ENV pip install ipykernel
conda run -v -n ${CONDA_ENV} pip install --upgrade pip
conda run -v -n ${CONDA_ENV} pip install -r src/pipeline/requirements/training/requirements-dev.txt
conda run -v -n ${CONDA_ENV} pip install -r src/pipeline/requirements/version_control/requirements.txt
conda run -v -n ${CONDA_ENV} pre-commit install
conda run -v -n ${CONDA_ENV} pip install ipykernel
echo "TO ACTIVATE ENVIRONMENT, USE THE FOLLOWING COMMAND"
echo "conda activate $CONDA_ENV"
echo "conda activate ${CONDA_ENV}"
.PHONY: dev-pyenv

View file

@ -16,13 +16,9 @@ def run_cleanup(artefacts_directory: str, metrics_directory: str) -> None:
Remove the directory where artefacts are stored
"""
logger.info("---------------------")
logger.info(f"--- Run Clean up ---")
logger.info("---------------------")
logger.info("-------------------------")
logger.info(f"--- Delete artefacts ---")
logger.info("-------------------------")
artefact_directory_path = Path(artefacts_directory)
@ -31,9 +27,7 @@ def run_cleanup(artefacts_directory: str, metrics_directory: str) -> None:
logger.info(f"Removing the directory: {artefacts_directory}")
shutil.rmtree(artefact_directory_path)
logger.info("-----------------------")
logger.info(f"--- Delete metrics ---")
logger.info("-----------------------")
metrics_directory_path = Path(metrics_directory)
@ -45,15 +39,11 @@ def run_cleanup(artefacts_directory: str, metrics_directory: str) -> None:
if __name__ == "__main__":
logger.info("----------------------------")
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
run_cleanup(
artefacts_directory=startup_cleanup_params["artefacts"],
metrics_directory=startup_cleanup_params["metrics"],
)
logger.info("-------------------------------")
logger.info(f"--- {__file__} - Complete! ---")
logger.info("-------------------------------")

View file

@ -17,9 +17,7 @@ from core.DataClient import dataclient_factory
from core.FeatureProcessor import feature_processor_factory
from config import settings
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
@ -33,9 +31,7 @@ output_train_filepath = prepare_data_params["output_train_filepath"]
output_test_filepath = prepare_data_params["output_test_filepath"]
feature_processor_config = feature_process_params["feature_processor_config"]
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
input_dataclient_type = prepare_data_params["input_dataclient_type"]
output_dataclient_type = prepare_data_params["output_dataclient_type"]
@ -49,9 +45,7 @@ output_dataclient = dataclient_factory(
dataclient_config=client_params[output_dataclient_type],
)
logger.info("----------------------------------")
logger.info(f"--- Initiate FeatureProcessor ---")
logger.info("----------------------------------")
feature_processor = feature_processor_factory(
feature_process_params["feature_processor_type"]
@ -76,15 +70,11 @@ def prepare_data(
:param pipeline_mode: bool, Default False, this caches out the file for experimentation, objects returned in pipeline mode
"""
logger.info("--------------------")
logger.info("--- Loading data ---")
logger.info("--------------------")
data = input_dataclient.load_data(location=data_filepath, load_config={})
logger.info("--------------------------")
logger.info("--- Feature Processing ---")
logger.info("--------------------------")
data = feature_processor.feature_process(
data,
@ -93,9 +83,7 @@ def prepare_data(
new_feature_funcs=new_feature_funcs,
)
logger.info("----------------------")
logger.info("--- Splitting data ---")
logger.info("----------------------")
if train_proportion == 1:
train = data
@ -108,9 +96,7 @@ def prepare_data(
train = train.reset_index(drop=True)
logger.info("-----------------------")
logger.info("--- Outputting data ---")
logger.info("-----------------------")
output_dataclient.save_data(
obj=train, location=output_train_filepath, save_config=None
@ -126,13 +112,9 @@ def prepare_data(
if __name__ == "__main__":
logger.info("----------------------------")
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
logger.info("---------------------------")
logger.info(f"--- Prepare Data Stage ---")
logger.info("---------------------------")
prepare_data(
input_dataclient=input_dataclient,
@ -147,6 +129,4 @@ if __name__ == "__main__":
new_feature_funcs=new_feature_funcs,
)
logger.info("-------------------------------")
logger.info(f"--- {__file__} - Complete! ---")
logger.info("-------------------------------")

View file

@ -18,9 +18,7 @@ from core.MLMetrics import metrics_factory
from configs.post_prediction_logic import post_prediction_logic
from config import settings
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
@ -40,22 +38,16 @@ train_filepath = prepare_data_params["output_train_filepath"]
test_filepath = prepare_data_params["output_test_filepath"]
fit_metrics_filepath = build_model_params["fit_metrics_filepath"]
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
# Output of previous prepare data step, will be where the data is
dataclient = dataclient_factory(prepare_data_params["output_dataclient_type"])
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
model = model_factory(model_type)
logger.info("-------------------------")
logger.info(f"--- Initiate Metrics ---")
logger.info("-------------------------")
metrics = metrics_factory(generate_metrics_params["metrics_type"])
@ -75,9 +67,7 @@ def build_model(
test_data: Union[pd.DataFrame, None] = None,
pipeline_mode: bool = False,
):
logger.info("--------------------------------------")
logger.info("--- Loading Data for build process ---")
logger.info("--------------------------------------")
if train_data is None:
if train_filepath is None:
@ -89,9 +79,7 @@ def build_model(
raise ValueError(f"Need {test_filepath} if no data supplied")
test_data = dataclient.load_data(location=test_filepath, load_config=None)
logger.info("----------------------")
logger.info("--- Training model ---")
logger.info("----------------------")
model.train_model(
data=train_data.drop(columns=identifier_columns),
@ -99,32 +87,24 @@ def build_model(
model_hyperparameters=model_hyperparameters,
)
logger.info("----------------------------------")
logger.info("--- Generating fit predictions ---")
logger.info("----------------------------------")
fit_predictions = model.predict(
data=train_data, post_prediction_logic=post_prediction_logic
)
logger.info("------------------------------")
logger.info("--- Generating fit metrics ---")
logger.info("------------------------------")
metrics_output = metrics.generate_metrics(
target=train_data[target],
predictions=pd.Series(fit_predictions),
)
logger.info("--------------------")
logger.info("--- Saving model ---")
logger.info("--------------------")
model.save_model(path=Path(model_save_location))
logger.info("--------------------------")
logger.info("--- Saving fit metrics ---")
logger.info("--------------------------")
dataclient.save_data(
obj=metrics_output, location=fit_metrics_filepath, save_config=None
@ -133,13 +113,9 @@ def build_model(
if __name__ == "__main__":
logger.info("----------------------------")
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
logger.info("--------------------------")
logger.info(f"--- Build Model Stage ---")
logger.info("--------------------------")
build_model(
dataclient=dataclient,
@ -154,6 +130,4 @@ if __name__ == "__main__":
fit_metrics_filepath=fit_metrics_filepath,
)
logger.info("-------------------------------")
logger.info(f"--- {__file__} - Complete! ---")
logger.info("-------------------------------")

View file

@ -10,9 +10,7 @@ from core.Logger import logger
from config import settings
from generate_predictions import generate_predictions
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
@ -33,15 +31,11 @@ model_filepath = build_model_params["model_save_filepath"]
predictions_output_filepath = generate_predictions_params["predictions_output_filepath"]
predictions_column_name = generate_predictions_params["predictions_column_name"]
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
model = model_factory(build_model_params["model_type"])
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
# We may have different locations of loading hence why we use one specified in generate_predictions.yaml
# I.e. for metric runs, this will be a local data client
@ -59,13 +53,9 @@ output_dataclient = dataclient_factory(
if __name__ == "__main__":
logger.info("----------------------------")
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
logger.info("----------------------------------")
logger.info(f"--- Generate Predictions Stage---")
logger.info("----------------------------------")
generate_predictions(
input_dataclient=input_dataclient,
@ -78,6 +68,4 @@ if __name__ == "__main__":
predictions_column_name=predictions_column_name,
)
logger.info("-------------------------------")
logger.info(f"--- {__file__} - Complete! ---")
logger.info("-------------------------------")

View file

@ -14,9 +14,7 @@ from core.MLMetrics import metrics_factory
from core.Logger import logger
from config import settings
logger.info("----------------------------")
logger.info(f"--- Initiate Parameters ---")
logger.info("----------------------------")
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
@ -34,15 +32,11 @@ predictions_column_name = generate_predictions_params["predictions_column_name"]
metrics_output_filepath = generate_metrics_params["metrics_output_filepath"]
logger.info("-------------------------")
logger.info(f"--- Initiate MLModel ---")
logger.info("-------------------------")
model = model_factory(build_model_params["model_type"])
logger.info("----------------------------")
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
# Use data client for input and output, as we use dvc to cache later to the cloud
dataclient_type = generate_metrics_params["dataclient_type"]
@ -51,9 +45,7 @@ dataclient = dataclient_factory(
dataclient_config=client_params[dataclient_type],
)
logger.info("---------------------------")
logger.info(f"--- Initiate MLMetrics ---")
logger.info("---------------------------")
metrics = metrics_factory(generate_metrics_params["metrics_type"])
@ -73,34 +65,26 @@ def generate_metrics(
For a given model, we generate prediction and evaluate this against the true target
"""
logger.info("-------------------------")
logger.info("--- Loading test data ---")
logger.info("-------------------------")
test_data = input_dataclient.load_data(
location=test_data_filepath, load_config=None
)
logger.info("---------------------------")
logger.info("--- Loading predictions ---")
logger.info("---------------------------")
predictions = input_dataclient.load_data(
location=predictions_output_filepath, load_config=None
)
logger.info("--------------------------")
logger.info("--- Generating metrics ---")
logger.info("--------------------------")
metrics_output = metrics.generate_metrics(
target=test_data[target],
predictions=pd.Series(predictions[predictions_column_name]),
)
logger.info("----------------------")
logger.info("--- Saving metrics ---")
logger.info("----------------------")
output_dataclient.save_data(
obj=metrics_output, location=metrics_output_filepath, save_config=None
@ -109,13 +93,9 @@ def generate_metrics(
if __name__ == "__main__":
logger.info("----------------------------")
logger.info(f"--- {__file__} - Start! ---")
logger.info("----------------------------")
logger.info("------------------------------")
logger.info(f"--- Generate Metrics Stage---")
logger.info("------------------------------")
generate_metrics(
input_dataclient=dataclient,
@ -129,6 +109,4 @@ if __name__ == "__main__":
metrics_output_filepath=metrics_output_filepath,
)
logger.info("-------------------------------")
logger.info(f"--- {__file__} - Complete! ---")
logger.info("-------------------------------")

View file

@ -13,6 +13,8 @@ default:
output_filepath: ./data/model/allmodels/
problem_type: regression
eval_metric: mean_squared_error #mean_absolute_error
time_limit: 4000
time_limit: 400
presets: medium_quality
excluded_model_types: ['KNN', 'RF']
infer_limit: 0.05
infer_limit_batch_size: 10000

View file

@ -21,7 +21,7 @@ default:
# data_filepath: s3://retrofit-data-dev/sap_change_model/dataset_with_differencing.parquet
# data_filepath: s3://retrofit-data-dev/sap_change_model/floor_area_clean_test.parquet
# data_filepath: s3://retrofit-data-dev/sap_change_model/dataset_without_differencing.parquet
data_filepath: s3://retrofit-data-dev/sap_change_model/dataset.parquet
data_filepath: s3://retrofit-data-dev/sap_change_model/dataset_test.parquet
train_proportion: 0.9
output_train_filepath: ./data/prepared_data/train.parquet
output_test_filepath: ./data/prepared_data/test.parquet
@ -43,6 +43,7 @@ default:
test_data_filepath: ./data/prepared_data/test.parquet
predictions_output_filepath: ./data/predictions/predictions.parquet
predictions_column_name: predictions
identifier_column: id
generate_metrics:
dataclient_type: local

View file

@ -142,9 +142,15 @@ class AWSS3Client:
buffer = BytesIO()
obj.to_parquet(buffer, index=False)
# Reset the buffer position to the beginning
buffer.seek(0)
bucket, key = location.strip("s3://").split("/", 1)
self.client.upload_fileobj(buffer, bucket, key)
# Close the buffer
buffer.close()
def _load_parquet(self, location: str, load_config: dict) -> pd.DataFrame:
"""
Load a parquet file

View file

@ -21,6 +21,7 @@ def setup_logger():
# Add the stream handler to the logger
logger.addHandler(stream_handler)
logger.propagate = False
return logger

View file

@ -149,6 +149,8 @@ class AutogluonAutoML:
"time_limit",
"presets",
"excluded_model_types",
"infer_limit",
"infer_limit_batch_size",
]
def load_model(self, path: Union[Path, str]) -> None:
@ -203,6 +205,8 @@ class AutogluonAutoML:
time_limit=model_hyperparameters["time_limit"],
presets=model_hyperparameters["presets"],
excluded_model_types=model_hyperparameters["excluded_model_types"],
infer_limit=model_hyperparameters["infer_limit"],
infer_limit_batch_size=model_hyperparameters["infer_limit_batch_size"],
)
def predict(

View file

@ -5,8 +5,8 @@ stages:
deps:
- path: 1_prepare_data.py
hash: md5
md5: c9f030df733e318b80d1fa91b7732f79
size: 5132
md5: 896d3d88a4a9f68d174efe71dc089517
size: 4222
params:
configs/settings.yaml:
default.feature_processor.feature_processor_config.drop_columns:
@ -20,7 +20,7 @@ stages:
default.feature_processor.feature_processor_config.subsample_seed: 0
default.feature_processor.feature_processor_config.target: HEAT_DEMAND_ENDING
default.feature_processor.feature_processor_type: dataframe
default.prepare_data.data_filepath: s3://retrofit-data-dev/sap_change_model/dataset.parquet
default.prepare_data.data_filepath: s3://retrofit-data-dev/sap_change_model/dataset_test.parquet
default.prepare_data.input_dataclient_type: aws-s3
default.prepare_data.output_dataclient_type: local
default.prepare_data.output_test_filepath: ./data/prepared_data/test.parquet
@ -29,20 +29,20 @@ stages:
outs:
- path: data/prepared_data/
hash: md5
md5: e0be70d5025e40dd0d655d9949f72130.dir
size: 31800776
md5: 6f9c63363ad52a836524dbb6fae7a2ac.dir
size: 34480114
nfiles: 2
build_model:
cmd: python 2_build_model.py
deps:
- path: 2_build_model.py
hash: md5
md5: 84699d208874c52accaff61c6af9bb0a
size: 5359
md5: b824822475c222521516493e68eef9c5
size: 4149
- path: data/prepared_data
hash: md5
md5: e0be70d5025e40dd0d655d9949f72130.dir
size: 31800776
md5: 6f9c63363ad52a836524dbb6fae7a2ac.dir
size: 34480114
nfiles: 2
params:
configs/build_model.yaml:
@ -58,37 +58,39 @@ stages:
output_filepath: ./data/model/allmodels/
problem_type: regression
eval_metric: mean_squared_error
time_limit: 4000
time_limit: 400
presets: medium_quality
excluded_model_types:
- KNN
- RF
infer_limit: 0.05
infer_limit_batch_size: 10000
outs:
- path: data/model/
hash: md5
md5: 14ca33cde5e86770135f768abaf84978.dir
size: 422447808
nfiles: 27
md5: 452eba2d92233e81d321814aacefe5c2.dir
size: 323127043
nfiles: 24
- path: metrics/fit_metrics.json
hash: md5
md5: 41bfb8d2da8f06d1864d73ce125cc6aa
size: 221
md5: 888124b56e0c5008a6423e290fc5cc71
size: 222
generate_predictions:
cmd: python 3_generate_predictions.py
deps:
- path: 3_generate_predictions.py
hash: md5
md5: 5ef2856a5a977304f1ec01f9b4205262
size: 3028
md5: 0a70ad4dfe99414a75d1261c75a177b9
size: 2464
- path: data/model
hash: md5
md5: 14ca33cde5e86770135f768abaf84978.dir
size: 422447808
nfiles: 27
md5: 452eba2d92233e81d321814aacefe5c2.dir
size: 323127043
nfiles: 24
- path: data/prepared_data
hash: md5
md5: e0be70d5025e40dd0d655d9949f72130.dir
size: 31800776
md5: 6f9c63363ad52a836524dbb6fae7a2ac.dir
size: 34480114
nfiles: 2
params:
configs/settings.yaml:
@ -100,25 +102,25 @@ stages:
outs:
- path: data/predictions/
hash: md5
md5: 40d0c7a7fd4a15add0615e322cf341a0.dir
size: 352151
md5: f852550a0a51f0c2b120b0680c1a9b54.dir
size: 325890
nfiles: 1
generate_metrics:
cmd: python 4_generate_metrics.py
deps:
- path: 4_generate_metrics.py
hash: md5
md5: 2c9fb78955a8c19cff0a098976f81d1b
size: 4487
md5: 567b1acb819e2ff432b989cdbdd4a2bf
size: 3448
- path: data/predictions
hash: md5
md5: 40d0c7a7fd4a15add0615e322cf341a0.dir
size: 352151
md5: f852550a0a51f0c2b120b0680c1a9b54.dir
size: 325890
nfiles: 1
- path: data/prepared_data
hash: md5
md5: e0be70d5025e40dd0d655d9949f72130.dir
size: 31800776
md5: 6f9c63363ad52a836524dbb6fae7a2ac.dir
size: 34480114
nfiles: 2
params:
configs/settings.yaml:
@ -128,15 +130,15 @@ stages:
outs:
- path: metrics/metrics.json
hash: md5
md5: 4e023650240e78d6ad761f1db7aac922
size: 220
md5: ed3012943593fac4ac7ad9a5499ac18f
size: 219
startup_cleanup:
cmd: python 0_startup_cleanup.py
deps:
- path: 0_startup_cleanup.py
hash: md5
md5: fbb7e3b1b98b517c870f3e1df3e7f695
size: 1676
md5: b1b12f6b6393fbf8b83d23684df0a3d4
size: 1220
params:
configs/settings.yaml:
default.startup_cleanup.artefacts: ./data

View file

@ -14,28 +14,23 @@ def generate_predictions(
test_data_filepath: str,
predictions_output_filepath: str,
predictions_column_name: str,
identifier_column: str = "id",
):
"""
For a given model, we generate prediction and evaluate this against the true target
"""
logger.info("-------------------------")
logger.info("--- Loading test data ---")
logger.info("-------------------------")
test_data = input_dataclient.load_data(
location=test_data_filepath, load_config=None
)
logger.info("---------------------")
logger.info("--- Loading model ---")
logger.info("---------------------")
model.load_model(model_filepath)
logger.info("------------------------------")
logger.info("--- Generating predictions ---")
logger.info("------------------------------")
prediction_data = (
test_data.drop(columns=target) if target in test_data.columns else test_data
@ -45,13 +40,17 @@ def generate_predictions(
data=prediction_data, post_prediction_logic=post_prediction_logic
)
logger.info("--------------------------")
logger.info("--- Saving predictions ---")
logger.info("--------------------------")
predictions_df = pd.DataFrame(predictions)
predictions_df.columns = [predictions_column_name]
output_dataclient.save_data(
obj=predictions_df, location=predictions_output_filepath, save_config=None
output_df = (
pd.concat([test_data[identifier_column], predictions_df], axis=1)
if identifier_column in test_data.columns
else predictions_df
)
output_dataclient.save_data(
obj=output_df, location=predictions_output_filepath, save_config=None
)