Standard Logistic Regression is designed for binary (two-class) classification problems; extending it to more classes requires a strategy such as one-vs-rest or a multinomial formulation. Linear Discriminant Analysis (LDA) is a statistical technique that handles multiclass classification natively.
Note: Linear Discriminant Analysis is a linear machine learning (ML) algorithm, which generally makes it simpler and faster to train than non-linear algorithms.
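LDA models each class as a Gaussian with its own mean μ_k but a covariance matrix Σ shared by all classes. This makes each class's discriminant score linear in x: δ_k(x) = xᵀΣ⁻¹μ_k − ½ μ_kᵀΣ⁻¹μ_k + log π_k, where π_k is the class prior. Because the model computes one score per class and picks the largest, multiclass problems need no extra machinery. Below is a minimal from-scratch sketch of that decision rule. It is restricted to two features so the 2x2 inverse can be written out by hand. The names `fit_lda` and `predict_lda` and the toy points are hypothetical. scikit-learn's `LinearDiscriminantAnalysis` uses a different solver (SVD-based by default), so treat this as an illustration of the idea, not of the library.

```python
import math

def fit_lda(X, y):
    """Fit a two-feature LDA: class means, class priors, and the inverse pooled covariance."""
    classes = sorted(set(y))
    groups = {c: [x for x, label in zip(X, y) if label == c] for c in classes}
    means = {c: [sum(x[i] for x in g) / len(g) for i in range(2)] for c, g in groups.items()}
    priors = {c: len(g) / len(X) for c, g in groups.items()}
    # pooled within-class covariance (shared by all classes), divided by n - K
    s = [[0.0, 0.0], [0.0, 0.0]]
    for c, g in groups.items():
        for x in g:
            d = [x[0] - means[c][0], x[1] - means[c][1]]
            for i in range(2):
                for j in range(2):
                    s[i][j] += d[i] * d[j] / (len(X) - len(classes))
    # closed-form inverse of the 2x2 covariance matrix
    det = s[0][0] * s[1][1] - s[0][1] * s[1][0]
    inv = [[s[1][1] / det, -s[0][1] / det], [-s[1][0] / det, s[0][0] / det]]
    return classes, means, priors, inv

def predict_lda(model, x):
    """Return the class whose linear discriminant score is largest."""
    classes, means, priors, inv = model

    def score(c):
        m = means[c]
        # w = inverse covariance times the class mean
        w = [inv[0][0] * m[0] + inv[0][1] * m[1],
             inv[1][0] * m[0] + inv[1][1] * m[1]]
        return x[0] * w[0] + x[1] * w[1] - 0.5 * (m[0] * w[0] + m[1] * w[1]) + math.log(priors[c])

    return max(classes, key=score)

# hypothetical toy data: three well-separated classes with two features each
X = [(0, 0), (1, 0), (0, 1), (1, 1),
     (5, 0), (6, 0), (5, 1), (6, 1),
     (0, 5), (1, 5), (0, 6), (1, 6)]
y = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
model = fit_lda(X, y)
print([predict_lda(model, p) for p in [(0.5, 0.5), (5, 1), (1, 6)]])  # [0, 1, 2]
```

A single fitted model separates all three classes. No one-vs-rest wrapper is involved.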
Medium Post: Top 10 algorithms for ML newbies
This recipe covers the following steps:
- Load the classification dataset (Pima Indians diabetes) from GitHub
- Split the columns into the usual feature columns (X) and target column (Y)
- Set the k-fold count to 10
- Keep the folds reproducible (a random seed only matters if the data is shuffled)
- Split the data into folds using the KFold() class
- Instantiate the classification algorithm: LinearDiscriminantAnalysis
- Call cross_val_score() to run cross-validation
- Calculate the mean estimated accuracy from the scores returned by cross_val_score()
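The KFold splitting step can be pictured with a small plain-Python sketch. The scikit-learn documentation states that, with `shuffle=False`, folds are contiguous and the first `n_samples % n_splits` folds get one extra sample. The sketch below assumes 768 rows, the usual size of the Pima Indians dataset. The optional shuffle branch only illustrates why a fixed seed makes shuffled folds repeatable. scikit-learn draws its permutation from NumPy's random generator, so its actual shuffled folds differ from this stdlib version. The function name `kfold_indices` is hypothetical.

```python
import random

def kfold_indices(n_samples, n_splits, shuffle=False, seed=None):
    """Return a list of (train, test) index lists, one pair per fold."""
    indices = list(range(n_samples))
    if shuffle:
        # a seeded generator makes the permutation (and thus the folds) repeatable
        random.Random(seed).shuffle(indices)
    # the first n_samples % n_splits folds get one extra sample
    fold_sizes = [n_samples // n_splits] * n_splits
    for i in range(n_samples % n_splits):
        fold_sizes[i] += 1
    folds, start = [], 0
    for size in fold_sizes:
        stop = start + size
        test = indices[start:stop]
        train = indices[:start] + indices[stop:]
        folds.append((train, test))
        start = stop
    return folds

# 768 rows split into 10 contiguous folds
print([len(test) for _, test in kfold_indices(768, 10)])
# [77, 77, 77, 77, 77, 77, 77, 77, 76, 76]

# the same seed yields identical shuffled folds on every call
a = kfold_indices(768, 10, shuffle=True, seed=7)
b = kfold_indices(768, 10, shuffle=True, seed=7)
print(a == b)  # True
```

Without shuffling, there is nothing random to seed. Each row always lands in the same fold, so results are reproducible as-is.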
# import modules
import pandas as pd
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
# read data file from github
# dataframe: pimaDf
gitFileURL = 'https://raw.githubusercontent.com/andrewgurung/data-repository/master/pima-indians-diabetes.data.csv'
cols = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
pimaDf = pd.read_csv(gitFileURL, names=cols)
# convert into numpy array for scikit-learn
pimaArr = pimaDf.values
# Let's split the columns into the usual feature columns (X) and target column (Y)
# Y is the target 'class' column, whose value is either 0 or 1
X = pimaArr[:, 0:8]
Y = pimaArr[:, 8]
# set k-fold count
folds = 10
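# note: KFold only uses random_state when shuffle=True, and recent scikit-learn
# versions raise a ValueError if random_state is set while shuffle=False.
# Unshuffled folds are contiguous and deterministic, so no seed is needed here;
# for reproducible shuffled folds use KFold(n_splits=folds, shuffle=True, random_state=7)
kfold = KFold(n_splits=folds)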
# instantiate the classification algorithm
model = LinearDiscriminantAnalysis()
# call cross_val_score() to run cross validation
resultArr = cross_val_score(model, X, Y, cv=kfold)
# cross_val_score returns one score per fold (accuracy, the default for classifiers);
# average them and convert to a percentage
meanAccuracy = resultArr.mean() * 100
# display mean estimated accuracy
print("Mean estimated accuracy: %.3f%%" % meanAccuracy)
Mean estimated accuracy: 77.346%
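Conceptually, cross_val_score fits a fresh copy of the estimator on each training split and scores it on the held-out fold. For classifiers, the default score is accuracy. The sketch below spells out that loop in plain Python. A hypothetical `MajorityClassifier` stands in for `LinearDiscriminantAnalysis` so the example runs without scikit-learn. The sketch also reports the standard deviation across folds, which the mean alone hides. The names `MajorityClassifier` and `cross_val_accuracy` and the toy data are all hypothetical. This is an illustration of the procedure, not scikit-learn's implementation.

```python
import statistics

class MajorityClassifier:
    """Hypothetical stand-in model: always predicts the most common training label."""

    def fit(self, X, y):
        self.label_ = max(sorted(set(y)), key=y.count)
        return self

    def predict(self, X):
        return [self.label_ for _ in X]

def cross_val_accuracy(make_model, X, y, n_splits):
    """Fit a fresh model per contiguous fold and return each fold's held-out accuracy."""
    n = len(y)
    fold_sizes = [n // n_splits + (1 if i < n % n_splits else 0) for i in range(n_splits)]
    scores, start = [], 0
    for size in fold_sizes:
        stop = start + size
        train = list(range(0, start)) + list(range(stop, n))
        model = make_model().fit([X[i] for i in train], [y[i] for i in train])
        preds = model.predict(X[start:stop])
        correct = sum(p == t for p, t in zip(preds, y[start:stop]))
        scores.append(correct / size)
        start = stop
    return scores

# hypothetical toy data: one feature per row, 8 negatives and 2 positives
X = [[v] for v in range(10)]
y = [0, 0, 0, 1, 0, 0, 1, 0, 0, 0]
scores = cross_val_accuracy(MajorityClassifier, X, y, n_splits=5)
print(scores)  # [1.0, 0.5, 1.0, 0.5, 1.0]
mean_acc = statistics.mean(scores) * 100
std_acc = statistics.pstdev(scores) * 100  # population std, like NumPy's default .std()
print("Mean estimated accuracy: %.3f%% (std %.3f%%)" % (mean_acc, std_acc))
# Mean estimated accuracy: 80.000% (std 24.495%)
```

In the original recipe, the same spread can be read directly from the returned array with `resultArr.std() * 100`.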