mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
118 lines
3.4 KiB
Python
118 lines
3.4 KiB
Python
import pandas as pd
|
|
import os
|
|
from typing import Protocol
|
|
|
|
|
|
class DataLoader(Protocol):
|
|
"""
|
|
Interface for all DataLoader classes
|
|
"""
|
|
|
|
@staticmethod
|
|
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame | None:
|
|
"""
|
|
Loading data from the relevant source
|
|
"""
|
|
|
|
|
|
class LocalDataLoader:
|
|
"""
|
|
Implements the DataLoader Protocol for local files
|
|
"""
|
|
|
|
@staticmethod
|
|
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
|
|
|
|
if not os.path.exists(filepath):
|
|
raise FileNotFoundError(f"File not found: {filepath}")
|
|
|
|
if filepath.endswith(".parquet"):
|
|
df = pd.read_parquet(filepath)
|
|
if index_col is not None:
|
|
df = df.set_index(index_col)
|
|
elif filepath.endswith(".csv"):
|
|
df = pd.read_csv(filepath, index_col=index_col)
|
|
else:
|
|
raise ValueError(f"File format not supported for file: {filepath}")
|
|
|
|
return df
|
|
|
|
|
|
class S3MockDataLoader:
|
|
"""
|
|
Implements the DataLoader Protocol for s3 files, hosting locally in a mocked service
|
|
"""
|
|
|
|
@staticmethod
|
|
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
|
|
|
|
# TODO: Ingest these as environment variables in the docker compose file
|
|
storage_options = {
|
|
"key": os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
|
|
"secret": os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
|
|
"client_kwargs": {
|
|
"endpoint_url": os.environ.get("ENDPOINT_URL", "http://localhost:9000")
|
|
},
|
|
}
|
|
|
|
if filepath.endswith(".parquet"):
|
|
df = pd.read_parquet(filepath, storage_options=storage_options)
|
|
if index_col is not None:
|
|
df = df.set_index(index_col)
|
|
elif filepath.endswith(".csv"):
|
|
df = pd.read_csv(
|
|
filepath, index_col=index_col, storage_options=storage_options
|
|
)
|
|
else:
|
|
raise ValueError(f"File format not supported for file: {filepath}")
|
|
|
|
return df
|
|
|
|
|
|
class S3DataLoader:
|
|
"""
|
|
Implements the DataLoader Protocol for s3 files
|
|
"""
|
|
|
|
@staticmethod
|
|
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
|
|
|
|
storage_options = {
|
|
"key": os.environ.get("AWS_ACCESS_KEY_ID"),
|
|
"secret": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
|
}
|
|
|
|
if filepath.endswith(".parquet"):
|
|
df = pd.read_parquet(filepath, storage_options=storage_options)
|
|
if index_col is not None:
|
|
df = df.set_index(index_col)
|
|
elif filepath.endswith(".csv"):
|
|
df = pd.read_csv(
|
|
filepath, index_col=index_col, storage_options=storage_options
|
|
)
|
|
else:
|
|
raise ValueError(f"File format not supported for file: {filepath}")
|
|
|
|
return df
|
|
|
|
|
|
def dataloader_factory(runtime_environment: str | None = None) -> DataLoader:
    """
    Factory that picks the data loader for a runtime environment name.

    Args:
        runtime_environment: One of ``"local"``, ``"local-mock"``, ``"dev"``,
            ``"staging"`` or ``"prod"``. ``None`` is treated as ``"local"``.

    Returns:
        A ``LocalDataLoader`` for ``"local"``, an ``S3MockDataLoader`` for
        ``"local-mock"``, and an ``S3DataLoader`` for the remaining names.

    Raises:
        ValueError: If the name is not one of the supported environments.
    """
    environment = "local" if runtime_environment is None else runtime_environment

    registry = {
        "local": LocalDataLoader(),
        "local-mock": S3MockDataLoader(),
    }
    # Every deployed environment reads through the real s3 loader; each key
    # gets its own instance.
    for deployed in ("dev", "staging", "prod"):
        registry[deployed] = S3DataLoader()

    selected = registry.get(environment)
    if selected is None:
        raise ValueError("Incorrect runtime environment specified")
    return selected
|