Wheat Grain Classification
Raja Reddy
Friday, September 18, 2015
Machine learning for three-class data classification.
Can we classify wheat varieties based on images?
Why this analysis
This is my attempt to practice machine learning (ML). I am sourcing the data from the UCI Machine Learning Repository. Your suggestions and feedback are more than welcome.
What is this analysis about
Classification of wheat varieties based on physical measurements extracted from images, such as area, perimeter and length, has been explored for a long time. In fact, the seeds data set in the UCI Machine Learning Repository was created for this very objective, and several attempts with different ML algorithms have been made on it. Here I plan to explore the randomForest algorithm as implemented in R to classify this data.
How this analysis is structured
Broadly, I structure this analysis into four parts: 1) descriptive analysis, 2) exploratory analysis, 3) model building & choice, and 4) prediction and evaluation.
The first step is to set the working directory with “setwd()”. Then load all the required packages:
library(randomForest)
## randomForest 4.6-10
## Type rfNews() to see new features/changes/bug fixes.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(ggplot2)
library(lattice)
library(corrplot)
Now get the data from the URL into the R environment. The commented code below shows how to source the data from the UCI-ML repository; I do not execute it here. (Skip this if you already have the data in your working directory.)
#url <- "http://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt"
#download.file(url, destfile = "Wheat.txt")
#The downloaded file is Wheat.txt (a tab-delimited text file). If you want to use this file for your work, use the following commands
I will use the data I had downloaded earlier.
wheat <- read.csv("Wheat.txt", header = F, sep = "\t")
#wheat <- read.table("Wheat1.txt", header = F, sep = "\t")
names(wheat)<- c("Area","Perimeter","compactness","length","width","asymmetry","grooveLength","variety")
wheat$variety <- as.character(wheat$variety)
wheat$variety[wheat$variety == "1"] <- "Kama"
wheat$variety[wheat$variety == "2"] <- "Rosa"
wheat$variety[wheat$variety == "3"] <- "Canadian"
nrow(wheat)
## [1] 210
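The three replacement lines above can be collapsed into a single factor() call, which also produces the factor that randomForest and caret expect later on. A base-R sketch on a toy vector (substitute wheat$variety on the real data):

```r
# Map the numeric variety codes to labels in one step with factor()
codes <- c(1, 2, 3, 1)  # stand-in for the raw variety column
variety <- factor(codes, levels = c(1, 2, 3),
                  labels = c("Kama", "Rosa", "Canadian"))
as.character(variety)
# "Kama" "Rosa" "Canadian" "Kama"
```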
Begin descriptive analysis.
Now let us explore the data. This helps us understand it by describing the number and types of variables. See what is in the wheat data frame, along with the column names, with “head”:
head(wheat, n = 5)
## Area Perimeter compactness length width asymmetry grooveLength variety
## 1 15.26 14.84 0.8710 5.763 3.312 2.221 5.220 Kama
## 2 14.88 14.57 0.8811 5.554 3.333 1.018 4.956 Kama
## 3 14.29 14.09 0.9050 5.291 3.337 2.699 4.825 Kama
## 4 13.84 13.94 0.8955 5.324 3.379 2.259 4.805 Kama
## 5 16.14 14.99 0.9034 5.658 3.562 1.355 5.175 Kama
Get a summary of the data with “summary()”. This generates information on variable types, spread, etc. We will generate box plots of the variables to see the variability in the observations. Alternatively, with density plots we can explore the frequency distributions of individual variable values.
summary(wheat)
## Area Perimeter compactness length
## Min. :10.59 Min. :12.41 Min. :0.8081 Min. :4.899
## 1st Qu.:12.27 1st Qu.:13.45 1st Qu.:0.8569 1st Qu.:5.262
## Median :14.36 Median :14.32 Median :0.8734 Median :5.524
## Mean :14.85 Mean :14.56 Mean :0.8710 Mean :5.629
## 3rd Qu.:17.30 3rd Qu.:15.71 3rd Qu.:0.8878 3rd Qu.:5.980
## Max. :21.18 Max. :17.25 Max. :0.9183 Max. :6.675
## width asymmetry grooveLength variety
## Min. :2.630 Min. :0.7651 Min. :4.519 Length:210
## 1st Qu.:2.944 1st Qu.:2.5615 1st Qu.:5.045 Class :character
## Median :3.237 Median :3.5990 Median :5.223 Mode :character
## Mean :3.259 Mean :3.7002 Mean :5.408
## 3rd Qu.:3.562 3rd Qu.:4.7687 3rd Qu.:5.877
## Max. :4.033 Max. :8.4560 Max. :6.550
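The overall summary pools all three varieties; per-variety means make the group differences visible. A minimal sketch with aggregate() on a toy data frame (on the real data, run aggregate(. ~ variety, data = wheat, FUN = mean)):

```r
# Toy data frame standing in for the wheat data
toy <- data.frame(Area    = c(15.3, 18.9, 11.9, 14.9),
                  variety = c("Kama", "Rosa", "Canadian", "Kama"))
aggregate(Area ~ variety, data = toy, FUN = mean)  # mean Area per variety
#    variety Area
# 1 Canadian 11.9
# 2     Kama 15.1
# 3     Rosa 18.9
```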
boxplot(wheat[,1:7], notch = T, main = "Wheat seed measurement variability", col = rainbow(7))
featurePlot(x = wheat[,1:7], y = as.factor(wheat$variety), plot = "density", scales = list(x=list(relation="free"), y=list(relation="free")), auto.key = T, main = "Density distributions of variables across varieties")
Begin exploratory analysis
In this phase I analyze the variables and their relations. One good way to do that is to find the correlations among the variables:
par(mfrow = c(1,2))
cce <- cor(wheat[,1:7], use = "pairwise", method = "pearson") # calculate correlations
corrplot(cce) #plot correlations
title("Correlation", line = -3)
# However, I would like to see the plot sorted by correlations
cce.ord <- order(cce[1,])
cce.1 <- cce[cce.ord, cce.ord]
corrplot(cce.1)
title("Sorted by correlations", line = -3)
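Ordering by the first row of the matrix is somewhat ad hoc; reordering the whole matrix by hierarchical clustering of the correlations groups related variables together. A base-R sketch on the built-in mtcars data (corrplot also offers this directly via its order = "hclust" argument):

```r
# Correlation matrix of four mtcars columns as a stand-in for cce
m <- cor(mtcars[, c("mpg", "disp", "hp", "wt")])
ord <- hclust(as.dist(1 - m))$order  # cluster on correlation distance
m.ord <- m[ord, ord]                 # reordered correlation matrix
```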
Other ways to see correlations are scatter plots and parallel coordinate plots. These plots show us the relations between the variables across wheat varieties. I will do this with the pairs() and parallelplot() functions:
pairs(wheat[,1:7],pch=21, col=as.factor(wheat$variety))
parallelplot(~wheat[1:7] | variety, wheat) # conditioned by variety; not very useful here, but handy for checking individual varieties
parallelplot(~wheat[1:7], wheat, groups = wheat$variety, auto.key = T, ylab = "Grain measurements", main = "Parallel coordinate plot of variable relations across three varieties")
Now that we understand our variables and their relations, let us explore how these variables could be used to classify our wheat varieties. For this purpose I am choosing the randomForest machine learning algorithm. Though it is not strictly necessary to keep separate training and test sets for randomForest model building and evaluation (the out-of-bag error already gives an internal estimate), I am sticking to the classic methodology followed in supervised learning.
# create data partitioning
set.seed(756) # this makes the partition reproducible
inTrain<-createDataPartition(y = wheat$variety,p = 0.75, list=F)
train<- wheat[inTrain,]
str(train)
## 'data.frame': 159 obs. of 8 variables:
## $ Area : num 15.3 14.9 14.3 13.8 16.1 ...
## $ Perimeter : num 14.8 14.6 14.1 13.9 15 ...
## $ compactness : num 0.871 0.881 0.905 0.895 0.903 ...
## $ length : num 5.76 5.55 5.29 5.32 5.66 ...
## $ width : num 3.31 3.33 3.34 3.38 3.56 ...
## $ asymmetry : num 2.22 1.02 2.7 2.26 1.35 ...
## $ grooveLength: num 5.22 4.96 4.83 4.8 5.17 ...
## $ variety : chr "Kama" "Kama" "Kama" "Kama" ...
test<-wheat[-inTrain,]
str(test)
## 'data.frame': 51 obs. of 8 variables:
## $ Area : num 14.7 16.4 14.7 14.1 13 ...
## $ Perimeter : num 14.5 15.2 14.2 14.3 13.8 ...
## $ compactness : num 0.88 0.888 0.915 0.872 0.864 ...
## $ length : num 5.56 5.88 5.21 5.52 5.39 ...
## $ width : num 3.26 3.5 3.47 3.17 3.03 ...
## $ asymmetry : num 3.59 1.97 1.77 2.69 3.37 ...
## $ grooveLength: num 5.22 5.53 4.65 5.22 4.83 ...
## $ variety : chr "Kama" "Kama" "Kama" "Kama" ...
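createDataPartition samples within each class, so the class balance in train and test should mirror the full data. A quick base-R check of that idea on toy labels (on the real data, compare prop.table(table(train$variety)) with prop.table(table(test$variety))):

```r
y <- rep(c("Kama", "Rosa", "Canadian"), each = 8)           # balanced toy labels
idx <- as.vector(sapply(c(0, 8, 16), function(o) o + 1:6))  # take 6 of 8 per class
prop.table(table(y[idx]))   # each class is 1/3 in the "train" part
prop.table(table(y[-idx]))  # and 1/3 in the held-out part
```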
Begin Model building & choice
Now let us develop the model. As mentioned, I am using randomForest. The problem at hand is a typical classification challenge (predicting the variety); we will look into regression-type problems in another post. Basically, the randomForest algorithm builds many decision trees (you can specify how many), each split considering a random subset of the predictors/features, and combines their votes to predict the response variable (in our case, variety).
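The ensemble idea in miniature: each tree casts a vote, and the forest predicts the majority class per observation. This is an illustrative base-R sketch of the voting step only, not the package internals:

```r
# rows = observations, columns = votes from three hypothetical trees
votes <- matrix(c("Kama", "Kama", "Rosa",
                  "Rosa", "Rosa", "Rosa"), nrow = 2, byrow = TRUE)
# majority vote per row
forest_pred <- apply(votes, 1, function(v) names(which.max(table(v))))
forest_pred
# "Kama" "Rosa"
```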
# call the randomForest function, specify the predictors and response
# note: mtry = 7 tries all 7 predictors at every split (i.e. bagging);
# the classification default would be floor(sqrt(7)) = 2
fit <- randomForest(x = train[,1:7], y = as.factor(train$variety), ntree = 1000, mtry = 7, importance = T, proximity = TRUE)
fit # this will print the model
##
## Call:
## randomForest(x = train[, 1:7], y = as.factor(train$variety), ntree = 1000, mtry = 7, importance = T, proximity = TRUE)
## Type of random forest: classification
## Number of trees: 1000
## No. of variables tried at each split: 7
##
## OOB estimate of error rate: 8.81%
## Confusion matrix:
## Canadian Kama Rosa class.error
## Canadian 49 4 0 0.07547170
## Kama 7 44 2 0.16981132
## Rosa 0 1 52 0.01886792
plot(fit) # plots error rate over trees
varImpPlot(fit, main = "Average variable importance") # plots variable importance
#You can use the commands below to see importance for an individual class
#varImpPlot(fit, class = "Rosa", main = "Rosa Importance")
#varImpPlot(fit, class = "Kama", main = "Kama Importance")
#varImpPlot(fit, class = "Canadian", main = "Canadian Importance")
margins.rf <- margin(fit, train)
#margin() measures the extent to which the proportion of votes for the correct class exceeds the maximum proportion of votes for any other class; positive margins mean correct classification. (ref: http://www.statsoft.com/Textbook/Random-Forest)
plot(margins.rf)
## Loading required package: RColorBrewer
hist(margins.rf)
boxplot(margins.rf~train$variety)
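To make the margin definition concrete, here is the computation for a single observation with hypothetical vote shares (base-R sketch; margin() does this for every row of the training data):

```r
votes <- c(Canadian = 0.05, Kama = 0.80, Rosa = 0.15)  # hypothetical vote shares
true_class <- "Kama"
# vote share for the true class minus the best share among the other classes
m1 <- votes[true_class] - max(votes[names(votes) != true_class])
unname(m1)
# 0.65
```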
Begin predicting and evaluation
Now that we have the model, we can try to predict the wheat variety in our test set. The predict() function predicts the class of each instance from the seven variables, so we store these predictions in another variable and then compare them with the known labels. This gives us confidence in our model. So let us begin…
predictedClass <- predict(fit, newdata = test) # type = "response" (the default) returns the predicted class
predictedClass
## 7 10 19 22 27 32 33 37
## Kama Kama Kama Kama Kama Kama Kama Kama
## 41 44 45 47 51 54 56 57
## Kama Kama Kama Kama Kama Kama Kama Kama
## 66 72 75 76 82 83 89 92
## Kama Rosa Rosa Rosa Rosa Rosa Rosa Rosa
## 95 99 107 109 110 111 117 129
## Rosa Rosa Rosa Rosa Rosa Rosa Rosa Rosa
## 132 136 149 152 157 159 162 163
## Rosa Kama Canadian Canadian Canadian Canadian Canadian Canadian
## 166 169 176 177 184 186 188 192
## Canadian Canadian Canadian Canadian Canadian Canadian Canadian Canadian
## 196 209 210
## Canadian Canadian Canadian
## Levels: Canadian Kama Rosa
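Before calling caret's confusionMatrix(), note that the overall accuracy is just the fraction of matching labels, and table() gives a raw confusion matrix. A base-R sketch on toy vectors (on the real data, compare predictedClass against test$variety):

```r
pred  <- c("Kama", "Rosa", "Canadian", "Kama")  # toy predictions
truth <- c("Kama", "Rosa", "Canadian", "Rosa")  # toy reference labels
mean(pred == truth)                    # overall accuracy
# 0.75
table(Prediction = pred, Reference = truth)  # raw confusion matrix
```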
confusionMatrix(data = predictedClass, reference = as.factor(test$variety)) # as.factor() keeps newer caret versions happy
## Confusion Matrix and Statistics
##
## Reference
## Prediction Canadian Kama Rosa
## Canadian 17 0 0
## Kama 0 17 1
## Rosa 0 0 16
##
## Overall Statistics
##
## Accuracy : 0.9804
## 95% CI : (0.8955, 0.9995)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9706
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: Canadian Class: Kama Class: Rosa
## Sensitivity 1.0000 1.0000 0.9412
## Specificity 1.0000 0.9706 1.0000
## Pos Pred Value 1.0000 0.9444 1.0000
## Neg Pred Value 1.0000 1.0000 0.9714
## Prevalence 0.3333 0.3333 0.3333
## Detection Rate 0.3333 0.3333 0.3137
## Detection Prevalence 0.3333 0.3529 0.3137
## Balanced Accuracy 1.0000 0.9853 0.9706
#x <- confusionMatrix(data = predictedClass, test$variety)
#y<-as.data.frame(t(x$byClass))
#View(y)
#z<-as.data.frame(x$byClass)
#plot(z$Sensitivity, z$Specificity, xlab = "Sensitivity", ylab = "Specificity", main = "Prediction Performance")
Great. Now read through the per-class table in the confusion matrix; you can see the sensitivity, specificity, and so on. The overall accuracy of our model is 0.98, with a 95% CI of (0.8955, 0.9995). I am still working out ROC curves and AUC calculations; that is for the next post. All the best.