From 109808cb3331dd2d384434a97e401f3fdf7e5e61 Mon Sep 17 00:00:00 2001
From: Tereso del Rio
Date: Thu, 14 Sep 2023 19:58:50 +0200
Subject: [PATCH] cleaning some errors

---
 main.py                | 13 ++++++++-----
 test_train_datasets.py | 31 ++++++++++++++++++-------------
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/main.py b/main.py
index ab1064e..8f71563 100644
--- a/main.py
+++ b/main.py
@@ -30,7 +30,7 @@
 # Hyperparameter tuning takes a very long time;
 # the flag tune_hyperparameters is used to decide whether to tune them
 # or to use previously tuned ones
-# tune_hyperparameters = False
+tune_hyperparameters = False
 paradigm = 'classification'
 
 cleaning_dataset()
@@ -52,12 +52,14 @@
 output_file = "classification_output_acc_time.csv"
 for ml_model in ml_models:
     print(f"Testing models trained in {training_method}")
-    metrics = test_model(ml_model, paradigm=training_method, testing_method=testing_method)
+    metrics = test_model(ml_model,
+                         paradigm=training_method,
+                         testing_method=testing_method)
     if first_time == 1:
         first_time = 0
         keys = list(metrics.keys())
         with open(output_file, 'a') as f:
-            f.write('No more cheating\n')
+            f.write('Now really NO more cheating\n')
             f.write(', '.join(['Model'] + keys) + '\n')
     with open(output_file, 'a', newline='') as f:
         writer = csv.writer(f)
@@ -66,7 +68,7 @@
 
 # timings = dict()
 # testing_method = 'augmented'
-# test_dataset_filename = find_dataset_filename('test',
+# test_dataset_filename = find_dataset_filename('Test',
 #                                               testing_method)
 
 # with open("classification_output_timings.csv", 'w') as f:
@@ -83,7 +85,8 @@
 # #     with open("classification_output_acc_time.csv", 'a') as f:
 # #         f.write(f"{ml_model}, {accuracy}, {total_time}\n")
 #     with open("classification_output_timings.csv", 'a') as f:
-#         f.write(f"{ml_model}, {sum(timings['Normal'])}, {sum(timings['Balanced'])}, {sum(timings['Augmented'])}\n")
+#         f.write(f"{ml_model}, {sum(timings['Normal'])}, \
+# {sum(timings['Balanced'])}, {sum(timings['Augmented'])}\n")
 # timings['optimal'] = timings_in_test('optimal', testing_method)
 # print(sum(timings['optimal']))
 # from make_plots import survival_plot
diff --git a/test_train_datasets.py b/test_train_datasets.py
index 8777fcc..767fbc1 100644
--- a/test_train_datasets.py
+++ b/test_train_datasets.py
@@ -59,22 +59,24 @@ def create_train_test_datasets():
     keys = ['features', 'labels', 'timings', 'cells']
     for purpose in purposes:
         datasets[f'{purpose}_Balanced'] = \
-            {key: elem for key, elem in zip(keys,
-                                            balance_dataset(*[datasets[f'{purpose}_Normal'][key2]
-                                                              for key2 in keys])
-                                            )
+            {key: elem for key,
+             elem in zip(keys, balance_dataset(
+                 *[datasets[f'{purpose}_Normal'][key2]
+                   for key2 in keys]))
             }
         datasets[f'{purpose}_Augmented'] = \
-            {key: elem for key, elem in zip(keys,
-                                            augmentate_dataset(*[datasets[f'{purpose}_Normal'][key2]
-                                                                 for key2 in keys])
-                                            )
+            {key: elem for key,
+             elem in zip(keys, augmentate_dataset(
+                 *[datasets[f'{purpose}_Normal'][key2]
+                   for key2 in keys]))
             }
     for purpose in purposes:
         for quality in dataset_qualities:
-            this_dataset_filename = find_dataset_filename(purpose, method=quality)
+            this_dataset_filename = \
+                find_dataset_filename(purpose, method=quality)
             with open(this_dataset_filename, 'wb') as this_dataset_file:
-                pickle.dump(datasets[purpose + '_' + quality], this_dataset_file)
+                pickle.dump(datasets[purpose + '_' + quality],
+                            this_dataset_file)
 
 
 ## The following code counts how many instances of each kind there are in the different datasets
@@ -108,10 +110,13 @@ def create_regression_datasets(taking_logarithms=True):
         # we will use the augmented dataset here
         with open(this_dataset_filename, 'rb') as this_dataset_file:
             regression_dataset = pickle.load(this_dataset_file)
+        regression_dataset['labels'] = \
+            [timings[0] for timings
+             in regression_dataset['timings']]
         if taking_logarithms:
-            regression_dataset['labels'] = [log(timings[0]) for timings in regression_dataset['timings']]
-        else:
-            regression_dataset['labels'] = [timings[0] for timings in regression_dataset['timings']]
+            regression_dataset['labels'] = \
+                [log(label) for label
+                 in regression_dataset['labels']]
         this_dataset_filename =\
             find_dataset_filename(purpose, method='regression')
         with open(this_dataset_filename, 'wb') as this_dataset_file:
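Reviewer note (not part of the patch): the behaviour kept by the test_train_datasets.py changes can be sanity-checked in isolation. Below is a minimal, runnable sketch of the two patterns the diff reshapes: rebuilding a dataset dict by zipping the key names against the tuple returned by balance_dataset(*...), and deriving regression labels from the first timing per instance before optionally taking logarithms. The balance_dataset stub and the sample data are illustrative assumptions, not code from the repository.

from math import log

# Stub standing in for the repository's balance_dataset (its real body is
# outside this diff): it simply returns the four parallel lists unchanged.
def balance_dataset(features, labels, timings, cells):
    return features, labels, timings, cells

keys = ['features', 'labels', 'timings', 'cells']
normal = {'features': [[1.0], [2.0]],   # illustrative sample data
          'labels': [0, 1],
          'timings': [(2.0, 5.0), (1.0, 3.0)],
          'cells': [10, 20]}

# Same shape as the patched comprehension: zip the key names against the
# tuple returned by balance_dataset(*...) to rebuild the dataset dict.
balanced = {key: elem
            for key, elem in zip(keys,
                                 balance_dataset(*[normal[k] for k in keys]))}

# Patched regression-label logic: take the first timing per instance,
# then optionally move to log scale in a second pass.
taking_logarithms = True
labels = [timings[0] for timings in normal['timings']]
if taking_logarithms:
    labels = [log(label) for label in labels]

print(balanced['timings'])  # [(2.0, 5.0), (1.0, 3.0)] passes through unchanged
print(labels)               # [0.693..., 0.0]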