mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
Merge branch 'master' of github.com:Hestia-Homes/ML into model-test
This commit is contained in:
commit
6b7171adc0
10 changed files with 462 additions and 64 deletions
125
.github/workflows/Deploy.yml
vendored
Normal file
125
.github/workflows/Deploy.yml
vendored
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
# Builds the prediction Lambda image, pushes it to ECR and deploys it with
# Serverless whenever one of the model deployment branches is pushed.
# Branch names follow <model>-<environment>; the environment suffix (dev/prod)
# selects the AWS credentials, the DVC remote and the Serverless stage.
name: Sap Change Model Deploy

on:
  push:
    branches: [ sap_change-dev, sap_change-prod ]

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.10.12

      - name: Install Serverless and plugins
        run: |
          npm install -g serverless
          npm install -g serverless-domain-manager

      - name: Install DVC
        run: |
          pip install --upgrade pip
          pip install -r modules/ml-pipeline/src/pipeline/requirements/version_control/requirements.txt

      # Set up all of the secrets required for the deployment.
      # Step outputs are written to $GITHUB_OUTPUT; the `::set-output`
      # workflow command is deprecated.
      - name: set secret prefix which is used across multiple steps
        id: secret_prefix
        run: |
          # Convert branch name to uppercase and replace hyphens with underscores
          echo "secret_prefix=$(echo "${{ github.ref_name }}" | tr 'a-z-' 'A-Z_')" >> "$GITHUB_OUTPUT"

      - name: Set domain name
        id: set_domain
        run: echo "domain=${{ secrets[format('{0}_DOMAIN_NAME', steps.secret_prefix.outputs.secret_prefix)] }}" >> "$GITHUB_OUTPUT"

      - name: Set ECR credentials
        id: set_ecr_credentials
        run: |
          # Fetch the secret using the secret prefix
          echo "ecr_uri=${{ secrets[format('{0}_ECR_URI', steps.secret_prefix.outputs.secret_prefix)] }}" >> "$GITHUB_OUTPUT"

      - name: Set S3 buckets
        id: set_s3_buckets
        run: |
          # Fetch the secret using the secret prefix
          echo "data_bucket=${{ secrets[format('{0}_DATA_BUCKET', steps.secret_prefix.outputs.secret_prefix)] }}" >> "$GITHUB_OUTPUT"
          echo "predictions_bucket=${{ secrets[format('{0}_PREDICTIONS_BUCKET', steps.secret_prefix.outputs.secret_prefix)] }}" >> "$GITHUB_OUTPUT"

      - name: Set stack_name
        id: set_stack_name
        run: |
          if [[ "${{ github.ref_name }}" == "sap_change-dev" || "${{ github.ref_name }}" == "sap_change-prod" ]]; then
            echo "stack_name=sapmodel" >> "$GITHUB_OUTPUT"
          else
            echo "stack_name=" >> "$GITHUB_OUTPUT"
          fi

      - name: Set runtime_environment
        id: set_runtime_environment
        run: |
          # Extract the suffix after the hyphen from the branch name
          runtime_environment=$(echo "${{ github.ref_name }}" | awk -F'-' '{print $NF}')
          echo "runtime_environment=$runtime_environment" >> "$GITHUB_OUTPUT"

      # The whole comparison must be a single expression. Writing
      # `${{ <value> }} == 'dev'` substitutes the value first and leaves a
      # non-empty string, which is always truthy, so both credential steps
      # would run and the prod credentials (configured last) would always win.
      - name: AWS credentials for dev
        if: steps.set_runtime_environment.outputs.runtime_environment == 'dev'
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.DEV_AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.DEV_AWS_SECRET_ACCESS_KEY }}
          aws-region: eu-west-2

      - name: AWS credentials for prod
        if: steps.set_runtime_environment.outputs.runtime_environment == 'prod'
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.PROD_AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.PROD_AWS_SECRET_ACCESS_KEY }}
          aws-region: eu-west-2

      - name: DVC Pull
        run: |
          cd modules/ml-pipeline/src/pipeline
          dvc pull -r ${{ steps.set_runtime_environment.outputs.runtime_environment }}

      - name: Setup Docker
        uses: docker/setup-buildx-action@v1

      - name: Login to ECR
        run: |
          aws ecr get-login-password --region eu-west-2 | docker login --username AWS --password-stdin ${{ steps.set_ecr_credentials.outputs.ecr_uri }}

      # Building and pushing Docker image with caching
      - name: Build and push Docker image
        uses: docker/build-push-action@v3
        with:
          context: .
          file: ./deployment/Dockerfile.prediction.lambda
          push: true
          tags: ${{ steps.set_ecr_credentials.outputs.ecr_uri }}:${{ github.sha }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
          platforms: linux/amd64
          provenance: false
          build-args: |
            RUNTIME_ENVIRONMENT=${{ steps.set_runtime_environment.outputs.runtime_environment }}

      - name: Deploy to AWS Lambda via Serverless
        env:
          RUNTIME_ENVIRONMENT: ${{ steps.set_runtime_environment.outputs.runtime_environment }}
          PREDICTIONS_BUCKET: ${{ steps.set_s3_buckets.outputs.predictions_bucket }}
          DATA_BUCKET: ${{ steps.set_s3_buckets.outputs.data_bucket }}
          DOMAIN_NAME: ${{ steps.set_domain.outputs.domain }}
          ECR_URI: ${{ steps.set_ecr_credentials.outputs.ecr_uri }}
          GITHUB_SHA: ${{ github.sha }}
          STACK_NAME: ${{ steps.set_stack_name.outputs.stack_name }}
        run: |
          # Deploy to AWS Lambda via Serverless
          cd deployment
          sls deploy --config serverless.yml --stage ${{ steps.set_runtime_environment.outputs.runtime_environment }} --verbose
|
||||
|
|
@ -6,5 +6,13 @@
|
|||
},
|
||||
"registered": true,
|
||||
"active": true
|
||||
},
|
||||
"migrate": {
|
||||
"version": null,
|
||||
"stage": {
|
||||
"dev": "f320b9e0e9f3ea7735aed1abee07b1fb498c39c3"
|
||||
},
|
||||
"registered": true,
|
||||
"active": true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
67
README.md
67
README.md
|
|
@ -3,7 +3,7 @@
|
|||
Creating a ML-toolkit that can be reused:
|
||||
|
||||
- ML pipeline:
|
||||
- A generic pipeline that has data version control, experiment
|
||||
- A generic pipeline that has data version control, experiment
|
||||
tracking and a model registry
|
||||
|
||||
- ML monitoring:
|
||||
|
|
@ -17,7 +17,68 @@ There are multiple protected branches which adapt the generic pipeline to produc
|
|||
These branches will differ by the configuration files that define the data used and the outputs of the ML-pipeline
|
||||
- There can be different additional logic for each branch but the pipeline will be the same.
|
||||
|
||||
# Deployment
|
||||
# Deployment
|
||||
|
||||
TBD
|
||||
Scripts associated to deployment can be found in the deployment/ folder.
|
||||
|
||||
Deployment is automated via GitHub Actions, where a deployment is triggered by a push to one of the
|
||||
protected branches, with either dev or prod as the suffix describing the target environment.
|
||||
|
||||
The GitHub Actions workflow will build and push a Docker image to ECR and then deploy a Lambda
|
||||
which produces predictions for the relevant model.
|
||||
|
||||
In order for this to be set up, some key environment variables need to be added to GitHub
|
||||
secrets. Each different model and protected branch has its own set of secrets which allows for flexibility
|
||||
between different pipelines.
|
||||
|
||||
For example, for the branch sap_change-dev the prefix is SAP_CHANGE_DEV, and the following secrets are required:
|
||||
|
||||
- {prefix}_ECR_URI, which is the URI of the ECR repository to push to. For example, for the
|
||||
sap change model this is the lambda-sap-prediction-dev repository.
|
||||
- {prefix}_DOMAIN_NAME, is the custom domain name. This is likely going to be the same across the different
|
||||
models, but is still included in the secrets for flexibility.
|
||||
- {prefix}_DATA_BUCKET, is the name of the s3 data bucket where data to be scored by the model is stored
|
||||
- {prefix}_MODEL_BUCKET, is the name of the s3 bucket where the model is stored
|
||||
- {prefix}_PREDICTIONS_BUCKET, is the name of the s3 bucket where the predictions are stored
|
||||
|
||||
|
||||
# Building and Testing the Prediction Lambda Function Locally
|
||||
TODO: Generalise these instructions for the various different pipelines
|
||||
|
||||
This guide outlines the steps to build and test the Lambda function locally using Docker. These instructions assume you're working with a machine that has Docker installed.
|
||||
|
||||
### Prerequisites
|
||||
Docker: Make sure Docker is installed and running on your machine.
|
||||
AWS Credentials: Ensure you have AWS credentials set up on your local machine, typically stored
|
||||
in ~/.aws/credentials.
|
||||
Root Directory: All commands should be run from the root directory of the repository.
|
||||
Step-by-Step Guide
|
||||
1. Building the Docker Image
|
||||
First, navigate to the root directory of the repository. Open a terminal and execute the following
|
||||
command to build the Docker image:
|
||||
|
||||
```bash
|
||||
docker build -t sap_change -f deployment/Dockerfile.prediction.lambda .
|
||||
```
|
||||
|
||||
This will build a Docker image tagged as sap_change using the Dockerfile.prediction.lambda located
|
||||
in the deployment directory.
|
||||
|
||||
2. Running the Docker Image
|
||||
Once the image is built, you can run it using the following command:
|
||||
|
||||
```bash
|
||||
docker run -p 9000:8080 -v ~/.aws/credentials:/root/.aws/credentials:ro -e RUNTIME_ENVIRONMENT=dev sap_change
|
||||
```
|
||||
This command does the following:
|
||||
|
||||
Maps port 9000 on your local machine to port 8080 on the Docker container.
|
||||
Mounts your AWS credentials into the Docker container in read-only mode.
|
||||
Sets the RUNTIME_ENVIRONMENT variable to dev.
|
||||
3. Testing the Lambda Function
|
||||
To test the Lambda function, use the following curl command:
|
||||
|
||||
```bash
|
||||
curl -XPOST "http://localhost:9000/2015-03-31/functions/function/invocations" -d '{"body": "{\"file_location\": \"s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/test_data_with_id.parquet\", \"property_id\": 1, \"portfolio_id\": 4, \"created_at\": \"now\"}"'
|
||||
```
|
||||
This will send a POST request to the running Lambda function and pass in the required data as JSON.
|
||||
|
|
|
|||
25
deployment/Dockerfile.prediction.lambda
Normal file
25
deployment/Dockerfile.prediction.lambda
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# Container image for the prediction Lambda. Build from the repository root so
# that the COPY paths below resolve.
FROM public.ecr.aws/lambda/python:3.10

# Set the working directory
WORKDIR ${LAMBDA_TASK_ROOT}
# key=value form: the whitespace-separated `ENV key value` syntax is legacy.
ENV PYTHONPATH="${PYTHONPATH}:${LAMBDA_TASK_ROOT}"

# Environment variables
ARG RUNTIME_ENVIRONMENT
ENV RUNTIME_ENVIRONMENT=${RUNTIME_ENVIRONMENT}

# Install necessary build tools - required to test locally.
# The yum cache is removed in the same layer so it does not inflate the image.
RUN yum install -y gcc python3-devel \
    && yum clean all \
    && rm -rf /var/cache/yum

# Install python packages (copied on their own first so this layer stays
# cached until the requirements file changes)
COPY modules/ml-pipeline/src/pipeline/requirements/predictions/requirements.txt ./requirements.txt
RUN pip install --no-cache-dir -r ./requirements.txt

# Copy the project code
COPY modules/ml-pipeline/src/pipeline ./pipeline
# Copy the handler next to the pipeline modules it imports
COPY deployment/handlers/prediction_app.py ./pipeline/prediction_app.py
WORKDIR ${LAMBDA_TASK_ROOT}/pipeline

CMD [ "prediction_app.handler" ]
|
||||
128
deployment/handlers/prediction_app.py
Normal file
128
deployment/handlers/prediction_app.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
"""
|
||||
This script is the handler for the lambda prediction function, responsible
|
||||
for producing predictions for a model
|
||||
"""
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
import json
|
||||
from io import StringIO
|
||||
import os
|
||||
import logging
|
||||
from generate_predictions import generate_predictions
|
||||
from core.MLModels import model_factory
|
||||
from config import settings
|
||||
from core.DataClient import dataclient_factory
|
||||
|
||||
# Module-wide logger: the root logger, with its level set to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Bucket that receives prediction outputs; None when the variable is unset.
PREDICTIONS_BUCKET = os.environ.get("PREDICTIONS_BUCKET")
|
||||
|
||||
|
||||
def upload_dataframe_to_s3(df, bucket, s3_file_name):
    """
    Upload a pandas DataFrame to an S3 bucket as CSV

    The frame is serialised in memory (without its index) and written with a
    single ``put_object`` call.

    :param df: DataFrame to upload
    :param bucket: Bucket to upload to
    :param s3_file_name: S3 object name
    :return: True if file was uploaded, else False
    """
    # Imported here so this function does not depend on changes to the
    # module-level import section.
    from botocore.exceptions import BotoCoreError, ClientError

    # Write the DataFrame to an in-memory buffer as CSV
    csv_buffer = StringIO()
    df.to_csv(csv_buffer, index=False)

    try:
        # Initialize the S3 client and upload the CSV from the buffer
        s3 = boto3.client("s3")
        s3.put_object(Bucket=bucket, Key=s3_file_name, Body=csv_buffer.getvalue())
    except NoCredentialsError:
        logger.error("Credentials not available")
        return False
    except (BotoCoreError, ClientError) as err:
        # Honour the documented contract: any failed upload returns False
        # instead of raising.
        logger.error(
            "Failed to upload DataFrame to %s/%s: %s", bucket, s3_file_name, err
        )
        return False

    logger.info("Successfully uploaded DataFrame to %s/%s", bucket, s3_file_name)
    return True
|
||||
|
||||
|
||||
def handler(event, context):
    """
    Take in event and trigger the prediction pipeline

    :param event: Invocation event. ``event["body"]`` is either a dict or a
        JSON string holding ``property_id``, ``portfolio_id``, ``created_at``
        and ``file_location``.
    :param context: Lambda context object (unused).
    :return: Dict with ``statusCode`` and a JSON-encoded ``body``: 200 with
        the predictions location on success, 400 when the request body is
        missing or malformed, 500 when the prediction pipeline fails.
    """

    logger.info("received event: " + str(event))

    # Validate the request first: a missing or malformed body is a caller
    # error and is reported as 400 rather than as a pipeline failure.
    try:
        raw_body = event["body"]
        body = raw_body if isinstance(raw_body, dict) else json.loads(raw_body)

        property_id = body["property_id"]
        portfolio_id = body["portfolio_id"]
        created_at = body["created_at"]
        file_location = body["file_location"]
    except (KeyError, TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError; TypeError covers a None body
        # or a JSON payload that is not an object.
        logger.warning("Invalid request: %r", e)
        return {
            "statusCode": 400,
            "body": json.dumps({"message": "Invalid request", "error": str(e)}),
        }

    try:
        # Fail loudly instead of silently writing to "s3://None/...".
        if not PREDICTIONS_BUCKET:
            raise RuntimeError("PREDICTIONS_BUCKET environment variable is not set")

        storage_filepath = f"s3://{PREDICTIONS_BUCKET}/{portfolio_id}/{property_id}/{created_at}.parquet"

        logger.info("-------------------------")
        logger.info("--- Initiate MLModel ---")
        logger.info("-------------------------")

        build_model_params = settings.build_model
        client_params = settings.client
        feature_process_params = settings.feature_processor
        generate_predictions_params = settings.generate_predictions

        model = model_factory(build_model_params["model_type"])

        logger.info("----------------------------")
        logger.info("--- Initiate Input DataClient ---")
        logger.info("----------------------------")
        input_dataclient = dataclient_factory(
            dataclient_type="aws-s3",
            dataclient_config=client_params["aws-s3"],
        )

        logger.info("----------------------------")
        logger.info("--- Initiate Output DataClient ---")
        logger.info("----------------------------")
        output_dataclient = dataclient_factory(
            dataclient_type="aws-s3",
            dataclient_config=client_params["aws-s3"],
        )

        generate_predictions(
            input_dataclient=input_dataclient,
            output_dataclient=output_dataclient,
            model=model,
            target=feature_process_params["feature_processor_config"]["target"],
            model_filepath=build_model_params["model_save_filepath"],
            test_data_filepath=file_location,
            predictions_output_filepath=storage_filepath,
            predictions_column_name=generate_predictions_params[
                "predictions_column_name"
            ],
        )

        return {
            "statusCode": 200,
            "body": json.dumps(
                {
                    "message": "Successfully processed input",
                    "storage_filepath": storage_filepath,
                }
            ),
        }

    except Exception as e:
        # Top-level boundary of the Lambda: log the full traceback at ERROR
        # level and turn the failure into an HTTP response.
        logger.exception("Prediction failed")
        return {
            "statusCode": 500,
            "body": json.dumps({"message": "Prediction failed", "error": str(e)}),
        }
|
||||
53
deployment/serverless.yml
Normal file
53
deployment/serverless.yml
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
# Serverless Framework definition of the prediction Lambda. Every ${env:...}
# value is supplied by the "Deploy to AWS Lambda via Serverless" workflow step.
service: ${env:STACK_NAME}

provider:
  name: aws
  region: eu-west-2
  architecture: x86_64
  environment:
    RUNTIME_ENVIRONMENT: ${env:RUNTIME_ENVIRONMENT}
    PREDICTIONS_BUCKET: ${env:PREDICTIONS_BUCKET}
    DATA_BUCKET: ${env:DATA_BUCKET}
    DOMAIN_NAME: ${env:DOMAIN_NAME}
    ECR_URI: ${env:ECR_URI}
    GITHUB_SHA: ${env:GITHUB_SHA}
  iam:
    role:
      name: ${env:STACK_NAME}_s3_access
      statements:
        # Allow reading from the DATA_BUCKET (read-only: list + get).
        # Previously granted s3:*, which also allowed writes and deletes.
        - Effect: Allow
          Action:
            - s3:ListBucket
            - s3:GetObject
          Resource:
            - arn:aws:s3:::${env:DATA_BUCKET}
            - arn:aws:s3:::${env:DATA_BUCKET}/*
        # Allow reading and writing to PREDICTIONS_BUCKET.
        # NOTE(review): extend this list if the data client needs further
        # actions (e.g. s3:DeleteObject) - confirm against the DataClient.
        - Effect: Allow
          Action:
            - s3:ListBucket
            - s3:GetObject
            - s3:PutObject
            - s3:AbortMultipartUpload
          Resource:
            - arn:aws:s3:::${env:PREDICTIONS_BUCKET}
            - arn:aws:s3:::${env:PREDICTIONS_BUCKET}/*

plugins:
  - serverless-domain-manager

custom:
  customDomain:
    domainName: api.${self:provider.environment.DOMAIN_NAME}
    basePath: ${env:STACK_NAME}
    createRoute53Record: true
    certificateArn: ${ssm:/ssl_certificate_arn}

functions:
  sap_prediction_lambda:
    image:
      uri: ${env:ECR_URI}:${env:GITHUB_SHA}
    events:
      - http:
          path: /predict
          method: POST
    # NOTE(review): API Gateway's default integration timeout is shorter than
    # this value, so HTTP callers may time out before the function does.
    timeout: 120 # Set max run time to 2 minutes - we shouldn't need this much time so this can be reviewed
|
||||
1
modules/ml-pipeline/.gitignore
vendored
1
modules/ml-pipeline/.gitignore
vendored
|
|
@ -3,3 +3,4 @@
|
|||
__pycache__/
|
||||
.DS_Store
|
||||
.vscode/
|
||||
data/
|
||||
|
|
|
|||
|
|
@ -4,16 +4,11 @@ After the model is built, we can evaluate its performance
|
|||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from core.interface.InterfaceModels import MLModel
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.DataClient import dataclient_factory
|
||||
from core.MLModels import model_factory
|
||||
from core.Logger import logger
|
||||
from configs.post_prediction_logic import post_prediction_logic
|
||||
from config import settings
|
||||
from generate_predictions import generate_predictions
|
||||
|
||||
logger.info("----------------------------")
|
||||
logger.info(f"--- Initiate Parameters ---")
|
||||
|
|
@ -62,58 +57,6 @@ output_dataclient = dataclient_factory(
|
|||
)
|
||||
|
||||
|
||||
def generate_predictions(
    input_dataclient: DataClient,
    output_dataclient: DataClient,
    model: MLModel,
    target: str,
    model_filepath: str,
    test_data_filepath: str,
    predictions_output_filepath: str,
    predictions_column_name: str,
):
    """
    Score a dataset with a saved model and store the predictions.

    :param input_dataclient: client used to load the data to score
    :param output_dataclient: client used to save the predictions
    :param model: model object; its weights are loaded from ``model_filepath``
    :param target: name of the target column, dropped before scoring if present
    :param model_filepath: location the model is loaded from
    :param test_data_filepath: location of the data to score
    :param predictions_output_filepath: location the predictions are saved to
    :param predictions_column_name: name given to the single predictions column
    """

    def _announce(rule, title):
        # Log a title framed by two identical rule lines.
        for line in (rule, title, rule):
            logger.info(line)

    _announce("-------------------------", "--- Loading test data ---")
    test_data = input_dataclient.load_data(
        location=test_data_filepath, load_config=None
    )

    _announce("---------------------", "--- Loading model ---")
    model.load_model(model_filepath)

    _announce("------------------------------", "--- Generating predictions ---")
    # The model must not see the target column, so remove it when the data has it.
    features = test_data
    if target in test_data.columns:
        features = test_data.drop(columns=target)

    raw_predictions = model.predict(
        data=features, post_prediction_logic=post_prediction_logic
    )

    _announce("--------------------------", "--- Saving predictions ---")
    output_frame = pd.DataFrame(raw_predictions)
    output_frame.columns = [predictions_column_name]

    output_dataclient.save_data(
        obj=output_frame, location=predictions_output_filepath, save_config=None
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
logger.info("----------------------------")
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
/prepared_data
|
||||
/model
|
||||
/predictions
|
||||
57
modules/ml-pipeline/src/pipeline/generate_predictions.py
Normal file
57
modules/ml-pipeline/src/pipeline/generate_predictions.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import pandas as pd
|
||||
from configs.post_prediction_logic import post_prediction_logic
|
||||
from core.interface.InterfaceModels import MLModel
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.Logger import logger
|
||||
|
||||
|
||||
def generate_predictions(
    input_dataclient: DataClient,
    output_dataclient: DataClient,
    model: MLModel,
    target: str,
    model_filepath: str,
    test_data_filepath: str,
    predictions_output_filepath: str,
    predictions_column_name: str,
):
    """
    Score a dataset with a saved model and store the predictions.

    :param input_dataclient: client used to load the data to score
    :param output_dataclient: client used to save the predictions
    :param model: model object; its weights are loaded from ``model_filepath``
    :param target: name of the target column, dropped before scoring if present
    :param model_filepath: location the model is loaded from
    :param test_data_filepath: location of the data to score
    :param predictions_output_filepath: location the predictions are saved to
    :param predictions_column_name: name given to the single predictions column
    """

    def _announce(rule, title):
        # Log a title framed by two identical rule lines.
        for line in (rule, title, rule):
            logger.info(line)

    _announce("-------------------------", "--- Loading test data ---")
    test_data = input_dataclient.load_data(
        location=test_data_filepath, load_config=None
    )

    _announce("---------------------", "--- Loading model ---")
    model.load_model(model_filepath)

    _announce("------------------------------", "--- Generating predictions ---")
    # The model must not see the target column, so remove it when the data has it.
    features = test_data
    if target in test_data.columns:
        features = test_data.drop(columns=target)

    raw_predictions = model.predict(
        data=features, post_prediction_logic=post_prediction_logic
    )

    _announce("--------------------------", "--- Saving predictions ---")
    output_frame = pd.DataFrame(raw_predictions)
    output_frame.columns = [predictions_column_name]

    output_dataclient.save_data(
        obj=output_frame, location=predictions_output_filepath, save_config=None
    )
|
||||
Loading…
Add table
Reference in a new issue