fix type hints

This commit is contained in:
Michael Duong 2023-09-18 08:13:37 +01:00
parent 146fc3057e
commit 0bf5fdd6d8
6 changed files with 28 additions and 16 deletions

View file

@ -53,12 +53,12 @@ def build_model(
if train_data is None:
if train_filepath is None:
raise ValueError(f"Need {train_filepath} if no data supplied")
train_data = dataclient.load_data(location=train_filepath)
train_data = dataclient.load_data(location=train_filepath, load_config=None)
if test_data is None:
if test_filepath is None:
raise ValueError(f"Need {test_filepath} if no data supplied")
test_data = dataclient.load_data(location=test_filepath)
test_data = dataclient.load_data(location=test_filepath, load_config=None)
logger.info("----------------------")
logger.info("--- Training model ---")
@ -95,7 +95,9 @@ def build_model(
logger.info("--- Saving fit metrics ---")
logger.info("--------------------------")
dataclient.save_data(obj=metrics_output, location=fit_metrics_filepath)
dataclient.save_data(
obj=metrics_output, location=fit_metrics_filepath, save_config=None
)
if __name__ == "__main__":

View file

@ -8,7 +8,7 @@ import boto3
import pandas as pd
from pathlib import Path
from io import BytesIO
from typing import List, Union
from typing import List, Union, Any
from core.interface.InterfaceDataClient import DataClient
from core.Logger import logger
@ -105,7 +105,7 @@ class AWSS3Client:
def save_data(
self,
obj: object,
obj: Any,
location: str,
save_config: Union[dict, None] = None,
) -> None:
@ -134,7 +134,7 @@ class AWSS3Client:
obj=obj, location=location, save_config=save_config
)
def _save_parquet(self, obj: object, location: str, save_config: dict):
def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict):
"""
Save object as parquet
"""

View file

@ -3,8 +3,7 @@ Interface for all DataClient i.e. s3, database, local etc
"""
import pandas as pd
from io import BytesIO
from typing import Protocol, Union
from typing import Protocol, Union, Any
class DataClient(Protocol):
@ -22,9 +21,10 @@ class DataClient(Protocol):
"""
Generic to load data
"""
...
def save_data(
self, obj: object, location: str, save_config: Union[dict, None]
self, obj: Any, location: str, save_config: Union[dict, None]
) -> None:
"""
Generic to save data

View file

@ -59,14 +59,16 @@ def generate_metrics(
logger.info("-------------------------")
test_data = input_dataclient.load_data(
location=test_data_filepath,
location=test_data_filepath, load_config=None
)
logger.info("---------------------------")
logger.info("--- Loading predictions ---")
logger.info("---------------------------")
predictions = input_dataclient.load_data(location=predictions_output_filepath)
predictions = input_dataclient.load_data(
location=predictions_output_filepath, load_config=None
)
logger.info("--------------------------")
logger.info("--- Generating metrics ---")
@ -81,7 +83,9 @@ def generate_metrics(
logger.info("--- Saving metrics ---")
logger.info("----------------------")
output_dataclient.save_data(obj=metrics_output, location=metrics_output_filepath)
output_dataclient.save_data(
obj=metrics_output, location=metrics_output_filepath, save_config=None
)
if __name__ == "__main__":

View file

@ -52,7 +52,9 @@ def generate_predictions(
logger.info("--- Loading test data ---")
logger.info("-------------------------")
test_data = input_dataclient.load_data(location=test_data_filepath)
test_data = input_dataclient.load_data(
location=test_data_filepath, load_config=None
)
logger.info("---------------------")
logger.info("--- Loading model ---")
@ -78,7 +80,7 @@ def generate_predictions(
predictions_df.columns = [predictions_column_name]
output_dataclient.save_data(
obj=predictions_df, location=predictions_output_filepath
obj=predictions_df, location=predictions_output_filepath, save_config=None
)

View file

@ -79,10 +79,14 @@ def prepare_data(
logger.info("--- Outputting data ---")
logger.info("-----------------------")
output_dataclient.save_data(obj=train, location=output_train_filepath)
output_dataclient.save_data(
obj=train, location=output_train_filepath, save_config=None
)
if test is not None:
output_dataclient.save_data(obj=test, location=output_test_filepath)
output_dataclient.save_data(
obj=test, location=output_test_filepath, save_config=None
)
return train, test