-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
14 changed files
with
119 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,41 @@ | ||
import pickle | ||
import numpy as np | ||
from replicating_Dorians_features import extract_features | ||
from basic_ml import use_tf, basic_ml | ||
from itertools import product | ||
import sys | ||
import os | ||
import csv | ||
import importlib.util  # explicit: `import importlib` alone does not guarantee importlib.util is loaded

# Choose between a flat layout (dataset_manipulation on sys.path) and the
# packaged layout (packages.dataset_manipulation).
# NOTE(review): find_spec('dataset_manipulation') is None means the flat module
# is NOT importable, yet that branch performs the flat import — confirm the
# condition is not inverted.
if importlib.util.find_spec('dataset_manipulation') is None:
    from dataset_manipulation import name_unique_features
    from dataset_manipulation import remove_notunique_features
    from dataset_manipulation import balance_dataset
    from dataset_manipulation import augmentate_dataset
else:
    from packages.dataset_manipulation import name_unique_features
    from packages.dataset_manipulation import remove_notunique_features
    from packages.dataset_manipulation import balance_dataset
    from packages.dataset_manipulation import augmentate_dataset
|
||
|
||
# Paths used by cleaning_dataset below: the raw pickled dataset as input and
# the cleaned dataset as output.  (The pre-refactor inline load/dump script
# that previously lived here is superseded by cleaning_dataset.)
dataset_filename = os.path.join(os.path.dirname(__file__),
                                'DatasetsBeforeProcessing',
                                'dataset_without_repetition_return_ncells.txt')
clean_dataset_filename = os.path.join(os.path.dirname(__file__),
                                      'datasets',
                                      'clean_dataset.txt')
def cleaning_dataset(dataset_filename, clean_dataset_filename):
    """Load a raw pickled dataset, deduplicate its features, and pickle the result.

    Reads the raw dataset from ``dataset_filename``, extracts polynomials,
    feature names, feature vectors, targets and timings via extract_features,
    keeps only the unique features (remove_notunique_features), and writes the
    cleaned tuple to ``clean_dataset_filename``.
    """
    with open(dataset_filename, 'rb') as f:
        # NOTE(review): pickle.load is only safe on trusted, locally-produced
        # files — never use this loader on untrusted input.
        dataset = pickle.load(f)
    original_polys_list, names, features_list, targets_list, timings_list = \
        extract_features(dataset)

    # Work with the raw features as a numpy array so the deduplication helper
    # can operate column-wise.
    features = np.array(features_list)
    unique_names, unique_features = remove_notunique_features(names, features)

    targets = np.array(targets_list)
    timings = np.array(timings_list)
    original_polys = np.array(original_polys_list)

    with open(clean_dataset_filename, 'wb') as g:
        # Fix: pickle.dump returns None — the old code rebound `dataset` to
        # None here, a misleading dead assignment; just dump.
        pickle.dump((original_polys, unique_names, unique_features,
                     targets, timings), g)


cleaning_dataset(dataset_filename, clean_dataset_filename)
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,22 @@ | ||
import yaml | ||
from yaml import UnsafeLoader | ||
import os | ||
from config.ml_models import ml_models | ||
from config.ml_models import dataset_types | ||
|
||
print(ml_models) | ||
for ml_model in ml_models: | ||
for method in dataset_types: | ||
filename = os.path.join(os.path.dirname(__file__), | ||
'config', 'hyperparams', | ||
f'{method}_{ml_model}.yaml') | ||
with open(filename, 'r') as f: | ||
hyperparameters = yaml.load(f, Loader=UnsafeLoader) | ||
print(type(hyperparameters), hyperparameters) | ||
import pickle | ||
from yaml_tools import read_yaml_from_file | ||
from config.ml_models import classifiers | ||
|
||
|
||
def train_model(ml_model, method):
    """Train the classifier named ``ml_model`` on the ``method`` training set.

    Loads the pickled (X, y) pair from datasets/train/{method}_train_dataset.txt,
    reads the model's hyperparameters from config/hyperparams/{method}_{ml_model},
    fits the classifier, and returns the fitted instance.
    """
    train_data_file = os.path.join(os.path.dirname(__file__),
                                   'datasets', 'train',
                                   f'{method}_train_dataset.txt')
    # NOTE(review): no '.yaml' suffix here, unlike the previous loader which
    # used f'{method}_{ml_model}.yaml' — confirm read_yaml_from_file appends it.
    hyperparams_file = os.path.join(os.path.dirname(__file__),
                                    'config', 'hyperparams',
                                    f'{method}_{ml_model}')
    with open(train_data_file, 'rb') as f:
        method_x_train, method_y_train = pickle.load(f)
    hyperparams = read_yaml_from_file(hyperparams_file)
    current_classifier = classifiers[ml_model]
    clf = current_classifier(**hyperparams)
    clf.fit(method_x_train, method_y_train)
    # Fix: the fitted classifier was previously discarded (implicit None
    # return) even though the commented call below prints the result; return
    # it so callers can evaluate or persist the model.  Backward-compatible:
    # callers ignoring the return value are unaffected.
    return clf


# print(train_model(ml_models[1], dataset_types[0]))