Model/model_data/simulation_system/core/DataLoader.py
2023-09-01 19:25:35 +01:00

145 lines
4.1 KiB
Python

import pandas as pd
import os
from typing import Protocol
import boto3
from io import BytesIO, StringIO
from core.CloudClient import BotoClient
def read_parquet_from_s3(client, bucket_name, file_key):
    """
    Read a Parquet object from S3 into a pandas DataFrame.

    The object body is read fully into memory and passed to
    ``pandas.read_parquet`` through an in-memory bytes buffer.

    :param client: boto3-style S3 client; only its ``get_object`` method is used.
    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the object (including any directory prefix within the bucket).
    :return: DataFrame containing the Parquet data.
    """
    s3_object = client.get_object(Bucket=bucket_name, Key=file_key)
    # Parquet is a binary format: keep the raw bytes, no text decoding.
    parquet_bytes = s3_object["Body"].read()
    return pd.read_parquet(BytesIO(parquet_bytes))
def read_csv_from_s3(client, bucket_name, file_key, index_col, encoding="utf-8"):
    """
    Read a CSV object from S3 into a pandas DataFrame.

    :param client: boto3-style S3 client; only its ``get_object`` method is used.
    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the object (including any directory prefix within the bucket).
    :param index_col: Column to use as the DataFrame index; forwarded unchanged
        to ``pandas.read_csv`` (``None`` leaves the choice to pandas).
    :param encoding: Text encoding used to decode the object body
        (default ``"utf-8"``, matching the previous hard-coded behaviour).
    :return: DataFrame containing the CSV data.
    """
    s3_object = client.get_object(Bucket=bucket_name, Key=file_key)
    # Read the whole body, then decode it so pandas parses text rather than bytes.
    csv_text = s3_object["Body"].read().decode(encoding)
    return pd.read_csv(StringIO(csv_text), index_col=index_col)
class DataLoader(Protocol):
    """
    Structural interface shared by every data loader.

    Any class exposing a compatible static ``load`` method satisfies this
    protocol; explicit inheritance from it is not required.
    """

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame | None:
        """
        Load the data at ``filepath`` from the loader's backing source.

        :param client: Cloud client wrapper, for loaders that need remote access.
        :param filepath: Location of the data to load.
        :param index_col: Optional column to use as the resulting index.
        :return: The loaded DataFrame.
        """
        ...
class LocalDataLoader:
    """
    DataLoader implementation backed by the local filesystem.

    Only ``.parquet`` and ``.csv`` paths are accepted. The ``client`` argument
    exists so the signature matches the DataLoader protocol; it is ignored.
    """

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame:
        """
        Read ``filepath`` from disk into a DataFrame.

        :param client: Unused; present for protocol compatibility.
        :param filepath: Path to a ``.parquet`` or ``.csv`` file.
        :param index_col: Optional column to use as the index.
        :return: The loaded DataFrame.
        :raises FileNotFoundError: If ``filepath`` does not exist.
        :raises ValueError: If the path ends in neither ``.parquet`` nor ``.csv``.
        """
        # Fail fast with a clear message before handing the path to pandas.
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File not found: {filepath}")

        is_parquet = filepath.endswith(".parquet")
        is_csv = filepath.endswith(".csv")
        if not (is_parquet or is_csv):
            raise ValueError(f"File format not supported for file: {filepath}")

        if is_csv:
            # The CSV reader applies the index while parsing.
            return pd.read_csv(filepath, index_col=index_col)

        # Parquet is loaded as-is; the requested index is applied afterwards.
        frame = pd.read_parquet(filepath)
        return frame if index_col is None else frame.set_index(index_col)
class S3DataLoader:
    """
    Implements the DataLoader Protocol for files stored in S3.

    Objects are fetched through ``client.client`` (a boto3-style S3 client), so
    the target is whichever endpoint that client is configured for, such as a
    locally hosted mock service or a real bucket.
    """

    @staticmethod
    def _split_s3_uri(filepath: str) -> tuple[str, str]:
        """Split ``s3://bucket/key`` (scheme optional) into ``(bucket, key)``."""
        # removeprefix strips only a *leading* scheme; splitting on "s3://"
        # would mis-parse keys that themselves contain that substring.
        path = filepath.removeprefix("s3://")
        bucket, sep, key = path.partition("/")
        if not bucket or not sep or not key:
            raise ValueError(
                f"Invalid S3 path, expected 's3://bucket/key': {filepath}"
            )
        return bucket, key

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame:
        """
        Load a ``.parquet`` or ``.csv`` object from S3 into a DataFrame.

        :param client: Wrapper whose ``client`` attribute is the S3 client used
            for ``get_object``.
        :param filepath: ``s3://bucket/key`` or ``bucket/key`` location of the object.
        :param index_col: Optional column to use as the index.
        :return: The loaded DataFrame.
        :raises ValueError: If the path has no bucket or key, or the extension
            is neither ``.parquet`` nor ``.csv``.
        """
        bucket, key = S3DataLoader._split_s3_uri(filepath)

        if key.endswith(".parquet"):
            df = read_parquet_from_s3(
                client=client.client, bucket_name=bucket, file_key=key
            )
            # The parquet helper has no index option; apply it here.
            if index_col is not None:
                df = df.set_index(index_col)
        elif key.endswith(".csv"):
            df = read_csv_from_s3(
                client=client.client,
                bucket_name=bucket,
                file_key=key,
                index_col=index_col,
            )
        else:
            raise ValueError(f"File format not supported for file: {filepath}")
        return df
def dataloader_factory(runtime_environment: str | None = None) -> DataLoader:
    """
    Return the DataLoader implementation for a runtime environment.

    ``"local"`` (also used when ``runtime_environment`` is ``None``) maps to
    LocalDataLoader; ``"local-mock"``, ``"dev"``, ``"staging"`` and ``"prod"``
    map to S3DataLoader.

    :param runtime_environment: Environment name, or ``None`` for ``"local"``.
    :return: A new loader instance for that environment.
    :raises ValueError: If the environment name is not recognised.
    """
    if runtime_environment is None:
        runtime_environment = "local"
    # Map to classes rather than instances so only the selected loader is built.
    dataloader_types = {
        "local": LocalDataLoader,
        "local-mock": S3DataLoader,
        "dev": S3DataLoader,
        "staging": S3DataLoader,
        "prod": S3DataLoader,
    }
    try:
        loader_cls = dataloader_types[runtime_environment]
    except KeyError:
        valid = ", ".join(sorted(dataloader_types))
        # Keep the original message as a prefix; add the bad value and options.
        raise ValueError(
            "Incorrect runtime environment specified: "
            f"{runtime_environment!r} (expected one of: {valid})"
        ) from None
    return loader_cls()