Book club “Hands-On Machine Learning with R” #5

Chapter 7 Splines - Chapter 8 KNN

Chapter 7 Multivariate Adaptive Regression Splines


  • Multivariate adaptive regression splines (MARS)
  • Automatically creates a piecewise linear model
  • Inherently nonlinear


  • Will search for, and discover, nonlinearities and interactions in the data that help maximize predictive accuracy


  • Hinge function
  • Looks for the single point across the range of X values where two different linear relationships between Y and X achieve the smallest error
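A minimal base-R sketch of the hinge idea on simulated data. The data, the true knot at x = 5 and the helper names h_right/h_left are assumptions for illustration, not earth's internals. It defines the hinge pair h(x - c) = max(0, x - c) and h(c - x) = max(0, c - x), tries every observed x as a candidate knot, and keeps the knot with the smallest residual sum of squares.

```r
# Hinge pair for a knot c (hypothetical helper names)
h_right <- function(x, c) pmax(0, x - c)  # h(x - c): zero to the left of the knot
h_left  <- function(x, c) pmax(0, c - x)  # h(c - x): zero to the right of the knot

# Simulated data: slope +2 up to x = 5, slope -1.5 afterwards, plus noise
set.seed(1)
x <- seq(0, 10, length.out = 200)
y <- ifelse(x < 5, 2 * x, 10 - 1.5 * (x - 5)) + rnorm(length(x), sd = 0.5)

# Evaluate every interior x value as a knot and record the residual sum of squares
candidates <- x[-c(1, length(x))]
rss <- vapply(candidates, function(k) {
  fit <- lm(y ~ h_right(x, k) + h_left(x, k))
  sum(residuals(fit)^2)
}, numeric(1))

best_knot <- candidates[which.min(rss)]
best_knot
```

This brute-force loop refits the model for every candidate knot; real MARS implementations use much faster updating schemes to make the search feasible across many predictors.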

Capturing non-linear relationships

  • Polynomials

  • Step functions

  • Both require the user to specify the form of the nonlinearity

    • Which variables should have what specific degree of interaction, or at what points of a variable \(X\) the cut points for the step functions should be made
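A small base-R illustration, on simulated data, of why both approaches need user input: the polynomial degree and the step-function breaks must be fixed before fitting. The degree (3) and the cut points (2.5, 5, 7.5) are arbitrary choices made for this sketch.

```r
set.seed(2)
x <- runif(300, 0, 10)
y <- sin(x) + rnorm(300, sd = 0.3)

# Polynomial regression: the user chooses the degree (3 here, an arbitrary choice)
poly_fit <- lm(y ~ poly(x, degree = 3))

# Step function: the user chooses the cut points (also arbitrary here)
step_fit <- lm(y ~ cut(x, breaks = c(0, 2.5, 5, 7.5, 10), include.lowest = TRUE))

# Training R^2 of each user-specified fit
c(poly = summary(poly_fit)$r.squared, step = summary(step_fit)$r.squared)
```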

Multivariate adaptive regression splines (MARS)

  • Capture the nonlinear relationships in the data by assessing cutpoints (knots) similar to step functions

  • The procedure assesses each data point for each predictor as a knot and creates a linear regression model with the candidate feature(s)

  • Many knots may fit the training data well but may not generalize to new data

  • Pruning: remove knots that do not contribute to predictive accuracy, as assessed by, e.g., cross-validation

Fitting a basic MARS model with the earth package

library(dplyr)    # for data manipulation
library(ggplot2)  # for awesome graphics
library(rsample)  # splitting data
library(caret)    # for cross-validation, etc.
library(vip)      # variable importance
library(modeldata) # ames data

# Stratified sampling with the rsample package
set.seed(77654) # I used a different seed than in the book
split <- initial_split(ames, prop = 0.7, strata = "Sale_Price")
ames_train  <- training(split)
ames_test   <- testing(split)

mars1 <- earth::earth(
  Sale_Price ~ .,
  data = ames_train
)

# print model summary
print(mars1)
Selected 38 of 41 terms, and 24 of 276 predictors
Termination condition: RSq changed by less than 0.001 at 41 terms
Importance: Gr_Liv_Area, Year_Built, Total_Bsmt_SF, Kitchen_AbvGr, ...
Number of terms at each degree of interaction: 1 37 (additive model)
GCV 559029270    RSS 1.063131e+12    GRSq 0.9122792    RSq 0.9185039
# first 10 coefficients; the hinge functions are built from the 276 (dummy-expanded) predictors
summary(mars1) %>% .$coefficients %>% head(10)
(Intercept)           259584.69398
h(Gr_Liv_Area-3082)       93.14554
h(3082-Gr_Liv_Area)      -61.74407
h(Year_Built-2003)      5435.63870
h(2003-Year_Built)      -397.66277
h(Total_Bsmt_SF-2330)   -729.49718
h(2330-Total_Bsmt_SF)    -23.57389
h(Total_Bsmt_SF-1657)    110.09111
h(Bsmt_Unf_SF-555)       -23.77975
h(555-Bsmt_Unf_SF)        11.41421

Performance and residual plots

  • Generalized cross-validation (GCV) \(R^2\) (solid black line; left-hand y-axis)
  • Number of terms retained in the model (x-axis)
  • Number of original predictors (right-hand y-axis)
plot(mars1, which = 1)
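For reference, the GCV criterion behind this plot penalizes the training RSS by an effective number of parameters. The sketch below follows the formula as we understand it from the earth package documentation, GCV = RSS / (n (1 - enp / n)^2) with enp = nterms + penalty * (nterms - 1) / 2 and a default penalty of 2 for additive models; treat these details as assumptions to check against ?earth. With the values printed for mars1 (38 terms, RSS 1.063131e+12) and an assumed training size of n = 2049 (check with nrow(ames_train)), it essentially reproduces the reported GCV of 559029270.

```r
# GCV as described in the earth documentation (assumed formula; see text above)
gcv <- function(rss, n, nterms, penalty = 2) {
  enp <- nterms + penalty * (nterms - 1) / 2  # effective number of parameters
  rss / (n * (1 - enp / n)^2)
}

# Values from the mars1 printout; n = 2049 is an assumed training-set size
gcv(rss = 1.063131e12, n = 2049, nterms = 38)

# Without the penalty, this would just be the mean squared training error
1.063131e12 / 2049
```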

Interactions between different hinge functions

# degree = 2: interaction terms between a maximum of two hinge functions (e.g., h(2004-Year_Built)*h(Total_Bsmt_SF-1330))
mars2 <- earth::earth(
  Sale_Price ~ .,
  data = ames_train,
  degree = 2
)

# check out the first 10 coefficient terms
summary(mars2) %>% .$coefficients %>% head(10)
(Intercept)                               3.461254e+05
h(Gr_Liv_Area-3082)                       7.303777e+02
h(3082-Gr_Liv_Area)                      -6.622323e+01
h(Year_Built-2003)                        1.298095e+04
h(2003-Year_Built)                       -1.122875e+03
h(2330-Total_Bsmt_SF)                    -4.846894e+01
h(2003-Year_Built)*h(Total_Bsmt_SF-1237) -1.189485e+00
h(2003-Year_Built)*h(1237-Total_Bsmt_SF)  5.062716e-01
h(Bsmt_Unf_SF-876)*h(3082-Gr_Liv_Area)   -1.538933e-02
h(876-Bsmt_Unf_SF)*h(3082-Gr_Liv_Area)    5.832110e-03

Tuning hyperparameters

  • The maximum degree of interactions
  • The number of terms retained in the final model
  • Perform a CV grid search to identify the optimal hyperparameter mix
# degree: degree of interactions
# nprune: number of terms to retain

# create a tuning grid
hyper_grid <- expand.grid(
  degree = 1:3,
  nprune = seq(2, 100, length.out = 10) %>% floor()
)

# first rows of the 30-row grid
head(hyper_grid)
  degree nprune
1      1      2
2      2      2
3      3      2
4      1     12
5      2     12
6      3     12

Tuning hyperparameters with caret

  • Grid search using 10-fold CV
  • The optimal model’s cross-validated RMSE was $26,817 in the book and $27,246.61 in this example (different seed)
  • The optimal model retains 56 terms and includes up to 2nd-degree interactions in the book; with the seed used here, it retains 45 terms and has no interactions (degree = 1)
# Cross-validated model
set.seed(123)  # for reproducibility
cv_mars <- train(
  x = subset(ames_train, select = -Sale_Price),
  y = ames_train$Sale_Price,
  method = "earth",
  metric = "RMSE",
  trControl = trainControl(method = "cv", number = 10),
  tuneGrid = hyper_grid
)

# best model
cv_mars$bestTune
  nprune degree
5     45      1
cv_mars$results %>%
  filter(nprune == cv_mars$bestTune$nprune, degree == cv_mars$bestTune$degree)
  degree nprune     RMSE  Rsquared      MAE   RMSESD RsquaredSD    MAESD
1      1     45 27246.61 0.8844656 17822.15 3950.652 0.02631387 1398.517
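ggplot(cv_mars) # results differ from the book because of the different seed

# distribution of the fold-level CV results
summary(cv_mars$resample)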
      RMSE          Rsquared           MAE          Resample        
 Min.   :21382   Min.   :0.8429   Min.   :15338   Length:10         
 1st Qu.:25083   1st Qu.:0.8684   1st Qu.:16760   Class :character  
 Median :27005   Median :0.8781   Median :18240   Mode  :character  
 Mean   :27247   Mean   :0.8845   Mean   :17822                     
 3rd Qu.:29091   3rd Qu.:0.9046   3rd Qu.:18775                     
 Max.   :35227   Max.   :0.9242   Max.   :19859                     

Comparing MARS with other modelling approaches

# fold-level CV results of the tuned MARS model
cv_mars$resample
       RMSE  Rsquared      MAE Resample
1  26592.75 0.8747269 18206.57   Fold06
2  28382.27 0.8990683 18886.13   Fold03
3  21381.58 0.9165223 15338.32   Fold07
4  25412.92 0.8811599 16541.65   Fold10
5  35226.86 0.8428650 19858.67   Fold04
6  29327.27 0.8663435 18838.37   Fold08
7  27417.38 0.8749882 18584.27   Fold01
8  23167.69 0.9241581 16278.33   Fold05
9  30584.40 0.8584081 18273.82   Fold09
10 24972.94 0.9064159 17415.43   Fold02
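The fold table shows where the headline numbers come from: caret's RMSE and RMSESD for the selected model are the mean and standard deviation of the fold-level RMSEs. A quick base-R check using the ten values printed above:

```r
# Fold-level RMSE values copied from the cv_mars$resample output above
fold_rmse <- c(26592.75, 28382.27, 21381.58, 25412.92, 35226.86,
               29327.27, 27417.38, 23167.69, 30584.40, 24972.94)

mean(fold_rmse)  # matches the cross-validated RMSE reported for the best model
sd(fold_rmse)    # matches caret's RMSESD column
```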

Feature interpretation

  • earth includes a backwards elimination feature selection routine
  • This routine looks at reductions in the GCV estimate of error as each predictor is added to the model: value = "gcv"
  • MARS automatically includes and excludes terms during the pruning process (automated feature selection)
  • A feature never included in the final model has an importance value of 0
  • An alternative: the change in the residual sums of squares (RSS) as terms are added (value = "rss")
  • These measures do not quantify the impact of the individual hinge functions created for a given feature
# variable importance plots
p1 <- vip(cv_mars, num_features = 40, geom = "point", value = "gcv") + ggtitle("GCV")
p2 <- vip(cv_mars, num_features = 40, geom = "point", value = "rss") + ggtitle("RSS")

gridExtra::grid.arrange(p1, p2, ncol = 2)

Feature interpretation: hinge functions interactions

  • Investigate interactions
  • Create partial dependence plots (PDPs) for each feature individually and also together

# extract coefficients, convert to tidy data frame, and filter for interaction terms
cv_mars$finalModel %>%
  coef(.) %>%  
  broom::tidy(.) %>%  
  filter(stringr::str_detect(names, "\\*"))
# A tibble: 0 × 2
# … with 2 variables: names <chr>, x <dbl>
# no interactions with the seed I used
  • The model found that one knot in each feature provides the best fit
  • As Gr_Liv_Area increases, and for newer homes, Sale_Price increases dramatically
# Construct partial dependence plots (pdp package)
library(pdp)
p1 <- partial(cv_mars, pred.var = "Gr_Liv_Area", grid.resolution = 10) %>%
  autoplot()
p2 <- partial(cv_mars, pred.var = "Year_Built", grid.resolution = 10) %>%
  autoplot()
p3 <- partial(cv_mars, pred.var = c("Gr_Liv_Area", "Year_Built"),
              grid.resolution = 10) %>%
  plotPartial(levelplot = FALSE, zlab = "yhat", drape = TRUE, colorkey = TRUE,
              screen = list(z = -20, x = -60))

# Display plots side by side
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)

Attrition data example

  • MARS method and algorithm can be extended to handle classification problems and GLMs
# convert ordered factors to regular (unordered) factors
df <- attrition %>% mutate_if(is.ordered, factor, ordered = FALSE)

# Create training (70%) and test (30%) sets for the
# attrition data (formerly rsample::attrition, now in modeldata)
set.seed(123)  # for reproducibility
churn_split <- initial_split(df, prop = .7, strata = "Attrition")
churn_train <- training(churn_split)
churn_test  <- testing(churn_split)


# cross-validated model
tuned_mars <- train(
  x = subset(churn_train, select = -Attrition),
  y = churn_train$Attrition,
  method = "earth",
  trControl = trainControl(method = "cv", number = 10),
  tuneGrid = hyper_grid
)

# best model
tuned_mars$bestTune
  nprune degree
3     23      1

Attrition data example: compare MARS vs other approaches

Take home message for Chapter 7: MARS

  • MARS naturally handles mixed types of predictors (quantitative and qualitative)
    • Considers all possible binary partitions of the categories for a qualitative predictor into two groups
    • Each group then generates a pair of piecewise indicator functions for the two categories
  • Requires minimal feature engineering
    • Automated feature selection
    • Highly correlated predictors do not impede predictive accuracy (MARS uses the first one it happens to come across when scanning the features)
  • Drawback: MARS models are typically slower to train
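The partition search in the first bullet can be sketched in base R. This only illustrates the enumeration described above, not the earth package's implementation; binary_partitions is a hypothetical helper. For a predictor with q categories there are 2^(q-1) - 1 distinct ways to split the levels into two non-empty groups.

```r
# All ways to split a set of category levels into two non-empty groups
binary_partitions <- function(lv) {
  q <- length(lv)
  stopifnot(q >= 2)
  # Keep lv[1] in group A so each split is counted once; each of the other
  # q - 1 levels joins A when its bit in m is set. The all-bits-set mask is
  # skipped because it would leave group B empty.
  masks <- 0:(2^(q - 1) - 2)
  lapply(masks, function(m) {
    in_rest <- (m %/% 2^(seq_len(q - 1) - 1)) %% 2 == 1
    in_a <- c(TRUE, in_rest)
    list(A = lv[in_a], B = lv[!in_a])
  })
}

splits <- binary_partitions(c("a", "b", "c"))
length(splits)    # 2^(3 - 1) - 1 = 3
# Each split defines an indicator pair I(x in A) and I(x in B)
str(splits[[1]])
```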

Chapter 8 K-Nearest Neighbors


  • K-nearest neighbor (KNN) is a very simple algorithm in which each observation is predicted based on its “similarity” to other observations


  • KNNs have been successful in a large number of business problems
  • Useful for preprocessing purposes


  • Memory-based algorithm and cannot be summarized by a closed-form model
  • Training samples are required at run-time and predictions are made directly from the sample relationships

Measuring similarity

  • Algorithm identifies \(k\) observations that are “similar”/nearest to the new record being predicted and then uses the average response value (regression) or the most common class (classification) of those \(k\) observations as the predicted output
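A from-scratch sketch of this prediction rule in base R, for illustration only (knn_predict is a hypothetical helper, and Euclidean distance is assumed): find the k nearest training rows, then average their responses for regression or take the most common class for classification.

```r
# Minimal KNN prediction for a single new record (Euclidean distance)
knn_predict <- function(train_x, train_y, new_x, k = 3) {
  # distance from the new record to every training record
  d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
  nn <- order(d)[seq_len(k)]              # indices of the k nearest neighbors
  if (is.numeric(train_y)) {
    mean(train_y[nn])                     # regression: average response
  } else {
    names(which.max(table(train_y[nn])))  # classification: most common class
  }
}

train_x <- matrix(c(1, 1,
                    2, 2,
                    8, 8,
                    9, 9), ncol = 2, byrow = TRUE)

knn_predict(train_x, c(10, 12, 50, 52), c(1.5, 1.5), k = 2)               # 11
knn_predict(train_x, factor(c("a", "a", "b", "b")), c(8.5, 8.5), k = 3)  # "b"
```

Note that which.max() breaks ties in favor of the first level, which is one reason odd values of k are preferred for two-class problems.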

Distance measures

  • Euclidean distance: most common and measures the straight-line distance between two samples (i.e., how the crow flies)

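  • Manhattan distance: measures the point-to-point travel distance along a grid (i.e., city blocks) and is commonly used for binary predictors (e.g., one-hot encoded 0/1 indicator variables)

  • Minkowski distance: a generalization of the Euclidean and Manhattan distances

  • Mahalanobis distance: accounts for the scale of and correlation between features through their covariance matrix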
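The four measures can be computed directly in base R. This sketch uses the Gr_Liv_Area and Year_Built values of the two houses printed in the example below. The identity covariance matrix used for Mahalanobis distance is an assumption chosen to show that it then reduces to Euclidean distance; in practice one would use the features' estimated covariance.

```r
# Two houses (Gr_Liv_Area, Year_Built), values from ames_train[1:2, ] shown below
h1 <- c(896, 1961)
h2 <- c(864, 1971)

euclidean <- sqrt(sum((h1 - h2)^2))  # straight-line distance
manhattan <- sum(abs(h1 - h2))       # city-block distance
# Minkowski distance of order m: m = 1 is Manhattan, m = 2 is Euclidean
minkowski <- function(a, b, m) sum(abs(a - b)^m)^(1 / m)

# stats::mahalanobis() returns the *squared* distance for a covariance matrix;
# with the identity matrix (an assumption here) it reduces to Euclidean distance
S <- diag(2)
mahal <- sqrt(mahalanobis(h1, center = h2, cov = S))

c(euclidean = euclidean, manhattan = manhattan,
  minkowski_3 = minkowski(h1, h2, 3), mahalanobis_identity = mahal)

# base R's dist() also covers Minkowski through its p argument
dist(rbind(h1, h2), method = "minkowski", p = 3)
```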

Attrition data example

library(dplyr)      # for data wrangling
library(ggplot2)    # for awesome graphics
library(rsample)    # for creating validation splits
library(recipes)    # for feature engineering
library(caret)      # for fitting KNN models

# create training (70%) set for the rsample::attrition data.
attrit <- attrition %>% mutate_if(is.ordered, factor, ordered = FALSE)
churn_split <- initial_split(attrit, prop = .7, strata = "Attrition")
churn_train <- training(churn_split)

# import MNIST training data
mnist <- dslabs::read_mnist()
names(mnist)
[1] "train" "test" 
(two_houses <- ames_train[1:2, c("Gr_Liv_Area", "Year_Built")])
# A tibble: 2 × 2
  Gr_Liv_Area Year_Built
        <int>      <int>
1         896       1961
2         864       1971
# Euclidean
dist(two_houses, method = "euclidean")
         1
2 33.52611
# Manhattan
dist(two_houses, method = "manhattan")
   1
2 42


  • Euclidean distance is more sensitive to outliers
  • Most distance measures are sensitive to the scale of the features
  • Features with different scales bias the distance measures: predictors with the largest values contribute most to the distance between two samples
  • Therefore, standardize numeric features
  • All categorical features must be represented numerically (one-hot encoded or encoded using another method, e.g., ordinal encoding)
  • The KNN method is very sensitive to noisy predictors, since they cause similar samples to have larger magnitudes and more variability in distance values
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <int>      <int> <int>
## 1 home1             4       2008   423
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <int>      <int> <int>
## 1 home2             2       2008   424
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <int>      <int> <int>
## 1 home3             3       1998     6

# The Euclidean distance between home1 and home3 is larger than between home1 and home2, because of the larger difference in Year_Built
features <- c("Bedroom_AbvGr", "Year_Built")

# distance between home 1 and 2
dist(rbind(home1[,features], home2[,features]))
##   1
## 2 2

# distance between home 1 and 3
dist(rbind(home1[,features], home3[,features]))
##          1
## 2 10.04988
# Year_Built has a much larger range (1875–2010) than Bedroom_AbvGr (0–8). The difference between 2 and 4 bedrooms is much more important than a 10 year difference in the age of a home


## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <dbl>      <dbl> <int>
## 1 home1          1.38       1.21   423
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <dbl>      <dbl> <int>
## 1 home2         -1.03       1.21   424
## # A tibble: 1 x 4
##   home  Bedroom_AbvGr Year_Built    id
##   <chr>         <dbl>      <dbl> <int>
## 1 home3         0.176      0.881     6

# distance between home 1 and 2
dist(rbind(home1_std[,features], home2_std[,features]))
##          1
## 2 2.416244

# distance between home 1 and 3
dist(rbind(home1_std[,features], home3_std[,features]))
##          1
## 2 1.252547

Choosing \(k\)

  • Performance of KNNs is very sensitive to the choice of \(k\)
  • Low values of \(k\) tend to overfit and large values tend to underfit
  • With \(k\) = 1, the prediction is the response of the single nearest observation; with \(k\) = \(n\), it is the mean (regression) or most common class (classification) across all training samples
  • For high-signal data with very few noisy (irrelevant) features, smaller values of \(k\) tend to work best
  • As more irrelevant features are involved, larger values of \(k\) are required to smooth out the noise
  • When using KNN for classification, it is best to assess odd values of \(k\) to avoid ties when the neighbors are split equally between response levels (e.g., with \(k\) = 2, one neighbor could have class “0” and the other class “1”)
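A base-R sketch of the two extremes described above, on simulated one-dimensional data (knn_reg is a hypothetical helper). With k = 1, every training point is predicted by itself, so training error is zero: the overfitting end. With k = n, every prediction is the overall mean: the underfitting end.

```r
set.seed(42)
n <- 50
x <- runif(n, 0, 10)
y <- sin(x) + rnorm(n, sd = 0.3)

# 1-D KNN regression: predict each new x from the mean of its k nearest training x's
knn_reg <- function(x_train, y_train, x_new, k) {
  vapply(x_new, function(x0) {
    nn <- order(abs(x_train - x0))[seq_len(k)]
    mean(y_train[nn])
  }, numeric(1))
}

pred_k1 <- knn_reg(x, y, x, k = 1)  # each training point is its own nearest neighbor
pred_kn <- knn_reg(x, y, x, k = n)  # every prediction uses all n training points

# Training RMSE: 0 for k = 1 (memorizes the data), large for k = n (a constant fit)
rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))
c(k1 = rmse(y, pred_k1), kn = rmse(y, pred_kn))
```

Training error is shown only to expose the extremes; choosing k itself should rely on resampled error, as in the grid search that follows.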

Choosing \(k\): example

blueprint <- recipe(Attrition ~ ., data = churn_train) %>%
  step_nzv(all_nominal()) %>%
  step_integer(contains("Satisfaction")) %>%
  step_integer(WorkLifeBalance) %>%
  step_integer(JobInvolvement) %>%
  step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>%
  step_center(all_numeric(), -all_outcomes()) %>%
  step_scale(all_numeric(), -all_outcomes())

# Create a resampling method
cv <- trainControl(
  method = "repeatedcv",
  number = 10,
  repeats = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary
)

# Create a hyperparameter grid search
hyper_grid <- expand.grid(
  k = floor(seq(1, nrow(churn_train)/3, length.out = 20))
)

# Fit knn model and perform grid search
knn_grid <- train(
  blueprint,
  data = churn_train,
  method = "knn",
  trControl = cv,
  tuneGrid = hyper_grid,
  metric = "ROC"
)

# plot search grid results for the Attrition training data, where 20 values of k between 1 and 343 are assessed
ggplot(knn_grid)

MNIST example

  • 784 features representing the darkness (0–255) of pixels in images of handwritten numbers (0–9)
  • KNN models can be severely impacted by irrelevant features
# training initial models on a random sample of 10,000 rows from the training set
index <- sample(nrow(mnist$train$images), size = 10000)
mnist_x <- mnist$train$images[index, ]
mnist_y <- factor(mnist$train$labels[index])
  • Avoid zero, or near-zero variance features (see Section 3.4).
  • There are nearly 125 features that have zero variance and many more that have very little variation
mnist_x %>%
  as.data.frame() %>%
  purrr::map_df(sd) %>%
  tidyr::gather(feature, sd) %>%
  ggplot(aes(sd)) +
  geom_histogram(binwidth = 1)

MNIST example

  • Images (A)–(C) illustrate typical handwritten numbers from the test set
  • Image (D) illustrates which features in images have variability
  • The white center, representing the center pixels, shows regular variability, whereas the black exterior, representing the edge pixels, has zero or near-zero variability

  • Add column names to the feature matrices as these are required by caret
  • Perform search grid
  • Best model used 3 nearest neighbors and provided an accuracy of 93.8%
# Rename features
colnames(mnist_x) <- paste0("V", 1:ncol(mnist_x))

# Remove near zero variance features manually
nzv <- nearZeroVar(mnist_x)
index <- setdiff(1:ncol(mnist_x), nzv)
mnist_x <- mnist_x[, index]
# Use train/validate resampling method
cv <- trainControl(
  method = "LGOCV",
  p = 0.7,
  number = 1,
  savePredictions = TRUE
)

# Create a hyperparameter grid search: 12 odd k values between 3 and 25
hyper_grid <- expand.grid(k = seq(3, 25, by = 2))

# Execute grid search
knn_mnist <- train(
  mnist_x,
  mnist_y,
  method = "knn",
  tuneGrid = hyper_grid,
  preProc = c("center", "scale"),
  trControl = cv
)

ggplot(knn_mnist)

MNIST example

  • Digit 1 has the lowest specificity, i.e., other digits are most often incorrectly predicted as 1
# Create confusion matrix
cm <- confusionMatrix(knn_mnist$pred$pred, knn_mnist$pred$obs)
cm$byClass[, c(1:2, 11)]  # sensitivity, specificity, & balanced accuracy
         Sensitivity Specificity Balanced Accuracy
Class: 0   0.9641638   0.9962374         0.9802006
Class: 1   0.9916667   0.9841210         0.9878938
Class: 2   0.9155666   0.9955114         0.9555390
Class: 3   0.9163952   0.9920325         0.9542139
Class: 4   0.8698630   0.9960538         0.9329584
Class: 5   0.9151404   0.9914891         0.9533148
Class: 6   0.9795322   0.9888684         0.9842003
Class: 7   0.9326520   0.9896962         0.9611741
Class: 8   0.8224382   0.9978798         0.9101590
Class: 9   0.9329897   0.9852687         0.9591292
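To read the table above: these statistics are computed one-vs-rest for each class from the prediction-by-reference confusion matrix. This base-R sketch uses a made-up 3-class matrix (the counts are hypothetical) to show the arithmetic. A low specificity for a class means many observations of other classes were predicted as that class.

```r
# Hypothetical 3-class confusion matrix: rows = predicted, columns = observed
cm <- matrix(c(50,  2,  1,
                3, 45,  4,
                0,  5, 40),
             nrow = 3, byrow = TRUE,
             dimnames = list(pred = c("0", "1", "2"), obs = c("0", "1", "2")))

tp   <- diag(cm)                     # correctly predicted counts per class
sens <- tp / colSums(cm)             # share of each observed class that is caught
fp   <- rowSums(cm) - tp             # predicted as class c but actually another class
tn   <- sum(cm) - colSums(cm) - fp   # neither observed nor predicted as class c
spec <- tn / (tn + fp)

round(cbind(Sensitivity = sens, Specificity = spec,
            `Balanced Accuracy` = (sens + spec) / 2), 3)
```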

MNIST example: Feature importance for KNNs

  • Plot these results to get an understanding of what pixel features are driving our results
  • The most influential features lie around the edges of numbers (outer white circle) and along the very center
# Top 20 most important features
vi <- varImp(knn_mnist)
vi
ROC curve variable importance

  variables are sorted by maximum importance across the classes
  only 20 most important variables shown (out of 249)

         X0     X1     X2     X3     X4     X5     X6     X7     X8    X9
V435 100.00 100.00 100.00 100.00 100.00 100.00 100.00 100.00 100.00 80.56
V407  99.42  99.42  99.42  99.42  99.42  99.42  99.42  99.42  99.42 75.21
V463  97.88  97.88  97.88  97.88  97.88  97.88  97.88  97.88  97.88 83.27
V379  97.38  97.38  97.38  97.38  97.38  97.38  97.38  97.38  97.38 86.56
V434  95.87  95.87  95.87  95.87  95.87  95.87  96.66  95.87  95.87 76.20
V380  96.10  96.10  96.10  96.10  96.10  96.10  96.10  96.10  96.10 88.04
V462  95.56  95.56  95.56  95.56  95.56  95.56  95.56  95.56  95.56 83.38
V408  95.37  95.37  95.37  95.37  95.37  95.37  95.37  95.37  95.37 75.05
V352  93.55  93.55  93.55  93.55  93.55  93.55  93.55  93.55  93.55 87.13
V490  93.07  93.07  93.07  93.07  93.07  93.07  93.07  93.07  93.07 81.88
V406  92.90  92.90  92.90  92.90  92.90  92.90  92.90  92.90  92.90 74.55
V437  70.79  60.44  92.79  52.04  71.11  83.42  75.51  91.15  52.02 70.79
V351  92.41  92.41  92.41  92.41  92.41  92.41  92.41  92.41  92.41 82.08
V409  70.55  76.12  88.11  54.54  79.94  77.69  84.88  91.91  52.72 76.12
V436  89.96  89.96  90.89  89.96  89.96  89.96  91.39  89.96  89.96 78.83
V464  76.73  76.51  90.24  76.51  76.51  76.58  77.67  82.02  76.51 76.73
V491  89.49  89.49  89.49  89.49  89.49  89.49  89.49  89.49  89.49 77.41
V598  68.01  68.01  88.44  68.01  68.01  84.92  68.01  88.25  68.01 38.76
V465  63.09  36.58  87.68  38.16  50.72  80.62  59.88  84.28  57.13 63.09
V433  63.74  55.69  76.69  55.69  57.43  55.69  87.59  68.44  55.69 63.74
# Get median value for feature importance
imp <- vi$importance %>%
  tibble::rownames_to_column(var = "feature") %>%
  tidyr::gather(response, imp, -feature) %>%
  group_by(feature) %>%
  summarize(imp = median(imp))
# Create tibble for all edge pixels
edges <- tibble::tibble(
  feature = paste0("V", nzv),
  imp = 0
)

# Combine and plot
imp <- rbind(imp, edges) %>%
  mutate(ID  = as.numeric(stringr::str_extract(feature, "\\d+"))) %>%
  arrange(ID)
image(matrix(imp$imp, 28, 28), col = gray(seq(0, 1, 0.05)),
      xaxt="n", yaxt="n")
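The image() call relies on imp being sorted by feature number, and on all 784 features being present (hence adding the near-zero-variance edge pixels back with importance 0), so that matrix(imp$imp, 28, 28) puts each importance in the right cell. This base-R sketch (pixel_position is a hypothetical helper) shows R's column-major filling rule; it says nothing about image orientation, which image() handles separately.

```r
# Where does element i of a vector land in matrix(v, 28, 28)?
# R fills matrices column by column.
pixel_position <- function(i, n_row = 28) {
  c(row = (i - 1) %% n_row + 1, col = (i - 1) %/% n_row + 1)
}

m <- matrix(1:784, 28, 28)
pos <- pixel_position(435)     # V435, the top-ranked feature in the varImp output above
pos                            # row 15, col 16
m[pos[["row"]], pos[["col"]]]  # 435
```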

MNIST example: correctly vs incorrectly classified predictions

# Get a few accurate predictions
good <- knn_mnist$pred %>%
  filter(pred == obs) %>%
  sample_n(4)

# Get a few inaccurate predictions
bad <- knn_mnist$pred %>%
  filter(pred != obs) %>%
  sample_n(4)

combine <- bind_rows(good, bad)

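# Get original feature set with all pixel features
# NOTE: rowIndex in knn_mnist$pred refers to rows of mnist_x, so this draw must
# reproduce the original 10,000-row sample (same seed as the earlier sample() call);
# otherwise the plotted images will not match the predictions
index <- sample(nrow(mnist$train$images), 10000)
X <- mnist$train$images[index,]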
# Plot results
par(mfrow = c(4, 2), mar=c(1, 1, 1, 1))
layout(matrix(seq_len(nrow(combine)), 4, 2, byrow = FALSE))
for(i in seq_len(nrow(combine))) {
  image(matrix(X[combine$rowIndex[i],], 28, 28)[, 28:1],
        col = gray(seq(0, 1, 0.05)),
        main = paste("Actual:", combine$obs[i], "  ",
                     "Predicted:", combine$pred[i]),
        xaxt="n", yaxt="n")
}

Take home message for Chapter 8

  • Simple and intuitive algorithm with “average to decent predictive power”
  • A drawback of KNNs is their computation time, which grows with \(n \times p\) for each observation predicted
  • Lazy learner: requires the model to be run at prediction time, which limits its use for real-time modeling
  • Rarely provides the best predictive performance
  • Useful for feature engineering, data cleaning, and preprocessing
    • KNNs may be used to add a local knowledge feature (running a KNN to estimate the predicted output or class and using predicted value as a new feature for downstream modeling)
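A minimal base-R sketch of the “local knowledge feature” idea on simulated data (knn_feature is a hypothetical helper). Each row's own response is excluded from its neighborhood (leave-one-out), so the new feature does not simply copy the target. We would consider this design choice essential when the feature is built on training data.

```r
set.seed(7)
df <- data.frame(x1 = rnorm(100), x2 = rnorm(100))
df$y <- df$x1^2 + df$x2 + rnorm(100, sd = 0.2)

# Leave-one-out KNN prediction of y for every row
knn_feature <- function(X, y, k = 5) {
  D <- as.matrix(dist(X))  # pairwise Euclidean distances
  diag(D) <- Inf           # exclude each row from its own neighborhood
  apply(D, 1, function(d) mean(y[order(d)[seq_len(k)]]))
}

X_std <- scale(df[, c("x1", "x2")])  # standardize before computing distances
df$knn_y <- knn_feature(X_std, df$y, k = 5)
head(df)
```

The new knn_y column can then be passed to a downstream model alongside the original features.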