import pandas as pd
import os
from typing import Protocol
import boto3
from io import BytesIO, StringIO
from core.CloudClient import BotoClient


def read_parquet_from_s3(bucket_name, file_key):
    """
    Read a Parquet file from S3 using boto3 and pandas.

    The whole object is read into memory before it is parsed.

    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the file (including directory path within the bucket).
    :return: DataFrame containing the Parquet data.
    """
    # No explicit credentials are passed, so boto3 resolves them itself
    # (environment, config files, instance role, ...).
    s3_client = boto3.client("s3")
    s3_object = s3_client.get_object(Bucket=bucket_name, Key=file_key)
    # Parquet is binary: wrap the raw bytes instead of decoding them as text.
    parquet_body = s3_object["Body"].read()
    return pd.read_parquet(BytesIO(parquet_body))


def read_csv_from_s3(bucket_name, file_key, index_col=None):
    """
    Read a UTF-8 encoded CSV file from S3 using boto3 and pandas.

    The whole object is read into memory before it is parsed.

    :param bucket_name: Name of the S3 bucket.
    :param file_key: Key of the file (including directory path within the bucket).
    :param index_col: Passed through to ``pandas.read_csv``; ``None`` keeps
        the default index.
    :return: DataFrame containing the CSV data.
    """
    # No explicit credentials are passed, so boto3 resolves them itself.
    s3_client = boto3.client("s3")
    s3_object = s3_client.get_object(Bucket=bucket_name, Key=file_key)
    csv_body = s3_object["Body"].read().decode("utf-8")
    return pd.read_csv(StringIO(csv_body), index_col=index_col)


def _set_index_if_given(df: pd.DataFrame, index_col: str | None) -> pd.DataFrame:
    """Return ``df`` indexed by ``index_col``, or ``df`` unchanged when it is None."""
    # read_parquet has no index_col argument, so parquet loaders set it afterwards.
    if index_col is not None:
        df = df.set_index(index_col)
    return df


def _split_s3_path(filepath: str) -> tuple[str, str]:
    """
    Split ``s3://bucket/key`` (or ``bucket/key``) into ``(bucket, key)``.

    :raises ValueError: if the bucket or the key part is missing.
    """
    bucket, _, key = filepath.removeprefix("s3://").partition("/")
    if not bucket or not key:
        raise ValueError(f"Invalid S3 path, expected s3://bucket/key: {filepath}")
    return bucket, key


class DataLoader(Protocol):
    """
    Interface for all DataLoader classes
    """

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame | None:
        """
        Loading data from the relevant source
        """


class LocalDataLoader:
    """
    Implements the DataLoader Protocol for local files
    """

    # NOTE(review): unlike the DataLoader protocol, this signature has no
    # `client` parameter, so calling it positionally as load(client, path)
    # would bind `client` to `filepath`. Confirm how callers invoke it.
    @staticmethod
    def load(filepath: str, index_col: str | None = None) -> pd.DataFrame:
        """
        Load a local ``.parquet`` or ``.csv`` file into a DataFrame.

        :param filepath: Path to the file on the local filesystem.
        :param index_col: Optional column to use as the index.
        :raises FileNotFoundError: if ``filepath`` does not exist.
        :raises ValueError: if the file extension is not supported.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File not found: {filepath}")
        if filepath.endswith(".parquet"):
            df = _set_index_if_given(pd.read_parquet(filepath), index_col)
        elif filepath.endswith(".csv"):
            df = pd.read_csv(filepath, index_col=index_col)
        else:
            raise ValueError(f"File format not supported for file: {filepath}")
        return df


class S3MockDataLoader:
    """
    Implements the DataLoader Protocol for s3 files, hosting locally in a mocked service
    """

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame:
        """
        Load a ``.parquet`` or ``.csv`` file from a locally mocked S3 endpoint.

        Credentials and endpoint come from the ``AWS_ACCESS_KEY_ID``,
        ``AWS_SECRET_ACCESS_KEY`` and ``ENDPOINT_URL`` environment variables,
        falling back to local-development defaults.

        :param client: Unused here (accepted to satisfy the DataLoader protocol).
        :param filepath: ``s3://bucket/key`` or ``bucket/key``.
        :param index_col: Optional column to use as the index.
        :raises ValueError: if the file extension is not supported.
        """
        # TODO: Ingest these as environment variables in the docker compose file
        # pandas forwards storage_options to the fsspec S3 filesystem.
        storage_options = {
            "key": os.environ.get("AWS_ACCESS_KEY_ID", "admin"),
            "secret": os.environ.get("AWS_SECRET_ACCESS_KEY", "password"),
            "client_kwargs": {
                "endpoint_url": os.environ.get("ENDPOINT_URL", "http://localhost:9000")
            },
        }
        # pandas only routes to the S3 filesystem when the URL has a scheme.
        if not filepath.startswith("s3://"):
            filepath = "s3://" + filepath
        if filepath.endswith(".parquet"):
            df = _set_index_if_given(
                pd.read_parquet(filepath, storage_options=storage_options), index_col
            )
        elif filepath.endswith(".csv"):
            df = pd.read_csv(
                filepath, index_col=index_col, storage_options=storage_options
            )
        else:
            raise ValueError(f"File format not supported for file: {filepath}")
        return df


class S3DataLoader:
    """
    Implements the DataLoader Protocol for s3 files
    """

    @staticmethod
    def load(
        client: BotoClient, filepath: str, index_col: str | None = None
    ) -> pd.DataFrame:
        """
        Load a ``.parquet`` or ``.csv`` file from S3 via boto3.

        :param client: Unused here (accepted to satisfy the DataLoader protocol);
            the module-level readers create their own boto3 client.
        :param filepath: ``s3://bucket/key`` or ``bucket/key``.
        :param index_col: Optional column to use as the index.
        :raises ValueError: if the path has no bucket/key or the file
            extension is not supported.
        """
        bucket, key = _split_s3_path(filepath)
        if filepath.endswith(".parquet"):
            df = _set_index_if_given(read_parquet_from_s3(bucket, key), index_col)
        elif filepath.endswith(".csv"):
            df = read_csv_from_s3(bucket, key, index_col)
        else:
            raise ValueError(f"File format not supported for file: {filepath}")
        return df


def dataloader_factory(runtime_environment: str | None = None) -> DataLoader:
    """
    Use factory pattern to determine which loading method we use

    :param runtime_environment: One of ``"local"``, ``"local-mock"``,
        ``"dev"``, ``"staging"``, ``"prod"``; ``None`` means ``"local"``.
    :return: The loader instance for that environment.
    :raises ValueError: for any other environment name.
    """
    if runtime_environment is None:
        runtime_environment = "local"
    dataloader_types = {
        "local": LocalDataLoader(),
        "local-mock": S3MockDataLoader(),
        "dev": S3DataLoader(),
        "staging": S3DataLoader(),
        "prod": S3DataLoader(),
    }
    if runtime_environment not in dataloader_types:
        raise ValueError("Incorrect runtime environment specified")
    return dataloader_types[runtime_environment]