mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
fix type hints
This commit is contained in:
parent
146fc3057e
commit
0bf5fdd6d8
6 changed files with 28 additions and 16 deletions
|
|
@@ -53,12 +53,12 @@ def build_model(
|
|||
if train_data is None:
|
||||
if train_filepath is None:
|
||||
raise ValueError(f"Need {train_filepath} if no data supplied")
|
||||
train_data = dataclient.load_data(location=train_filepath)
|
||||
train_data = dataclient.load_data(location=train_filepath, load_config=None)
|
||||
|
||||
if test_data is None:
|
||||
if test_filepath is None:
|
||||
raise ValueError(f"Need {test_filepath} if no data supplied")
|
||||
test_data = dataclient.load_data(location=test_filepath)
|
||||
test_data = dataclient.load_data(location=test_filepath, load_config=None)
|
||||
|
||||
logger.info("----------------------")
|
||||
logger.info("--- Training model ---")
|
||||
|
|
@@ -95,7 +95,9 @@ def build_model(
|
|||
logger.info("--- Saving fit metrics ---")
|
||||
logger.info("--------------------------")
|
||||
|
||||
dataclient.save_data(obj=metrics_output, location=fit_metrics_filepath)
|
||||
dataclient.save_data(
|
||||
obj=metrics_output, location=fit_metrics_filepath, save_config=None
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@@ -8,7 +8,7 @@ import boto3
|
|||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Any
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.Logger import logger
|
||||
|
||||
|
|
@@ -105,7 +105,7 @@ class AWSS3Client:
|
|||
|
||||
def save_data(
|
||||
self,
|
||||
obj: object,
|
||||
obj: Any,
|
||||
location: str,
|
||||
save_config: Union[dict, None] = None,
|
||||
) -> None:
|
||||
|
|
@@ -134,7 +134,7 @@ class AWSS3Client:
|
|||
obj=obj, location=location, save_config=save_config
|
||||
)
|
||||
|
||||
def _save_parquet(self, obj: object, location: str, save_config: dict):
|
||||
def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict):
|
||||
"""
|
||||
Save object as parquet
|
||||
"""
|
||||
|
|
|
|||
|
|
@@ -3,8 +3,7 @@ Interface for all DataClient i.e. s3, database, local etc
|
|||
"""
|
||||
|
||||
import pandas as pd
|
||||
from io import BytesIO
|
||||
from typing import Protocol, Union
|
||||
from typing import Protocol, Union, Any
|
||||
|
||||
|
||||
class DataClient(Protocol):
|
||||
|
|
@@ -22,9 +21,10 @@ class DataClient(Protocol):
|
|||
"""
|
||||
Generic to load data
|
||||
"""
|
||||
...
|
||||
|
||||
def save_data(
|
||||
self, obj: object, location: str, save_config: Union[dict, None]
|
||||
self, obj: Any, location: str, save_config: Union[dict, None]
|
||||
) -> None:
|
||||
"""
|
||||
Generic to save data
|
||||
|
|
|
|||
|
|
@@ -59,14 +59,16 @@ def generate_metrics(
|
|||
logger.info("-------------------------")
|
||||
|
||||
test_data = input_dataclient.load_data(
|
||||
location=test_data_filepath,
|
||||
location=test_data_filepath, load_config=None
|
||||
)
|
||||
|
||||
logger.info("---------------------------")
|
||||
logger.info("--- Loading predictions ---")
|
||||
logger.info("---------------------------")
|
||||
|
||||
predictions = input_dataclient.load_data(location=predictions_output_filepath)
|
||||
predictions = input_dataclient.load_data(
|
||||
location=predictions_output_filepath, load_config=None
|
||||
)
|
||||
|
||||
logger.info("--------------------------")
|
||||
logger.info("--- Generating metrics ---")
|
||||
|
|
@@ -81,7 +83,9 @@ def generate_metrics(
|
|||
logger.info("--- Saving metrics ---")
|
||||
logger.info("----------------------")
|
||||
|
||||
output_dataclient.save_data(obj=metrics_output, location=metrics_output_filepath)
|
||||
output_dataclient.save_data(
|
||||
obj=metrics_output, location=metrics_output_filepath, save_config=None
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@@ -52,7 +52,9 @@ def generate_predictions(
|
|||
logger.info("--- Loading test data ---")
|
||||
logger.info("-------------------------")
|
||||
|
||||
test_data = input_dataclient.load_data(location=test_data_filepath)
|
||||
test_data = input_dataclient.load_data(
|
||||
location=test_data_filepath, load_config=None
|
||||
)
|
||||
|
||||
logger.info("---------------------")
|
||||
logger.info("--- Loading model ---")
|
||||
|
|
@@ -78,7 +80,7 @@ def generate_predictions(
|
|||
predictions_df.columns = [predictions_column_name]
|
||||
|
||||
output_dataclient.save_data(
|
||||
obj=predictions_df, location=predictions_output_filepath
|
||||
obj=predictions_df, location=predictions_output_filepath, save_config=None
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -79,10 +79,14 @@ def prepare_data(
|
|||
logger.info("--- Outputting data ---")
|
||||
logger.info("-----------------------")
|
||||
|
||||
output_dataclient.save_data(obj=train, location=output_train_filepath)
|
||||
output_dataclient.save_data(
|
||||
obj=train, location=output_train_filepath, save_config=None
|
||||
)
|
||||
|
||||
if test is not None:
|
||||
output_dataclient.save_data(obj=test, location=output_test_filepath)
|
||||
output_dataclient.save_data(
|
||||
obj=test, location=output_test_filepath, save_config=None
|
||||
)
|
||||
|
||||
return train, test
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue