diff --git a/.github/workflows/MLPipelinePullRequest.yml b/.github/workflows/MLPipelinePullRequest.yml index 0d0f661..cf6a231 100644 --- a/.github/workflows/MLPipelinePullRequest.yml +++ b/.github/workflows/MLPipelinePullRequest.yml @@ -88,7 +88,7 @@ jobs: sleep 2 curl -X POST "http://localhost:9000/2015-03-31/functions/function/invocations" \ -H "Content-Type: application/json" \ - -d "{\"body\": \"{\\\"file_location\\\": \\\"s3://retrofit-data-dev/sap_change_model/sample_data_for_cicd/${timestamp}/sample_test.parquet\\\", \\\"property_id\\\": 1, \\\"portfolio_id\\\": 4, \\\"created_at\\\": \\\"now\\\", \\\"testing\\\": true}\"}" + -d "{\"body\": \"{\\\"file_location\\\": \\\"s3://retrofit-data-dev/sap_change_model/sample_data_for_cicd/${timestamp}/sample_test.parquet\\\", \\\"property_id\\\": 1, \\\"portfolio_id\\\": 4, \\\"created_at\\\": \\\"now\\\", \\\"testing\\\": true, \\\"warm\\\": true}\"}" - name: Get Lambda logs run: | diff --git a/deployment/handlers/prediction_app.py b/deployment/handlers/prediction_app.py index 25fa120..7deae3a 100644 --- a/deployment/handlers/prediction_app.py +++ b/deployment/handlers/prediction_app.py @@ -66,14 +66,6 @@ def handler(event, context): created_at = body["created_at"] # TODO: Implement the loading of the model and prediction - - if "testing" in body: - storage_filepath = body["file_location"].replace( - ".parquet", "_output.parquet" - ) - else: - storage_filepath = f"s3://{PREDICTIONS_BUCKET}/{portfolio_id}/{property_id}/{created_at}.parquet" - logger.info(f"--- Initiate MLModel ---") build_model_params = settings.build_model @@ -83,6 +75,27 @@ def handler(event, context): model = model_factory(build_model_params["model_type"]) + model_filepath = build_model_params["model_save_filepath"] + + if "testing" in body: + storage_filepath = body["file_location"].replace( + ".parquet", "_output.parquet" + ) + elif "warm" in body: + logger.info("Warm up invocation - skipping prediction") + + import pandas as pd + + model.load_model(model_filepath) + return { + "statusCode": 200, + "body": json.dumps( + {"message": f"{model.predict(data=pd.DataFrame({'a': [1]}))}"} + ), + } + else: + storage_filepath = f"s3://{PREDICTIONS_BUCKET}/{portfolio_id}/{property_id}/{created_at}.parquet" + logger.info(f"--- Initiate Input DataClient ---") input_dataclient = dataclient_factory( dataclient_type="aws-s3", @@ -100,7 +113,7 @@ def handler(event, context): output_dataclient=output_dataclient, model=model, target=feature_process_params["feature_processor_config"]["target"], - model_filepath=build_model_params["model_save_filepath"], + model_filepath=model_filepath, test_data_filepath=body["file_location"], predictions_output_filepath=storage_filepath, predictions_column_name=generate_predictions_params[