# Mirror of https://github.com/Hestia-Homes/Model.git
# (synced 2026-06-08 11:17:27 +00:00; 110 lines, 3 KiB, Python)
from core.Logger import logger
|
|
import argparse
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from core.Settings import RANDOM_SEED, TRAIN_AND_VALIDATION_DATA_NAME, TEST_DATA_NAME
|
|
|
|
|
|
def ingest_arguments(argv=None) -> argparse.Namespace:
    """
    Parse the command-line options for the test-data generation script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default, and how the script calls it), argparse reads
            ``sys.argv[1:]``. Passing a list lets the parser be exercised
            from tests or other Python code.

    Returns:
        An ``argparse.Namespace`` with the attributes ``filepath``,
        ``output_folder``, ``percentage``, ``volume`` and ``sampling``.
    """
    parser = argparse.ArgumentParser(
        description="Split a Parquet dataset into training+validation data and test data"
    )

    parser.add_argument(
        "--filepath",
        type=str,
        help="Location of Parquet dataset to load",
        required=True,
    )
    parser.add_argument(
        "--output-folder",
        type=str,
        help="Folder in which to save the training+validation and test Parquet files",
        required=True,
    )
    # The split code multiplies the row count by this value, so it is a
    # fraction (0.2 == 20%), not a 0-100 percentage.
    parser.add_argument(
        "--percentage",
        type=float,
        help="Fraction of rows (between 0 and 1) to use as test data",
        default=None,
    )
    parser.add_argument(
        "--volume",
        type=int,
        help="Number of rows to use as test data",
        default=None,
    )
    parser.add_argument(
        "--sampling",
        type=str,
        help="Type of sampling to do for test data",
        choices=["random", "stratified"],
        default="random",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
def main(
    filepath: str, output_folder: str, percentage: float, volume: int, sampling: str
):
    """
    Load a Parquet dataset and split it into training+validation data and test data.

    The number of test rows is derived from ``percentage`` and/or ``volume``:

    * only ``percentage`` given: ``round(len(data) * percentage)`` rows;
    * only ``volume`` given: exactly ``volume`` rows;
    * both given: the larger of the two amounts;
    * neither given: an error is logged and the process exits with status 1.

    Args:
        filepath: Path of the Parquet dataset to load.
        output_folder: Folder to write the two output Parquet files into
            (named by ``TRAIN_AND_VALIDATION_DATA_NAME`` and ``TEST_DATA_NAME``).
            It is created if it does not already exist.
        percentage: Fraction of rows (0 to 1) to hold out as test data, or None.
        volume: Absolute number of rows to hold out as test data, or None.
        sampling: Sampling strategy. Only "random" is currently implemented.

    Raises:
        SystemExit: With status 1 if no split size is given, the split size
            is invalid for the dataset, or ``sampling`` is not implemented.
    """
    logger.info("---Loading Data---")
    data = pd.read_parquet(filepath).reset_index(drop=True)
    n_rows = len(data)

    # Compare against None explicitly. 0 / 0.0 are legitimate values, and the
    # earlier truthiness checks sent them into the "both specified" branch.
    # There, ``max(..., None)`` or ``n_rows * None`` raised a TypeError.
    if percentage is None and volume is None:
        logger.error(
            "No amount specified - please specify either a percentage or volume"
        )
        raise SystemExit(1)

    if percentage is not None and not 0 <= percentage <= 1:
        # Also rejects NaN, since every comparison with NaN is False.
        logger.error(
            f"Percentage must be a fraction between 0 and 1, got {percentage}"
        )
        raise SystemExit(1)

    if percentage is not None and volume is not None:
        logger.info("Both percentage and volume specified - taking largest of the two")
        test_amount = max(round(n_rows * percentage), volume)
    elif percentage is not None:
        test_amount = round(n_rows * percentage)
    else:
        test_amount = volume

    # Without this check, pandas raises an opaque ValueError when asked to
    # sample a negative number of rows or more rows than exist.
    if not 0 <= test_amount <= n_rows:
        logger.error(
            f"Cannot extract {test_amount} rows as test data from a dataset of {n_rows} rows"
        )
        raise SystemExit(1)

    logger.info(f"---Extracting {test_amount} from dataset to be test data")

    if sampling == "random":
        logger.info("--- Using random sample method ---")
        # Sample and drop by the same index labels, so the two outputs are
        # disjoint and together cover every row.
        test_data = data.sample(n=test_amount, random_state=RANDOM_SEED)
        train_validation_data = data.drop(test_data.index)
    else:
        # "stratified" is accepted by the CLI but not implemented yet.
        # Previously it silently wrote two empty files and reported success.
        logger.error(f"Sampling method '{sampling}' is not yet implemented")
        raise SystemExit(1)

    logger.info("--- Saving data ---")
    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)
    train_validation_data.to_parquet(output_path / TRAIN_AND_VALIDATION_DATA_NAME)
    test_data.to_parquet(output_path / TEST_DATA_NAME)

    logger.info(" ---Pipeline complete---")
|
|
|
|
|
|
if __name__ == "__main__":
    logger.info("--- Generate test data pipeline ---")

    # Parse the command line once, then hand every option to the pipeline
    # by keyword so the mapping from CLI flag to parameter stays explicit.
    cli_args = ingest_arguments()
    main(
        sampling=cli_args.sampling,
        volume=cli_args.volume,
        percentage=cli_args.percentage,
        output_folder=cli_args.output_folder,
        filepath=cli_args.filepath,
    )
|