Wheat Grain Classification
Raja Reddy
Friday, September 18, 2015
Machine learning for three-class data classification.
Can we classify wheat varieties based on images?
Why this analysis:
This is my attempt to practice machine learning (ML). I am sourcing the data from the UCI Machine Learning Repository. Your suggestions and feedback are more than welcome.
What is this analysis about:
Classification of wheat varieties based on physical measurements, such as area, perimeter, and length extracted from images, has been explored for a long time. In fact, the seeds data set in the UCI Machine Learning Repository was created for this very objective, and several attempts with different ML algorithms have been made on it. Here I plan to explore the randomForest algorithm as implemented in R to classify this data.
How this analysis is structured
Broadly, I structure this analysis into 1) descriptive analysis, 2) exploratory analysis, 3) model building & choice, and 4) prediction and evaluation.
The first step is to set the working directory with “setwd()”. Then load all the required packages:
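For example (the path here is purely hypothetical; point it at your own project folder):
setwd("~/projects/wheat-classification") # hypothetical path, replace with your own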
library(randomForest)
## randomForest 4.6-10
## Type rfNews() to see new features/changes/bug fixes.
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(ggplot2)
library(lattice)
library(corrplot)
Now get the data from the URL into the R environment. The commented commands below show how you can source the data from the UCI ML repo; I don’t execute this portion here. (Ignore this if you already have the data in your working directory.)
#url <- "http://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt"
#download.file(url, destfile = "Wheat.txt")
#The downloaded file is Wheat.txt (a tab-delimited text file). If you want to use this file for your own work, use the following commands.
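If you prefer a self-contained script, a minimal sketch like this (assuming the same file name as above) downloads the data only when it is missing:
if (!file.exists("Wheat.txt")) {
  url <- "http://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt"
  download.file(url, destfile = "Wheat.txt") # skipped when the file already exists
}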
I will use the data I downloaded earlier.
wheat <- read.csv("Wheat.txt", header = F, sep = "\t") # header = F: the file has no column names
#wheat <- read.table("Wheat1.txt", header = F, sep = "\t") # alternative using read.table
names(wheat)<- c("Area","Perimeter","compactness","length","width","asymmetry","grooveLength","variety")
wheat$variety <- as.character(wheat$variety)
wheat$variety[wheat$variety == "1"] <- "Kama"
wheat$variety[wheat$variety == "2"] <- "Rosa"
wheat$variety[wheat$variety == "3"] <- "Canadian"
nrow(wheat)
## [1] 210
Begin descriptive analysis
Now let us explore this data. This will help us understand it by describing the data in terms of the number of variables, their types, etc. See what is in the wheat data frame, and its column names, with “head()”.
head(wheat, n = 5)
##    Area Perimeter compactness length width asymmetry grooveLength variety
## 1 15.26     14.84      0.8710  5.763 3.312     2.221        5.220    Kama
## 2 14.88     14.57      0.8811  5.554 3.333     1.018        4.956    Kama
## 3 14.29     14.09      0.9050  5.291 3.337     2.699        4.825    Kama
## 4 13.84     13.94      0.8955  5.324 3.379     2.259        4.805    Kama
## 5 16.14     14.99      0.9034  5.658 3.562     1.355        5.175    Kama
Get a summary of the data with “summary()”. This generates information on the type and spread of each variable. We will then generate box plots of the variables to see the variability in observations. Alternatively, with density plots, we can explore the frequency distributions of the individual variables.
summary(wheat)
##       Area         Perimeter      compactness         length     
##  Min.   :10.59   Min.   :12.41   Min.   :0.8081   Min.   :4.899  
##  1st Qu.:12.27   1st Qu.:13.45   1st Qu.:0.8569   1st Qu.:5.262  
##  Median :14.36   Median :14.32   Median :0.8734   Median :5.524  
##  Mean   :14.85   Mean   :14.56   Mean   :0.8710   Mean   :5.629  
##  3rd Qu.:17.30   3rd Qu.:15.71   3rd Qu.:0.8878   3rd Qu.:5.980  
##  Max.   :21.18   Max.   :17.25   Max.   :0.9183   Max.   :6.675  
##      width         asymmetry       grooveLength     variety         
##  Min.   :2.630   Min.   :0.7651   Min.   :4.519   Length:210        
##  1st Qu.:2.944   1st Qu.:2.5615   1st Qu.:5.045   Class :character  
##  Median :3.237   Median :3.5990   Median :5.223   Mode  :character  
##  Mean   :3.259   Mean   :3.7002   Mean   :5.408                     
##  3rd Qu.:3.562   3rd Qu.:4.7687   3rd Qu.:5.877                     
##  Max.   :4.033   Max.   :8.4560   Max.   :6.550
boxplot(wheat[,1:7], notch = T, main = "Wheat seed measurement variability", col = rainbow(7)) # rainbow(7): one colour per measurement column
featurePlot(x = wheat[,1:7], y = as.factor(wheat$variety), plot = "density", scales = list(x=list(relation="free"), y=list(relation="free")), auto.key = T, main = "Density distributions of variables across varieties")
Begin exploratory analysis
In this phase I analyze the variables and their relations. One good way to do that is to find the correlations among the variables.
par(mfrow = c(1,2))
cce <- cor(wheat[,1:7], use = "pairwise", method = "pearson") # calculate correlations
corrplot(cce) #plot correlations
title("Correlation", line = -3)
# However, I would like to see the plot sorted by correlations
cce.ord <- order(cce[1,]) # order the variables by their correlation with Area
cce.1 <- cce[cce.ord, cce.ord]
corrplot(cce.1)
title("sorted on correlations", line = -3)
Other ways to see correlations are scatter plots and parallel coordinate plots. These plots show us the relations between the variables across the wheat varieties. I will make them with the pairs() and parallelplot() functions.
pairs(wheat[,1:7],pch=21, col=as.factor(wheat$variety))
parallelplot(~wheat[1:7] | variety, wheat) # one panel per variety; not very useful on its own, but handy for a quick check
parallelplot(~wheat[1:7], wheat, groups = wheat$variety, auto.key = T, ylab = "Grain measurements", main = "Parallel coordinate plot of variable relations across three varieties")
Now that we understand our variables and their relations, let us explore how these variables can be used to classify our wheat varieties. For this purpose I am choosing the randomForest machine learning algorithm. Though randomForest does not strictly need separate training and test sets for model building and evaluation (its out-of-bag error gives an internal estimate), I am sticking to the classic methodology followed in supervised learning.
# create data partitioning
set.seed(756) # this helps in reproducing
inTrain<-createDataPartition(y = wheat$variety,p = 0.75, list=F)
train<- wheat[inTrain,]
str(train)
## 'data.frame':    159 obs. of  8 variables:
##  $ Area        : num  15.3 14.9 14.3 13.8 16.1 ...
##  $ Perimeter   : num  14.8 14.6 14.1 13.9 15 ...
##  $ compactness : num  0.871 0.881 0.905 0.895 0.903 ...
##  $ length      : num  5.76 5.55 5.29 5.32 5.66 ...
##  $ width       : num  3.31 3.33 3.34 3.38 3.56 ...
##  $ asymmetry   : num  2.22 1.02 2.7 2.26 1.35 ...
##  $ grooveLength: num  5.22 4.96 4.83 4.8 5.17 ...
##  $ variety     : chr  "Kama" "Kama" "Kama" "Kama" ...
test<-wheat[-inTrain,]
str(test)
## 'data.frame':    51 obs. of  8 variables:
##  $ Area        : num  14.7 16.4 14.7 14.1 13 ...
##  $ Perimeter   : num  14.5 15.2 14.2 14.3 13.8 ...
##  $ compactness : num  0.88 0.888 0.915 0.872 0.864 ...
##  $ length      : num  5.56 5.88 5.21 5.52 5.39 ...
##  $ width       : num  3.26 3.5 3.47 3.17 3.03 ...
##  $ asymmetry   : num  3.59 1.97 1.77 2.69 3.37 ...
##  $ grooveLength: num  5.22 5.53 4.65 5.22 4.83 ...
##  $ variety     : chr  "Kama" "Kama" "Kama" "Kama" ...
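Since createDataPartition() samples within each class, the 75/25 split should preserve the class balance. A quick check:
table(train$variety) # class counts in the training set
table(test$variety) # class counts in the test set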
Begin model building & choice
Now let us develop the model. As mentioned, I am using randomForest. The problem at hand is a typical classification challenge (which variety?); we will look into the regression type of problem in another post. Basically, the randomForest algorithm builds multiple decision trees (you can specify how many), each tree choosing from a defined set of predictors/features at each split, and aggregates their votes to arrive at the response variable (in our case, variety).
# call the randomForest function, specify the predictors and response
fit <- randomForest(x = train[,1:7], y = as.factor(train$variety), ntree = 1000, mtry = 7, importance = T, proximity = TRUE) # mtry = 7 tries all predictors at each split; the classification default would be floor(sqrt(7)) = 2
fit # this will print the model
## 
## Call:
##  randomForest(x = train[, 1:7], y = as.factor(train$variety),      ntree = 1000, mtry = 7, importance = T, proximity = TRUE) 
##                Type of random forest: classification
##                      Number of trees: 1000
## No. of variables tried at each split: 7
## 
##         OOB estimate of  error rate: 8.81%
## Confusion matrix:
##          Canadian Kama Rosa class.error
## Canadian       49    4    0  0.07547170
## Kama            7   44    2  0.16981132
## Rosa            0    1   52  0.01886792
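Because mtry = 7 tries every predictor at each split, it is worth checking whether a smaller mtry lowers the OOB error. A minimal sketch with the tuneRF() helper from randomForest (the ntreeTry and improve values are my own choices, not part of the original analysis):
set.seed(756)
tuned <- tuneRF(x = train[,1:7], y = as.factor(train$variety), ntreeTry = 500, stepFactor = 2, improve = 0.01, trace = TRUE) # reports the OOB error for each mtry tried
tuned # pick the mtry with the lowest OOB error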
plot(fit) # plots error rate over the number of trees
varImpPlot(fit, main = "Average variable importance") # plots variable importance
#You can use the commands below to see importance for an individual class
  #varImpPlot(fit, class = "Rosa", main = "Rosa Importance")
  #varImpPlot(fit, class = "Kama", main = "Kama Importance")
  #varImpPlot(fit, class = "Canadian", main = "Canadian Importance")
margins.rf <- margin(fit) # the fit object already stores the observed training classes
#The function margin() measures the extent to which the proportion of votes for the correct class exceeds the maximum proportion of votes for any other class. (ref: http://www.statsoft.com/Textbook/Random-Forest)
plot(margins.rf)
## Loading required package: RColorBrewer
hist(margins.rf)
boxplot(margins.rf~train$variety)
Begin predicting and evaluation
Now that we have the model, we can try to predict the wheat variety in our test set. Note that the predict() function predicts the class of each instance from the seven variables. We store the predictions in another variable so that we can later compare them with the known classes; this gives us confidence in our model. So let us begin…
predictedClass <- predict(fit, newdata = test) # returns the predicted class labels
predictedClass
##        7       10       19       22       27       32       33       37 
##     Kama     Kama     Kama     Kama     Kama     Kama     Kama     Kama 
##       41       44       45       47       51       54       56       57 
##     Kama     Kama     Kama     Kama     Kama     Kama     Kama     Kama 
##       66       72       75       76       82       83       89       92 
##     Kama     Rosa     Rosa     Rosa     Rosa     Rosa     Rosa     Rosa 
##       95       99      107      109      110      111      117      129 
##     Rosa     Rosa     Rosa     Rosa     Rosa     Rosa     Rosa     Rosa 
##      132      136      149      152      157      159      162      163 
##     Rosa     Kama Canadian Canadian Canadian Canadian Canadian Canadian 
##      166      169      176      177      184      186      188      192 
## Canadian Canadian Canadian Canadian Canadian Canadian Canadian Canadian 
##      196      209      210 
## Canadian Canadian Canadian 
## Levels: Canadian Kama Rosa
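If you want the class probabilities (vote fractions) rather than hard labels, the randomForest predict method supports type = "prob":
predictedProbs <- predict(fit, newdata = test, type = "prob") # one row per test case, one column per class
head(predictedProbs)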
confusionMatrix(data = predictedClass, reference = as.factor(test$variety)) # the reference classes should be a factor
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction Canadian Kama Rosa
##   Canadian       17    0    0
##   Kama            0   17    1
##   Rosa            0    0   16
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9804          
##                  95% CI : (0.8955, 0.9995)
##     No Information Rate : 0.3333          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9706          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: Canadian Class: Kama Class: Rosa
## Sensitivity                   1.0000      1.0000      0.9412
## Specificity                   1.0000      0.9706      1.0000
## Pos Pred Value                1.0000      0.9444      1.0000
## Neg Pred Value                1.0000      1.0000      0.9714
## Prevalence                    0.3333      0.3333      0.3333
## Detection Rate                0.3333      0.3333      0.3137
## Detection Prevalence          0.3333      0.3529      0.3137
## Balanced Accuracy             1.0000      0.9853      0.9706
#x <- confusionMatrix(data = predictedClass, test$variety)
#y<-as.data.frame(t(x$byClass))
#View(y)
#z<-as.data.frame(x$byClass)
#plot(z$Sensitivity, z$Specificity, xlab = "Sensitivity", ylab = "Specificity", main = "Prediction Performance")
Great. Now read through the by-class table in the confusion matrix; you can see the sensitivity, specificity, etc. The overall accuracy of our model is 0.98, with a 95% CI of (0.8955, 0.9995). I am still having difficulty with ROC curves and AUC calculations; those are for the next post.
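As a head start on that post, here is a minimal one-vs-rest sketch, assuming the pROC package (not used above) is installed. It scores each class’s probability column from predict(..., type = "prob") against a binary indicator for that class:
library(pROC)
probs <- predict(fit, newdata = test, type = "prob") # class probabilities for each test case
for (cls in colnames(probs)) {
  # treat each class in turn as the "positive" class
  r <- roc(response = as.numeric(test$variety == cls), predictor = probs[, cls])
  cat(cls, "AUC:", auc(r), "\n")
}
All the best.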