Model/model_data/simulation_system/core/DataLoader.py
2023-09-01 11:27:20 +01:00

121 lines
3.5 KiB
Python

import pandas as pd
import os
from typing import Protocol
class DataLoader(Protocol):
"""
Interface for all DataLoader classes
"""
@staticmethod
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame | None:
"""
Loading data from the relevant source
"""
class LocalDataLoader:
"""
Implements the DataLoader Protocol for local files
"""
@staticmethod
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
if not os.path.exists(filepath):
raise FileNotFoundError(f"File not found: {filepath}")
if filepath.endswith(".parquet"):
df = pd.read_parquet(filepath)
if index_col is not None:
df = df.set_index(index_col)
elif filepath.endswith(".csv"):
df = pd.read_csv(filepath, index_col=index_col)
else:
raise ValueError(f"File format not supported for file: {filepath}")
return df
class S3MockDataLoader:
"""
Implements the DataLoader Protocol for s3 files, hosting locally in a mocked service
"""
@staticmethod
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
# TODO: Ingest these as environment variables in the docker compose file
storage_options = {
"key": os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
"secret": os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
"client_kwargs": {
"endpoint_url": os.environ.get("ENDPOINT_URL", "http://localhost:9000")
},
}
if not filepath.startswith("s3://"):
filepath = "s3://" + filepath
if filepath.endswith(".parquet"):
df = pd.read_parquet(filepath, storage_options=storage_options)
if index_col is not None:
df = df.set_index(index_col)
elif filepath.endswith(".csv"):
df = pd.read_csv(
filepath, index_col=index_col, storage_options=storage_options
)
else:
raise ValueError(f"File format not supported for file: {filepath}")
return df
class S3DataLoader:
"""
Implements the DataLoader Protocol for s3 files
"""
@staticmethod
def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
storage_options = {
"key": os.environ.get("AWS_ACCESS_KEY_ID"),
"secret": os.environ.get("AWS_SECRET_ACCESS_KEY"),
}
if filepath.endswith(".parquet"):
df = pd.read_parquet(filepath, storage_options=storage_options)
if index_col is not None:
df = df.set_index(index_col)
elif filepath.endswith(".csv"):
df = pd.read_csv(
filepath, index_col=index_col, storage_options=storage_options
)
else:
raise ValueError(f"File format not supported for file: {filepath}")
return df
def dataloader_factory(runtime_environment: str | None = None) -> DataLoader:
    """
    Factory that picks the DataLoader matching a runtime environment.

    Args:
        runtime_environment: One of ``"local"``, ``"local-mock"``, ``"dev"``,
            ``"staging"`` or ``"prod"``. ``None`` is treated as ``"local"``.

    Returns:
        A new instance of the loader class registered for that environment.

    Raises:
        ValueError: If the environment name is not one of the known keys.
    """
    env_name = "local" if runtime_environment is None else runtime_environment

    # Environment name -> loader class; both mock-backed environments share
    # S3MockDataLoader, both real-s3 environments share S3DataLoader.
    registry = {
        "local": LocalDataLoader,
        "local-mock": S3MockDataLoader,
        "dev": S3MockDataLoader,
        "staging": S3DataLoader,
        "prod": S3DataLoader,
    }

    loader_cls = registry.get(env_name)
    if loader_cls is None:
        raise ValueError("Incorrect runtime environment specified")
    return loader_cls()