mirror of
https://github.com/Hestia-Homes/ML.git
synced 2026-06-08 11:17:25 +00:00
296 lines
8.2 KiB
Python
296 lines
8.2 KiB
Python
"""
|
|
Implementations of the DataClient Protocol
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import boto3
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from io import BytesIO
|
|
from typing import List, Union, Any
|
|
from core.interface.InterfaceDataClient import DataClient
|
|
from core.Logger import logger
|
|
|
|
|
|
def dataclient_factory(
    dataclient_type: str, dataclient_config: Union[dict, None] = None
) -> DataClient:
    """
    Build the data client registered under ``dataclient_type``.

    ``dataclient_config`` is expanded as keyword arguments into the selected
    client's constructor; ``None`` means "no arguments". A ``ValueError`` is
    raised when ``dataclient_type`` is not one of the registered names.
    """
    registry = {
        "local": LocalClient,
        "aws-s3": AWSS3Client,
        # ADD MORE DATACLIENTS HERE
    }

    client_cls = registry.get(dataclient_type)
    if client_cls is None:
        raise ValueError("Dataclient type specified is not in factory")

    constructor_kwargs = {} if dataclient_config is None else dataclient_config
    return client_cls(**constructor_kwargs)
|
|
|
|
|
|
def validate_dict_keys(keys_1: List[str], keys_2: List[str], config_type: str):
    """
    Check that every key in ``keys_1`` also appears in ``keys_2``.

    Returns ``None`` when the check passes and raises ``ValueError`` (naming
    ``config_type`` in the message) when ``keys_1`` contains a key that is
    missing from ``keys_2``.
    """
    requested = set(keys_1)
    if requested.issubset(keys_2):
        return

    raise ValueError(f"Incorrect {config_type} keys specified")
|
|
|
|
|
|
class AWSS3Client:
    """
    Data client that loads and saves objects in S3 through a Boto3 client.

    Locations are ``s3://<bucket>/<key>`` URIs. The file suffix of the
    location selects the load/save routine; only ``.parquet`` is registered.
    """

    ACCEPTED_LOAD_CONFIG_KEYS = []
    ACCEPTED_SAVE_CONFIG_KEYS = []

    # URI scheme every location handled by this client must start with.
    _S3_PREFIX = "s3://"

    def __init__(
        self,
        AWS_ACCESS_KEY_ID: Union[str, None] = None,
        AWS_SECRET_ACCESS_KEY: Union[str, None] = None,
        ENDPOINT_URL: Union[str, None] = None,
    ):
        """
        Store the connection settings and create the S3 client.

        All three settings default to ``None`` so the client can be built
        without any configuration.
        """
        self.AWS_ACCESS_KEY_ID = AWS_ACCESS_KEY_ID
        self.AWS_SECRET_ACCESS_KEY = AWS_SECRET_ACCESS_KEY
        self.ENDPOINT_URL = ENDPOINT_URL

        self._establish_client()

    def _establish_client(self) -> None:
        """
        With the given configurations, create the connection to the client (self.client)
        """
        logger.info("Establishing S3 Client")
        session = boto3.Session()

        if self.AWS_ACCESS_KEY_ID is None and self.AWS_SECRET_ACCESS_KEY is None:
            # No explicit keys were supplied: using local credentials.
            # NOTE(review): ENDPOINT_URL is not applied on this path - confirm
            # whether that is intended.
            self.client = session.client(service_name="s3")
        else:
            self.client = session.client(
                service_name="s3",
                aws_access_key_id=self.AWS_ACCESS_KEY_ID,
                aws_secret_access_key=self.AWS_SECRET_ACCESS_KEY,
                endpoint_url=self.ENDPOINT_URL,
            )

    @staticmethod
    def _split_location(location: str) -> tuple:
        """
        Split an ``s3://<bucket>/<key>`` URI into ``(bucket, key)``.

        The prefix is removed by slicing. ``str.strip("s3://")`` must not be
        used here: it strips any of the characters ``s``, ``3``, ``:`` and
        ``/`` from both ends, which corrupts names such as ``sales-data``.

        Raises ``ValueError`` when the prefix, the bucket or the key is missing.
        """
        prefix = AWSS3Client._S3_PREFIX
        if not location.startswith(prefix):
            raise ValueError("S3 file path specified without s3://")

        bucket, separator, key = location[len(prefix):].partition("/")
        if not bucket or not separator or not key:
            raise ValueError(
                f"S3 file path must look like s3://<bucket>/<key>: {location}"
            )

        return bucket, key

    def load_data(
        self, location: str, load_config: Union[dict, None] = None
    ) -> pd.DataFrame:
        """
        Load the object stored at ``location``.

        ``load_config`` is handed to the routine registered for the location's
        file suffix. Raises ``ValueError`` when ``location`` does not start
        with ``s3://`` or when its suffix has no registered load routine.
        """
        if not location.startswith(self._S3_PREFIX):
            raise ValueError("S3 file path specified without s3://")

        if load_config is None:
            load_config = {}

        filetype = Path(location).suffix

        load_methods = {
            ".parquet": self._load_parquet,
            # "": _load_directory(**load_config),
            # ADD MORE load_methods HERE
        }

        if filetype not in load_methods:
            raise ValueError("load methods specified is not in factory")

        return load_methods[filetype](location=location, load_config=load_config)

    def save_data(
        self,
        obj: Any,
        location: str,
        save_config: Union[dict, None] = None,
    ) -> None:
        """
        Save ``obj`` to ``location``.

        ``save_config`` is handed to the routine registered for the location's
        file suffix. Raises ``ValueError`` when ``location`` does not start
        with ``s3://`` or when its suffix has no registered save routine.
        """
        if not location.startswith(self._S3_PREFIX):
            raise ValueError("S3 file path specified without s3://")

        if save_config is None:
            save_config = {}

        filetype = Path(location).suffix

        save_methods = {
            ".parquet": self._save_parquet,
            # "": _save_directory(**save_config),
            # ADD MORE save_methods HERE
        }

        if filetype not in save_methods:
            raise ValueError("save_methods specified is not in factory")

        return save_methods[filetype](
            obj=obj, location=location, save_config=save_config
        )

    def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict):
        """
        Serialise ``obj`` to parquet in memory and upload it to ``location``.

        NOTE(review): ``save_config`` is accepted for signature parity with
        the other save routines but is not used here - confirm whether it
        should be forwarded to ``to_parquet``.
        """
        # Parse the destination first so a malformed URI fails before any
        # serialisation work is done.
        bucket, key = self._split_location(location)

        # The context manager releases the buffer even if the upload raises.
        with BytesIO() as buffer:
            obj.to_parquet(buffer, index=False)

            # Reset the buffer position to the beginning so the upload reads
            # the whole payload.
            buffer.seek(0)

            self.client.upload_fileobj(buffer, bucket, key)

    def _load_parquet(self, location: str, load_config: dict) -> pd.DataFrame:
        """
        Download the parquet file at ``location`` and read it into a DataFrame.

        ``load_config`` is expanded as keyword arguments to ``pd.read_parquet``.
        """
        bucket, key = self._split_location(location)
        buffer = BytesIO()
        self.client.download_fileobj(bucket, key, buffer)

        # The download leaves the position at the end of the buffer; rewind so
        # the reader starts from the first byte.
        buffer.seek(0)

        return pd.read_parquet(buffer, **load_config)
|
|
|
|
# def load_data_as_buffer(self, location: str) -> BytesIO:
|
|
# """
|
|
# When the client is established, we can load data in a buffer
|
|
# """
|
|
|
|
# bucket, key = location.strip("s3://").split("/", 1)
|
|
# buffer = BytesIO()
|
|
# self.client.download_fileobj(bucket, key, buffer)
|
|
# buffer.seek(0)
|
|
|
|
# return buffer
|
|
|
|
# def upload_data_from_buffer(self, buffer: BytesIO, location: str) -> None:
|
|
# """
|
|
# When the client is established, we can save out objects from a buffer
|
|
# """
|
|
# if not location.startswith("s3://"):
|
|
# raise ValueError("S3 file path specified without s3://")
|
|
|
|
# bucket, key = location.strip("s3://").split("/", 1)
|
|
# self.client.upload_fileobj(buffer, bucket, key)
|
|
|
|
|
|
class LocalClient:
    """
    Data client that loads and saves objects on the local filesystem.

    The file suffix of the location selects the load/save routine:
    ``.parquet`` can be loaded and saved, ``.json`` can only be saved.
    """

    def __init__(self):
        """
        No initialisation needed for local client
        """
        logger.info("Local - No configuration required")
        self._establish_client()

    def _establish_client(self) -> None:
        """
        Nothing to connect to for local files; only logs that fact.
        """
        logger.info("Local - No establishing client required")

    def load_data(
        self, location: str, load_config: Union[dict, None] = None
    ) -> pd.DataFrame:
        """
        Load the object stored at ``location``.

        ``load_config`` is handed to the routine registered for the location's
        file suffix. Raises ``ValueError`` when the suffix has no registered
        load routine.
        """
        if load_config is None:
            load_config = {}

        filetype = Path(location).suffix

        load_methods = {
            ".parquet": self._load_parquet,
            # "": _load_directory(**load_config),
            # ADD MORE load_methods HERE
        }

        if filetype not in load_methods:
            raise ValueError("load methods specified is not in factory")

        return load_methods[filetype](location=location, load_config=load_config)

    def save_data(
        self,
        obj: object,
        location: str,
        save_config: Union[dict, None] = None,
    ) -> None:
        """
        Save ``obj`` to ``location``, creating missing parent directories.

        ``save_config`` is handed to the routine registered for the location's
        file suffix. Raises ``ValueError`` when the suffix has no registered
        save routine; in that case nothing is created on disk.
        """
        if save_config is None:
            save_config = {}

        save_methods = {
            ".parquet": self._save_parquet,
            ".json": self._save_json,
            # "": _save_directory(**save_config),
            # ADD MORE save_methods HERE
        }

        target = Path(location)
        filetype = target.suffix

        # Validate before touching the filesystem so an unsupported suffix
        # does not leave empty directories behind.
        if filetype not in save_methods:
            raise ValueError("save_methods specified is not in factory")

        # exist_ok avoids the check-then-create race when another writer
        # creates the same directory concurrently.
        target.parent.mkdir(parents=True, exist_ok=True)

        return save_methods[filetype](
            obj=obj, location=location, save_config=save_config
        )

    def _load_parquet(self, location: str, load_config: dict) -> pd.DataFrame:
        """
        Read the parquet file at ``location`` into a DataFrame.

        ``load_config`` is expanded as keyword arguments to ``pd.read_parquet``.
        """
        return pd.read_parquet(location, **load_config)

    def _save_parquet(self, obj: pd.DataFrame, location: str, save_config: dict):
        """
        Write ``obj`` to ``location`` as parquet.

        ``save_config`` is expanded as keyword arguments to ``to_parquet``.
        """
        obj.to_parquet(location, **save_config)

    def _save_json(self, obj: dict, location: str, save_config: dict):
        """
        Write ``obj`` to ``location`` as UTF-8 encoded JSON.

        NOTE(review): ``save_config`` is accepted for signature parity with
        the other save routines but is not used here - confirm whether it
        should be forwarded to ``json.dumps``.
        """
        # Serialise once and write the bytes directly; the file is opened in
        # binary mode so the encoding is always UTF-8.
        json_bytes = json.dumps(obj).encode("utf-8")

        with open(location, "wb") as f:
            f.write(json_bytes)
|