Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
COVID-19-Detection/train_covid19.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
169 lines (142 sloc)
5.84 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Abdullah Jalil 2020
# imported libraries
import tensorflow as tf | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
from tensorflow.keras.applications import VGG16 | |
from tensorflow.keras.layers import AveragePooling2D | |
from tensorflow.keras.layers import Dropout | |
from tensorflow.keras.layers import Flatten | |
from tensorflow.keras.layers import Dense | |
from tensorflow.keras.layers import Input | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.optimizers import Adam | |
from tensorflow.keras.utils import to_categorical | |
from sklearn.preprocessing import LabelBinarizer | |
from sklearn.model_selection import train_test_split | |
from sklearn.metrics import classification_report | |
from sklearn.metrics import confusion_matrix | |
from sklearn.metrics import roc_curve | |
from imutils import paths | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import argparse | |
import cv2 | |
import os | |
# Command-line interface: the dataset location plus the two output paths.
argument_parser = argparse.ArgumentParser()
# path to the chest x-ray dataset
argument_parser.add_argument("-d", "--dataset", required=True,
                             help="path to input dataset")
# where the training loss/accuracy plot is written
argument_parser.add_argument("-p", "--plot", type=str, default="plot.png",
                             help="path to output loss/accuracy plot")
# where the trained COVID-19 model is written (the help text previously
# repeated the --plot description by mistake)
argument_parser.add_argument("-m", "--model", type=str, default="covid.model",
                             help="path to output trained model")
# parsed options as a plain dict keyed by the long option name
arguments = vars(argument_parser.parse_args())

# training hyperparameters: initial learning rate, number of epochs and
# mini-batch size
initialize_learning_rate = 1e-5
epochs = 40
batchSize = 10
# collect the path of every image found under the dataset directory
print("[1] Images loading from Directory...")
imagePaths = list(paths.list_images(arguments["dataset"]))
if not imagePaths:
    # fail early with a clear message instead of an obscure error later on
    raise SystemExit(
        "no images found under {!r}".format(arguments["dataset"]))

input_data = []
labels = []
# loop over the image paths
for imagePath in imagePaths:
    # cv2.imread does not raise on a missing/undecodable file, it returns
    # None; skip such files rather than crash inside cvtColor
    image = cv2.imread(imagePath)
    if image is None:
        print("[1] skipping unreadable image: {}".format(imagePath))
        continue
    # swap color channels (OpenCV loads BGR) and resize to the fixed
    # 224x224 input size used by the network
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    input_data.append(image)
    # the class label is the name of the directory containing the file
    labels.append(imagePath.split(os.path.sep)[-2])

# convert the data and labels to NumPy arrays, scaling the pixel
# intensities to the range [0, 1]
input_data = np.array(input_data) / 255.0
labels = np.array(labels)

# one-hot encode the labels: LabelBinarizer maps the class names to
# integers and to_categorical expands them into one column per class
lb = LabelBinarizer()
labels = lb.fit_transform(labels)
labels = to_categorical(labels)
# hold out 20% of the samples for testing and train on the remaining 80%;
# the split is stratified on the labels and uses a fixed seed so that it is
# reproducible between runs
trainX, testX, trainY, testY = train_test_split(
    input_data,
    labels,
    test_size=0.20,
    stratify=labels,
    random_state=42,
)

# training-time augmentation: random rotations of up to 15 degrees, with
# newly exposed pixels filled from the nearest existing pixel
trainAugmentation = ImageDataGenerator(rotation_range=15, fill_mode="nearest")
# VGG16 pre-trained on ImageNet, without its fully-connected head, taking
# 224x224 three-channel input
baseModel = VGG16(
    weights="imagenet",
    include_top=False,
    input_tensor=Input(shape=(224, 224, 3)),
)

# the new classification head, listed in the order the layers are applied:
# 4x4 average pooling, flatten, a 64-unit ReLU layer, dropout of 0.5 and a
# 2-unit softmax output
head_layers = [
    AveragePooling2D(pool_size=(4, 4)),
    Flatten(name="flatten"),
    Dense(64, activation="relu"),
    Dropout(0.5),
    Dense(2, activation="softmax"),
]
topModel = baseModel.output
for head_layer in head_layers:
    topModel = head_layer(topModel)

# stack the new head on top of the convolutional base
model = Model(inputs=baseModel.input, outputs=topModel)

# freeze every layer of the base network so only the head is trained
for layer in baseModel.layers:
    layer.trainable = False
# compile our model
print("[2] Creating model...")
# The legacy `lr=` / `decay=` optimizer arguments are deprecated, and
# `decay` is rejected outright by newer Keras optimizers.  An
# InverseTimeDecay schedule with decay_steps=1 yields the same per-step
# learning rate: lr_t = initial_lr / (1 + decay_rate * step).
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=initialize_learning_rate,
    decay_steps=1,
    decay_rate=initialize_learning_rate / epochs,
)
opt = Adam(learning_rate=lr_schedule)
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])

# train the top of the network (the base layers are frozen)
print("[3] Training model head...")
# Model.fit accepts the augmentation iterator directly; fit_generator is
# deprecated.  validation_steps is omitted because it only applies to
# dataset/generator validation data -- the in-memory test arrays are
# evaluated in full after every epoch.
H = model.fit(
    trainAugmentation.flow(trainX, trainY, batch_size=batchSize),
    # at least one step, even if the training set is smaller than a batch
    steps_per_epoch=max(1, len(trainX) // batchSize),
    validation_data=(testX, testY),
    epochs=epochs,
)
# run the trained network over the held-out test images
print("[4] Making prediction using testing set...")
predicitionMatrix = model.predict(testX, batch_size=batchSize)
# reduce each row of class probabilities to the index of the most likely
# class; do the same for the one-hot ground truth
predicitionMatrix = np.argmax(predicitionMatrix, axis=1)
true_classes = testY.argmax(axis=1)

# per-class precision / recall / f1 summary
print(classification_report(true_classes, predicitionMatrix,
                            target_names=lb.classes_))

# confusion matrix: rows are true classes, columns are predicted classes
confusionMatrix = confusion_matrix(true_classes, predicitionMatrix)
total = confusionMatrix.sum()
class0_correct = confusionMatrix[0, 0]
class0_missed = confusionMatrix[0, 1]
class1_missed = confusionMatrix[1, 0]
class1_correct = confusionMatrix[1, 1]

# raw accuracy: all correct predictions over all samples
accuracy = (class0_correct + class1_correct) / total
# sensitivity: fraction of true class-0 samples predicted as class 0
sensitivity = class0_correct / (class0_correct + class0_missed)
# specificity: fraction of true class-1 samples predicted as class 1
specificity = class1_correct / (class1_missed + class1_correct)

# report the confusion matrix followed by the derived metrics
print(confusionMatrix)
print("accuracy: {:.4f}".format(accuracy))
print("sensitivity: {:.4f}".format(sensitivity))
print("specificity: {:.4f}".format(specificity))
# draw the per-epoch training/validation loss and accuracy curves on a
# single figure
N = epochs
plt.style.use("ggplot")
plt.figure()
epoch_axis = np.arange(0, N)
# (key in the training history, legend label) for each curve, in draw order
curves = (
    ("loss", "train_loss"),
    ("val_loss", "val_loss"),
    ("accuracy", "train_acc"),
    ("val_accuracy", "val_acc"),
)
for history_key, curve_label in curves:
    plt.plot(epoch_axis, H.history[history_key], label=curve_label)
plt.title("Training Loss and Accuracy on COVID-19 Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
# write the figure to the path given by --plot
plt.savefig(arguments["plot"])

# save the trained model to the path given by --model, in HDF5 format
print("[5] saving COVID-19 Model...")
model.save(arguments["model"], save_format="h5")