mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
refactored dataclient
This commit is contained in:
parent
c62f32c1e5
commit
2d7af3ed69
9 changed files with 85 additions and 235 deletions
|
|
@ -12,7 +12,6 @@ from core.Logger import logger
|
|||
from core.interface.InterfaceModels import MLModel
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.DataClient import dataclient_factory
|
||||
from core.DataHandler import datahandler_factory
|
||||
from core.MLModels import model_factory
|
||||
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
|
||||
|
|
@ -46,14 +45,12 @@ def build_model(
|
|||
if train_data is None:
|
||||
if train_filepath is None:
|
||||
raise ValueError(f"Need {train_filepath} if no data supplied")
|
||||
train_data = datahandler.load_data(
|
||||
dataclient=dataclient, location=train_filepath
|
||||
)
|
||||
train_data = dataclient.load_data(location=train_filepath)
|
||||
|
||||
if test_data is None:
|
||||
if test_filepath is None:
|
||||
raise ValueError(f"Need {test_filepath} if no data supplied")
|
||||
test_data = datahandler.load_data(dataclient=dataclient, location=test_filepath)
|
||||
test_data = dataclient.load_data(location=test_filepath)
|
||||
|
||||
logger.info("----------------------")
|
||||
logger.info("--- Training model ---")
|
||||
|
|
@ -80,14 +77,9 @@ if __name__ == "__main__":
|
|||
logger.info(f"--- Initiate DataClient ---")
|
||||
logger.info("----------------------------")
|
||||
|
||||
# Output of previous prepare data step, will be where the data is
|
||||
dataclient = dataclient_factory(prepare_data_params["output_dataclient_type"])
|
||||
|
||||
logger.info("-----------------------------")
|
||||
logger.info(f"--- Initiate DataHandler ---")
|
||||
logger.info("-----------------------------")
|
||||
|
||||
datahandler = datahandler_factory(prepare_data_params["datahandler_type"])
|
||||
|
||||
logger.info("-------------------------")
|
||||
logger.info(f"--- Initiate MLModel ---")
|
||||
logger.info("-------------------------")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ Implementations of the DataClient Protocol
|
|||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import boto3
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
|
@ -13,7 +14,7 @@ from core.Logger import logger
|
|||
|
||||
|
||||
def dataclient_factory(
|
||||
dataclient_type: str, dataclient_config: Union[dict, None]
|
||||
dataclient_type: str, dataclient_config: Union[dict, None] = None
|
||||
) -> DataClient:
|
||||
"""
|
||||
Determine which dataclient to use
|
||||
|
|
@ -77,7 +78,7 @@ class AWSS3Client:
|
|||
)
|
||||
|
||||
def load_data(
|
||||
self, location: str, filetype: str, load_config: Union[dict, None] = None
|
||||
self, location: str, load_config: Union[dict, None] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Generic to load data
|
||||
|
|
@ -89,6 +90,8 @@ class AWSS3Client:
|
|||
if load_config is None:
|
||||
load_config = {}
|
||||
|
||||
filetype = Path(location).suffix
|
||||
|
||||
load_methods = {
|
||||
".parquet": self._load_parquet,
|
||||
# "": _load_directory(**load_config),
|
||||
|
|
@ -104,7 +107,6 @@ class AWSS3Client:
|
|||
self,
|
||||
obj: object,
|
||||
location: str,
|
||||
filetype: str,
|
||||
save_config: Union[dict, None] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -117,6 +119,8 @@ class AWSS3Client:
|
|||
if save_config is None:
|
||||
save_config = {}
|
||||
|
||||
filetype = Path(location).suffix
|
||||
|
||||
save_methods = {
|
||||
".parquet": self._save_parquet,
|
||||
# "": _save_directory(**save_config),
|
||||
|
|
@ -196,7 +200,7 @@ class LocalClient:
|
|||
logger.info("Local - No establishing client required")
|
||||
|
||||
def load_data(
|
||||
self, location: str, filetype: str, load_config: Union[dict, None] = None
|
||||
self, location: str, load_config: Union[dict, None] = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Generic to load data
|
||||
|
|
@ -205,6 +209,8 @@ class LocalClient:
|
|||
if load_config is None:
|
||||
load_config = {}
|
||||
|
||||
filetype = Path(location).suffix
|
||||
|
||||
load_methods = {
|
||||
".parquet": self._load_parquet,
|
||||
# "": _load_directory(**load_config),
|
||||
|
|
@ -220,7 +226,6 @@ class LocalClient:
|
|||
self,
|
||||
obj: object,
|
||||
location: str,
|
||||
filetype: str,
|
||||
save_config: Union[dict, None] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -234,10 +239,13 @@ class LocalClient:
|
|||
|
||||
save_methods = {
|
||||
".parquet": self._save_parquet,
|
||||
".json": self._save_json
|
||||
# "": _save_directory(**save_config),
|
||||
# ADD MORE save_methods HERE
|
||||
}
|
||||
|
||||
filetype = Path(location).suffix
|
||||
|
||||
if filetype not in save_methods:
|
||||
raise ValueError("save_methods specified is not in factory")
|
||||
|
||||
|
|
@ -254,31 +262,29 @@ class LocalClient:
|
|||
|
||||
return df
|
||||
|
||||
def _save_parquet(self, obj: object, location: str, save_config: dict):
|
||||
def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict):
|
||||
"""
|
||||
Save object as parquet
|
||||
"""
|
||||
|
||||
obj.to_parquet(location, **save_config)
|
||||
|
||||
# def load_data_as_buffer(self, location: str) -> BytesIO:
|
||||
# """
|
||||
# When the client is established, we can load data from a buffer
|
||||
# """
|
||||
# with open(location, "rb") as file:
|
||||
# # Read the entire file into a BytesIO object
|
||||
# buffer = BytesIO(file.read())
|
||||
# buffer.seek(0)
|
||||
def _save_json(self, obj: dict, location: str, save_config: dict):
|
||||
"""
|
||||
Save object as json
|
||||
"""
|
||||
# Serialize the dictionary to a JSON-formatted string
|
||||
json_string = json.dumps(obj) # indent for pretty formatting
|
||||
|
||||
# return buffer
|
||||
# Convert the JSON string to bytes (UTF-8 encoding)
|
||||
json_bytes = json_string.encode("utf-8")
|
||||
|
||||
# def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
|
||||
# """
|
||||
# When the client is established, we can save out objects from a buffer
|
||||
# """
|
||||
# if not Path(location).parent.exists():
|
||||
# os.makedirs(Path(location).parent)
|
||||
# Create a BytesIO object and write the JSON bytes to it
|
||||
buffer = BytesIO()
|
||||
buffer.write(json_bytes)
|
||||
|
||||
# # Write the contents of the buffer to the local file
|
||||
# with open(location, "wb") as f:
|
||||
# f.write(buffer.getvalue())
|
||||
buffer.seek(0)
|
||||
|
||||
# Write the contents of the buffer to the local file
|
||||
with open(location, "wb") as f:
|
||||
f.write(buffer.getvalue())
|
||||
|
|
|
|||
|
|
@ -1,86 +0,0 @@
|
|||
"""
|
||||
Implementations of the datahandler Protocol
|
||||
"""
|
||||
|
||||
import json
|
||||
import pandas as pd
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from core.interface.InterfaceDataHandler import DataHandler
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
|
||||
|
||||
def datahandler_factory(datahandler_type: str) -> DataHandler:
    """
    Determine which datahandler to use

    Args:
        datahandler_type: registry key of the handler, e.g. "parquet" or "json"

    Returns:
        A new instance of the handler registered under `datahandler_type`

    Raises:
        ValueError: if `datahandler_type` is not a registered key
    """
    # Map keys to classes (not instances) so that only the requested handler
    # is constructed, instead of building every handler on each call.
    datahandlers = {
        "parquet": ParquetHandler,
        "json": JSONHandler,
        # ADD MORE DATAHANDLERS HERE
    }

    if datahandler_type not in datahandlers:
        raise ValueError(
            f"Datahandler type {datahandler_type!r} is not in factory; "
            f"expected one of {sorted(datahandlers)}"
        )

    return datahandlers[datahandler_type]()
|
||||
|
||||
|
||||
def validate_dict_keys(keys_1: List[str], keys_2: List[str], config_type: str):
    """
    Check that every key in `keys_1` is also present in `keys_2`

    Args:
        keys_1: the keys that were supplied
        keys_2: the keys that are allowed
        config_type: label inserted into the error message to say which
            configuration was being checked

    Raises:
        ValueError: if `keys_1` contains a key that is not in `keys_2`
    """
    unexpected = set(keys_1).difference(keys_2)
    if unexpected:
        # Name the offending keys so the caller can see what to fix.
        raise ValueError(
            f"Incorrect {config_type} keys specified: {sorted(unexpected, key=str)}"
        )
|
||||
|
||||
|
||||
class ParquetHandler:
    """
    Load and save Parquet datasets

    The handler only converts between DataFrames and in-memory Parquet
    buffers; moving the bytes to and from storage is delegated to the
    dataclient passed into each call.
    """

    def load_data(self, dataclient: DataClient, location: str) -> pd.DataFrame:
        """
        Read the Parquet data at `location` into a DataFrame

        Args:
            dataclient: client whose `load_data_as_buffer` supplies the bytes
            location: location handed to the dataclient

        Returns:
            The DataFrame produced by `pd.read_parquet`
        """
        df = pd.read_parquet(dataclient.load_data_as_buffer(location=location))
        return df

    def save_data(
        self, dataclient: DataClient, obj: pd.DataFrame, location: str
    ) -> None:
        """
        Serialise `obj` to Parquet and upload it through the dataclient

        Args:
            dataclient: client whose `upload_data_from_buffer` stores the bytes
            obj: DataFrame to save (written without its index)
            location: location handed to the dataclient
        """
        # Convert the Pandas DataFrame to a Parquet buffer
        parquet_buffer = BytesIO()
        obj.to_parquet(parquet_buffer, index=False)

        # Writing leaves the position at the end of the buffer. Rewind it, as
        # JSONHandler.save_data does, so a client that reads from the current
        # position uploads the full payload rather than zero bytes.
        parquet_buffer.seek(0)

        dataclient.upload_data_from_buffer(buffer=parquet_buffer, location=location)
|
||||
|
||||
|
||||
class JSONHandler:
    """
    Load and save JSON objects

    Storage access is delegated to the dataclient passed into each call.
    """

    def load_data(self, dataclient: DataClient, location: str) -> pd.DataFrame:
        """
        Placeholder for loading JSON data

        Not implemented yet: the body is empty, so calling this returns None.
        """
        ...

    def save_data(self, dataclient: DataClient, obj: dict, location: str) -> None:
        """
        Serialise `obj` to JSON and upload it through the dataclient

        Args:
            dataclient: client whose `upload_data_from_buffer` stores the bytes
            obj: dictionary to serialise with `json.dumps`
            location: location handed to the dataclient
        """
        # Dictionary -> JSON text -> UTF-8 bytes
        payload = json.dumps(obj).encode("utf-8")

        # A BytesIO created from initial bytes is already positioned at the
        # start, so the uploader sees the whole payload.
        payload_buffer = BytesIO(payload)

        dataclient.upload_data_from_buffer(buffer=payload_buffer, location=location)
|
||||
|
|
@ -18,15 +18,13 @@ class DataClient(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def load_data(
|
||||
self, location: str, filetype: str, load_config: Union[dict, None]
|
||||
) -> pd.DataFrame:
|
||||
def load_data(self, location: str, load_config: Union[dict, None]) -> pd.DataFrame:
|
||||
"""
|
||||
Generic to load data
|
||||
"""
|
||||
|
||||
def save_data(
|
||||
self, obj: object, location: str, filetype: str, save_config: Union[dict, None]
|
||||
self, obj: object, location: str, save_config: Union[dict, None]
|
||||
) -> None:
|
||||
"""
|
||||
Generic to save data
|
||||
|
|
|
|||
|
|
@ -1,26 +0,0 @@
|
|||
"""
|
||||
Interface for all DataHandler i.e. Parquet data, csv data
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Protocol, Union, Any
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
|
||||
|
||||
class DataHandler(Protocol):
    """
    Declare the methods required for a DataHandler

    Any class providing `load_data` and `save_data` with these signatures
    satisfies the protocol.
    """

    def load_data(self, dataclient: DataClient, location: str) -> pd.DataFrame:
        """
        Load the data at `location` via `dataclient` and return a DataFrame
        """
        ...

    def save_data(
        self,
        dataclient: DataClient,
        obj: Union[pd.DataFrame, dict, Any],
        location: str,
    ) -> None:
        """
        Save `obj` to `location` via `dataclient`
        """
        ...
|
||||
|
|
@ -5,8 +5,8 @@ stages:
|
|||
deps:
|
||||
- path: prepare_data.py
|
||||
hash: md5
|
||||
md5: 9c31bfb1b75ea3c9685ec459cbb50e62
|
||||
size: 5921
|
||||
md5: 2cfe9e3012280e0cecdb84da12c974d9
|
||||
size: 5009
|
||||
params:
|
||||
configs/prepare_data.yaml:
|
||||
output_test_filepath: ./data/prepared_data/test.parquet
|
||||
|
|
@ -15,20 +15,20 @@ stages:
|
|||
outs:
|
||||
- path: data/prepared_data/
|
||||
hash: md5
|
||||
md5: 5cbabd20ff23b9d6734c5c68684dc8dc.dir
|
||||
size: 11982694
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
nfiles: 2
|
||||
build_model:
|
||||
cmd: python build_model.py
|
||||
deps:
|
||||
- path: build_model.py
|
||||
hash: md5
|
||||
md5: 662cd6b1562fbbc2c7d30dd0f2375a66
|
||||
size: 3948
|
||||
md5: 46bcc34f20c6851cd987640889eefde6
|
||||
size: 3671
|
||||
- path: data/prepared_data
|
||||
hash: md5
|
||||
md5: 5cbabd20ff23b9d6734c5c68684dc8dc.dir
|
||||
size: 11982694
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
nfiles: 2
|
||||
params:
|
||||
configs/build_model.yaml:
|
||||
|
|
@ -48,7 +48,7 @@ stages:
|
|||
outs:
|
||||
- path: data/model/
|
||||
hash: md5
|
||||
md5: f53ceced818ffe9e3ae327492d5a049a.dir
|
||||
md5: eb2b910dec66481e75bb6058622f6e55.dir
|
||||
size: 1832
|
||||
nfiles: 1
|
||||
generate_predictions:
|
||||
|
|
@ -56,18 +56,18 @@ stages:
|
|||
deps:
|
||||
- path: data/model
|
||||
hash: md5
|
||||
md5: f53ceced818ffe9e3ae327492d5a049a.dir
|
||||
md5: eb2b910dec66481e75bb6058622f6e55.dir
|
||||
size: 1832
|
||||
nfiles: 1
|
||||
- path: data/prepared_data
|
||||
hash: md5
|
||||
md5: 5cbabd20ff23b9d6734c5c68684dc8dc.dir
|
||||
size: 11982694
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
nfiles: 2
|
||||
- path: generate_predictions.py
|
||||
hash: md5
|
||||
md5: 32c0ecd082e1f8fc4426338d6629979c
|
||||
size: 4686
|
||||
md5: d412c8c9b48b59a29f569633280a6e7f
|
||||
size: 4237
|
||||
params:
|
||||
configs/generate_predictions.yaml:
|
||||
input_dataclient_type: local
|
||||
|
|
@ -78,26 +78,26 @@ stages:
|
|||
outs:
|
||||
- path: data/predictions/
|
||||
hash: md5
|
||||
md5: e71d1d864228b3f3994217bfcdbcc5b7.dir
|
||||
size: 643090
|
||||
md5: 85ec3fa0cb387a7775eccd23185f7966.dir
|
||||
size: 643406
|
||||
nfiles: 1
|
||||
generate_metrics:
|
||||
cmd: python generate_metrics.py
|
||||
deps:
|
||||
- path: data/predictions
|
||||
hash: md5
|
||||
md5: e71d1d864228b3f3994217bfcdbcc5b7.dir
|
||||
size: 643090
|
||||
md5: 85ec3fa0cb387a7775eccd23185f7966.dir
|
||||
size: 643406
|
||||
nfiles: 1
|
||||
- path: data/prepared_data
|
||||
hash: md5
|
||||
md5: 5cbabd20ff23b9d6734c5c68684dc8dc.dir
|
||||
size: 11982694
|
||||
md5: ea0a2baf3931e692d6344ba609331089.dir
|
||||
size: 13232732
|
||||
nfiles: 2
|
||||
- path: generate_metrics.py
|
||||
hash: md5
|
||||
md5: 4709c42d93f8e717a3d9e4958e46cd76
|
||||
size: 4587
|
||||
md5: 5577a28107458dc1e6bcaaa098388095
|
||||
size: 4144
|
||||
params:
|
||||
configs/generate_metrics.yaml:
|
||||
dataclient_type: local
|
||||
|
|
@ -108,8 +108,8 @@ stages:
|
|||
outs:
|
||||
- path: metrics/metrics.json
|
||||
hash: md5
|
||||
md5: 915100dc1b46b4517a3e1d71d211849d
|
||||
size: 179
|
||||
md5: d79f798a272e6b50597be4d08ae48fa8
|
||||
size: 180
|
||||
startup_cleanup:
|
||||
cmd: python startup_cleanup.py
|
||||
deps:
|
||||
|
|
|
|||
|
|
@ -5,17 +5,14 @@ After the model is built, we can evaluate its performance
|
|||
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from core.interface.InterfaceModels import MLModel
|
||||
from core.interface.InterfaceMetrics import MLMetrics
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.interface.InterfaceDataHandler import DataHandler
|
||||
from core.DataClient import dataclient_factory
|
||||
from core.MLModels import model_factory
|
||||
from core.MLMetrics import metrics_factory
|
||||
from core.DataHandler import datahandler_factory
|
||||
from core.Logger import logger
|
||||
|
||||
|
||||
|
|
@ -43,9 +40,8 @@ feature_process_params = yaml.safe_load(open(feature_process_path))
|
|||
|
||||
|
||||
def generate_metrics(
|
||||
dataclient: DataClient,
|
||||
input_datahandler: DataHandler,
|
||||
output_datahandler: DataHandler,
|
||||
input_dataclient: DataClient,
|
||||
output_dataclient: DataClient,
|
||||
model: MLModel,
|
||||
metrics: MLMetrics,
|
||||
target: str,
|
||||
|
|
@ -62,17 +58,15 @@ def generate_metrics(
|
|||
logger.info("--- Loading test data ---")
|
||||
logger.info("-------------------------")
|
||||
|
||||
test_data = input_datahandler.load_data(
|
||||
dataclient=dataclient, location=test_data_filepath
|
||||
test_data = input_dataclient.load_data(
|
||||
location=test_data_filepath,
|
||||
)
|
||||
|
||||
logger.info("---------------------------")
|
||||
logger.info("--- Loading predictions ---")
|
||||
logger.info("---------------------------")
|
||||
|
||||
predictions = input_datahandler.load_data(
|
||||
dataclient=dataclient, location=predictions_output_filepath
|
||||
)
|
||||
predictions = input_dataclient.load_data(location=predictions_output_filepath)
|
||||
|
||||
logger.info("--------------------------")
|
||||
logger.info("--- Generating metrics ---")
|
||||
|
|
@ -87,9 +81,7 @@ def generate_metrics(
|
|||
logger.info("--- Saving metrics ---")
|
||||
logger.info("----------------------")
|
||||
|
||||
output_datahandler.save_data(
|
||||
dataclient=dataclient, obj=metrics_output, location=metrics_output_filepath
|
||||
)
|
||||
output_dataclient.save_data(obj=metrics_output, location=metrics_output_filepath)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -100,23 +92,18 @@ if __name__ == "__main__":
|
|||
|
||||
model = model_factory(build_model_params["model_type"])
|
||||
|
||||
# Use data client for input and output, as we use dvc to cache later to the cloud
|
||||
dataclient_type = generate_metrics_params["dataclient_type"]
|
||||
dataclient = dataclient_factory(dataclient_type)
|
||||
dataclient.ingest_configurations(client_params[dataclient_type])
|
||||
dataclient.establish_client()
|
||||
dataclient = dataclient_factory(
|
||||
dataclient_type=dataclient_type,
|
||||
dataclient_config=client_params[dataclient_type],
|
||||
)
|
||||
|
||||
input_datahandler = datahandler_factory(
|
||||
generate_metrics_params["input_datahandler_type"]
|
||||
)
|
||||
output_datahandler = datahandler_factory(
|
||||
generate_metrics_params["output_datahandler_type"]
|
||||
)
|
||||
metrics = metrics_factory(generate_metrics_params["metrics_type"])
|
||||
|
||||
generate_metrics(
|
||||
dataclient=dataclient,
|
||||
input_datahandler=input_datahandler,
|
||||
output_datahandler=output_datahandler,
|
||||
input_dataclient=dataclient,
|
||||
output_dataclient=dataclient,
|
||||
model=model,
|
||||
metrics=metrics,
|
||||
target=feature_process_params["feature_processor_config"]["target"],
|
||||
|
|
|
|||
|
|
@ -5,15 +5,12 @@ After the model is built, we can evaluate its performance
|
|||
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from core.interface.InterfaceModels import MLModel
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.interface.InterfaceDataHandler import DataHandler
|
||||
from core.DataClient import dataclient_factory
|
||||
from core.MLModels import model_factory
|
||||
from core.DataHandler import datahandler_factory
|
||||
from core.Logger import logger
|
||||
|
||||
|
||||
|
|
@ -40,7 +37,6 @@ feature_process_params = yaml.safe_load(open(feature_process_path))
|
|||
def generate_predictions(
|
||||
input_dataclient: DataClient,
|
||||
output_dataclient: DataClient,
|
||||
datahandler: DataHandler,
|
||||
model: MLModel,
|
||||
target: str,
|
||||
model_filepath: str,
|
||||
|
|
@ -56,9 +52,7 @@ def generate_predictions(
|
|||
logger.info("--- Loading test data ---")
|
||||
logger.info("-------------------------")
|
||||
|
||||
test_data = datahandler.load_data(
|
||||
dataclient=input_dataclient, location=test_data_filepath
|
||||
)
|
||||
test_data = input_dataclient.load_data(location=test_data_filepath)
|
||||
|
||||
logger.info("---------------------")
|
||||
logger.info("--- Loading model ---")
|
||||
|
|
@ -83,10 +77,8 @@ def generate_predictions(
|
|||
predictions_df = pd.DataFrame(predictions)
|
||||
predictions_df.columns = [predictions_column_name]
|
||||
|
||||
datahandler.save_data(
|
||||
dataclient=output_dataclient,
|
||||
obj=predictions_df,
|
||||
location=predictions_output_filepath,
|
||||
output_dataclient.save_data(
|
||||
obj=predictions_df, location=predictions_output_filepath
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -102,23 +94,20 @@ if __name__ == "__main__":
|
|||
# For predictions, we will want a cloud data client
|
||||
|
||||
input_dataclient_type = generate_predictions_params["input_dataclient_type"]
|
||||
input_dataclient = dataclient_factory(input_dataclient_type)
|
||||
input_dataclient.ingest_configurations(config=client_params[input_dataclient_type])
|
||||
input_dataclient.establish_client()
|
||||
input_dataclient = dataclient_factory(
|
||||
dataclient_type=input_dataclient_type,
|
||||
dataclient_config=client_params[input_dataclient_type],
|
||||
)
|
||||
|
||||
output_dataclient_type = generate_predictions_params["output_dataclient_type"]
|
||||
output_dataclient = dataclient_factory(output_dataclient_type)
|
||||
output_dataclient.ingest_configurations(
|
||||
config=client_params[output_dataclient_type]
|
||||
output_dataclient = dataclient_factory(
|
||||
dataclient_type=output_dataclient_type,
|
||||
dataclient_config=client_params[output_dataclient_type],
|
||||
)
|
||||
output_dataclient.establish_client()
|
||||
|
||||
datahandler = datahandler_factory(prepare_data_params["datahandler_type"])
|
||||
|
||||
generate_predictions(
|
||||
input_dataclient=input_dataclient,
|
||||
output_dataclient=output_dataclient,
|
||||
datahandler=datahandler,
|
||||
model=model,
|
||||
target=feature_process_params["feature_processor_config"]["target"],
|
||||
model_filepath=build_model_params["model_save_filepath"],
|
||||
|
|
|
|||
|
|
@ -50,11 +50,7 @@ def prepare_data(
|
|||
logger.info("--- Loading data ---")
|
||||
logger.info("--------------------")
|
||||
|
||||
data_filetype = Path(data_filepath).suffix
|
||||
|
||||
data = input_dataclient.load_data(
|
||||
location=data_filepath, filetype=data_filetype, load_config={}
|
||||
)
|
||||
data = input_dataclient.load_data(location=data_filepath, load_config={})
|
||||
|
||||
logger.info("--------------------------")
|
||||
logger.info("--- Feature Processing ---")
|
||||
|
|
@ -83,16 +79,10 @@ def prepare_data(
|
|||
logger.info("--- Outputting data ---")
|
||||
logger.info("-----------------------")
|
||||
|
||||
output_train_filetype = Path(output_train_filepath).suffix
|
||||
output_dataclient.save_data(
|
||||
obj=train, location=output_train_filepath, filetype=output_train_filetype
|
||||
)
|
||||
output_dataclient.save_data(obj=train, location=output_train_filepath)
|
||||
|
||||
if test is not None:
|
||||
output_test_filetype = Path(output_test_filepath).suffix
|
||||
output_dataclient.save_data(
|
||||
obj=test, location=output_test_filepath, filetype=output_test_filetype
|
||||
)
|
||||
output_dataclient.save_data(obj=test, location=output_test_filepath)
|
||||
|
||||
return train, test
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue