diff --git a/modules/ml-pipeline/src/pipeline/src/build_model.py b/modules/ml-pipeline/src/pipeline/src/build_model.py index 5dfc71a..a07e9cf 100644 --- a/modules/ml-pipeline/src/pipeline/src/build_model.py +++ b/modules/ml-pipeline/src/pipeline/src/build_model.py @@ -53,12 +53,12 @@ def build_model( if train_data is None: if train_filepath is None: raise ValueError(f"Need {train_filepath} if no data supplied") - train_data = dataclient.load_data(location=train_filepath) + train_data = dataclient.load_data(location=train_filepath, load_config=None) if test_data is None: if test_filepath is None: raise ValueError(f"Need {test_filepath} if no data supplied") - test_data = dataclient.load_data(location=test_filepath) + test_data = dataclient.load_data(location=test_filepath, load_config=None) logger.info("----------------------") logger.info("--- Training model ---") @@ -95,7 +95,9 @@ def build_model( logger.info("--- Saving fit metrics ---") logger.info("--------------------------") - dataclient.save_data(obj=metrics_output, location=fit_metrics_filepath) + dataclient.save_data( + obj=metrics_output, location=fit_metrics_filepath, save_config=None + ) if __name__ == "__main__": diff --git a/modules/ml-pipeline/src/pipeline/src/core/DataClient.py b/modules/ml-pipeline/src/pipeline/src/core/DataClient.py index c8c9f2c..28ffff7 100644 --- a/modules/ml-pipeline/src/pipeline/src/core/DataClient.py +++ b/modules/ml-pipeline/src/pipeline/src/core/DataClient.py @@ -8,7 +8,7 @@ import boto3 import pandas as pd from pathlib import Path from io import BytesIO -from typing import List, Union +from typing import List, Union, Any from core.interface.InterfaceDataClient import DataClient from core.Logger import logger @@ -105,7 +105,7 @@ class AWSS3Client: def save_data( self, - obj: object, + obj: Any, location: str, save_config: Union[dict, None] = None, ) -> None: @@ -134,7 +134,7 @@ class AWSS3Client: obj=obj, location=location, save_config=save_config ) - def _save_parquet(self, obj: object, location: str, save_config: dict): + def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict): """ Save object as parquet """ diff --git a/modules/ml-pipeline/src/pipeline/src/core/interface/InterfaceDataClient.py b/modules/ml-pipeline/src/pipeline/src/core/interface/InterfaceDataClient.py index d572c2b..5e51a99 100644 --- a/modules/ml-pipeline/src/pipeline/src/core/interface/InterfaceDataClient.py +++ b/modules/ml-pipeline/src/pipeline/src/core/interface/InterfaceDataClient.py @@ -3,8 +3,7 @@ Interface for all DataClient i.e. s3, database, local etc """ import pandas as pd -from io import BytesIO -from typing import Protocol, Union +from typing import Protocol, Union, Any class DataClient(Protocol): @@ -22,9 +21,10 @@ class DataClient(Protocol): """ Generic to load data """ + ... def save_data( - self, obj: object, location: str, save_config: Union[dict, None] + self, obj: Any, location: str, save_config: Union[dict, None] ) -> None: """ Generic to save data diff --git a/modules/ml-pipeline/src/pipeline/src/generate_metrics.py b/modules/ml-pipeline/src/pipeline/src/generate_metrics.py index 7efeda9..58244bc 100644 --- a/modules/ml-pipeline/src/pipeline/src/generate_metrics.py +++ b/modules/ml-pipeline/src/pipeline/src/generate_metrics.py @@ -59,14 +59,16 @@ def generate_metrics( logger.info("-------------------------") test_data = input_dataclient.load_data( - location=test_data_filepath, + location=test_data_filepath, load_config=None ) logger.info("---------------------------") logger.info("--- Loading predictions ---") logger.info("---------------------------") - predictions = input_dataclient.load_data(location=predictions_output_filepath) + predictions = input_dataclient.load_data( + location=predictions_output_filepath, load_config=None + ) logger.info("--------------------------") logger.info("--- Generating metrics ---") @@ -81,7 +83,9 @@ def generate_metrics( logger.info("--- Saving metrics ---") logger.info("----------------------") - output_dataclient.save_data(obj=metrics_output, location=metrics_output_filepath) + output_dataclient.save_data( + obj=metrics_output, location=metrics_output_filepath, save_config=None + ) if __name__ == "__main__": diff --git a/modules/ml-pipeline/src/pipeline/src/generate_predictions.py b/modules/ml-pipeline/src/pipeline/src/generate_predictions.py index f80ec18..490d7e9 100644 --- a/modules/ml-pipeline/src/pipeline/src/generate_predictions.py +++ b/modules/ml-pipeline/src/pipeline/src/generate_predictions.py @@ -52,7 +52,9 @@ def generate_predictions( logger.info("--- Loading test data ---") logger.info("-------------------------") - test_data = input_dataclient.load_data(location=test_data_filepath) + test_data = input_dataclient.load_data( + location=test_data_filepath, load_config=None + ) logger.info("---------------------") logger.info("--- Loading model ---") @@ -78,7 +80,7 @@ def generate_predictions( predictions_df.columns = [predictions_column_name] output_dataclient.save_data( - obj=predictions_df, location=predictions_output_filepath + obj=predictions_df, location=predictions_output_filepath, save_config=None ) diff --git a/modules/ml-pipeline/src/pipeline/src/prepare_data.py b/modules/ml-pipeline/src/pipeline/src/prepare_data.py index 851be48..8caa101 100644 --- a/modules/ml-pipeline/src/pipeline/src/prepare_data.py +++ b/modules/ml-pipeline/src/pipeline/src/prepare_data.py @@ -79,10 +79,14 @@ def prepare_data( logger.info("--- Outputting data ---") logger.info("-----------------------") - output_dataclient.save_data(obj=train, location=output_train_filepath) + output_dataclient.save_data( + obj=train, location=output_train_filepath, save_config=None + ) if test is not None: - output_dataclient.save_data(obj=test, location=output_test_filepath) + output_dataclient.save_data( + obj=test, location=output_test_filepath, save_config=None + ) return train, test