# Mirror of https://github.com/Hestia-Homes/Model.git
# Synced 2026-06-08 11:17:27 +00:00
# 145 lines, 4.1 KiB, Python
import pandas as pd
|
|
import os
|
|
from typing import Protocol
|
|
import boto3
|
|
from io import BytesIO, StringIO
|
|
from core.CloudClient import BotoClient
|
|
|
|
|
|
def read_parquet_from_s3(client, bucket_name, file_key):
    """
    Read a Parquet file from S3 using boto3 and pandas.

    :param client: S3 client object exposing ``get_object(Bucket=..., Key=...)``
        (e.g. a boto3 S3 client). Credentials are whatever the client was
        built with; this function takes no access keys itself.
    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the file (including directory path within the bucket).
    :return: DataFrame containing the Parquet data.
    """
    # Get the object
    s3_object = client.get_object(Bucket=bucket_name, Key=file_key)

    # Parquet is a binary format, so the body is kept as raw bytes (no text
    # decoding) and handed to pandas through an in-memory buffer.
    parquet_body = s3_object["Body"].read()
    return pd.read_parquet(BytesIO(parquet_body))
|
|
|
|
|
|
def read_csv_from_s3(client, bucket_name, file_key, index_col=None):
    """
    Read a CSV file from S3 using boto3 and pandas.

    :param client: S3 client object exposing ``get_object(Bucket=..., Key=...)``
        (e.g. a boto3 S3 client). Credentials are whatever the client was
        built with; this function takes no access keys itself.
    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the file (including directory path within the bucket).
    :param index_col: Forwarded unchanged to ``pandas.read_csv(index_col=...)``.
        Defaults to ``None``, matching the optional ``index_col`` of the
        loader classes in this file.
    :return: DataFrame containing the CSV data.
    :raises UnicodeDecodeError: If the object body is not valid UTF-8.
    """
    # Get the object
    s3_object = client.get_object(Bucket=bucket_name, Key=file_key)

    # The body arrives as bytes; decode it once as UTF-8 and let pandas parse
    # the resulting text from an in-memory buffer.
    csv_body = s3_object["Body"].read().decode("utf-8")
    return pd.read_csv(StringIO(csv_body), index_col=index_col)
|
|
|
|
|
|
class DataLoader(Protocol):
    """
    Structural interface that every DataLoader class in this module follows
    """

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame | None:
        """
        Load the data found at ``filepath`` from the relevant source

        :param client: Client wrapper handed to the loader.
        :param filepath: Location of the file to load.
        :param index_col: Optional name of the column to use as the index.
        :return: The loaded DataFrame (the protocol also permits ``None``).
        """
        ...
|
|
|
|
|
|
class LocalDataLoader:
    """
    DataLoader implementation that reads files from the local filesystem
    """

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame:
        """
        Read a local ``.parquet`` or ``.csv`` file into a DataFrame.

        :param client: Not used by this loader; present so the signature
            matches the other loaders in this file.
        :param filepath: Path of the file on disk.
        :param index_col: Optional column to use as the DataFrame index.
        :return: DataFrame with the file contents.
        :raises FileNotFoundError: If ``filepath`` does not exist.
        :raises ValueError: If the path ends with neither ``.parquet`` nor ``.csv``.
        """
        # Existence is checked first, so a missing file with an unsupported
        # extension is reported as missing rather than as a bad format.
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File not found: {filepath}")

        if filepath.endswith(".csv"):
            return pd.read_csv(filepath, index_col=index_col)

        if not filepath.endswith(".parquet"):
            raise ValueError(f"File format not supported for file: {filepath}")

        # read_parquet is called without an index argument here, so the index
        # column is applied afterwards when one was requested.
        frame = pd.read_parquet(filepath)
        if index_col is None:
            return frame
        return frame.set_index(index_col)
|
|
|
|
|
|
class S3DataLoader:
    """
    Implements the DataLoader Protocol for files stored in S3.

    The S3 endpoint actually contacted is determined by the client passed to
    ``load`` (for example a locally hosted mock service).
    """

    @staticmethod
    def _split_s3_path(filepath: str) -> tuple[str, str]:
        """
        Split ``s3://bucket/key`` (the prefix is optional) into ``(bucket, key)``.

        :raises ValueError: If either the bucket or the key part is empty.
        """
        # Strip only a *leading* scheme and split on the first "/" so that
        # keys which themselves contain "s3://" or further slashes stay intact.
        bucket, _, key = filepath.removeprefix("s3://").partition("/")
        if not bucket or not key:
            raise ValueError(
                f"Expected an S3 path of the form 's3://bucket/key', got: {filepath}"
            )
        return bucket, key

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame:
        """
        Download a ``.parquet`` or ``.csv`` object from S3 into a DataFrame.

        :param client: Wrapper whose ``client`` attribute is the underlying S3
            client used for ``get_object``.
        :param filepath: ``s3://bucket/key`` location; the ``s3://`` prefix may
            be omitted.
        :param index_col: Optional column to use as the DataFrame index.
        :return: DataFrame with the object contents.
        :raises ValueError: If the path has no bucket or key, or ends with
            neither ``.parquet`` nor ``.csv``.
        """
        # Normalise first so error messages always show the full S3 URI.
        if not filepath.startswith("s3://"):
            filepath = "s3://" + filepath

        bucket, key = S3DataLoader._split_s3_path(filepath)

        if filepath.endswith(".parquet"):
            df = read_parquet_from_s3(
                client=client.client, bucket_name=bucket, file_key=key
            )
            # The parquet reader takes no index argument, so apply it here.
            if index_col is not None:
                df = df.set_index(index_col)
        elif filepath.endswith(".csv"):
            df = read_csv_from_s3(
                client=client.client,
                bucket_name=bucket,
                file_key=key,
                index_col=index_col,
            )
        else:
            raise ValueError(f"File format not supported for file: {filepath}")

        return df
|
|
|
|
|
|
def dataloader_factory(runtime_environment: str | None = None) -> DataLoader:
    """
    Factory returning the loader that matches the runtime environment

    :param runtime_environment: One of ``"local"``, ``"local-mock"``,
        ``"dev"``, ``"staging"`` or ``"prod"``; ``None`` is treated as
        ``"local"``.
    :return: A new loader instance for that environment.
    :raises ValueError: If the environment name is not one of the above.
    """
    environment = "local" if runtime_environment is None else runtime_environment

    # Only "local" reads from disk; every other environment goes through S3.
    loader_classes = {
        "local": LocalDataLoader,
        "local-mock": S3DataLoader,
        "dev": S3DataLoader,
        "staging": S3DataLoader,
        "prod": S3DataLoader,
    }

    loader_class = loader_classes.get(environment)
    if loader_class is None:
        raise ValueError("Incorrect runtime environment specified")

    return loader_class()
|