mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
Add datahandler
This commit is contained in:
parent
74bce78d2e
commit
9487b4def2
8 changed files with 309 additions and 51 deletions
|
|
@ -65,9 +65,6 @@ def build_model(
|
|||
|
||||
model.save_model(path=Path(model_save_location))
|
||||
|
||||
# TODO: replace this with the data client to load
|
||||
# TODO: can fine tune model here if need with the test data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,13 @@
|
|||
dataclient_type: minio
|
||||
data_location: s3://dev_bucket
|
||||
input_dataclient_type: aws-s3
|
||||
input_dataclient:
|
||||
AWS_ACCESS_KEY_ID: null
|
||||
AWS_SECRET_ACCESS_KEY: null
|
||||
ENDPOINT_URL: null
|
||||
output_dataclient_type: local
|
||||
output_dataclient:
|
||||
null
|
||||
datahandler_type: parquet
|
||||
data_filepath: s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/train_validation_data.parquet
|
||||
train_proportion: 0.8
|
||||
output_train_filepath: ./data/prepared_data/train.parquet
|
||||
output_test_filepath: ./data/prepared_data/test.parquet
|
||||
|
|
|
|||
|
|
@ -2,9 +2,13 @@
|
|||
Implementations of the DataClient Protocol
|
||||
"""
|
||||
|
||||
import os
|
||||
import boto3
|
||||
import pandas as pd
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.Logger import logger
|
||||
|
||||
|
||||
def dataclient_factory(dataclient_type: str) -> DataClient:
|
||||
|
|
@ -12,7 +16,8 @@ def dataclient_factory(dataclient_type: str) -> DataClient:
|
|||
Determine which dataclient to use
|
||||
"""
|
||||
dataclients = {
|
||||
"minio": MinioClient(),
|
||||
"local": LocalClient(),
|
||||
"aws-s3": AWSS3Client(),
|
||||
# ADD MORE DATACLIENTS HERE
|
||||
}
|
||||
|
||||
|
|
@ -27,15 +32,69 @@ def validate_dict_keys(keys_1: List[str], keys_2: List[str], config_type: str):
|
|||
raise ValueError(f"Incorrect {config_type} keys specified")
|
||||
|
||||
|
||||
class MinioClient:
|
||||
# class MinioClient:
|
||||
# """
|
||||
# Using the Minio s3 client, to do local testing
|
||||
# """
|
||||
|
||||
# ACCEPTED_CONFIG_KEYS = [
|
||||
# "aws_access_key_id",
|
||||
# "aws_secret_access_key",
|
||||
# "endpoint_url",
|
||||
# ]
|
||||
# ACCEPTED_LOAD_CONFIG_KEYS = []
|
||||
# ACCEPTED_SAVE_CONFIG_KEYS = []
|
||||
|
||||
# def ingest_configurations(self, config: dict) -> None:
|
||||
# """
|
||||
# Load all configuration into the instance (self.config)
|
||||
# """
|
||||
# validate_dict_keys(
|
||||
# keys_1=list(config.keys()),
|
||||
# keys_2=self.ACCEPTED_CONFIG_KEYS,
|
||||
# config_type="config",
|
||||
# )
|
||||
|
||||
# self.config = config
|
||||
|
||||
# def establish_client(self) -> None:
|
||||
# """
|
||||
# With the given configurations, create the connection to the client (self.client)
|
||||
# """
|
||||
|
||||
# ...
|
||||
|
||||
# def download_data(self, download_config: dict) -> pd.DataFrame:
|
||||
# """
|
||||
# When the client is established, we can load data
|
||||
# """
|
||||
# validate_dict_keys(
|
||||
# keys_1=list(download_config.keys()),
|
||||
# keys_2=self.ACCEPTED_LOAD_CONFIG_KEYS,
|
||||
# config_type="load_config",
|
||||
# )
|
||||
|
||||
# return pd.DataFrame()
|
||||
|
||||
# def save_data(self, obj: object, save_config: dict) -> None:
|
||||
# """
|
||||
# When the client is established, we can save out objects
|
||||
# """
|
||||
# validate_dict_keys(
|
||||
# keys_1=list(save_config.keys()),
|
||||
# keys_2=self.ACCEPTED_SAVE_CONFIG_KEYS,
|
||||
# config_type="save_config",
|
||||
# )
|
||||
|
||||
|
||||
class AWSS3Client:
|
||||
"""
|
||||
Using the Minio s3 client, to do local testing
|
||||
Using Boto3, set up the AWS client
|
||||
"""
|
||||
|
||||
ACCEPTED_CONFIG_KEYS = [
|
||||
"aws_access_key_id",
|
||||
"aws_secret_access_key",
|
||||
"endpoint_url",
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
]
|
||||
ACCEPTED_LOAD_CONFIG_KEYS = []
|
||||
ACCEPTED_SAVE_CONFIG_KEYS = []
|
||||
|
|
@ -45,38 +104,120 @@ class MinioClient:
|
|||
Load all configuration into the instance (self.config)
|
||||
"""
|
||||
validate_dict_keys(
|
||||
keys_1=list(config.keys()),
|
||||
keys_2=self.ACCEPTED_CONFIG_KEYS,
|
||||
config_type="config",
|
||||
keys_1=self.ACCEPTED_CONFIG_KEYS,
|
||||
keys_2=list(config.keys()),
|
||||
config_type="Ingest Config",
|
||||
)
|
||||
|
||||
self.config = config
|
||||
|
||||
def establish_client(self) -> None:
|
||||
"""
|
||||
With the given configurations, create the connection to the client (self.client)
|
||||
"""
|
||||
logger.info(f"Establishing S3 Client")
|
||||
session = boto3.Session()
|
||||
|
||||
...
|
||||
if (
|
||||
self.config["AWS_ACCESS_KEY_ID"] is None
|
||||
and self.config["AWS_SECRET_ACCESS_KEY"] is None
|
||||
):
|
||||
self.client = session.client(service_name="s3") # Using local credentials
|
||||
else:
|
||||
self.client = session.client(
|
||||
service_name="s3",
|
||||
aws_access_key_id=self.config["AWS_ACCESS_KEY_ID"],
|
||||
aws_secret_access_key=self.config["AWS_SECRET_ACCESS_KEY"],
|
||||
)
|
||||
|
||||
def load_data(self, load_config: dict) -> pd.DataFrame:
|
||||
def download_data(self, location: dict) -> pd.DataFrame:
|
||||
"""
|
||||
When the client is established, we can load data
|
||||
"""
|
||||
validate_dict_keys(
|
||||
keys_1=list(load_config.keys()),
|
||||
keys_2=self.ACCEPTED_LOAD_CONFIG_KEYS,
|
||||
config_type="load_config",
|
||||
)
|
||||
...
|
||||
|
||||
return pd.DataFrame()
|
||||
def load_data_as_buffer(self, location: str) -> BytesIO:
    """
    Download the S3 object at `location` into an in-memory buffer

    `self.client` must already have been created by `establish_client`.

    :param location: object URI of the form s3://<bucket>/<key>
    :return: BytesIO holding the downloaded bytes, rewound to position 0
    :raises ValueError: if `location` does not start with s3:// or does not
        contain both a bucket and a key
    """
    prefix = "s3://"
    if not location.startswith(prefix):
        raise ValueError("S3 file path specified without s3://")

    # Slice the scheme off instead of calling str.strip("s3://"): strip()
    # treats its argument as a set of characters, so it would also remove
    # leading/trailing 's', '3', ':' and '/' from the bucket and the key.
    bucket, _, key = location[len(prefix):].partition("/")
    if not bucket or not key:
        raise ValueError(
            f"S3 file path must look like s3://<bucket>/<key>, got {location!r}"
        )

    buffer = BytesIO()
    self.client.download_fileobj(bucket, key, buffer)
    # Rewind so the caller reads from the first byte
    buffer.seek(0)

    return buffer
|
||||
|
||||
def load_database(self, database_location: dict) -> None:
|
||||
"""
|
||||
When the client is established, we can read from a database
|
||||
"""
|
||||
...
|
||||
|
||||
def upload_data(self, location: str) -> None:
|
||||
"""
|
||||
When the client is established, we can save out objects
|
||||
"""
|
||||
validate_dict_keys(
|
||||
keys_1=list(save_config.keys()),
|
||||
keys_2=self.ACCEPTED_SAVE_CONFIG_KEYS,
|
||||
config_type="save_config",
|
||||
)
|
||||
...
|
||||
|
||||
def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
    """
    Upload the full contents of `buffer` to the S3 object at `location`

    `self.client` must already have been created by `establish_client`.

    :param buffer: in-memory bytes to upload
    :param location: destination URI of the form s3://<bucket>/<key>
    :raises ValueError: if `location` does not start with s3:// or does not
        contain both a bucket and a key
    """
    prefix = "s3://"
    if not location.startswith(prefix):
        raise ValueError("S3 file path specified without s3://")

    # Slice the scheme off instead of calling str.strip("s3://"): strip()
    # treats its argument as a set of characters, so it would also remove
    # leading/trailing 's', '3', ':' and '/' from the bucket and the key.
    bucket, _, key = location[len(prefix):].partition("/")
    if not bucket or not key:
        raise ValueError(
            f"S3 file path must look like s3://<bucket>/<key>, got {location!r}"
        )

    # The upload reads from the buffer's current position; a buffer that was
    # just written to sits at its end and would be uploaded as an empty object.
    buffer.seek(0)
    self.client.upload_fileobj(buffer, bucket, key)
|
||||
|
||||
|
||||
class LocalClient:
    """
    Interacting with data on the local filesystem

    No credentials or connection are involved, so the configuration and
    client-establishment steps only log that there is nothing to do.
    """

    def ingest_configurations(self, config: dict) -> None:
        """
        Accept a configuration for interface parity; nothing is stored

        :param config: ignored
        """
        logger.info("Local - No configuration required")

    def establish_client(self) -> None:
        """
        No connection is needed for local files; only logs that fact
        """
        logger.info("Local - No establishing client required")

    def download_data(self, location: dict) -> pd.DataFrame:
        """
        Placeholder - not implemented yet (currently returns None)
        """
        ...

    def load_data_as_buffer(self, location: str) -> BytesIO:
        """
        Read the local file at `location` into an in-memory buffer

        :param location: path of the file to read
        :return: BytesIO holding the file's bytes, positioned at 0
        :raises OSError: if the file cannot be opened or read
        """
        with open(location, "rb") as f:
            return BytesIO(f.read())

    def load_database(self, database_location: dict) -> None:
        """
        Placeholder - not implemented yet
        """
        ...

    def upload_data(self, location: str) -> None:
        """
        Placeholder - not implemented yet
        """
        ...

    def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
        """
        Write the full contents of `buffer` to the local file at `location`

        Missing parent directories are created first, so a fresh checkout
        without the output folder does not fail.

        :param buffer: in-memory bytes to write
        :param location: path of the file to create or overwrite
        """
        # A bare filename has no directory component; nothing to create then
        parent = os.path.dirname(location)
        if parent:
            os.makedirs(parent, exist_ok=True)

        # getvalue() returns the whole buffer regardless of its read position
        with open(location, "wb") as f:
            f.write(buffer.getvalue())
|
||||
|
|
|
|||
54
modules/ml-pipeline/src/pipeline/src/core/DataHandler.py
Normal file
54
modules/ml-pipeline/src/pipeline/src/core/DataHandler.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Implementations of the datahandler Protocol
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from core.interface.InterfaceDataHandler import DataHandler
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
|
||||
|
||||
def datahandler_factory(datahandler_type: str) -> DataHandler:
    """
    Determine which datahandler to use

    :param datahandler_type: registered handler name, e.g. "parquet"
    :return: a new instance of the matching handler
    :raises ValueError: if `datahandler_type` is not registered below
    """
    # Map names to classes (not instances) so only the requested handler is built
    datahandlers = {
        "parquet": ParquetHandler,
        # ADD MORE DATAHANDLERS HERE
    }

    if datahandler_type not in datahandlers:
        raise ValueError(
            f"Datahandler type '{datahandler_type}' is not in factory; "
            f"expected one of {sorted(datahandlers)}"
        )

    return datahandlers[datahandler_type]()
|
||||
|
||||
|
||||
def validate_dict_keys(keys_1: List[str], keys_2: List[str], config_type: str):
    """
    Check that every entry of `keys_1` also appears in `keys_2`

    :param keys_1: keys that must all be present in `keys_2`
    :param keys_2: keys to check against
    :param config_type: label inserted into the error message
    :raises ValueError: if at least one entry of `keys_1` is absent from `keys_2`
    """
    unmatched = set(keys_1).difference(keys_2)
    if unmatched:
        raise ValueError(f"Incorrect {config_type} keys specified")
|
||||
|
||||
|
||||
class ParquetHandler:
    """
    Load and save Parquet datasets

    The handler only converts between DataFrames and Parquet bytes; moving
    those bytes to and from storage is delegated to the dataclient passed in.
    """

    def load_data(self, dataclient: DataClient, location: str) -> pd.DataFrame:
        """
        Fetch the Parquet file at `location` through `dataclient` and parse it

        :param dataclient: client whose `load_data_as_buffer` returns the bytes
        :param location: where the Parquet file lives, in the client's format
        :return: the dataset as a DataFrame
        """
        buffer = dataclient.load_data_as_buffer(location=location)
        return pd.read_parquet(buffer)

    def save_data(
        self, dataclient: DataClient, obj: pd.DataFrame, location: str
    ) -> None:
        """
        Serialise `obj` to Parquet and hand the bytes to `dataclient`

        :param dataclient: client whose `upload_data_from_buffer` stores the bytes
        :param obj: DataFrame to write (the index is not written)
        :param location: destination, in the client's format
        """
        # Convert the Pandas DataFrame to a Parquet buffer
        parquet_buffer = BytesIO()
        obj.to_parquet(parquet_buffer, index=False)

        # Writing leaves the cursor at the end of the buffer; rewind so that
        # clients which stream from the current position get the whole file
        parquet_buffer.seek(0)

        dataclient.upload_data_from_buffer(buffer=parquet_buffer, location=location)
|
||||
|
|
@ -3,6 +3,7 @@ Interface for all DataClient i.e. s3, database, local etc
|
|||
"""
|
||||
|
||||
import pandas as pd
|
||||
from io import BytesIO
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
|
|
@ -23,13 +24,32 @@ class DataClient(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def load_data(self, load_config: dict) -> pd.DataFrame:
|
||||
def download_data(self, location: dict) -> pd.DataFrame:
|
||||
"""
|
||||
When the client is established, we can load data
|
||||
"""
|
||||
...
|
||||
|
||||
def save_data(self, obj: object, save_config: dict) -> None:
|
||||
def load_data_as_buffer(self, location: str) -> BytesIO:
|
||||
"""
|
||||
When the client is established, we can load data
|
||||
"""
|
||||
...
|
||||
|
||||
def load_database(self, database_location: dict) -> None:
|
||||
"""
|
||||
When the client is established, we can read from a database
|
||||
"""
|
||||
...
|
||||
|
||||
def upload_data(self, location: str) -> None:
|
||||
"""
|
||||
When the client is established, we can save out objects
|
||||
"""
|
||||
...
|
||||
|
||||
def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
|
||||
"""
|
||||
When the client is established, we can save out objects
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
"""
|
||||
Interface for all DataHandler i.e. Parquet data, csv data
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Protocol, Union, Any
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
|
||||
|
||||
class DataHandler(Protocol):
    """
    Declare the methods required for a DataHandler
    """

    def load_data(self, dataclient: DataClient, location: str) -> pd.DataFrame:
        """
        Read the dataset at `location` through `dataclient` as a DataFrame
        """
        ...

    def save_data(
        self, dataclient: DataClient, obj: Union[pd.DataFrame, Any], location: str
    ) -> None:
        """
        Write `obj` to `location` through `dataclient`
        """
        ...
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
/prepared_data
|
||||
/model
|
||||
/predictions
|
||||
.DS_Store
|
||||
|
|
|
|||
|
|
@ -11,8 +11,10 @@ from pathlib import Path
|
|||
from sklearn.datasets import load_diabetes
|
||||
from sklearn.model_selection import train_test_split
|
||||
from core.interface.InterfaceDataClient import DataClient
|
||||
from core.interface.InterfaceDataHandler import DataHandler
|
||||
from core.Logger import logger
|
||||
from core.DataClient import dataclient_factory
|
||||
from core.DataHandler import datahandler_factory
|
||||
|
||||
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
|
||||
|
||||
|
|
@ -31,7 +33,10 @@ def use_dummy_data() -> pd.DataFrame:
|
|||
|
||||
|
||||
def prepare_data(
|
||||
dataclient: DataClient,
|
||||
input_dataclient: DataClient,
|
||||
output_dataclient: DataClient,
|
||||
datahandler: DataHandler,
|
||||
data_filepath: str,
|
||||
train_proportion: float,
|
||||
output_train_filepath: str = "train.parquet",
|
||||
output_test_filepath: str = "test.parquet",
|
||||
|
|
@ -46,8 +51,7 @@ def prepare_data(
|
|||
logger.info("--- Loading data ---")
|
||||
logger.info("--------------------")
|
||||
|
||||
# TODO: REPLACE THIS WIL CLIENT AND LOAD DATA
|
||||
data = use_dummy_data()
|
||||
data = datahandler.load_data(dataclient=input_dataclient, location=data_filepath)
|
||||
|
||||
logger.info("----------------------")
|
||||
logger.info("--- Splitting data ---")
|
||||
|
|
@ -65,21 +69,12 @@ def prepare_data(
|
|||
logger.info("--- Outputting data ---")
|
||||
logger.info("-----------------------")
|
||||
|
||||
# TODO: REPLACE WITH CLIENT
|
||||
output_directory = Path(output_train_filepath)
|
||||
if not output_directory.parent.exists():
|
||||
os.makedirs(output_directory.parent)
|
||||
|
||||
output_directory = Path(output_test_filepath)
|
||||
if not output_directory.parent.exists():
|
||||
os.makedirs(output_directory.parent)
|
||||
|
||||
logger.info("--- Outputting train and test data ---")
|
||||
train.to_parquet(output_train_filepath)
|
||||
test.to_parquet(output_test_filepath)
|
||||
|
||||
# client.save_data(obj=train)
|
||||
# client.save_data(obj=test)
|
||||
datahandler.save_data(
|
||||
dataclient=output_dataclient, obj=train, location=output_train_filepath
|
||||
)
|
||||
datahandler.save_data(
|
||||
dataclient=output_dataclient, obj=test, location=output_test_filepath
|
||||
)
|
||||
|
||||
return train, test
|
||||
|
||||
|
|
@ -94,14 +89,30 @@ if __name__ == "__main__":
|
|||
logger.info(f"--- Initiate DataClient ---")
|
||||
logger.info("----------------------------")
|
||||
|
||||
dataclient = dataclient_factory(params["dataclient_type"])
|
||||
input_dataclient = dataclient_factory(params["input_dataclient_type"])
|
||||
output_dataclient = dataclient_factory(params["output_dataclient_type"])
|
||||
|
||||
input_dataclient.ingest_configurations(config=params["input_dataclient"])
|
||||
input_dataclient.establish_client()
|
||||
|
||||
output_dataclient.ingest_configurations(config=params["output_dataclient"])
|
||||
output_dataclient.establish_client()
|
||||
|
||||
logger.info("-----------------------------")
|
||||
logger.info(f"--- Initiate DataHandler ---")
|
||||
logger.info("-----------------------------")
|
||||
|
||||
datahandler = datahandler_factory(params["datahandler_type"])
|
||||
|
||||
logger.info("---------------------------")
|
||||
logger.info(f"--- Prepare Data Stage ---")
|
||||
logger.info("---------------------------")
|
||||
|
||||
prepare_data(
|
||||
dataclient=dataclient,
|
||||
input_dataclient=input_dataclient,
|
||||
output_dataclient=output_dataclient,
|
||||
datahandler=datahandler,
|
||||
data_filepath=params["data_filepath"],
|
||||
train_proportion=params["train_proportion"],
|
||||
output_train_filepath=params["output_train_filepath"],
|
||||
output_test_filepath=params["output_test_filepath"],
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue