diff --git a/deployment/handlers/prediction_app.py b/deployment/handlers/prediction_app.py index 5b9a807..e8b02ac 100644 --- a/deployment/handlers/prediction_app.py +++ b/deployment/handlers/prediction_app.py @@ -47,21 +47,28 @@ def upload_dataframe_to_s3(df, bucket, s3_file_name): return False -def warming_up_invocation(model_filepath: str): +def warming_up_invocation( + model, + model_filepath: str, +): """ Function to handle warm up invocations """ import pandas as pd + import numpy as np - model = model_factory(settings.build_model["model_type"]) - model_filepath = settings.build_model["model_save_filepath"] model.load_model(model_filepath) - warmup_df = pd.DataFrame(columns=model.model.original_features) - warmup_df = pd.concat([warmup_df.T, pd.DataFrame([0] * len(warmup_df.T))], axis=1).T - warmup_df.fillna(0, inplace=True) + warmup_df = pd.DataFrame( + np.zeros((1, len(model.model.original_features))), + columns=model.model.original_features, + ) - model.predict(data=warmup_df) + model_names = model.model.model_names() + if "NeuralNetFastAI" in model_names: + model.model.predict(warmup_df, model="NeuralNetFastAI") + else: + model.predict(data=warmup_df) def handler(event, context): @@ -97,7 +104,7 @@ def handler(event, context): if "warm" in body: logger.info("Warm up invocation - synthetic prediction") - warming_up_invocation(model_filepath=model_filepath) + warming_up_invocation(model=model, model_filepath=model_filepath) return { "statusCode": 200,