Decision Trees in R: Classification and Regression Trees

Decision trees are a highly useful visual aid when analysing a series of predicted outcomes for a particular model. As such, they are often used as a supplement to (or even an alternative to) regression analysis when determining how a series of explanatory variables affects the dependent variable.

In this particular example, we analyse the impact of the explanatory variables age, gender, education, relationship (marital) status, and income on the dependent variable, sales, across a series of customers.

1. Firstly, we load our dataset and create a response variable (which is used for the classification tree, since we need to convert sales from a numerical to a categorical variable):

#Load dataset and create response variable
setwd("directory")
fullData <- read.csv("filepath")

# Bin the numeric Sales variable into three categories
fullData$response[fullData$Sales > 10000] <- ">10000"
fullData$response[fullData$Sales > 1000 & fullData$Sales <= 10000] <- ">1000 & <=10000"
fullData$response[fullData$Sales <= 1000] <- "<1000"
fullData$response <- as.factor(fullData$response)
str(fullData)
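
Before modelling, it is worth checking (an optional extra step) how the customers are distributed across the three sales bands:

#Check the Response Distribution
table(fullData$response, useNA = "ifany")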

2. We then create the training and test data (i.e. the data we will use to build our model, and the separate data we will use to evaluate it):

#Create training and test data
factor32 <- sapply(fullData, function(x) class(x) == "factor" & nlevels(x) > 32)
fullData <- fullData[, !factor32] # drop any factor with more than 32 levels (too many for functions such as randomForest)
str(fullData)
set.seed(123) # optional: fix the random seed so the train/test split is reproducible
train <- sample(1:nrow(fullData), size = 0.8 * nrow(fullData)) # training row indices
inputData <- fullData[train, ] # training data
testData <- fullData[-train, ] # test data
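
A quick sanity check (an optional addition) confirms the 80/20 split:

#Check Split Sizes
nrow(inputData) # roughly 80% of the rows
nrow(testData) # the remaining 20%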

3. Then, our classification tree is created:

#Classification Tree
library(rpart)
formula <- response ~ Age + Gender + Education + Relationship + Income
fitclas <- rpart(formula, data = inputData, method = "class",
                 control = rpart.control(minsplit = 30, cp = 0.001))
plot(fitclas)
text(fitclas)
summary(fitclas)
printcp(fitclas) # display the cp table
plotcp(fitclas) # plot cross-validated error against cp

Note that the cp (complexity parameter) value controls the size of the tree we keep: the cp table and plot show that the cross-validated relative error (xerror) is minimised at a tree size of 4. The pruned tree is therefore built from fitclas using the cp value associated with that minimum.
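
For illustration, that optimal cp value can be read off the cp table directly; this is the same lookup that the pruning step below performs inline:

#Locate the cp Value that Minimises Cross-Validated Error
fitclas$cptable
bestcp <- fitclas$cptable[which.min(fitclas$cptable[, "xerror"]), "CP"]
bestcp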

4. The decision tree is then "pruned", removing branches that add little predictive value so that the model does not overfit the data:

#Prune the Tree and Plot
pfitclas <- prune(fitclas, cp = fitclas$cptable[which.min(fitclas$cptable[, "xerror"]), "CP"])
plot(pfitclas, uniform = TRUE, main = "Pruned Classification Tree For Sales")
text(pfitclas, use.n = TRUE, all = TRUE, cex = .8)
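
If the base-graphics plot is hard to read, the rpart.plot package (an optional extra, assuming it is installed) draws a cleaner version of the same tree:

#Optional: Nicer Plot of the Pruned Tree
library(rpart.plot)
rpart.plot(pfitclas, main = "Pruned Classification Tree For Sales")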

5. The model is now tested against the test data, giving a misclassification rate of 16.75%. The lower this rate, the better, since it indicates that our model is more accurate at predicting unseen "real" data:

#Model Testing
out <- predict(pfitclas, newdata = testData) # predict on the held-out test data
pred.response <- colnames(out)[max.col(out, ties.method = "first")] # predicted classes
input.response <- as.character(testData$response) # actual classes
mean(input.response != pred.response) # misclassification rate
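
To see where these errors occur (a supplementary check rather than one of the original steps), cross-tabulate the predicted classes against the actual ones:

#Confusion Matrix
table(predicted = pred.response, actual = input.response)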

6. When the dependent variable is numerical rather than categorical, we set up a regression tree instead, as follows:

#Regression Tree
fitreg <- rpart(Sales ~ Age + Gender + Education + Relationship + Income, method = "anova", data = inputData)
printcp(fitreg) # display the results
plotcp(fitreg) # visualize cross-validation results
summary(fitreg) # detailed summary of splits

#Create Additional Plots
par(mfrow=c(1,2)) # two plots on one page
rsq.rpart(fitreg) # visualize cross-validation results
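
Since rsq.rpart draws into the split plotting window, it is worth resetting the layout afterwards:

par(mfrow = c(1, 1)) # reset to a single plot per page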

#Prune the Tree
pfitreg <- prune(fitreg, cp = fitreg$cptable[which.min(fitreg$cptable[, "xerror"]), "CP"]) # use the cp that minimises xerror

#Plot The Pruned Tree
plot(pfitreg, uniform = TRUE, main = "Pruned Regression Tree for Sales")
text(pfitreg, use.n = TRUE, all = TRUE, cex = .8)
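
As with the classification tree, the pruned regression tree can be evaluated against the held-out data (an additional check, assuming testData retains the numeric Sales column):

#Test the Regression Tree
predreg <- predict(pfitreg, newdata = testData)
sqrt(mean((testData$Sales - predreg)^2)) # root mean squared error on the test data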

7. A random forest prediction (an ensemble built by aggregating many decision trees) can also be performed in order to view results across many trees and to determine the importance of each predictor:

#Random Forest Prediction of Sales Data
library(randomForest)
fitregforest <- randomForest(Sales ~ Age + Gender + Education + Relationship + Income, data = inputData)
print(fitregforest) # view results
importance(fitregforest) # importance of each predictor
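
For completeness (again an optional extra), the importance measures can be plotted and the forest evaluated on the held-out test data:

#Visualise Importance and Test the Forest
varImpPlot(fitregforest) # plot variable importance
predrf <- predict(fitregforest, newdata = testData)
sqrt(mean((testData$Sales - predrf)^2)) # root mean squared error on the test data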