{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Section 1: Input and process data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# import the libraries are required\n",
"import pandas as pd # efficient data analysis tool library\n",
"import numpy as np # A numerical computation extension of an open element\n",
"import matplotlib.pyplot as plt # 2d drawing library\n",
"import warnings \n",
"%matplotlib inline\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.1 data input"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv(\"healthcare-dataset-stroke-data.csv\") # read the data from dataset\n",
"df.head() # check the data frame"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.2 cherck the data structure"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.info() # displays the information\n",
"df.shape # displays the shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.3 drop the unvalued data and repeat data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.drop(\"id\",axis=1,inplace=True) # remove the unvalued data\n",
"df = df.drop_duplicates() # delete the repeat data\n",
"df.shape # check the frame shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.describe() # check data description statistics "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.4 check the unique and null data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.nunique() # check the unique value in the data\n",
"df.isnull().any() # find null value"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.5 deal with the outliers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fill the part of bmi who has null with its mean value\n",
"df[\"bmi\"].fillna(df[\"bmi\"].mean(),inplace=True)\n",
"df.isnull().any() # check the data after process "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# find the unique type for all the cols\n",
"columns = [\"gender\", \"hypertension\", \"heart_disease\", \"ever_married\", \"work_type\", \"Residence_type\", \"smoking_status\", \"stroke\"]\n",
"for i in range(0, len(columns)):\n",
" colm = columns[i]\n",
" print(\"The unique data of\", colm, \"is:\", df[colm].unique())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check the num of other\n",
"df[\"gender\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# since just one special type, just remove it\n",
"df = df.drop(df[df[\"gender\"] == \"Other\"].index)\n",
"df[\"gender\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check the how mach Unknow in smoking_status\n",
"df[\"smoking_status\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use mode to replace unknow value\n",
"from numpy import nan\n",
"df[\"smoking_status\"].replace(\"Unknown\", nan, inplace=True) # first, replace Unknow to nan\n",
"df[\"smoking_status\"].fillna(df[\"smoking_status\"].mode()[0], inplace=True) # and then replace to the mode\n",
"df[\"smoking_status\"].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Section 2: Data analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.1 Analysis of discrete variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check for the stroke distribution\n",
"plt.figure(figsize=(8,6)) # define the figure size\n",
"plt.title(\"The rate of stroke\") # define the title\n",
"labels = df[\"stroke\"].value_counts().index\n",
"plt.pie(df[\"stroke\"].value_counts(), labels=labels,autopct='%1.2f%%',shadow=True,explode=(0,1))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# explore the relationship between stroke and rest of indicators\n",
"import seaborn as sns\n",
"# list of data to analysis\n",
"columns = [\"gender\",\"heart_disease\", \"ever_married\", \"work_type\", \"Residence_type\", \"smoking_status\"]\n",
"# define plot size\n",
"plt.figure(figsize=(2*15,3*12))\n",
"plt.rcParams[\"font.size\"] = 20\n",
"for col in columns:\n",
" idx = columns.index(col) + 1\n",
" xn = plt.subplot(3,2,idx)\n",
" sns.countplot(df[\"stroke\"], hue = df[col])\n",
" plt.ylabel(\"headcount\")\n",
" xn.set_title(\"Correlation between \"f\"{col}\"\" and stroke\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.2 Analysis of continuous variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"figure = plt.figure(figsize=(8,6))\n",
"plt.title(\"age and bmi\")\n",
"sns.scatterplot(x=df[\"age\"], y=df[\"bmi\"],hue=df[\"stroke\"],style=df[\"stroke\"])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"figure = plt.figure(figsize=(8,6))\n",
"plt.title(\"age and avg_glucose_level\")\n",
"sns.scatterplot(x=df[\"age\"], y=df[\"avg_glucose_level\"],hue=df[\"stroke\"],style=df[\"stroke\"])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"figure = plt.figure(figsize=(8,6))\n",
"plt.title(\"bmi and avg_glucose_level\")\n",
"sns.scatterplot(x=df[\"bmi\"], y=df[\"avg_glucose_level\"],hue=df[\"stroke\"],style=df[\"stroke\"])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# display the distribution of avg_glucose_level\n",
"plt.figure(figsize=(10,6))\n",
"plt.title(\"The distribution of avg_glucose_level\")\n",
"sns.kdeplot(df[\"avg_glucose_level\"])\n",
"# save the stroke == 1 data from avg_glucose_level\n",
"stroker = df.query(\"stroke==1\")[\"avg_glucose_level\"]\n",
"sns.kdeplot(stroker)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# display the distribution of bmi\n",
"plt.figure(figsize=(10,6))\n",
"plt.title(\"The distribution of bmi\")\n",
"sns.kdeplot(df[\"bmi\"])\n",
"# save the stroke == 1 data from bmi\n",
"stroker = df.query(\"stroke==1\")[\"bmi\"]\n",
"sns.kdeplot(stroker)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.3 Eigenvalue digitization and analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use get_dummies and label encoder to encode data\n",
"from pandas import get_dummies\n",
"from sklearn.preprocessing import LabelEncoder\n",
"# OneHotEncode the data has more than 2 categories\n",
"df1 = get_dummies(df, columns=[\"work_type\", \"Residence_type\", \"smoking_status\"])\n",
"encoder = LabelEncoder()\n",
"# LabelEncode the data binary classification data\n",
"df1[\"gender\"] = encoder.fit_transform(df1[\"gender\"])\n",
"df1[\"ever_married\"] = encoder.fit_transform(df1[\"ever_married\"])\n",
"df1.head()"
]
},
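{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a toy illustration of the two encoders (illustrative values only, not part of the pipeline): get_dummies expands a multi-category column into one indicator column per category, while LabelEncoder maps a binary column to 0/1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy illustration of get_dummies vs LabelEncoder (illustrative data)\n",
"toy = pd.DataFrame({\"work_type\": [\"Private\", \"children\", \"Private\"], \"ever_married\": [\"Yes\", \"No\", \"Yes\"]})\n",
"toy_encoded = get_dummies(toy, columns=[\"work_type\"]) # one indicator column per category\n",
"toy_encoded[\"ever_married\"] = LabelEncoder().fit_transform(toy_encoded[\"ever_married\"]) # binary column becomes 0/1\n",
"toy_encoded"
]
},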
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build a correlation matrix to see the correlation between columns. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot a hotmap to explore the correlation between each data\n",
"clomns = [\"age\",\"hypertension\",\"heart_disease\",\"avg_glucose_level\",\"bmi\",\"gender\",\"ever_married\",\"work_type_Never_worked\",\"work_type_Private\",\n",
"\"work_type_Self-employed\",\"work_type_children\",\"Residence_type_Rural\",\"smoking_status_formerly smoked\",\"smoking_status_never smoked\",\"smoking_status_smokes\"]\n",
"data = df1[clomns]\n",
"correlation_matrix = data.corr().round(2) # keep values to two decimal places\n",
"plt.figure(figsize=(25,15),dpi=200)\n",
"sns.heatmap(correlation_matrix, annot=True)\n",
"plt.title(\"correlation between classes\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot a bar chart to show the correlation between stroke and variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(15,5))\n",
"plt.title(\"The correlation between stroke and features\")\n",
"# use sort_value to distribute the data from high to low\n",
"# sort the value from high to low\n",
"df1.corr()[\"stroke\"].sort_values(ascending = False).plot(kind = \"bar\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# explore the stroker distrubute in different ages\n",
"plt.figure(figsize=(10,6))\n",
"plt.title(\"Changes in stroke probability with age\")\n",
"# plot the density distribution of people who have not stroke with ages. \n",
"sns.kdeplot(df1.age[df1.stroke==0],color=\"b\",shade=True,label=\"stroke=0\")\n",
"# plot the density distribution of people who have stroke with ages. \n",
"sns.kdeplot(df1.age[df1.stroke==1],color=\"r\",shade=True,label=\"stroke=1\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.4 Features Selection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# drop the low correlation data\n",
"df1.drop([\"gender\",\"Residence_type_Rural\",\"Residence_type_Urban\"], axis = 1, inplace=True)\n",
"df1.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# scaler the continuous variable to plot a boxplot to find the outliers\n",
"from sklearn.preprocessing import StandardScaler\n",
"ss = StandardScaler()\n",
"data_check = df1[[\"age\", \"bmi\", \"avg_glucose_level\"]]\n",
"data_check[[\"age\", \"bmi\", \"avg_glucose_level\"]] = ss.fit_transform(data_check[[\"age\", \"bmi\", \"avg_glucose_level\"]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use the box plot to check is any outliers in the data\n",
"plt.figure(figsize = (10,5),dpi=120) # use high dpi is greater to analysis the plot\n",
"plt.title(\"The value distribution of age and avg_glucose_level and bmi\")\n",
"boxes = sns.boxplot(data=data_check[[\"age\",\"avg_glucose_level\",\"bmi\"]])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the quartile method to find outliers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_to_test = df1[[\"avg_glucose_level\",\"bmi\"]]\n",
"\n",
"def delete_outliers(data,cloms,n): # define quartile function\n",
" indices = [] # saves the outlier's index\n",
" for clom in cloms:\n",
" q1 = np.percentile(data[clom],25) # define the quarter 1\n",
" q3 = np.percentile(data[clom],75) # define the quarter 3\n",
" IQR = q3 - q1 # IQR is the aera between q1 and q3\n",
" outlier_step = 1.5 * IQR \n",
" outlier_colm = data[(data[clom] < q1 - outlier_step) | (data[clom] > q3 + outlier_step)].index # the value is over than q3 + 1.5 IQR or less than q1 - 1.5 IQR will be recognized as outlier\n",
" indices.extend(outlier_colm)\n",
" from collections import Counter\n",
" indices = Counter(indices) # counter the occur times of the outlier\n",
" outliers_ = list(k for k, v in indices.items() if v >= n) # if the occur times of the outlier is over than n(conditions), then they will be returned\n",
" return outliers_ "
]
},
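{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the 1.5 * IQR rule above can be demonstrated on a small toy column (illustrative values only, not from the dataset)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy demonstration of the 1.5 * IQR rule implemented above (illustrative data)\n",
"toy = pd.DataFrame({\"x\": [1, 2, 2, 3, 3, 3, 4, 4, 5, 100]})\n",
"# with n=1, any value outside [q1 - 1.5*IQR, q3 + 1.5*IQR] is flagged; 100 should be the only outlier\n",
"delete_outliers(toy, [\"x\"], 1)"
]
},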
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# values are recognized as an outlier when both conditions (exists in avg glucose level and bmi) are met\n",
"outliers = delete_outliers(data_to_test, [\"avg_glucose_level\", \"bmi\"], 2)\n",
"len(outliers) #print the outlers number"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df1 = df1.drop(outliers,axis=0).reset_index(drop=True) # removed the outliers from original dataset\n",
"df1.shape # check the shape after drop"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Section 3: establish analysis models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.1 Build the train and test data sets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.1.1 scaler the large distribution data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# scaler the data is not evenly distributed\n",
"from sklearn.preprocessing import StandardScaler\n",
"mms = StandardScaler()\n",
"df1[[\"avg_glucose_level\"]] = mms.fit_transform(df1[[\"avg_glucose_level\"]])\n",
"df1[\"avg_glucose_level\"]"
]
},
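{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal verification of the transform above, the standardized column should now have a mean of roughly 0 and a standard deviation of roughly 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# verify the StandardScaler output: mean ~ 0, std ~ 1\n",
"print(\"mean:\", round(df1[\"avg_glucose_level\"].mean(), 4))\n",
"print(\"std:\", round(df1[\"avg_glucose_level\"].std(), 4))"
]
},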
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.1.2 split the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data = df1.drop([\"stroke\"],axis=1)\n",
"# feature data\n",
"X = train_data.iloc[:,:]\n",
"# prediction data\n",
"y = df1.iloc[:,6]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# split the data to build model\n",
"from sklearn.model_selection import train_test_split\n",
"# split the 75% date use for training, and 25% use for test\n",
"X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.25,random_state=50)"
]
},
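{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before oversampling, it is worth confirming how imbalanced the training labels are; this motivates the SMOTE step below (a minimal check, assuming y_train is the label Series produced by the split above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# count the stroke / non-stroke labels before SMOTE\n",
"y_train.value_counts()"
]
},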
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.1.3 use SMOTE to oversample the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from imblearn.over_sampling import SMOTE\n",
"smo = SMOTE(random_state=32)\n",
"X_train, y_train = smo.fit_resample(X_train,y_train.ravel())\n",
"# check the data after SMOTE\n",
"# count the stroke num after SMOTE\n",
"print(\"After SMOTE the stroke number is :\", sum(y_train == 1))\n",
"# count the non-stroke num after SMOTE\n",
"print(\"After SMOTE the non-stroke number is :\", sum(y_train == 0))\n",
"X_train.shape, y_train.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.1.4 Create a function to execute models and evaluate them"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define a function to run the models and make the evaluations, and then record them in a list."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# import the evlation methods\n",
"from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, precision_score, recall_score, f1_score\n",
"from sklearn.model_selection import cross_val_score\n",
"results = [] # where the evaluate are stored\n",
"def evaluator(model):\n",
" # Put the data into the model\n",
" model.fit(X_train,y_train)\n",
" # get the predict data\n",
" y_pred = model.predict(X_test)\n",
" # a series of evaluation methods\n",
" cm = confusion_matrix(y_test, y_pred)\n",
" # accuracy score\n",
" accuracy = accuracy_score(y_test, y_pred)\n",
" # cross validation\n",
" cvs = cross_val_score(model,X_train,y_train,cv=4)\n",
" # roc auc score\n",
" roc_auc = roc_auc_score(y_test, y_pred)\n",
" # recall score\n",
" recall = recall_score(y_test, y_pred)\n",
" # precision score\n",
" precision = precision_score(y_test,y_pred)\n",
" f1 = f1_score(y_test,y_pred)\n",
" # print the results\n",
" print(model,\":\")\n",
" print(cm)\n",
" print(\"accuracy score:\", accuracy)\n",
" print(\"cvs mean score:\", cvs.mean())\n",
" print(\"roc auc score:\", roc_auc)\n",
" print(\"recall:\", recall)\n",
" print(\"Precision:\", precision)\n",
" print(\"f1 score:\", f1)\n",
" # append to a list to further comparison\n",
" results.append([accuracy, cvs.mean(), roc_auc, recall, precision, f1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.2 Analysis of model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.2.1 Logistic Regression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"lr = LogisticRegression()\n",
"evaluator(lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.2.2 Random Forest"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"rf = RandomForestClassifier()\n",
"evaluator(rf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.2.3 Support Vector Machine"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"svm = SVC(kernel='rbf')\n",
"evaluator(svm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.2.4 XG Boost"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from xgboost import XGBClassifier\n",
"xgb = XGBClassifier()\n",
"evaluator(xgb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Section 4: Comparison of scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# keep four decimal places of all results\n",
"results = np.round(results,4)\n",
"results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# list models\n",
"models = [\"LogisticRegression\",\"RandomForest\",\"Support Vector Machine\",\"XG Boosting\"]\n",
"# list of evaluation method\n",
"cloms = [\"model\",\"accuracy\",\"cvs mean\",\"roc_auc\",\"recall\",\"precision\",\"f1\"]\n",
"results_compare = pd.DataFrame(columns=cloms)\n",
"results_compare[\"model\"] = [i for i in models]\n",
"for i in range(0,len(results) + 2):\n",
" results_compare[cloms[i + 1]] = results[:,i]\n",
" # sort them by f1 and accuracy\n",
"results_compare.sort_values(by=[\"f1\",\"accuracy\"],inplace=True,ascending=False)\n",
"results_compare"
]
},
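{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same comparison can also be visualized as a grouped bar chart (a minimal sketch built from the results_compare frame above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot every metric for every model side by side\n",
"results_compare.set_index(\"model\").plot(kind=\"bar\", figsize=(12,6))\n",
"plt.title(\"Model comparison across evaluation metrics\")\n",
"plt.ylabel(\"score\")\n",
"plt.show()"
]
},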
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Section 5: model tuning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use grid rearch cv to tune the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"6.1 tune the Logistic regression model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"# Adjust the hyperparameters of logistic regression\n",
"lr_grid = GridSearchCV(lr, [{\"penalty\":[\"none\",\"l1\",\"l2\"], \"C\":[1, 5, 10]}], scoring = \"accuracy\", cv=10)\n",
"lr_grid.fit(X_train,y_train)\n",
"best_accuracy = lr_grid.best_score_\n",
"bset_parameters = lr_grid.best_params_\n",
"print(\"model\", lr, \"\\n best accuracy is:\", best_accuracy)\n",
"print(\"best parameters are:\", bset_parameters)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the classification report of lr\n",
"from sklearn.metrics import classification_report\n",
"y_pred = lr_grid.predict(X_test)\n",
"print(classification_report(y_test,y_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot the lr confusion matrix\n",
"cfm = confusion_matrix(y_test, y_pred)\n",
"sns.heatmap(cfm, annot=True, fmt = 'd', cmap='Blues')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"6.2 tune the XGB model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Adjust the hyperparameters of XG Boosting\n",
"xgb_grid = GridSearchCV(xgb, [{\"learning_rate\":[0.01, 0.05, 0.1], \"eval_metric\":[\"error\"]}], scoring = \"accuracy\", cv=10)\n",
"xgb_grid.fit(X_train,y_train)\n",
"best_accuracy = xgb_grid.best_score_\n",
"bset_parameters = xgb_grid.best_params_\n",
"print(\"model\", rf, \"\\n best accuracy is:\", best_accuracy)\n",
"print(\"best parameters are:\", bset_parameters)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the classification report of lr\n",
"y_pred = xgb_grid.predict(X_test)\n",
"print(classification_report(y_test,y_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot the xgb confusion matrix\n",
"cfm = confusion_matrix(y_test, y_pred)\n",
"sns.heatmap(cfm, annot=True, fmt = 'd', cmap='Blues')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5.3 Tune the Support Vector Machine model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# adjust the hyperparameters of the Support Vector Machine\n",
"svm_grid = GridSearchCV(svm, [{'kernel':['sigmoid'],'C':[1, 5, 10],'gamma':[0.001, 0.0001]}], scoring = \"accuracy\", cv=10)\n",
"svm_grid.fit(X_train,y_train)\n",
"best_accuracy = svm_grid.best_score_\n",
"best_parameters = svm_grid.best_params_\n",
"print(\"model\", svm, \"\\n best accuracy is:\", best_accuracy)\n",
"print(\"best parameters are:\", best_parameters)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print the classification report of lr\n",
"y_pred = svm_grid.predict(X_test)\n",
"print(classification_report(y_test,y_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot the svm confusion matrix\n",
"cfm = confusion_matrix(y_test, y_pred)\n",
"sns.heatmap(cfm, annot=True, fmt = 'd', cmap='Blues')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.6 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "865d8b2eb28e274047ba64063dfb6a2aabf0dfec4905d304d7a76618dae6fdd4"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}