Skip to content

Commit

Permalink
Adding start of reinforcement
Browse files Browse the repository at this point in the history
  • Loading branch information
Tereso del Rio committed Sep 15, 2023
1 parent 109808c commit 7a8a578
Showing 19 changed files with 225 additions and 109 deletions.
7 changes: 4 additions & 3 deletions choose_hyperparams.py
Original file line number Diff line number Diff line change
@@ -31,10 +31,11 @@ def k_folds_ml(x_train, y_train, model, random_state=0):
def choose_hyperparams(ml_model, method):
"""Given a ml_model and a method, a file with the hyperparameters
chosen by cross validation is created"""
this_dataset_file = find_dataset_filename('train', method=method)
this_dataset_file = find_dataset_filename('Train', method=method)
with open(this_dataset_file, 'rb') as f:
x_train, y_train, _ = pickle.load(f)
hyperparams = k_folds_ml(x_train, y_train, model=ml_model)
dataset = pickle.load(f)
hyperparams = k_folds_ml(dataset['features'], dataset['labels'], model=ml_model)
print(hyperparams)
hyperparams_filename = find_hyperparams_filename(method, ml_model)
print(hyperparams_filename)
write_yaml_to_file(hyperparams, hyperparams_filename)
29 changes: 22 additions & 7 deletions config/hyperparameters_grid.py
Original file line number Diff line number Diff line change
@@ -2,10 +2,16 @@

grid = dict()
grid['RF'] = {
'n_estimators': [200, 300, 400, 500],
'max_features': ['sqrt', 'log2'],
'max_depth': [4, 5, 6, 7, 8],
'criterion': ['gini', 'entropy']
# 'n_estimators': [200, 300, 400, 500],
# 'max_features': ['sqrt', 'log2'],
# 'max_depth': [4, 5, 6, 7, 8],
# 'criterion': ['gini', 'entropy']
'n_estimators': [50, 100, 200],
'criterion': ['gini', 'entropy'],
'max_depth': [None, 10, 20, 30],
'min_samples_split': [2, 5, 10],
'min_samples_leaf': [1, 2, 4],
'class_weight': [None, 'balanced'],
}
grid['KNN'] = {
'n_neighbors': [1,3,5,7,12],
@@ -35,9 +41,18 @@
}

grid['RFR'] = {
'criterion': ['squared_error', 'friedman_mse'],
"max_depth": [1,3,7],
"min_samples_leaf": [1,5,10],
# 'n_estimators': [200, 300, 400, 500],
# 'max_features': ['sqrt', 'log2'],
# 'max_depth': [4, 5, 6, 7, 8],
# 'criterion': ['squared_error', 'entropy']
# # 'criterion': ['squared_error', 'friedman_mse'],
# # "max_depth": [1,3,7],
# # "min_samples_leaf": [1,5,10],
'n_estimators': [50, 100, 200],
'criterion': ['mse', 'mae'],
'max_depth': [None, 10, 20, 30],
'min_samples_split': [2, 5, 10],
'min_samples_leaf': [1, 2, 4],
}
grid['KNNR'] = {
'n_neighbors': [3, 5, 10],
9 changes: 5 additions & 4 deletions config/hyperparams/augmented_RF.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
class_weight: null
criterion: entropy
max_depth: 8
max_features: sqrt
n_estimators: 500
random_state: 18
max_depth: null
min_samples_leaf: 1
min_samples_split: 2
n_estimators: 200
11 changes: 6 additions & 5 deletions config/hyperparams/balanced_RF.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
criterion: gini
max_depth: 8
max_features: sqrt
n_estimators: 500
random_state: 18
class_weight: balanced
criterion: entropy
max_depth: 20
min_samples_leaf: 2
min_samples_split: 5
n_estimators: 50
Loading

0 comments on commit 7a8a578

Please sign in to comment.