Building a classification tree in R

In week 6 of the Data Analysis course offered freely on Coursera, there was a lecture on building classification trees in R (also known as decision trees). I thoroughly enjoyed the lecture and here I reiterate what was taught, both to reinforce my memory and to share it with others.

I will jump straight into building a classification tree in R and explain the concepts along the way. We will use the iris dataset, which gives measurements in centimeters of the variables sepal length and width, and petal length and width, respectively, for 50 flowers from three different species of iris.

data(iris)
names(iris)
[1] "Sepal.Length" "Sepal.Width"  "Petal.Length" "Petal.Width"  "Species"   
table(iris$Species)

    setosa versicolor  virginica 
        50         50         50
#install if necessary
install.packages("ggplot2")
library(ggplot2)
qplot(Petal.Width, Sepal.Width, data=iris, colour=Species, size=I(4))

Figure: Petal.Width versus Sepal.Width scatter plot. Different species have characteristic sepal and petal widths.

The basic idea of a classification tree is to start with all observations in one group; imagine all the points in the scatter plot above. Then find the characteristic that best separates the groups: for example, the first split could ask whether the petal width is less than or greater than 0.8. This process continues recursively until the partitions are sufficiently homogeneous or become too small to split further.
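
As a quick sanity check of that first split (the 0.8 threshold is the one chosen by the tree fitted below), we can tabulate it against the species labels; all 50 setosa fall below the threshold and none of the other two species do.

#how well does the first candidate split separate the species?
table(iris$Petal.Width < 0.8, iris$Species)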

#building the classification tree
#install if necessary
install.packages("tree")
library(tree)
tree1 <- tree(Species ~ Sepal.Width + Petal.Width, data = iris)
summary(tree1)
Classification tree:
tree(formula = Species ~ Sepal.Width + Petal.Width, data = iris)
Number of terminal nodes:  5 
Residual mean deviance:  0.204 = 29.57 / 145 
Misclassification error rate: 0.03333 = 5 / 150
plot(tree1)
text(tree1)

Figure: a classification tree showing the splitting rule at each internal node and the predicted species at each terminal node.

plot(iris$Petal.Width,iris$Sepal.Width,pch=19,col=as.numeric(iris$Species))
partition.tree(tree1,label="Species",add=TRUE)
legend(1.75,4.5,legend=unique(iris$Species),col=unique(as.numeric(iris$Species)),pch=19)

Figure: the partitions defined by the classification tree above. For example, the first split classifies every flower with petal width < 0.8 as setosa; next, flowers with petal width > 1.75 are classified as virginica, and so on.

graph <- qplot(Petal.Width, Sepal.Width, data=iris, colour=Species, size=I(4))
graph + geom_hline(aes(yintercept=2.65)) + geom_vline(aes(xintercept=0.8)) + geom_vline(aes(xintercept=1.75)) + geom_vline(aes(xintercept=1.35))

Figure: the partitions redrawn with ggplot2. Counting by eye, I find only 3 misclassifications; however, the output of summary(tree1) reports 5.
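
Rather than counting by eye, we can ask the fitted tree for its predictions and tabulate them against the true labels; the off-diagonal cells should account for the 5 errors reported by summary(tree1).

#confusion matrix for the two-variable tree
pred <- predict(tree1, type="class")
table(pred, iris$Species)
#total number of misclassified flowers
sum(pred != iris$Species)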

Using more variables

I used two variables above, Petal.Width and Sepal.Width, to illustrate the classification process. We can also include all four variables:

tree1 <- tree(Species ~ Sepal.Width + Sepal.Length + Petal.Length + Petal.Width, data = iris)
summary(tree1)

Classification tree:
tree(formula = Species ~ Sepal.Width + Sepal.Length + Petal.Length + 
    Petal.Width, data = iris)
Variables actually used in tree construction:
[1] "Petal.Length" "Petal.Width"  "Sepal.Length"
Number of terminal nodes:  6 
Residual mean deviance:  0.1253 = 18.05 / 144 
Misclassification error rate: 0.02667 = 4 / 150

We get a slightly lower misclassification error rate (0.02667) and here’s how the classification tree looks:

plot(tree1)
text(tree1)

Figure: the classification tree built using all four variables.

Let’s check some of these by subsetting the iris dataset:

#Petal.Length < 2.45
iris[iris$Petal.Length<2.45,5]
 [1] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
[18] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
[35] setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa setosa
Levels: setosa versicolor virginica
#we get all 50 setosa
length(iris[iris$Petal.Length<2.45,5])
[1] 50
iris[iris$Petal.Length>2.45&iris$Petal.Width>1.75,5]
 [1] versicolor virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[11] virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[21] virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[31] virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
[41] virginica  virginica  virginica  virginica  virginica  virginica 
Levels: setosa versicolor virginica
#most of the virginica
length(iris[iris$Petal.Length>2.45&iris$Petal.Width>1.75,5])
[1] 46
#2 misclassifications
iris[iris$Petal.Length>2.45&iris$Petal.Width<1.75&iris$Petal.Length>4.95,5]
[1] versicolor versicolor virginica  virginica  virginica  virginica 
Levels: setosa versicolor virginica
#viewing the Petal.Length of all species
boxplot(formula=Petal.Length ~ Species, data=iris, xlab="Species", ylab="Petal length")

Figure: boxplots of petal length by species. We can easily distinguish the setosa species by its petal length.
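
The boxplot can be backed up numerically; the petal length ranges of setosa and the other two species do not overlap at all.

#petal length range per species
aggregate(Petal.Length ~ Species, data=iris, FUN=range)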

Prettier classification trees in R using the rpart package

# install if necessary
install.packages('rpart')
install.packages('rattle')

# load libraries
library(rpart)
library(rattle)

#fit the tree using all four variables; avoid naming the object "rpart", which masks the function
rpart_tree <- rpart(Species ~ ., data=iris, method="class")

rpart_tree
n= 150 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
  2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
  3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
    6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
    7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *

# plot decision tree
fancyRpartPlot(rpart_tree, main="Iris")

Figure: same story as above, but a fancier classification tree.
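
The fitted rpart object can also be used for prediction. Here is a small sketch classifying a single made-up flower (the measurements below are hypothetical, chosen only for illustration); following the splits above, it lands in the virginica leaf.

#predict the species of a new (hypothetical) flower
new_flower <- data.frame(Sepal.Length=6.0, Sepal.Width=3.0, Petal.Length=5.2, Petal.Width=2.0)
predict(rpart_tree, newdata=new_flower, type="class")
#class probabilities for the same flower
predict(rpart_tree, newdata=new_flower, type="prob")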

Conclusions

One of the disadvantages of decision trees is their tendency to overfit, i.e. to keep creating partitions until each one is relatively homogeneous. This problem can be alleviated by pruning the tree, which essentially removes decisions from the bottom up. Another way is to combine several trees and take a consensus, which is what random forests do. I have previously used random forests to predict different wines based on their chemical properties.
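
Here is a rough sketch of both approaches, using cv.tree() and prune.misclass() from the tree package and the randomForest package (install it first if necessary); the seed and the choice of 3 terminal nodes are arbitrary.

#prune the four-variable tree: cross-validate to choose a size, then prune
set.seed(31)
cv.tree(tree1, FUN=prune.misclass)
pruned <- prune.misclass(tree1, best=3)
summary(pruned)

#random forest: a consensus over many trees
#install.packages("randomForest")
library(randomForest)
set.seed(31)
rf <- randomForest(Species ~ ., data=iris)
rf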




This work is licensed under a Creative Commons Attribution 4.0 International License.