ML/modules/ml-pipeline/src/pipeline/core/DataClient.py
2023-10-10 12:35:34 +01:00

296 lines
8.2 KiB
Python

"""
Implementations of the DataClient Protocol
"""
import os
import json
import boto3
import pandas as pd
from pathlib import Path
from io import BytesIO
from typing import List, Union, Any
from core.interface.InterfaceDataClient import DataClient
from core.Logger import logger
def dataclient_factory(
    dataclient_type: str, dataclient_config: Union[dict, None] = None
) -> DataClient:
    """
    Build the data client registered under ``dataclient_type``.

    ``dataclient_config`` is unpacked as keyword arguments into the selected
    client's constructor; ``None`` is treated as "no arguments".

    Raises ``ValueError`` when ``dataclient_type`` is not a registered name.
    """
    registry = {
        "local": LocalClient,
        "aws-s3": AWSS3Client,
        # ADD MORE DATACLIENTS HERE
    }

    client_cls = registry.get(dataclient_type)
    if client_cls is None:
        raise ValueError("Dataclient type specified is not in factory")

    constructor_kwargs = {} if dataclient_config is None else dataclient_config
    return client_cls(**constructor_kwargs)
def validate_dict_keys(keys_1: List[str], keys_2: List[str], config_type: str):
    """
    Check that every key in ``keys_1`` also appears in ``keys_2``.

    Raises ``ValueError`` (mentioning ``config_type``) when at least one key
    of ``keys_1`` is missing from ``keys_2``; returns ``None`` otherwise.
    """
    unexpected = set(keys_1) - set(keys_2)
    if unexpected:
        raise ValueError(f"Incorrect {config_type} keys specified")
class AWSS3Client:
    """
    DataClient implementation that loads and saves data through a Boto3 S3
    client.
    """

    # Config keys accepted by load_data / save_data. Both lists are empty and
    # nothing in this class reads them.
    ACCEPTED_LOAD_CONFIG_KEYS = []
    ACCEPTED_SAVE_CONFIG_KEYS = []

    # URI scheme every S3 location handled by this client must start with.
    _S3_SCHEME = "s3://"

    def __init__(
        self,
        AWS_ACCESS_KEY_ID: Union[str, None] = None,
        AWS_SECRET_ACCESS_KEY: Union[str, None] = None,
        ENDPOINT_URL: Union[str, None] = None,
    ):
        """
        Store the connection settings and create the S3 client.

        All three arguments default to ``None`` so the client can be built
        from an empty configuration; ``None`` values are handed to Boto3
        unchanged.
        """
        self.AWS_ACCESS_KEY_ID = AWS_ACCESS_KEY_ID
        self.AWS_SECRET_ACCESS_KEY = AWS_SECRET_ACCESS_KEY
        self.ENDPOINT_URL = ENDPOINT_URL
        self._establish_client()

    def _establish_client(self) -> None:
        """
        With the given configurations, create the connection to the client
        (self.client)
        """
        logger.info("Establishing S3 Client")
        session = boto3.Session()
        # ``None`` is the default for each of these keyword arguments, so
        # passing unset credentials is the same as omitting them. Passing
        # everything unconditionally also keeps a configured ENDPOINT_URL from
        # being dropped when no explicit keys are supplied.
        self.client = session.client(
            service_name="s3",
            aws_access_key_id=self.AWS_ACCESS_KEY_ID,
            aws_secret_access_key=self.AWS_SECRET_ACCESS_KEY,
            endpoint_url=self.ENDPOINT_URL,
        )

    @staticmethod
    def _split_s3_uri(location: str) -> tuple:
        """
        Split ``s3://<bucket>/<key>`` into ``(bucket, key)``.

        The prefix is removed by slicing. ``str.strip("s3://")`` must not be
        used for this: it strips any of the characters ``s``, ``3``, ``:`` and
        ``/`` from both ends, which corrupts bucket names such as ``sales``.

        Raises ``ValueError`` when the scheme is missing or when the bucket or
        the key is empty.
        """
        scheme = AWSS3Client._S3_SCHEME
        if not location.startswith(scheme):
            raise ValueError("S3 file path specified without s3://")
        bucket, _, key = location[len(scheme):].partition("/")
        if not bucket or not key:
            raise ValueError(
                f"S3 file path must look like s3://<bucket>/<key>: {location}"
            )
        return bucket, key

    def load_data(
        self, location: str, load_config: Union[dict, None] = None
    ) -> pd.DataFrame:
        """
        Load the object at ``location``, choosing the loader from the file
        suffix.

        ``load_config`` is forwarded to the selected loader; ``None`` means an
        empty config.

        Raises ``ValueError`` when ``location`` does not start with ``s3://``
        or when no loader is registered for its suffix.
        """
        if not location.startswith(self._S3_SCHEME):
            raise ValueError("S3 file path specified without s3://")
        if load_config is None:
            load_config = {}
        load_methods = {
            ".parquet": self._load_parquet,
            # ADD MORE load_methods HERE
        }
        filetype = Path(location).suffix
        if filetype not in load_methods:
            raise ValueError("load methods specified is not in factory")
        return load_methods[filetype](location=location, load_config=load_config)

    def save_data(
        self,
        obj: Any,
        location: str,
        save_config: Union[dict, None] = None,
    ) -> None:
        """
        Save ``obj`` to ``location``, choosing the writer from the file
        suffix.

        ``save_config`` is forwarded to the selected writer; ``None`` means an
        empty config.

        Raises ``ValueError`` when ``location`` does not start with ``s3://``
        or when no writer is registered for its suffix.
        """
        if not location.startswith(self._S3_SCHEME):
            raise ValueError("S3 file path specified without s3://")
        if save_config is None:
            save_config = {}
        save_methods = {
            ".parquet": self._save_parquet,
            # ADD MORE save_methods HERE
        }
        filetype = Path(location).suffix
        if filetype not in save_methods:
            raise ValueError("save_methods specified is not in factory")
        return save_methods[filetype](
            obj=obj, location=location, save_config=save_config
        )

    def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict):
        """
        Serialise ``obj`` to parquet in memory (without the index) and upload
        it to the bucket/key named by ``location``.

        ``save_config`` is accepted for signature compatibility but is not
        used by this writer.
        """
        bucket, key = self._split_s3_uri(location)
        # The context manager closes the buffer even if the upload raises.
        with BytesIO() as buffer:
            obj.to_parquet(buffer, index=False)
            # Rewind so the upload starts from the first byte.
            buffer.seek(0)
            self.client.upload_fileobj(buffer, bucket, key)

    def _load_parquet(self, location: str, load_config: dict) -> pd.DataFrame:
        """
        Download the object named by ``location`` into memory and parse it as
        parquet; ``load_config`` is passed to ``pd.read_parquet`` as keyword
        arguments.
        """
        bucket, key = self._split_s3_uri(location)
        buffer = BytesIO()
        self.client.download_fileobj(bucket, key, buffer)
        # The download leaves the position at the end of the data; rewind so
        # the reader starts from the beginning.
        buffer.seek(0)
        return pd.read_parquet(buffer, **load_config)
# def load_data_as_buffer(self, location: str) -> BytesIO:
# """
# When the client is established, we can load data in a buffer
# """
# bucket, key = location.strip("s3://").split("/", 1)
# buffer = BytesIO()
# self.client.download_fileobj(bucket, key, buffer)
# buffer.seek(0)
# return buffer
# def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
# """
# When the client is established, we can save out objects from a buffer
# """
# if not location.startswith("s3://"):
# raise ValueError("S3 file path specified without s3://")
# bucket, key = location.strip("s3://").split("/", 1)
# self.client.upload_fileobj(buffer, bucket, key)
class LocalClient:
    """
    DataClient implementation that loads and saves data on the local
    filesystem.
    """

    def __init__(self):
        """
        No initialisation needed for local client
        """
        logger.info("Local - No configuration required")
        self._establish_client()

    def _establish_client(self) -> None:
        """
        Nothing to connect to for local files; only logs that fact.
        """
        logger.info("Local - No establishing client required")

    def load_data(
        self, location: str, load_config: Union[dict, None] = None
    ) -> pd.DataFrame:
        """
        Load the file at ``location``, choosing the loader from the file
        suffix.

        ``load_config`` is forwarded to the selected loader; ``None`` means an
        empty config.

        Raises ``ValueError`` when no loader is registered for the suffix.
        """
        if load_config is None:
            load_config = {}
        load_methods = {
            ".parquet": self._load_parquet,
            # ADD MORE load_methods HERE
        }
        filetype = Path(location).suffix
        if filetype not in load_methods:
            raise ValueError("load methods specified is not in factory")
        return load_methods[filetype](location=location, load_config=load_config)

    def save_data(
        self,
        obj: object,
        location: str,
        save_config: Union[dict, None] = None,
    ) -> None:
        """
        Save ``obj`` to ``location``, choosing the writer from the file
        suffix. Missing parent directories are created.

        ``save_config`` is forwarded to the selected writer; ``None`` means an
        empty config.

        Raises ``ValueError`` when no writer is registered for the suffix.
        """
        if save_config is None:
            save_config = {}
        save_methods = {
            ".parquet": self._save_parquet,
            ".json": self._save_json,
            # ADD MORE save_methods HERE
        }
        target = Path(location)
        filetype = target.suffix
        # Validate before touching the filesystem so an unsupported suffix
        # does not leave an empty directory behind.
        if filetype not in save_methods:
            raise ValueError("save_methods specified is not in factory")
        # exist_ok avoids failing if the directory appears between a check
        # and the creation.
        os.makedirs(target.parent, exist_ok=True)
        return save_methods[filetype](
            obj=obj, location=location, save_config=save_config
        )

    def _load_parquet(self, location: str, load_config: dict) -> pd.DataFrame:
        """
        Read a parquet file; ``load_config`` is passed to ``pd.read_parquet``
        as keyword arguments.
        """
        return pd.read_parquet(location, **load_config)

    def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict):
        """
        Write ``obj`` as parquet; ``save_config`` is passed to
        ``DataFrame.to_parquet`` as keyword arguments.
        """
        obj.to_parquet(location, **save_config)

    def _save_json(self, obj: dict, location: str, save_config: dict):
        """
        Write ``obj`` to ``location`` as compact, UTF-8 encoded JSON.

        ``save_config`` is accepted for signature compatibility but is not
        used by this writer.
        """
        # Serialise before opening the file so a non-serialisable object does
        # not leave an empty file behind.
        json_bytes = json.dumps(obj).encode("utf-8")
        with open(location, "wb") as f:
            f.write(json_bytes)