Decision Trees in R: Classification and Regression Trees

Decision trees are a highly useful visual aid for analysing a series of predicted outcomes for a particular model. As such, they are often used as a supplement to (or even an alternative to) regression analysis when determining how a series of explanatory variables affects the dependent variable.

In this particular example, we analyse the impact of the explanatory variables age, gender, education, relationship (marital) status, and income on the dependent variable sales across a series of customers.

1. Firstly, we load our dataset and create a response variable (which is needed for the classification tree, since we must convert sales from a numerical to a categorical variable):

> #Load dataset and create response variable
> setwd("directory")
> fullData <- read.csv("filepath")
> attach(fullData)
> fullData$response[Sales > 10000] <- ">10000"
> fullData$response[Sales > 1000 & Sales <= 10000] <- ">1000 & <10000"
> fullData$response[Sales <= 1000] <- "<1000"
> fullData$response<-as.factor(fullData$response)
> str(fullData)
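
As an aside, the same three-way banding can be written more compactly with cut(), which returns the factor directly. This is a minimal alternative to the manual assignments above, assuming the Sales column is numeric:

> #Alternative: band Sales into a factor in one step
> fullData$response <- cut(fullData$Sales,
       breaks=c(-Inf, 1000, 10000, Inf),
       labels=c("<1000", ">1000 & <10000", ">10000"))
> str(fullData)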

2. We then create the training and test data (i.e. the data we will use to build our model, and the data we will then test that model against):

> #Create training and test data
> factor32 <- sapply(fullData, function(x) class(x) == "factor" & nlevels(x) > 32)
> fullData <- fullData[, !factor32]
> str(fullData)
> train <- sample (1:nrow(fullData), size=0.8*nrow(fullData)) # training row indices
> inputData <- fullData[train, ] # training data
> testData <- fullData[-train, ] # test data
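
Note that sample() draws a random 80% of the rows, so the split (and therefore the results below) will vary between runs. Placing set.seed() immediately before the sample() call makes the split reproducible (the seed value 123 here is arbitrary):

> #Optional: fix the random seed for a reproducible train/test split
> set.seed(123)
> train <- sample(1:nrow(fullData), size=0.8*nrow(fullData))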

3. Then, our classification tree is created:

> #Classification Tree
> library(rpart)
> formula <- response ~ Age + Gender + Education + Relationship + Income
> dtree <- rpart(formula, data=inputData, method="class",
           control=rpart.control(minsplit=30, cp=0.001))
> plot(dtree)
> text(dtree)
> summary(dtree)
> printcp(dtree)
> plotcp(dtree)

Note that the cp value is what controls how large a tree we keep. From the cp table (and the plotcp chart), the cross-validated relative error (xerror) is minimised at a tree size of 4, i.e. after 3 splits, so this is the cp value we will prune with. The full summary for the fitted tree stored in dtree is shown below.

> summary(dtree)

Call:
rpart(formula = formula, data = inputData, method = "class", 
    control = rpart.control(minsplit = 30, cp = 0.001))
  n= 800 

          CP nsplit rel error    xerror      xstd
1 0.06751055      0 1.0000000 1.0000000 0.1068093
2 0.02953586      3 0.7974684 0.9493671 0.1043584
3 0.02531646      6 0.7088608 1.0506329 0.1091758
4 0.00100000      8 0.6582278 1.1012658 0.1114634
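
For reference, the cp value at the minimum xerror can be read straight off the cptable stored in the fitted object; this is exactly the lookup the pruning command in the next step performs (bestcp is just an illustrative name):

> #cp value corresponding to the lowest cross-validated error
> bestcp <- dtree$cptable[which.min(dtree$cptable[, "xerror"]), "CP"]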

4. The decision tree is then "pruned": branches that do not improve the cross-validated error are removed to prevent overfitting the data:

> #Prune the Tree and Plot
> pdtree<- prune(dtree, cp=dtree$cptable[which.min(dtree$cptable[,"xerror"]),"CP"])
> plot(pdtree, uniform=TRUE,
     main="Pruned Classification Tree For Sales")
> text(pdtree, use.n=TRUE, all=TRUE, cex=.8)
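
The base plot() and text() output can be hard to read for larger trees. If the rpart.plot package is installed (install.packages("rpart.plot")), it gives a cleaner rendering of the same pruned tree:

> #Optional: nicer plot via the rpart.plot package
> library(rpart.plot)
> rpart.plot(pdtree, main="Pruned Classification Tree For Sales")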

5. The model is now tested against the test data, giving a misclassification rate of 16.75%. The lower this rate, the better, since it indicates the model is more accurate at predicting unseen ("real") data:

> #Model Testing
> out <- predict(pdtree, newdata=testData) # class probabilities for the test rows
> response_predicted <- colnames(out)[max.col(out, ties.method="first")] # predicted class per row
> response_input<- as.character (testData$response) # actuals
> mean (response_input != response_predicted) # misclassification %
[1] 0.1675
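
Beyond the single misclassification rate, a confusion matrix shows which of the three sales bands are being confused with one another; a minimal sketch using the vectors created above:

> #Confusion matrix of actual vs predicted classes
> table(actual=response_input, predicted=response_predicted)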

6. When the dependent variable is numerical rather than categorical, we will want to set up a regression tree instead as follows:

> #Regression Tree
> fitreg <- rpart(Sales~Age+Gender+Education+Relationship+Income,
             method="anova", data=inputData)
> printcp(fitreg) # display the results
> plotcp(fitreg) # visualize cross-validation results
> summary(fitreg) # detailed summary of splits

> #Create Additional Plots
> par(mfrow=c(1,2)) # two plots on one page
> rsq.rpart(fitreg) # visualize cross-validation results

Regression tree:
rpart(formula = Sales ~ Age + Gender + Education + Relationship + 
    Income, data = inputData, method = "anova")

Variables actually used in tree construction:
[1] Age          Income       Relationship

Root node error: 6.545e+09/800 = 8181215

n= 800 

        CP nsplit rel error  xerror     xstd
1 0.215653      0   1.00000 1.00358 0.031966
2 0.058822      1   0.78435 0.78808 0.034297
3 0.057839      2   0.72553 0.77490 0.033874
4 0.034801      3   0.66769 0.70979 0.031150
5 0.015115      4   0.63289 0.65365 0.029596
6 0.010682      5   0.61777 0.65254 0.029500
7 0.010000      6   0.60709 0.65592 0.029128

Now, we prune our regression tree:

> #Prune the Tree
> pfitreg<- prune(fitreg, cp=fitreg$cptable[which.min(fitreg$cptable[,"xerror"]),"CP"]) # from cptable

> #Plot The Pruned Tree
> plot(pfitreg, uniform=TRUE,
     main="Pruned Regression Tree for Sales")
> text(pfitreg, use.n=TRUE, all=TRUE, cex=.8)
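
As with the classification tree in step 5, the pruned regression tree can be checked against the held-out data; a minimal sketch computing the root mean squared error of the predictions (salesPred is just an illustrative name):

> #Evaluate the pruned regression tree on the test data
> salesPred <- predict(pfitreg, newdata=testData)
> sqrt(mean((testData$Sales - salesPred)^2)) # root mean squared error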

Dataset

salesanddecisiondataset.csv