Add datahandler

This commit is contained in:
Michael Duong 2023-09-12 00:34:30 +01:00
parent 74bce78d2e
commit 9487b4def2
8 changed files with 309 additions and 51 deletions

View file

@ -65,9 +65,6 @@ def build_model(
model.save_model(path=Path(model_save_location))
# TODO: replace this with the data client to load
# TODO: can fine tune model here if need with the test data
if __name__ == "__main__":

View file

@ -1,5 +1,13 @@
dataclient_type: minio
data_location: s3://dev_bucket
input_dataclient_type: aws-s3
input_dataclient:
AWS_ACCESS_KEY_ID: null
AWS_SECRET_ACCESS_KEY: null
ENDPOINT_URL: null
output_dataclient_type: local
output_dataclient:
null
datahandler_type: parquet
data_filepath: s3://retrofit-data-dev/model_build_data/change_data/rdsap_full/train_validation_data.parquet
train_proportion: 0.8
output_train_filepath: ./data/prepared_data/train.parquet
output_test_filepath: ./data/prepared_data/test.parquet

View file

@ -2,9 +2,13 @@
Implementations of the DataClient Protocol
"""
import os
import boto3
import pandas as pd
from io import BytesIO
from typing import List
from core.interface.InterfaceDataClient import DataClient
from core.Logger import logger
def dataclient_factory(dataclient_type: str) -> DataClient:
@ -12,7 +16,8 @@ def dataclient_factory(dataclient_type: str) -> DataClient:
Determine which dataclient to use
"""
dataclients = {
"minio": MinioClient(),
"local": LocalClient(),
"aws-s3": AWSS3Client(),
# ADD MORE DATACLIENTS HERE
}
@ -27,15 +32,69 @@ def validate_dict_keys(keys_1: List[str], keys_2: List[str], config_type: str):
raise ValueError(f"Incorrect {config_type} keys specified")
class MinioClient:
# class MinioClient:
# """
# Using the Minio s3 client, to do local testing
# """
# ACCEPTED_CONFIG_KEYS = [
# "aws_access_key_id",
# "aws_secret_access_key",
# "endpoint_url",
# ]
# ACCEPTED_LOAD_CONFIG_KEYS = []
# ACCEPTED_SAVE_CONFIG_KEYS = []
# def ingest_configurations(self, config: dict) -> None:
# """
# Load all configuration into the instance (self.config)
# """
# validate_dict_keys(
# keys_1=list(config.keys()),
# keys_2=self.ACCEPTED_CONFIG_KEYS,
# config_type="config",
# )
# self.config = config
# def establish_client(self) -> None:
# """
# With the given configurations, create the connection to the client (self.client)
# """
# ...
# def download_data(self, download_config: dict) -> pd.DataFrame:
# """
# When the client is established, we can load data
# """
# validate_dict_keys(
# keys_1=list(download_config.keys()),
# keys_2=self.ACCEPTED_LOAD_CONFIG_KEYS,
# config_type="load_config",
# )
# return pd.DataFrame()
# def save_data(self, obj: object, save_config: dict) -> None:
# """
# When the client is established, we can save out objects
# """
# validate_dict_keys(
# keys_1=list(save_config.keys()),
# keys_2=self.ACCEPTED_SAVE_CONFIG_KEYS,
# config_type="save_config",
# )
class AWSS3Client:
"""
Using the Minio s3 client, to do local testing
Using Boto3, set up the AWS client
"""
ACCEPTED_CONFIG_KEYS = [
"aws_access_key_id",
"aws_secret_access_key",
"endpoint_url",
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
]
ACCEPTED_LOAD_CONFIG_KEYS = []
ACCEPTED_SAVE_CONFIG_KEYS = []
@ -45,38 +104,120 @@ class MinioClient:
Load all configuration into the instance (self.config)
"""
validate_dict_keys(
keys_1=list(config.keys()),
keys_2=self.ACCEPTED_CONFIG_KEYS,
config_type="config",
keys_1=self.ACCEPTED_CONFIG_KEYS,
keys_2=list(config.keys()),
config_type="Ingest Config",
)
self.config = config
def establish_client(self) -> None:
"""
With the given configurations, create the connection to the client (self.client)
"""
logger.info(f"Establishing S3 Client")
session = boto3.Session()
...
if (
self.config["AWS_ACCESS_KEY_ID"] is None
and self.config["AWS_SECRET_ACCESS_KEY"] is None
):
self.client = session.client(service_name="s3") # Using local credentials
else:
self.client = session.client(
service_name="s3",
aws_access_key_id=self.config["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=self.config["AWS_SECRET_ACCESS_KEY"],
)
def load_data(self, load_config: dict) -> pd.DataFrame:
def download_data(self, location: dict) -> pd.DataFrame:
"""
When the client is established, we can load data
"""
validate_dict_keys(
keys_1=list(load_config.keys()),
keys_2=self.ACCEPTED_LOAD_CONFIG_KEYS,
config_type="load_config",
)
...
return pd.DataFrame()
def load_data_as_buffer(self, location: str) -> BytesIO:
"""
When the client is established, we can load data
"""
if not location.startswith("s3://"):
raise ValueError("S3 file path specified without s3://")
def save_data(self, obj: object, save_config: dict) -> None:
bucket, key = location.strip("s3://").split("/", 1)
buffer = BytesIO()
self.client.download_fileobj(bucket, key, buffer)
buffer.seek(0)
return buffer
def load_database(self, database_location: dict) -> None:
"""
When the client is established, we can read from a database
"""
...
def upload_data(self, location: str) -> None:
"""
When the client is established, we can save out objects
"""
validate_dict_keys(
keys_1=list(save_config.keys()),
keys_2=self.ACCEPTED_SAVE_CONFIG_KEYS,
config_type="save_config",
)
...
def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
"""
When the client is established, we can save out objects
"""
if not location.startswith("s3://"):
raise ValueError("S3 file path specified without s3://")
bucket, key = location.strip("s3://").split("/", 1)
self.client.upload_fileobj(buffer, bucket, key)
class LocalClient:
"""
Interacting with data locally
"""
def ingest_configurations(self, config: dict) -> None:
"""
Load all configuration into the instance (self.config)
"""
logger.info("Local - No configuration required")
def establish_client(self) -> None:
"""
With the given configurations, create the connection to the client (self.client)
"""
logger.info("Local - No establishing client required")
def download_data(self, location: dict) -> pd.DataFrame:
"""
When the client is established, we can load data
"""
...
def load_data_as_buffer(self, location: str) -> BytesIO:
"""
When the client is established, we can load data
"""
...
def load_database(self, database_location: dict) -> None:
"""
When the client is established, we can read from a database
"""
...
def upload_data(self, location: str) -> None:
"""
When the client is established, we can save out objects
"""
...
def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
"""
When the client is established, we can save out objects
"""
# Write the contents of the buffer to the local file
with open(location, "wb") as f:
f.write(buffer.getvalue())

View file

@ -0,0 +1,54 @@
"""
Implementations of the datahandler Protocol
"""
import pandas as pd
from io import BytesIO
from typing import List
from core.interface.InterfaceDataHandler import DataHandler
from core.interface.InterfaceDataClient import DataClient
def datahandler_factory(datahandler_type: str) -> DataHandler:
"""
Determine which dataclient to use
"""
datahandler = {
"parquet": ParquetHandler(),
# ADD MORE DATACLIENTS HERE
}
if datahandler_type not in datahandler:
raise ValueError("Dataloader type specified is not in factory")
return datahandler[datahandler_type]
def validate_dict_keys(keys_1: List[str], keys_2: List[str], config_type: str):
if not set(keys_1).issubset(keys_2):
raise ValueError(f"Incorrect {config_type} keys specified")
class ParquetHandler:
"""
Load and save Parquet datasets
"""
def load_data(self, dataclient: DataClient, location: str) -> pd.DataFrame:
"""
When the client is established, we can load data
"""
df = pd.read_parquet(dataclient.load_data_as_buffer(location=location))
return df
def save_data(
self, dataclient: DataClient, obj: pd.DataFrame, location: str
) -> None:
"""
When the client is established, we can save out objects
"""
# Convert the Pandas DataFrame to a Parquet buffer
parquet_buffer = BytesIO()
obj.to_parquet(parquet_buffer, index=False)
dataclient.upload_data_from_buffer(buffer=parquet_buffer, location=location)

View file

@ -3,6 +3,7 @@ Interface for all DataClient i.e. s3, database, local etc
"""
import pandas as pd
from io import BytesIO
from typing import Protocol
@ -23,13 +24,32 @@ class DataClient(Protocol):
"""
...
def load_data(self, load_config: dict) -> pd.DataFrame:
def download_data(self, location: dict) -> pd.DataFrame:
"""
When the client is established, we can load data
"""
...
def save_data(self, obj: object, save_config: dict) -> None:
def load_data_as_buffer(self, location: str) -> BytesIO:
"""
When the client is established, we can load data
"""
...
def load_database(self, database_location: dict) -> None:
"""
When the client is established, we can read from a database
"""
...
def upload_data(self, location: str) -> None:
"""
When the client is established, we can save out objects
"""
...
def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
"""
When the client is established, we can save out objects
"""
...

View file

@ -0,0 +1,26 @@
"""
Interface for all DataHandler i.e. Parquet data, csv data
"""
import pandas as pd
from typing import Protocol, Union, Any
from core.interface.InterfaceDataClient import DataClient
class DataHandler(Protocol):
"""
Declare the methods required for a DataClient
"""
def load_data(self, dataclient: DataClient, location: str) -> pd.DataFrame:
"""
When the client is established, we can load data
"""
...
def save_data(
self, dataclient: DataClient, obj: Union[pd.DataFrame, Any], location: str
) -> None:
"""
When the client is established, we can save out objects
"""

View file

@ -1,3 +1,4 @@
/prepared_data
/model
/predictions
.DS_Store

View file

@ -11,8 +11,10 @@ from pathlib import Path
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from core.interface.InterfaceDataClient import DataClient
from core.interface.InterfaceDataHandler import DataHandler
from core.Logger import logger
from core.DataClient import dataclient_factory
from core.DataHandler import datahandler_factory
RUNTIME_ENVIRONMENT = os.environ.get("RUNTIME_ENVIRONMENT", "local")
@ -31,7 +33,10 @@ def use_dummy_data() -> pd.DataFrame:
def prepare_data(
dataclient: DataClient,
input_dataclient: DataClient,
output_dataclient: DataClient,
datahandler: DataHandler,
data_filepath: str,
train_proportion: float,
output_train_filepath: str = "train.parquet",
output_test_filepath: str = "test.parquet",
@ -46,8 +51,7 @@ def prepare_data(
logger.info("--- Loading data ---")
logger.info("--------------------")
# TODO: REPLACE THIS WIL CLIENT AND LOAD DATA
data = use_dummy_data()
data = datahandler.load_data(dataclient=input_dataclient, location=data_filepath)
logger.info("----------------------")
logger.info("--- Splitting data ---")
@ -65,21 +69,12 @@ def prepare_data(
logger.info("--- Outputting data ---")
logger.info("-----------------------")
# TODO: REPLACE WITH CLIENT
output_directory = Path(output_train_filepath)
if not output_directory.parent.exists():
os.makedirs(output_directory.parent)
output_directory = Path(output_test_filepath)
if not output_directory.parent.exists():
os.makedirs(output_directory.parent)
logger.info("--- Outputting train and test data ---")
train.to_parquet(output_train_filepath)
test.to_parquet(output_test_filepath)
# client.save_data(obj=train)
# client.save_data(obj=test)
datahandler.save_data(
dataclient=output_dataclient, obj=train, location=output_train_filepath
)
datahandler.save_data(
dataclient=output_dataclient, obj=test, location=output_test_filepath
)
return train, test
@ -94,14 +89,30 @@ if __name__ == "__main__":
logger.info(f"--- Initiate DataClient ---")
logger.info("----------------------------")
dataclient = dataclient_factory(params["dataclient_type"])
input_dataclient = dataclient_factory(params["input_dataclient_type"])
output_dataclient = dataclient_factory(params["output_dataclient_type"])
input_dataclient.ingest_configurations(config=params["input_dataclient"])
input_dataclient.establish_client()
output_dataclient.ingest_configurations(config=params["output_dataclient"])
output_dataclient.establish_client()
logger.info("-----------------------------")
logger.info(f"--- Initiate DataHandler ---")
logger.info("-----------------------------")
datahandler = datahandler_factory(params["datahandler_type"])
logger.info("---------------------------")
logger.info(f"--- Prepare Data Stage ---")
logger.info("---------------------------")
prepare_data(
dataclient=dataclient,
input_dataclient=input_dataclient,
output_dataclient=output_dataclient,
datahandler=datahandler,
data_filepath=params["data_filepath"],
train_proportion=params["train_proportion"],
output_train_filepath=params["output_train_filepath"],
output_test_filepath=params["output_test_filepath"],