# Mirror of https://github.com/Hestia-Homes/Model.git
# (synced 2026-06-08 11:17:27 +00:00; 64 lines, 2 KiB, Python)
import pandas as pd
|
|
import os
|
|
|
|
|
|
class DataLoader:
    """
    Load CSV or Parquet datasets into pandas DataFrames.

    Paths starting with ``s3://`` are read with S3 ``storage_options``
    (see ``s3_load``). All other paths are treated as local files.
    """

    @staticmethod
    def _read(
        filepath: str, index_col: str = None, storage_options: dict = None
    ) -> pd.DataFrame:
        """
        Read ``filepath`` as Parquet or CSV, dispatching on its extension.

        ``load`` and ``s3_load`` both use this helper, so they apply
        identical format and index handling.

        :param filepath: path or URL ending in ``.parquet`` or ``.csv``
            (matched case-insensitively).
        :param index_col: optional column to use as the DataFrame index.
        :param storage_options: forwarded to the pandas reader; ``None`` for
            local files (the pandas default).
        :raises ValueError: if the extension is neither ``.parquet`` nor ``.csv``.
        """
        lowered = filepath.lower()

        if lowered.endswith(".parquet"):
            df = pd.read_parquet(filepath, storage_options=storage_options)
            # Parquet written by pandas can restore a named index on read.
            # In that case the column is no longer in df.columns and
            # set_index would raise KeyError, so only set it when needed.
            if index_col is not None and df.index.name != index_col:
                df = df.set_index(index_col)
            return df

        if lowered.endswith(".csv"):
            return pd.read_csv(
                filepath, index_col=index_col, storage_options=storage_options
            )

        raise ValueError(f"File format not supported for file: {filepath}")

    @staticmethod
    def load(filepath: str, index_col: str = None) -> pd.DataFrame:
        """
        Load a dataset from the local filesystem.

        :param filepath: local path ending in ``.parquet`` or ``.csv``.
        :param index_col: optional column to use as the DataFrame index.
        :returns: the loaded DataFrame.
        :raises FileNotFoundError: if ``filepath`` does not exist.
        :raises ValueError: if the file extension is not supported.
        """
        # Check existence first so a missing file reports FileNotFoundError
        # even when its extension is unsupported.
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File not found: {filepath}")

        return DataLoader._read(filepath, index_col=index_col)

    @staticmethod
    def s3_load(filepath: str, index_col: str = None) -> pd.DataFrame:
        """
        Load a dataset from S3 or an S3-compatible endpoint.

        Credentials are read from ``AWS_ACCESS_KEY_ID`` and
        ``AWS_SECRET_ACCESS_KEY``. The endpoint is read from ``ENDPOINT_URL``.

        :param filepath: ``s3://`` URL ending in ``.parquet`` or ``.csv``.
        :param index_col: optional column to use as the DataFrame index.
        :returns: the loaded DataFrame.
        :raises ValueError: if the file extension is not supported.
        """
        # NOTE(review): the "admin"/"password" fallbacks look like local
        # MinIO defaults. They are kept for backward compatibility. However,
        # when the env vars are unset they are sent as explicit credentials,
        # presumably bypassing the S3 client's normal credential chain.
        # Consider dropping them for production use.
        storage_options = {
            "key": os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
            "secret": os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
            # When ENDPOINT_URL is unset this is None, presumably meaning the
            # client's default (AWS) endpoint. For a local MinIO, set
            # ENDPOINT_URL explicitly (e.g. http://localhost:9000).
            "client_kwargs": {"endpoint_url": os.environ.get("ENDPOINT_URL")},
        }

        return DataLoader._read(
            filepath, index_col=index_col, storage_options=storage_options
        )

    def process(self, filepath: str, index_col: str = None) -> pd.DataFrame:
        """
        Load ``filepath``, choosing the S3 or local loader from its prefix.

        :param filepath: ``s3://`` URL or local path.
        :param index_col: optional column to use as the DataFrame index.
        :returns: the loaded DataFrame.
        """
        if filepath.startswith("s3://"):
            return self.s3_load(filepath=filepath, index_col=index_col)
        return self.load(filepath=filepath, index_col=index_col)