mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
from Logger import logger
|
|
import argparse
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
|
|
# Seed for the random test-data sample, so repeated runs on the same dataset
# draw the same test rows.
RANDOM_SEED: int = 0
|
|
|
|
def ingest_arguments(argv=None) -> argparse.Namespace:
    """
    Parse the command-line options for the test data generation script.

    Args:
        argv: Argument strings to parse. The default of None makes argparse read
            ``sys.argv[1:]`` exactly as before; pass a list to parse arguments
            programmatically (e.g. from tests).

    Returns:
        Namespace with ``filepath``, ``output_folder``, ``percentage``, ``volume``
        and ``sampling`` attributes.
    """
    parser = argparse.ArgumentParser(description='Inputs for test data generation script')

    parser.add_argument('--filepath', type=str, help='Location of Parquet dataset to load', required=True)
    parser.add_argument('--output-folder', type=str, help='Folder to save the train/validation and test Parquet datasets in', required=True)
    # main() multiplies this value directly by the row count, so it is a
    # fraction (0.2 == 20%), not a number between 0 and 100.
    parser.add_argument('--percentage', type=float, help='Fraction of data (between 0 and 1) to use as test data', default=None)
    parser.add_argument('--volume', type=int, help='Number of rows to use as test data', default=None)
    parser.add_argument('--sampling', type=str, help='Type of sampling to do for test data', choices=['random', 'stratified'], default='random')

    return parser.parse_args(argv)
|
|
|
|
def _resolve_test_amount(n_rows: int, percentage: float, volume: int) -> int:
    """
    Work out how many rows to hold out as test data.

    Uses ``round(n_rows * percentage)`` when only ``percentage`` is given, ``volume``
    when only ``volume`` is given, and the larger of the two when both are given.
    Either argument may be None.

    Raises:
        SystemExit: with status 1 (after logging an error) when neither amount is
            given, or when the resulting amount is negative or exceeds ``n_rows``.
    """
    if percentage is None and volume is None:
        logger.error('No amount specified - please specify either a percentage or volume')
        raise SystemExit(1)

    # Compare against None explicitly: 0 / 0.0 are valid values, and truthiness
    # checks would route them into the "both specified" branch where max() or
    # round() is handed None and raises TypeError.
    if volume is None:
        test_amount = round(n_rows * percentage)
    elif percentage is None:
        test_amount = volume
    else:
        logger.info('Both percentage and volume specified - taking largest of the two')
        test_amount = max(round(n_rows * percentage), volume)

    # Sampling without replacement cannot draw more rows than exist (nor a
    # negative count); fail with a readable message instead of a pandas traceback.
    if not 0 <= test_amount <= n_rows:
        logger.error(f'Requested {test_amount} test rows but the dataset only has {n_rows} rows')
        raise SystemExit(1)

    return test_amount


def _random_split(data: pd.DataFrame, test_amount: int, random_state: int):
    """
    Randomly hold out ``test_amount`` rows of ``data``.

    Returns:
        ``(train_validation_data, test_data)``. Train rows keep their original order;
        test rows are in sampled order.
    """
    n_rows = len(data)
    # Select rows by position rather than by index label: with duplicate index
    # labels, drop(labels) removes every row sharing a sampled label and
    # loc[labels] returns all of them, so rows would be lost or leaked into both
    # splits. Sampling a RangeIndex of the same length with the same seed is
    # intended to pick the same positions data.sample() would.
    positions = (
        pd.RangeIndex(n_rows)
        .to_series()
        .sample(n=test_amount, random_state=random_state)
        .to_numpy()
    )
    is_test = pd.RangeIndex(n_rows).isin(positions)

    return data.iloc[~is_test], data.iloc[positions]


def main(filepath: str, output_folder: str, percentage: float, volume: int, sampling: str):
    """
    Load a dataset in and split out the training+validation data and the test data.

    The two splits are written to ``output_folder`` as
    ``train_validation_data.parquet`` and ``test_data.parquet``.

    Args:
        filepath: Path of the Parquet dataset to load.
        output_folder: Directory to write the two Parquet files to; created if missing.
        percentage: Fraction of rows (e.g. 0.2) to hold out as test data, or None.
        volume: Number of rows to hold out as test data, or None. If both are
            given, the larger resulting row count is used.
        sampling: 'random' is implemented; any other value (including
            'stratified') is rejected.

    Raises:
        SystemExit: with status 1 (after logging an error) when no test size is
            given, the requested size does not fit the dataset, or the sampling
            method is not implemented.
    """
    logger.info('---Loading Data---')
    data = pd.read_parquet(filepath)

    test_amount = _resolve_test_amount(len(data), percentage, volume)

    logger.info(f'---Extracting {test_amount} from dataset to be test data')

    if sampling == 'random':
        logger.info('--- Using random sample method ---')
        train_validation_data, test_data = _random_split(data, test_amount, RANDOM_SEED)
    else:
        # 'stratified' is accepted by the CLI but not implemented yet; falling
        # through used to crash later with an UnboundLocalError at save time.
        logger.error(f'Sampling method {sampling!r} is not implemented')
        raise SystemExit(1)

    logger.info('--- Saving data ---')

    output_dir = Path(output_folder)
    output_dir.mkdir(parents=True, exist_ok=True)
    train_validation_data.to_parquet(output_dir / 'train_validation_data.parquet')
    test_data.to_parquet(output_dir / 'test_data.parquet')

    logger.info(' ---Pipeline complete---')
|
|
|
|
if __name__ == "__main__":
    # Script entry point: announce the pipeline, parse the CLI options and
    # forward each one to main() by keyword.
    logger.info('--- Generate test data pipeline ---')

    cli_options = ingest_arguments()

    main(
        filepath=cli_options.filepath,
        output_folder=cli_options.output_folder,
        percentage=cli_options.percentage,
        volume=cli_options.volume,
        sampling=cli_options.sampling,
    )
|
|
|