Model/model_data/simulation_system/test_data_generation.py
2023-08-14 17:22:21 +00:00

77 lines
2.6 KiB
Python

from Logger import logger
import argparse
import pandas as pd
from pathlib import Path
# Seed passed as random_state to DataFrame.sample so the random test split is reproducible across runs.
RANDOM_SEED: int = 0
def ingest_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """
    Parse the command-line arguments for the test-data generation script.

    Args:
        argv: Argument strings to parse. The default of None makes argparse read
            sys.argv[1:] (the original behaviour); passing an explicit list allows
            the parser to be reused and tested.

    Returns:
        Namespace with attributes filepath, output_folder, percentage, volume
        and sampling.
    """
    parser = argparse.ArgumentParser(
        description='Split a Parquet dataset into train/validation data and test data'
    )
    parser.add_argument('--filepath', type=str, help='Location of Parquet dataset to load', required=True)
    parser.add_argument(
        '--output-folder',
        type=str,
        help='Folder in which to save the train/validation and test Parquet files',
        required=True,
    )
    # main() multiplies this value by the row count, so it is a fraction (0.2 == 20%), not 0-100.
    parser.add_argument(
        '--percentage',
        type=float,
        help='Fraction of data (between 0 and 1) to use as test data',
        default=None,
    )
    parser.add_argument('--volume', type=int, help='Number of rows to use as test data', default=None)
    parser.add_argument(
        '--sampling',
        type=str,
        help='Type of sampling to do for test data',
        choices=['random', 'stratified'],
        default='random',
    )
    return parser.parse_args(argv)
def _resolve_test_amount(n_rows: int, percentage: float, volume: int) -> int:
    """
    Work out how many rows to hold out as test data.

    Args:
        n_rows: Number of rows in the loaded dataset.
        percentage: Fraction of n_rows to hold out (e.g. 0.2 for 20%), or None.
        volume: Absolute number of rows to hold out, or None.

    Returns:
        The number of test rows. When both percentage and volume are given, the
        larger of the two resulting amounts is used.

    Raises:
        ValueError: if neither amount is given, if percentage is outside [0, 1],
            if volume is negative, or if the amount exceeds n_rows.
    """
    # Compare against None explicitly: 0 / 0.0 are valid (if unusual) inputs and
    # must not be treated as "not specified".
    if percentage is None and volume is None:
        raise ValueError('No amount specified - please specify either a percentage or volume')
    if percentage is not None and not 0 <= percentage <= 1:
        raise ValueError(f'Percentage must be a fraction between 0 and 1, got {percentage}')
    if volume is not None and volume < 0:
        raise ValueError(f'Volume must be non-negative, got {volume}')

    if volume is None:
        test_amount = round(n_rows * percentage)
    elif percentage is None:
        test_amount = volume
    else:
        logger.info('Both percentage and volume specified - taking largest of the two')
        test_amount = max(round(n_rows * percentage), volume)

    # DataFrame.sample without replacement cannot draw more rows than exist.
    if test_amount > n_rows:
        raise ValueError(f'Requested {test_amount} test rows but the dataset only has {n_rows}')
    return test_amount


def main(filepath: str, output_folder: str, percentage: float, volume: int, sampling: str):
    """
    Load a dataset in and split out the training+validation data and the test data.

    Args:
        filepath: Path of the Parquet dataset to load.
        output_folder: Folder to write train_validation_data.parquet and
            test_data.parquet into; it is created if it does not exist.
        percentage: Fraction of rows (0-1) to use as test data, or None.
        volume: Number of rows to use as test data, or None. If both percentage
            and volume are given, the larger resulting amount is used.
        sampling: 'random' (implemented) or 'stratified' (not yet implemented).

    Raises:
        SystemExit: with status 1 if the requested test amount is missing or invalid.
        NotImplementedError: if sampling is 'stratified'.
        ValueError: if sampling is any other unknown value.
    """
    logger.info('---Loading Data---')
    # reset_index gives a clean RangeIndex so sampled labels map back unambiguously.
    data = pd.read_parquet(filepath).reset_index(drop=True)

    try:
        test_amount = _resolve_test_amount(len(data), percentage, volume)
    except ValueError as err:
        logger.error(str(err))
        raise SystemExit(1) from err

    logger.info(f'---Extracting {test_amount} from dataset to be test data')
    if sampling == 'random':
        logger.info('--- Using random sample method ---')
        sample_index = data.sample(n=test_amount, random_state=RANDOM_SEED).index
        train_validation_data = data.drop(sample_index)
        # Index labels, so use label-based .loc to mirror drop() above.
        test_data = data.loc[sample_index]
    elif sampling == 'stratified':
        # Previously fell through with `pass`, leaving the outputs unbound.
        raise NotImplementedError('Stratified sampling is not yet implemented')
    else:
        raise ValueError(f'Unknown sampling method: {sampling!r}')

    logger.info('--- Saving data ---')
    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)
    train_validation_data.to_parquet(output_path / 'train_validation_data.parquet')
    test_data.to_parquet(output_path / 'test_data.parquet')
    logger.info(' ---Pipeline complete---')
if __name__ == "__main__":
    # Script entry point: announce the pipeline, parse the CLI flags and pass
    # each one to main() by keyword.
    logger.info('--- Generate test data pipeline ---')
    cli = ingest_arguments()
    main(
        filepath=cli.filepath,
        output_folder=cli.output_folder,
        percentage=cli.percentage,
        volume=cli.volume,
        sampling=cli.sampling,
    )