Model/model_data/simulation_system/test_data_generation.py
2023-08-17 16:07:22 +01:00

80 lines
2.7 KiB
Python

from core.Logger import logger
import argparse
import pandas as pd
from pathlib import Path
from core.Settings import (
RANDOM_SEED,
TRAIN_AND_VALIDATION_DATA_NAME,
TEST_DATA_NAME
)
def ingest_arguments(argv=None) -> argparse.Namespace:
    """
    Parse the command-line arguments for the test-data generation script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, so existing
            callers behave exactly as before; passing a list allows the
            script to be driven programmatically or from tests.

    Returns:
        Namespace with ``filepath``, ``output_folder``, ``percentage``,
        ``volume`` and ``sampling`` attributes.
    """
    parser = argparse.ArgumentParser(description='Inputs for test data generation script')
    parser.add_argument('--filepath', type=str, help='Location of Parquet dataset to load', required=True)
    parser.add_argument('--output-folder', type=str,
                        help='Folder to save the train/validation and test Parquet datasets into', required=True)
    # The value is multiplied by the row count downstream, so it is a fraction (0.2 == 20%).
    parser.add_argument('--percentage', type=float,
                        help='Fraction of data (0-1, e.g. 0.2 for 20%%) to use as test data', default=None)
    parser.add_argument('--volume', type=int, help='Volume of data to use as test data', default=None)
    parser.add_argument('--sampling', type=str, help='Type of sampling to do for test data',
                        choices=['random', 'stratified'], default='random')
    return parser.parse_args(argv)
def main(filepath: str, output_folder: str, percentage: float, volume: int, sampling: str):
    """
    Load a dataset in and split out the training+validation data and the test data.

    The number of test rows comes from ``percentage`` (a fraction of the rows,
    e.g. 0.2 for 20%), from ``volume`` (an absolute row count), or - when both
    are given - whichever of the two yields more rows. The sampled rows are
    saved as the test set and every remaining row as the training+validation set.

    Args:
        filepath: Location of the Parquet dataset to load.
        output_folder: Folder to write the two output Parquet files into;
            created (with parents) if it does not already exist.
        percentage: Fraction of rows to use as test data, or None.
        volume: Number of rows to use as test data, or None.
        sampling: Sampling method. Only 'random' is implemented.

    Raises:
        SystemExit: (code 1) if neither ``percentage`` nor ``volume`` is given.
        NotImplementedError: if ``sampling`` is 'stratified'.
        ValueError: if ``sampling`` is unrecognised, or the requested test size
            is negative or larger than the dataset.
    """
    # Validate the cheap arguments before reading a potentially large file.
    if percentage is None and volume is None:
        logger.error('No amount specified - please specify either a percentage or volume')
        raise SystemExit(1)  # same effect as exit(1) without relying on the `site` builtin
    if sampling == 'stratified':
        # Offered on the CLI but not implemented yet; previously this fell
        # through and crashed later with an UnboundLocalError when saving.
        raise NotImplementedError('Stratified sampling is not implemented yet - use random sampling')
    if sampling != 'random':
        raise ValueError(f"Unknown sampling method '{sampling}' - expected 'random' or 'stratified'")

    logger.info('---Loading Data---')
    # A fresh 0..n-1 index guarantees the sampled labels are unique row labels.
    data = pd.read_parquet(filepath).reset_index(drop=True)
    n_rows = len(data)

    # Compare against None explicitly: 0 / 0.0 are valid values and must not be
    # treated as "not provided" (the previous truthiness checks did exactly that).
    if volume is None:
        test_amount = round(n_rows * percentage)
    elif percentage is None:
        test_amount = volume
    else:
        logger.info('Both percentage and volume specified - taking largest of the two')
        test_amount = max(round(n_rows * percentage), volume)

    if not 0 <= test_amount <= n_rows:
        raise ValueError(f'Cannot extract {test_amount} test rows from a dataset of {n_rows} rows')

    logger.info(f'---Extracting {test_amount} from dataset to be test data')
    logger.info('--- Using random sample method ---')
    sample_index = data.sample(n=test_amount, random_state=RANDOM_SEED).index
    train_validation_data = data.drop(sample_index)
    # Label-based .loc matches the label-based .drop above (identical to .iloc
    # here because of the RangeIndex, but robust if that ever changes).
    test_data = data.loc[sample_index]

    logger.info('--- Saving data ---')
    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)
    train_validation_data.to_parquet(output_path / TRAIN_AND_VALIDATION_DATA_NAME)
    test_data.to_parquet(output_path / TEST_DATA_NAME)
    logger.info(' ---Pipeline complete---')
if __name__ == "__main__":
    # Script entry point: announce the pipeline, read the CLI flags and run it.
    logger.info('--- Generate test data pipeline ---')
    cli_args = ingest_arguments()
    main(
        filepath=cli_args.filepath,
        output_folder=cli_args.output_folder,
        percentage=cli_args.percentage,
        volume=cli_args.volume,
        sampling=cli_args.sampling,
    )