Skip to content

Commit

Permalink
Modify my_mlp to improve perforemance
Browse files Browse the repository at this point in the history
delriot committed Oct 31, 2022
1 parent 74cd0eb commit d49d890
Showing 3 changed files with 7 additions and 14 deletions.
10 changes: 5 additions & 5 deletions basic_ml.py
Original file line number Diff line number Diff line change
@@ -53,17 +53,17 @@ def basic_ml(x_train, x_test, y_train, y_test, model, random_state=0):
return accuracy_score(clf.predict(x_test), y_test)


def use_tf(x_train, x_test, y_train, y_test, batch_size=64, epochs=100):
def use_tf(x_train, x_test, y_train, y_test, batch_size=50, epochs=100):
"""Train a fully connected neural network and return its accuracy."""
# design model
model = keras.models.Sequential([
keras.layers.Dense(18, activation='relu',
keras.layers.Dense(80, activation='softmax',
input_shape=(x_train.shape[1],)),
keras.layers.Dense(6),
keras.layers.Dense(40, activation='softmax'),
keras.layers.Dense(6, activation='softmax'),
])
# choose hyperparameters
print(y_train[:5])
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optim = keras.optimizers.Adam(learning_rate=0.001)
metrics = ["accuracy"]
model.compile(loss=loss, optimizer=optim, metrics=metrics)
4 changes: 1 addition & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -18,10 +18,9 @@
import pickle
import random
import csv
from types import NoneType
import importlib.util
# Check if 'dataset_manipulation' is installed
if importlib.util.find_spec('dataset_manipulation') is NoneType:
if isinstance(importlib.util.find_spec('dataset_manipulation'), type(None)):
from dataset_manipulation import name_unique_features
from dataset_manipulation import remove_notunique_features
from dataset_manipulation import balance_dataset
@@ -42,7 +41,6 @@
'names_features_targets.txt')
with open(names_features_targets_file, 'rb') as f:
names, features, targets = pickle.load(f)
print(len(features), len(features[0]), len(targets))
augmented_features, augmented_targets = augmentate_dataset(features, targets)

normalized_augmented_features = normalize(augmented_features)
7 changes: 1 addition & 6 deletions ml_results.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,2 @@
Name,Normal,Balance data,Augment data
SVC,0.21,0.21,0.16
DT,0.32,0.27,0.34
KNN,0.29,0.39,0.51
RF,0.41,0.53,0.58
MPL,0.21,0.2,0.17
my_mlp,0.2,0.24,0.25
my_mlp,0.28,0.31,0.41

0 comments on commit d49d890

Please sign in to comment.