Machine Learning Workflow: Classification Trees

Biostat 212A

Author

Dr. Jin Zhou @ UCLA

Published

February 24, 2025

Display system information for reproducibility.

import IPython
print(IPython.sys_info())
{'commit_hash': '22d6a1c16',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Users/jinjinzhou/.virtualenvs/r-reticulate/lib/python3.13/site-packages/IPython',
 'ipython_version': '8.31.0',
 'os_name': 'posix',
 'platform': 'macOS-15.3.1-arm64-arm-64bit-Mach-O',
 'sys_executable': '/Users/jinjinzhou/.virtualenvs/r-reticulate/bin/python',
 'sys_platform': 'darwin',
 'sys_version': '3.13.0 (main, Oct  7 2024, 05:02:14) [Clang 16.0.0 '
                '(clang-1600.0.26.4)]'}

1 Overview

(Figure: resampling diagram. All data is split into training and testing sets; the training set is repeatedly resampled into analysis and assessment sets, resamples 1 through B.)

We illustrate the typical machine learning workflow for classification trees using the Heart data set from the R ISLR2 package.

  1. Initial splitting into test and non-test sets.

  2. Pre-processing of data: not much is needed for classification trees.

  3. Tune the cost complexity pruning hyper-parameter(s) using 5-fold cross-validation (CV) on the non-test data.

  4. Choose the best model by CV and refit it on the whole non-test data.

  5. Final prediction on the test data.
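These five steps can be sketched end to end with scikit-learn on synthetic data (a minimal sketch for orientation; the data and settings here are illustrative, not the Heart analysis itself):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the Heart data
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

# 1. Initial split into test and non-test sets
X_other, X_test, y_other, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=0
)

# 2.-4. Little preprocessing is needed; tune the pruning parameter by
# 5-fold CV and refit the best model on the whole non-test set
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    {"ccp_alpha": np.linspace(0.0, 0.05, 11)},
    cv=5,
    scoring="roc_auc",
    refit=True,
)
search.fit(X_other, y_other)

# 5. Final prediction on the held-out test data
test_accuracy = search.best_estimator_.score(X_test, y_test)
```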

2 Heart data

The goal is to predict the binary outcome AHD (Yes or No) of patients.

sessionInfo()
R version 4.4.1 (2024-06-14)
Platform: aarch64-apple-darwin20
Running under: macOS 15.3.1

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] digest_0.6.37     fastmap_1.2.0     xfun_0.47         Matrix_1.7-0     
 [5] lattice_0.22-6    reticulate_1.39.0 knitr_1.48        htmltools_0.5.8.1
 [9] png_0.1-8         rmarkdown_2.28    cli_3.6.3         grid_4.4.1       
[13] compiler_4.4.1    rstudioapi_0.16.0 tools_4.4.1       evaluate_1.0.0   
[17] Rcpp_1.0.13       yaml_2.3.10       rlang_1.1.4       jsonlite_1.8.9   
[21] htmlwidgets_1.6.4
library(GGally)
library(gtsummary)
library(ranger)
library(tidyverse)
library(tidymodels)
library(ISLR2)

Heart <- read_csv("../data/Heart.csv") 
Heart <-  Heart %>% 
  # first column is patient ID, which we don't need
  select(-1) %>%
  # RestECG is categorical with value 0, 1, 2
  mutate(RestECG = as.character(RestECG)) %>%
  print(width = Inf)
# A tibble: 303 × 14
     Age   Sex ChestPain    RestBP  Chol   Fbs RestECG MaxHR ExAng Oldpeak Slope
   <dbl> <dbl> <chr>         <dbl> <dbl> <dbl> <chr>   <dbl> <dbl>   <dbl> <dbl>
 1    63     1 typical         145   233     1 2         150     0     2.3     3
 2    67     1 asymptomatic    160   286     0 2         108     1     1.5     2
 3    67     1 asymptomatic    120   229     0 2         129     1     2.6     2
 4    37     1 nonanginal      130   250     0 0         187     0     3.5     3
 5    41     0 nontypical      130   204     0 2         172     0     1.4     1
 6    56     1 nontypical      120   236     0 0         178     0     0.8     1
 7    62     0 asymptomatic    140   268     0 2         160     0     3.6     3
 8    57     0 asymptomatic    120   354     0 0         163     1     0.6     1
 9    63     1 asymptomatic    130   254     0 2         147     0     1.4     2
10    53     1 asymptomatic    140   203     1 2         155     1     3.1     3
      Ca Thal       AHD  
   <dbl> <chr>      <chr>
 1     0 fixed      No   
 2     3 normal     Yes  
 3     2 reversable Yes  
 4     0 normal     No   
 5     0 normal     No   
 6     0 normal     No   
 7     2 normal     Yes  
 8     0 normal     No   
 9     1 reversable Yes  
10     0 reversable Yes  
# ℹ 293 more rows
Heart %>% tbl_summary()

Characteristic	N = 303¹
Age 56 (48, 61)
Sex 206 (68%)
ChestPain
    asymptomatic 144 (48%)
    nonanginal 86 (28%)
    nontypical 50 (17%)
    typical 23 (7.6%)
RestBP 130 (120, 140)
Chol 241 (211, 275)
Fbs 45 (15%)
RestECG
    0 151 (50%)
    1 4 (1.3%)
    2 148 (49%)
MaxHR 153 (133, 166)
ExAng 99 (33%)
Oldpeak 0.80 (0.00, 1.60)
Slope
    1 142 (47%)
    2 140 (46%)
    3 21 (6.9%)
Ca
    0 176 (59%)
    1 65 (22%)
    2 38 (13%)
    3 20 (6.7%)
    Unknown 4
Thal
    fixed 18 (6.0%)
    normal 166 (55%)
    reversable 117 (39%)
    Unknown 2
AHD 139 (46%)
¹ Median (Q1, Q3); n (%)

# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font sizes in plots
sns.set(font_scale = 1.2)
# Display all columns
pd.set_option('display.max_columns', None)

Heart = pd.read_csv("../data/Heart.csv")
Heart
     Unnamed: 0  Age  Sex     ChestPain  RestBP  Chol  Fbs  RestECG  MaxHR  \
0             1   63    1       typical     145   233    1        2    150   
1             2   67    1  asymptomatic     160   286    0        2    108   
2             3   67    1  asymptomatic     120   229    0        2    129   
3             4   37    1    nonanginal     130   250    0        0    187   
4             5   41    0    nontypical     130   204    0        2    172   
..          ...  ...  ...           ...     ...   ...  ...      ...    ...   
298         299   45    1       typical     110   264    0        0    132   
299         300   68    1  asymptomatic     144   193    1        0    141   
300         301   57    1  asymptomatic     130   131    0        0    115   
301         302   57    0    nontypical     130   236    0        2    174   
302         303   38    1    nonanginal     138   175    0        0    173   

     ExAng  Oldpeak  Slope   Ca        Thal  AHD  
0        0      2.3      3  0.0       fixed   No  
1        1      1.5      2  3.0      normal  Yes  
2        1      2.6      2  2.0  reversable  Yes  
3        0      3.5      3  0.0      normal   No  
4        0      1.4      1  0.0      normal   No  
..     ...      ...    ...  ...         ...  ...  
298      0      1.2      2  0.0  reversable  Yes  
299      0      3.4      2  2.0  reversable  Yes  
300      1      1.2      2  1.0  reversable  Yes  
301      0      0.0      2  1.0      normal  Yes  
302      0      0.0      1  NaN      normal   No  

[303 rows x 15 columns]
# Numerical summaries
Heart.describe(include = 'all')
        Unnamed: 0         Age         Sex     ChestPain      RestBP  \
count   303.000000  303.000000  303.000000           303  303.000000   
unique         NaN         NaN         NaN             4         NaN   
top            NaN         NaN         NaN  asymptomatic         NaN   
freq           NaN         NaN         NaN           144         NaN   
mean    152.000000   54.438944    0.679868           NaN  131.689769   
std      87.612784    9.038662    0.467299           NaN   17.599748   
min       1.000000   29.000000    0.000000           NaN   94.000000   
25%      76.500000   48.000000    0.000000           NaN  120.000000   
50%     152.000000   56.000000    1.000000           NaN  130.000000   
75%     227.500000   61.000000    1.000000           NaN  140.000000   
max     303.000000   77.000000    1.000000           NaN  200.000000   

              Chol         Fbs     RestECG       MaxHR       ExAng  \
count   303.000000  303.000000  303.000000  303.000000  303.000000   
unique         NaN         NaN         NaN         NaN         NaN   
top            NaN         NaN         NaN         NaN         NaN   
freq           NaN         NaN         NaN         NaN         NaN   
mean    246.693069    0.148515    0.990099  149.607261    0.326733   
std      51.776918    0.356198    0.994971   22.875003    0.469794   
min     126.000000    0.000000    0.000000   71.000000    0.000000   
25%     211.000000    0.000000    0.000000  133.500000    0.000000   
50%     241.000000    0.000000    1.000000  153.000000    0.000000   
75%     275.000000    0.000000    2.000000  166.000000    1.000000   
max     564.000000    1.000000    2.000000  202.000000    1.000000   

           Oldpeak       Slope          Ca    Thal  AHD  
count   303.000000  303.000000  299.000000     301  303  
unique         NaN         NaN         NaN       3    2  
top            NaN         NaN         NaN  normal   No  
freq           NaN         NaN         NaN     166  164  
mean      1.039604    1.600660    0.672241     NaN  NaN  
std       1.161075    0.616226    0.937438     NaN  NaN  
min       0.000000    1.000000    0.000000     NaN  NaN  
25%       0.000000    1.000000    0.000000     NaN  NaN  
50%       0.800000    2.000000    0.000000     NaN  NaN  
75%       1.600000    2.000000    1.000000     NaN  NaN  
max       6.200000    3.000000    3.000000     NaN  NaN  

Graphical summary:

# Graphical summaries
plt.figure()
sns.pairplot(data = Heart);
plt.show()

3 Initial split into test and non-test sets

We randomly split the data into a test set and a non-test set, stratifying on AHD. (The R code below uses a 50/50 split; the Python code uses a 75/25 split.)

# For reproducibility
set.seed(212)

data_split <- initial_split(
  Heart, 
  prop = 0.5,
  strata = AHD
  )
data_split
<Training/Testing/Total>
<151/152/303>
Heart_other <- training(data_split)
dim(Heart_other)
[1] 151  14
Heart_test <- testing(data_split)
dim(Heart_test)
[1] 152  14
from sklearn.model_selection import train_test_split

Heart_other, Heart_test = train_test_split(
  Heart, 
  train_size = 0.75,
  random_state = 425, # seed
  stratify = Heart.AHD
  )
Heart_test.shape
(76, 15)
Heart_other.shape
(227, 15)
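Stratifying on the outcome keeps the class proportions (nearly) identical across the split; a quick self-contained check on synthetic labels (illustrative only):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic binary outcome with a 30% positive rate
y = np.array([0] * 70 + [1] * 30)

# Stratified 50/50 split: each half inherits the overall positive rate
half1, half2 = train_test_split(y, train_size=0.5, stratify=y, random_state=0)
print(half1.mean(), half2.mean())
```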

Separate \(X\) and \(y\). We will use 13 features.

num_features = ['Age', 'Sex', 'RestBP', 'Chol', 'Fbs', 'RestECG', 'MaxHR', 'ExAng', 'Oldpeak', 'Slope', 'Ca']
cat_features = ['ChestPain', 'Thal']
features = np.concatenate([num_features, cat_features])
# Non-test X and y
X_other = Heart_other[features]
y_other = Heart_other.AHD
# Test X and y
X_test = Heart_test[features]
y_test = Heart_test.AHD

4 Preprocessing (Python) or recipe (R)

  • A data dictionary (roughly) is at https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/.

  • We have the following features:

    • Numerical features: Age, RestBP, Chol, Slope (1, 2 or 3), MaxHR, ExAng, Oldpeak, Ca (0, 1, 2 or 3).

    • Categorical features coded as integer: Sex (0 or 1), Fbs (0 or 1), RestECG (0, 1 or 2). (For simplicity, the Python code below treats these integer-coded features as numerical.)

    • Categorical features coded as string: ChestPain, Thal.

  • There are missing values in Ca and Thal. Since the missing proportion is not high, we will use simple mean imputation (for the numerical feature Ca) and mode imputation (for the categorical feature Thal).

tree_recipe <- 
  recipe(
    AHD ~ ., 
    data = Heart_other
  ) %>%
  step_naomit(all_predictors()) %>%
  # # create traditional dummy variables (not necessary for random forest in R)
  step_dummy(all_nominal_predictors()) %>%
  # zero-variance filter
  step_zv(all_numeric_predictors()) %>% 
  # # center and scale numeric data (not necessary for random forest)
  step_normalize(all_numeric_predictors()) # %>%
  # estimate the means and standard deviations
  # prep(training = Heart_other, retain = TRUE)
tree_recipe

There are missing values in the Ca (quantitative) and Thal (qualitative) variables. We are going to use simple mean imputation for Ca and most_frequent imputation for Thal. This is suboptimal; a better strategy is multiple imputation.

# How many NaNs
Heart.isna().sum()
Unnamed: 0    0
Age           0
Sex           0
ChestPain     0
RestBP        0
Chol          0
Fbs           0
RestECG       0
MaxHR         0
ExAng         0
Oldpeak       0
Slope         0
Ca            4
Thal          2
AHD           0
dtype: int64

In principle, decision trees can handle categorical predictors directly. However, the scikit-learn implementation (and, by default, xgboost) does not accept string-valued categorical predictors, so we one-hot encode them.

from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

# Transformer for categorical variables
categorical_tf = Pipeline(steps = [
  ("cat_impute", SimpleImputer(strategy = 'most_frequent')),
  ("encoder", OneHotEncoder())
])

# Transformer for continuous variables
numeric_tf = Pipeline(steps = [
  ("num_impute", SimpleImputer(strategy = 'mean')),
])

# Column transformer
col_tf = ColumnTransformer(transformers = [
  ('num', numeric_tf, num_features),
  ('cat', categorical_tf, cat_features)
])

5 Model

classtree_mod <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = 5,
  mode = "classification",
  engine = "rpart"
  ) 
from sklearn.tree import DecisionTreeClassifier, plot_tree

classtree_mod = DecisionTreeClassifier(
  criterion = 'gini',
  random_state = 425
  )
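The criterion = 'gini' argument means node impurity is measured by the Gini index, \(G = \sum_k \hat{p}_k (1 - \hat{p}_k)\), where \(\hat{p}_k\) is the proportion of class \(k\) in the node. A quick illustration:

```python
import numpy as np

def gini(proportions):
    """Gini impurity: sum_k p_k * (1 - p_k) over class proportions."""
    p = np.asarray(proportions, dtype=float)
    return float(np.sum(p * (1.0 - p)))

# A pure node has impurity 0; a 50/50 binary node attains the maximum 0.5
print(gini([1.0, 0.0]), gini([0.5, 0.5]))
```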

6 Pipeline (Python) or workflow (R)

Here we bundle the preprocessing step (Python) or recipe (R) and model.

tree_wf <- workflow() %>%
  add_recipe(tree_recipe) %>%
  add_model(classtree_mod) 
tree_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_naomit()
• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = tune()
  tree_depth = tune()
  min_n = 5

Computational engine: rpart 
from sklearn.pipeline import Pipeline

pipe = Pipeline(steps = [
  ("col_tf", col_tf),
  ("model", classtree_mod)
  ])
pipe
Pipeline(steps=[('col_tf',
                 ColumnTransformer(transformers=[('num',
                                                  Pipeline(steps=[('num_impute',
                                                                   SimpleImputer())]),
                                                  ['Age', 'Sex', 'RestBP',
                                                   'Chol', 'Fbs', 'RestECG',
                                                   'MaxHR', 'ExAng', 'Oldpeak',
                                                   'Slope', 'Ca']),
                                                 ('cat',
                                                  Pipeline(steps=[('cat_impute',
                                                                   SimpleImputer(strategy='most_frequent')),
                                                                  ('encoder',
                                                                   OneHotEncoder())]),
                                                  ['ChestPain', 'Thal'])])),
                ('model', DecisionTreeClassifier(random_state=425))])

7 Tuning grid

ccp_alpha is scikit-learn's minimal cost-complexity pruning parameter; larger values of ccp_alpha prune more nodes.

tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          levels = c(100,5))
# Tune hyper-parameter(s)
ccp_alpha_grid = np.linspace(start = 0.0, stop = 0.05, num = 100)
tuned_parameters = {
  "model__ccp_alpha": ccp_alpha_grid
  }
tuned_parameters  
{'model__ccp_alpha': array([0.        , 0.00050505, 0.0010101 , 0.00151515, 0.0020202 ,
       0.00252525, 0.0030303 , 0.00353535, 0.0040404 , 0.00454545,
       0.00505051, 0.00555556, 0.00606061, 0.00656566, 0.00707071,
       0.00757576, 0.00808081, 0.00858586, 0.00909091, 0.00959596,
       0.01010101, 0.01060606, 0.01111111, 0.01161616, 0.01212121,
       0.01262626, 0.01313131, 0.01363636, 0.01414141, 0.01464646,
       0.01515152, 0.01565657, 0.01616162, 0.01666667, 0.01717172,
       0.01767677, 0.01818182, 0.01868687, 0.01919192, 0.01969697,
       0.02020202, 0.02070707, 0.02121212, 0.02171717, 0.02222222,
       0.02272727, 0.02323232, 0.02373737, 0.02424242, 0.02474747,
       0.02525253, 0.02575758, 0.02626263, 0.02676768, 0.02727273,
       0.02777778, 0.02828283, 0.02878788, 0.02929293, 0.02979798,
       0.03030303, 0.03080808, 0.03131313, 0.03181818, 0.03232323,
       0.03282828, 0.03333333, 0.03383838, 0.03434343, 0.03484848,
       0.03535354, 0.03585859, 0.03636364, 0.03686869, 0.03737374,
       0.03787879, 0.03838384, 0.03888889, 0.03939394, 0.03989899,
       0.04040404, 0.04090909, 0.04141414, 0.04191919, 0.04242424,
       0.04292929, 0.04343434, 0.04393939, 0.04444444, 0.04494949,
       0.04545455, 0.0459596 , 0.04646465, 0.0469697 , 0.04747475,
       0.0479798 , 0.04848485, 0.0489899 , 0.04949495, 0.05      ])}
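Instead of a fixed grid, scikit-learn can also compute the effective \(\alpha\) values at which the fitted tree actually changes during pruning, via cost_complexity_pruning_path (a sketch on synthetic data; in practice one would call it on the non-test Heart data):

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, random_state=0)

# Each ccp_alpha is a threshold at which one more subtree is pruned away;
# total leaf impurity increases as alpha grows
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)
print(path.ccp_alphas)
```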

8 Cross-validation (CV)

Set cross-validation partitions.

set.seed(212)

folds <- vfold_cv(Heart_other, v = 5)
folds
#  5-fold cross-validation 
# A tibble: 5 × 2
  splits           id   
  <list>           <chr>
1 <split [120/31]> Fold1
2 <split [121/30]> Fold2
3 <split [121/30]> Fold3
4 <split [121/30]> Fold4
5 <split [121/30]> Fold5

Fit cross-validation.

tree_fit <- tree_wf %>%
  tune_grid(
    resamples = folds,
    grid = tree_grid,
    metrics = metric_set(accuracy, roc_auc)
    )
tree_fit
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 4
  splits           id    .metrics             .notes            
  <list>           <chr> <list>               <list>            
1 <split [120/31]> Fold1 <tibble [1,000 × 6]> <tibble [0 × 3]>  
2 <split [121/30]> Fold2 <tibble [1,000 × 6]> <tibble [0 × 3]>  
3 <split [121/30]> Fold3 <tibble [1,000 × 6]> <tibble [0 × 3]>  
4 <split [121/30]> Fold4 <tibble [1,000 × 6]> <tibble [500 × 3]>
5 <split [121/30]> Fold5 <tibble [1,000 × 6]> <tibble [500 × 3]>

There were issues with some computations:

  - Warning(s) x500: ! There are new levels in `RestECG`: "1". ℹ Consider using step_n...
  - Warning(s) x500: ! There are new levels in `Thal`: NA. ℹ Consider using step_unkno...

Run `show_notes(.Last.tune.result)` for more information.

Visualize CV results:

tree_fit %>%
  collect_metrics() %>%
  print(width = Inf) %>%
  filter(.metric == "roc_auc") %>%
  mutate(tree_depth = as.factor(tree_depth)) %>%
  ggplot(mapping = aes(x = cost_complexity, y = mean, color = tree_depth)) +
  geom_point() + 
  geom_line() + 
  labs(x = "cost_complexity", y = "CV ROC AUC", color = "tree_depth") 
# A tibble: 1,000 × 8
   cost_complexity tree_depth .metric  .estimator  mean     n std_err
             <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl>
 1        1   e-10          1 accuracy binary     0.689     5  0.0165
 2        1   e-10          1 roc_auc  binary     0.674     5  0.0185
 3        1.23e-10          1 accuracy binary     0.689     5  0.0165
 4        1.23e-10          1 roc_auc  binary     0.674     5  0.0185
 5        1.52e-10          1 accuracy binary     0.689     5  0.0165
 6        1.52e-10          1 roc_auc  binary     0.674     5  0.0185
 7        1.87e-10          1 accuracy binary     0.689     5  0.0165
 8        1.87e-10          1 roc_auc  binary     0.674     5  0.0185
 9        2.31e-10          1 accuracy binary     0.689     5  0.0165
10        2.31e-10          1 roc_auc  binary     0.674     5  0.0185
   .config               
   <chr>                 
 1 Preprocessor1_Model001
 2 Preprocessor1_Model001
 3 Preprocessor1_Model002
 4 Preprocessor1_Model002
 5 Preprocessor1_Model003
 6 Preprocessor1_Model003
 7 Preprocessor1_Model004
 8 Preprocessor1_Model004
 9 Preprocessor1_Model005
10 Preprocessor1_Model005
# ℹ 990 more rows

Set up CV partitions and CV criterion.

from sklearn.model_selection import GridSearchCV

# Set up CV
n_folds = 5
search = GridSearchCV(
  pipe,
  tuned_parameters,
  cv = n_folds, 
  scoring = "roc_auc",
  # Refit the best model on the whole data set
  refit = True
  )

Fit CV. This is typically the most time-consuming step.

# Fit CV
search.fit(X_other, y_other)
GridSearchCV(cv=5,
             estimator=Pipeline(steps=[('col_tf',
                                        ColumnTransformer(transformers=[('num',
                                                                         Pipeline(steps=[('num_impute',
                                                                                          SimpleImputer())]),
                                                                         ['Age',
                                                                          'Sex',
                                                                          'RestBP',
                                                                          'Chol',
                                                                          'Fbs',
                                                                          'RestECG',
                                                                          'MaxHR',
                                                                          'ExAng',
                                                                          'Oldpeak',
                                                                          'Slope',
                                                                          'Ca']),
                                                                        ('cat',
                                                                         Pipeline(steps=[('cat_impute',
                                                                                          SimpleImputer(strategy='most_frequent')),
                                                                                         ('encoder',
                                                                                          OneHotEncoder())]),
                                                                         ['ChestPain',
                                                                          '...
       0.03282828, 0.03333333, 0.03383838, 0.03434343, 0.03484848,
       0.03535354, 0.03585859, 0.03636364, 0.03686869, 0.03737374,
       0.03787879, 0.03838384, 0.03888889, 0.03939394, 0.03989899,
       0.04040404, 0.04090909, 0.04141414, 0.04191919, 0.04242424,
       0.04292929, 0.04343434, 0.04393939, 0.04444444, 0.04494949,
       0.04545455, 0.0459596 , 0.04646465, 0.0469697 , 0.04747475,
       0.0479798 , 0.04848485, 0.0489899 , 0.04949495, 0.05      ])},
             scoring='roc_auc')

Visualize CV results.

cv_res = pd.DataFrame({
  "ccp_alpha": np.array(search.cv_results_["param_model__ccp_alpha"]),
  "auc": search.cv_results_["mean_test_score"]
  })

plt.figure()
sns.relplot(
  # kind = "line",
  data = cv_res,
  x = "ccp_alpha",
  y = "auc"
  ).set(
    xlabel = "CCP Alpha",
    ylabel = "CV AUC"
);
plt.show()

Best CV AUC:

search.best_score_
np.float64(0.803652380952381)

The training accuracy is

from sklearn.metrics import accuracy_score, roc_auc_score

accuracy_score(
  y_other,
  search.best_estimator_.predict(X_other)
  )
0.8149779735682819

9 Finalize our model

Now we are done tuning. Let's fit the chosen model to the whole non-test data and use the test data to estimate the performance we expect to see on new data.

tree_fit %>%
  show_best(metric = "roc_auc")
# A tibble: 5 × 8
  cost_complexity tree_depth .metric .estimator  mean     n std_err .config     
            <dbl>      <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>       
1        1.87e- 2          8 roc_auc binary     0.756     5  0.0326 Preprocesso…
2        1.87e- 2         11 roc_auc binary     0.756     5  0.0326 Preprocesso…
3        1.87e- 2         15 roc_auc binary     0.756     5  0.0326 Preprocesso…
4        1   e-10          4 roc_auc binary     0.752     5  0.0556 Preprocesso…
5        1.23e-10          4 roc_auc binary     0.752     5  0.0556 Preprocesso…

Let’s select the best model.

best_tree <- tree_fit %>%
  select_best(metric = "roc_auc")
best_tree
# A tibble: 1 × 3
  cost_complexity tree_depth .config               
            <dbl>      <int> <chr>                 
1          0.0187          8 Preprocessor1_Model292
# Final workflow
final_wf <- tree_wf %>%
  finalize_workflow(best_tree)
final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_naomit()
• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = 0.0187381742286039
  tree_depth = 8
  min_n = 5

Computational engine: rpart 
# Fit the whole training set, then predict the test cases
final_fit <- 
  final_wf %>%
  last_fit(data_split)
final_fit
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits            id               .metrics .notes   .predictions .workflow 
  <list>            <chr>            <list>   <list>   <list>       <list>    
1 <split [151/152]> train/test split <tibble> <tibble> <tibble>     <workflow>

There were issues with some computations:

  - Warning(s) x1: ! There are new levels in `Thal`: NA. ℹ Consider using step_unkno...

Run `show_notes(.Last.tune.result)` for more information.
# Test metrics
final_fit %>% 
  collect_metrics()
# A tibble: 3 × 4
  .metric     .estimator .estimate .config             
  <chr>       <chr>          <dbl> <chr>               
1 accuracy    binary         0.796 Preprocessor1_Model1
2 roc_auc     binary         0.795 Preprocessor1_Model1
3 brier_class binary         0.176 Preprocessor1_Model1

Since we called GridSearchCV with refit = True, the best model fit on the whole non-test data is readily available.

search.best_estimator_
Pipeline(steps=[('col_tf',
                 ColumnTransformer(transformers=[('num',
                                                  Pipeline(steps=[('num_impute',
                                                                   SimpleImputer())]),
                                                  ['Age', 'Sex', 'RestBP',
                                                   'Chol', 'Fbs', 'RestECG',
                                                   'MaxHR', 'ExAng', 'Oldpeak',
                                                   'Slope', 'Ca']),
                                                 ('cat',
                                                  Pipeline(steps=[('cat_impute',
                                                                   SimpleImputer(strategy='most_frequent')),
                                                                  ('encoder',
                                                                   OneHotEncoder())]),
                                                  ['ChestPain', 'Thal'])])),
                ('model',
                 DecisionTreeClassifier(ccp_alpha=np.float64(0.0202020202020202),
                                        random_state=425))])

Visualize the best classification tree.

features = np.concatenate([
    features[:-2], 
    ['ChestPain:asymptomatic', 'ChestPain:nonanginal', 'ChestPain:nontypical', 'ChestPain:typical'],
    ['Thal:fixed', 'Thal:normal', 'Thal:reversable']
    ])

plt.figure()
plot_tree(
  search.best_estimator_['model'],
  feature_names = features
  );
plt.show()

Feature importances:

vi_df = pd.DataFrame({
  "feature": features,
  "vi": search.best_estimator_['model'].feature_importances_
  })

plt.figure()
sns.barplot(
  data = vi_df,
  x = "feature",
  y = "vi"
  ).set(
    xlabel = "Feature",
    ylabel = "VI"
);
plt.xticks(rotation = 90);
plt.show()

The final AUC on the test set is

roc_auc_score(
  y_test,
  search.best_estimator_.predict_proba(X_test)[:, 1]
  )
np.float64(0.8574912891986062)

The final classification accuracy on the test set is

accuracy_score(
  y_test, 
  search.best_estimator_.predict(X_test)
  )
0.8421052631578947

10 Visualize the final model

library(rpart.plot)
final_tree <- extract_workflow(final_fit)
final_tree
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
4 Recipe Steps

• step_naomit()
• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
n= 148 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 148 68 No (0.54054054 0.45945946)  
   2) Ca< -0.2344095 83 21 No (0.74698795 0.25301205)  
     4) Thal_reversable< 0.2949493 61  9 No (0.85245902 0.14754098)  
       8) Oldpeak< 1.4531 57  6 No (0.89473684 0.10526316) *
       9) Oldpeak>=1.4531 4  1 Yes (0.25000000 0.75000000) *
     5) Thal_reversable>=0.2949493 22 10 Yes (0.45454545 0.54545455)  
      10) Chol< 0.438099 13  4 No (0.69230769 0.30769231)  
        20) Age>=-0.3617272 7  0 No (1.00000000 0.00000000) *
        21) Age< -0.3617272 6  2 Yes (0.33333333 0.66666667) *
      11) Chol>=0.438099 9  1 Yes (0.11111111 0.88888889) *
   3) Ca>=-0.2344095 65 18 Yes (0.27692308 0.72307692)  
     6) ChestPain_nonanginal>=0.3905548 18  6 No (0.66666667 0.33333333)  
      12) Oldpeak< 0.742162 15  3 No (0.80000000 0.20000000) *
      13) Oldpeak>=0.742162 3  0 Yes (0.00000000 1.00000000) *
     7) ChestPain_nonanginal< 0.3905548 47  6 Yes (0.12765957 0.87234043)  
      14) ChestPain_typical>=1.380583 4  1 No (0.75000000 0.25000000) *
      15) ChestPain_typical< 1.380583 43  3 Yes (0.06976744 0.93023256) *
final_tree %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE)

library(vip)

final_tree %>% 
  extract_fit_parsnip() %>% 
  vip()