Machine Learning Workflow: Regression Trees

Biostat 274

Author

Dr. Jin Zhou @ UCLA

Published

December 24, 2025

Display system information for reproducibility.

sessionInfo()
R version 4.5.2 (2025-10-31)
Platform: aarch64-apple-darwin20
Running under: macOS Sequoia 15.7.3

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] htmlwidgets_1.6.4 compiler_4.5.2    fastmap_1.2.0     cli_3.6.5        
 [5] tools_4.5.2       htmltools_0.5.8.1 yaml_2.3.10       rmarkdown_2.29   
 [9] knitr_1.50        jsonlite_2.0.0    xfun_0.53         digest_0.6.37    
[13] rlang_1.1.6       evaluate_1.0.5   
import IPython
print(IPython.sys_info())
{'commit_hash': '5ed988a91',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Users/jinjinzhou/.virtualenvs/r-tensorflow/lib/python3.10/site-packages/IPython',
 'ipython_version': '8.33.0',
 'os_name': 'posix',
 'platform': 'macOS-15.7.3-arm64-arm-64bit',
 'sys_executable': '/Users/jinjinzhou/.virtualenvs/r-tensorflow/bin/python',
 'sys_platform': 'darwin',
 'sys_version': '3.10.16 (main, Mar  3 2025, 20:01:33) [Clang 16.0.0 '
                '(clang-1600.0.26.6)]'}

1 Overview

We illustrate the typical machine learning workflow for regression trees using the Hitters data set from the R ISLR2 package. A condensed Python sketch of the steps appears right after the list below; the details follow in later sections.

  1. Initial splitting to test and non-test sets.

  2. Pre-processing of data: not much is needed for regression trees.

  3. Tune the cost-complexity pruning hyper-parameter(s) using cross-validation (CV) on the non-test data (5-fold CV in R and 6-fold CV in Python below).

  4. Choose the best model by CV and refit it on the whole non-test data.

  5. Final prediction on the test data.
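A condensed Python sketch of steps 1-5, assuming the same Hitters data frame (with missing salaries already dropped) that is loaded in the sections below:

import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeRegressor

features = ['Assists', 'AtBat', 'Hits', 'HmRun', 'PutOuts', 'RBI', 'Runs', 'Walks', 'Years']

# 1. Initial split into test and non-test sets
Hitters_other, Hitters_test = train_test_split(Hitters, train_size = 0.5, random_state = 425)
X_other, y_other = Hitters_other[features], np.log(Hitters_other.Salary)
X_test, y_test = Hitters_test[features], np.log(Hitters_test.Salary)

# 2-4. Bundle the (minimal) preprocessing and model, tune ccp_alpha by CV, and
#      refit the best model on the whole non-test data (refit = True)
pipe = Pipeline(steps = [("model", DecisionTreeRegressor(random_state = 425))])
search = GridSearchCV(
  pipe,
  {"model__ccp_alpha": np.linspace(0.0, 0.1, num = 100)},
  cv = 6,
  scoring = "neg_root_mean_squared_error",
  refit = True
  ).fit(X_other, y_other)

# 5. Final prediction on the test data
test_pred = search.best_estimator_.predict(X_test)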

2 Hitters data

Documentation of the Hitters data is here. The goal is to predict log(Salary) (at the opening of the 1987 season) of MLB players from their 1986 season and career performance metrics.

library(GGally)
library(gtsummary)
library(ranger)
library(tidyverse)
library(tidymodels)
library(ISLR2)

# Numerical summaries of the `Hitters` data.
Hitters %>% tbl_summary()
Characteristic  N = 322¹
AtBat 380 (255, 512)
Hits 96 (64, 137)
HmRun 8 (4, 16)
Runs 48 (30, 69)
RBI 44 (28, 65)
Walks 35 (22, 53)
Years 6 (4, 11)
CAtBat 1,928 (815, 3,926)
CHits 508 (209, 1,062)
CHmRun 38 (14, 90)
CRuns 247 (100, 529)
CRBI 221 (88, 428)
CWalks 171 (67, 340)
League
    A 175 (54%)
    N 147 (46%)
Division
    E 157 (49%)
    W 165 (51%)
PutOuts 212 (109, 325)
Assists 40 (7, 166)
Errors 6 (3, 11)
Salary 425 (190, 750)
    Unknown 59
NewLeague
    A 176 (55%)
    N 146 (45%)
¹ Median (Q1, Q3); n (%)
# Drop players with missing `Salary`
Hitters <- Hitters %>% filter(!is.na(Salary))
# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font sizes in plots
sns.set(font_scale = 1.2)
# Display all columns
pd.set_option('display.max_columns', None)

Hitters = pd.read_csv("./slides/data/Hitters.csv")
Hitters
     AtBat  Hits  HmRun  Runs  RBI  Walks  Years  CAtBat  CHits  CHmRun  \
0      293    66      1    30   29     14      1     293     66       1   
1      315    81      7    24   38     39     14    3449    835      69   
2      479   130     18    66   72     76      3    1624    457      63   
3      496   141     20    65   78     37     11    5628   1575     225   
4      321    87     10    39   42     30      2     396    101      12   
..     ...   ...    ...   ...  ...    ...    ...     ...    ...     ...   
317    497   127      7    65   48     37      5    2703    806      32   
318    492   136      5    76   50     94     12    5511   1511      39   
319    475   126      3    61   43     52      6    1700    433       7   
320    573   144      9    85   60     78      8    3198    857      97   
321    631   170      9    77   44     31     11    4908   1457      30   

     CRuns  CRBI  CWalks League Division  PutOuts  Assists  Errors  Salary  \
0       30    29      14      A        E      446       33      20     NaN   
1      321   414     375      N        W      632       43      10   475.0   
2      224   266     263      A        W      880       82      14   480.0   
3      828   838     354      N        E      200       11       3   500.0   
4       48    46      33      N        E      805       40       4    91.5   
..     ...   ...     ...    ...      ...      ...      ...     ...     ...   
317    379   311     138      N        E      325        9       3   700.0   
318    897   451     875      A        E      313      381      20   875.0   
319    217    93     146      A        W       37      113       7   385.0   
320    470   420     332      A        E     1314      131      12   960.0   
321    775   357     249      A        W      408        4       3  1000.0   

    NewLeague  
0           A  
1           N  
2           A  
3           N  
4           N  
..        ...  
317         N  
318         A  
319         A  
320         A  
321         A  

[322 rows x 20 columns]
# Numerical summaries
Hitters.describe()
            AtBat        Hits       HmRun        Runs         RBI       Walks  \
count  322.000000  322.000000  322.000000  322.000000  322.000000  322.000000   
mean   380.928571  101.024845   10.770186   50.909938   48.027950   38.742236   
std    153.404981   46.454741    8.709037   26.024095   26.166895   21.639327   
min     16.000000    1.000000    0.000000    0.000000    0.000000    0.000000   
25%    255.250000   64.000000    4.000000   30.250000   28.000000   22.000000   
50%    379.500000   96.000000    8.000000   48.000000   44.000000   35.000000   
75%    512.000000  137.000000   16.000000   69.000000   64.750000   53.000000   
max    687.000000  238.000000   40.000000  130.000000  121.000000  105.000000   

            Years       CAtBat        CHits      CHmRun        CRuns  \
count  322.000000    322.00000   322.000000  322.000000   322.000000   
mean     7.444099   2648.68323   717.571429   69.490683   358.795031   
std      4.926087   2324.20587   654.472627   86.266061   334.105886   
min      1.000000     19.00000     4.000000    0.000000     1.000000   
25%      4.000000    816.75000   209.000000   14.000000   100.250000   
50%      6.000000   1928.00000   508.000000   37.500000   247.000000   
75%     11.000000   3924.25000  1059.250000   90.000000   526.250000   
max     24.000000  14053.00000  4256.000000  548.000000  2165.000000   

              CRBI       CWalks      PutOuts     Assists      Errors  \
count   322.000000   322.000000   322.000000  322.000000  322.000000   
mean    330.118012   260.239130   288.937888  106.913043    8.040373   
std     333.219617   267.058085   280.704614  136.854876    6.368359   
min       0.000000     0.000000     0.000000    0.000000    0.000000   
25%      88.750000    67.250000   109.250000    7.000000    3.000000   
50%     220.500000   170.500000   212.000000   39.500000    6.000000   
75%     426.250000   339.250000   325.000000  166.000000   11.000000   
max    1659.000000  1566.000000  1378.000000  492.000000   32.000000   

            Salary  
count   263.000000  
mean    535.925882  
std     451.118681  
min      67.500000  
25%     190.000000  
50%     425.000000  
75%     750.000000  
max    2460.000000  

Graphical summary:

# Graphical summaries
plt.figure()
sns.pairplot(data = Hitters);
plt.show()

There are 59 NAs for the salary. Let’s drop those cases. We are left with 263 data points.

Hitters.dropna(inplace = True)
Hitters.shape
(263, 20)

3 Initial split into test and non-test sets

We randomly split the data in half: one half serves as the test set and the other half as the non-test set.

# For reproducibility
set.seed(203)

data_split <- initial_split(
  Hitters, 
  prop = 0.5
  )
data_split
<Training/Testing/Total>
<131/132/263>
Hitters_other <- training(data_split)
dim(Hitters_other)
[1] 131  20
Hitters_test <- testing(data_split)
dim(Hitters_test)
[1] 132  20
from sklearn.model_selection import train_test_split

Hitters_other, Hitters_test = train_test_split(
  Hitters, 
  train_size = 0.5,
  random_state = 425, # seed
  )
Hitters_test.shape
(132, 20)
Hitters_other.shape
(131, 20)

Separate \(X\) and \(y\). We will use 9 of the features.

features = ['Assists', 'AtBat', 'Hits', 'HmRun', 'PutOuts', 'RBI', 'Runs', 'Walks', 'Years']
# Non-test X and y
X_other = Hitters_other[features]
y_other = np.log(Hitters_other.Salary)
# Test X and y
X_test = Hitters_test[features]
y_test = np.log(Hitters_test.Salary)

4 Preprocessing (Python) or recipe (R)

tree_recipe <- 
  recipe(
    Salary ~ ., 
    data = Hitters_other
  ) %>%
  # # create traditional dummy variables (not necessary for random forest in R)
  # step_dummy(all_nominal()) %>%
  step_naomit(Salary) %>%
  # zero-variance filter
  step_zv(all_numeric_predictors()) # %>% 
  # # center and scale numeric data (not necessary for random forest)
  # step_normalize(all_numeric_predictors()) %>%
  # estimate the means and standard deviations
  # prep(training = Hitters_other, retain = TRUE)
tree_recipe

Not much preprocessing is needed here. The rpart engine in R handles the categorical predictors (League, Division, NewLeague) natively, and the Python model uses only quantitative features, so dummy coding and normalization are unnecessary for these trees.
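If we wanted to include the categorical predictors on the Python side as well, one option is to one-hot encode them inside the pipeline. A hypothetical sketch (not used in the rest of this document):

# Hypothetical sketch: one-hot encode the categorical predictors inside the
# pipeline so they can be used by scikit-learn's DecisionTreeRegressor
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeRegressor

cat_features = ['League', 'Division', 'NewLeague']
preproc = ColumnTransformer(
  transformers = [("onehot", OneHotEncoder(drop = "first"), cat_features)],
  remainder = "passthrough"  # pass the quantitative predictors through unchanged
  )
pipe_with_cats = Pipeline(steps = [
  ("preproc", preproc),
  ("model", DecisionTreeRegressor(random_state = 425))
  ])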

5 Model

regtree_mod <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = 5,
  mode = "regression",
  engine = "rpart"
  ) 
from sklearn.tree import DecisionTreeRegressor, plot_tree

regtree_mod = DecisionTreeRegressor(random_state = 425)
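The R specification above fixes min_n = 5 (the minimum number of observations a node needs before it can be split further). A rough scikit-learn analogue, offered only as an assumption about how to mirror that setting, is min_samples_split:

# Rough analogue of the R `min_n = 5` setting (an approximation, not exact):
# require at least 5 observations in a node before attempting a split
regtree_mod_min5 = DecisionTreeRegressor(min_samples_split = 5, random_state = 425)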

6 Pipeline (Python) or workflow (R)

Here we bundle the preprocessing (Python) or recipe (R) and the model. On the Python side no preprocessing steps are needed, so the pipeline contains only the model.

tree_wf <- workflow() %>%
  add_recipe(tree_recipe) %>%
  add_model(regtree_mod)
tree_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_naomit()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (regression)

Main Arguments:
  cost_complexity = tune()
  tree_depth = tune()
  min_n = 5

Computational engine: rpart 
from sklearn.pipeline import Pipeline

pipe = Pipeline(steps = [
  ("model", regtree_mod)
  ])
pipe
Pipeline(steps=[('model', DecisionTreeRegressor(random_state=425))])

7 Tuning grid

In R, we lay out a regular grid over cost_complexity and tree_depth. In Python, ccp_alpha is the minimal cost-complexity pruning parameter; greater values of ccp_alpha increase the number of nodes pruned.

tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          levels = c(100, 5))
# Tune hyper-parameter(s)
ccp_alpha_grid = np.linspace(start = 0.0, stop = 0.1, num = 100)
tuned_parameters = {
  "model__ccp_alpha": ccp_alpha_grid
  }
tuned_parameters  
{'model__ccp_alpha': array([0.        , 0.0010101 , 0.0020202 , 0.0030303 , 0.0040404 ,
       0.00505051, 0.00606061, 0.00707071, 0.00808081, 0.00909091,
       0.01010101, 0.01111111, 0.01212121, 0.01313131, 0.01414141,
       0.01515152, 0.01616162, 0.01717172, 0.01818182, 0.01919192,
       0.02020202, 0.02121212, 0.02222222, 0.02323232, 0.02424242,
       0.02525253, 0.02626263, 0.02727273, 0.02828283, 0.02929293,
       0.03030303, 0.03131313, 0.03232323, 0.03333333, 0.03434343,
       0.03535354, 0.03636364, 0.03737374, 0.03838384, 0.03939394,
       0.04040404, 0.04141414, 0.04242424, 0.04343434, 0.04444444,
       0.04545455, 0.04646465, 0.04747475, 0.04848485, 0.04949495,
       0.05050505, 0.05151515, 0.05252525, 0.05353535, 0.05454545,
       0.05555556, 0.05656566, 0.05757576, 0.05858586, 0.05959596,
       0.06060606, 0.06161616, 0.06262626, 0.06363636, 0.06464646,
       0.06565657, 0.06666667, 0.06767677, 0.06868687, 0.06969697,
       0.07070707, 0.07171717, 0.07272727, 0.07373737, 0.07474747,
       0.07575758, 0.07676768, 0.07777778, 0.07878788, 0.07979798,
       0.08080808, 0.08181818, 0.08282828, 0.08383838, 0.08484848,
       0.08585859, 0.08686869, 0.08787879, 0.08888889, 0.08989899,
       0.09090909, 0.09191919, 0.09292929, 0.09393939, 0.09494949,
       0.0959596 , 0.0969697 , 0.0979798 , 0.0989899 , 0.1       ])}
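Instead of an evenly spaced grid, one could let the data suggest candidate ccp_alpha values via cost_complexity_pruning_path. A sketch of this alternative (the evenly spaced grid above is what the rest of this document uses):

# Alternative (not used below): derive candidate ccp_alpha values from the
# pruning path of a fully grown tree instead of an evenly spaced grid
path = DecisionTreeRegressor(random_state = 425).cost_complexity_pruning_path(X_other, y_other)
# Drop the largest alpha, which corresponds to the trivial single-node tree
tuned_parameters_path = {"model__ccp_alpha": path.ccp_alphas[:-1]}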

8 Cross-validation (CV)

Set cross-validation partitions.

set.seed(203)

folds <- vfold_cv(Hitters_other, v = 5)
folds
#  5-fold cross-validation 
# A tibble: 5 × 2
  splits           id   
  <list>           <chr>
1 <split [104/27]> Fold1
2 <split [105/26]> Fold2
3 <split [105/26]> Fold3
4 <split [105/26]> Fold4
5 <split [105/26]> Fold5

Fit cross-validation.

tree_fit <- tree_wf %>%
  tune_grid(
    resamples = folds,
    grid = tree_grid,
    metrics = metric_set(rmse, rsq)
    )
tree_fit
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 4
  splits           id    .metrics             .notes          
  <list>           <chr> <list>               <list>          
1 <split [104/27]> Fold1 <tibble [1,000 × 6]> <tibble [0 × 4]>
2 <split [105/26]> Fold2 <tibble [1,000 × 6]> <tibble [0 × 4]>
3 <split [105/26]> Fold3 <tibble [1,000 × 6]> <tibble [0 × 4]>
4 <split [105/26]> Fold4 <tibble [1,000 × 6]> <tibble [0 × 4]>
5 <split [105/26]> Fold5 <tibble [1,000 × 6]> <tibble [0 × 4]>

Visualize CV results:

tree_fit %>%
  collect_metrics() %>%
  print(width = Inf) %>%
  filter(.metric == "rmse") %>%
  mutate(tree_depth = as.factor(tree_depth)) %>%
  ggplot(mapping = aes(x = cost_complexity, y = mean, color = tree_depth)) +
  geom_point() + 
  geom_line() + 
  labs(x = "cost_complexity", y = "CV mse")
# A tibble: 1,000 × 8
   cost_complexity tree_depth .metric .estimator    mean     n std_err
             <dbl>      <int> <chr>   <chr>        <dbl> <int>   <dbl>
 1    0.0000000001          1 rmse    standard   397.        5 34.7   
 2    0.0000000001          1 rsq     standard     0.342     5  0.0626
 3    0.0000000001          4 rmse    standard   378.        5 45.5   
 4    0.0000000001          4 rsq     standard     0.453     5  0.107 
 5    0.0000000001          8 rmse    standard   407.        5 39.4   
 6    0.0000000001          8 rsq     standard     0.410     5  0.0822
 7    0.0000000001         11 rmse    standard   408.        5 39.5   
 8    0.0000000001         11 rsq     standard     0.408     5  0.0817
 9    0.0000000001         15 rmse    standard   408.        5 39.5   
10    0.0000000001         15 rsq     standard     0.408     5  0.0817
   .config          
   <chr>            
 1 pre0_mod001_post0
 2 pre0_mod001_post0
 3 pre0_mod002_post0
 4 pre0_mod002_post0
 5 pre0_mod003_post0
 6 pre0_mod003_post0
 7 pre0_mod004_post0
 8 pre0_mod004_post0
 9 pre0_mod005_post0
10 pre0_mod005_post0
# ℹ 990 more rows

Set up CV partitions and CV criterion.

from sklearn.model_selection import GridSearchCV

# Set up CV
n_folds = 6
search = GridSearchCV(
  pipe,
  tuned_parameters,
  cv = n_folds, 
  scoring = "neg_root_mean_squared_error",
  # Refit the best model on the whole data set
  refit = True
  )
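With cv = 6, GridSearchCV partitions the non-test data into 6 folds in their given order (unshuffled KFold). For reproducible shuffled folds, analogous to calling set.seed() before vfold_cv() in R, one could pass a KFold object explicitly. A sketch:

# Sketch: explicitly shuffled, seeded CV folds (the code above simply uses cv = 6)
from sklearn.model_selection import KFold

cv_folds = KFold(n_splits = 6, shuffle = True, random_state = 425)
search_shuffled = GridSearchCV(
  pipe,
  tuned_parameters,
  cv = cv_folds,
  scoring = "neg_root_mean_squared_error",
  refit = True
  )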

Fit CV. This is typically the most time-consuming step.

# Fit CV
search.fit(X_other, y_other)
GridSearchCV(cv=6,
             estimator=Pipeline(steps=[('model',
                                        DecisionTreeRegressor(random_state=425))]),
             param_grid={'model__ccp_alpha': array([0.        , 0.0010101 , 0.0020202 , 0.0030303 , 0.0040404 ,
       0.00505051, 0.00606061, 0.00707071, 0.00808081, 0.00909091,
       0.01010101, 0.01111111, 0.01212121, 0.01313131, 0.01414141,
       0.01515152, 0.01616162, 0.01717172, 0.01818182, 0.01919192,
       0.020202...
       0.07070707, 0.07171717, 0.07272727, 0.07373737, 0.07474747,
       0.07575758, 0.07676768, 0.07777778, 0.07878788, 0.07979798,
       0.08080808, 0.08181818, 0.08282828, 0.08383838, 0.08484848,
       0.08585859, 0.08686869, 0.08787879, 0.08888889, 0.08989899,
       0.09090909, 0.09191919, 0.09292929, 0.09393939, 0.09494949,
       0.0959596 , 0.0969697 , 0.0979798 , 0.0989899 , 0.1       ])},
             scoring='neg_root_mean_squared_error')

Visualize CV results.

cv_res = pd.DataFrame({
  "ccp_alpha": np.array(search.cv_results_["param_model__ccp_alpha"]),
  "rmse": -search.cv_results_["mean_test_score"]
  })

plt.figure()
sns.relplot(
  # kind = "line",
  data = cv_res,
  x = "ccp_alpha",
  y = "rmse"
  ).set(
    xlabel = "CCP Alpha",
    ylabel = "CV RMSE"
);
plt.show()

Best CV RMSE (on the log(Salary) scale):

-search.best_score_
np.float64(0.5217085199223737)
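The corresponding tuned value of ccp_alpha can be read off the fitted search object:

# The ccp_alpha value selected by cross-validation
search.best_params_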

9 Finalize our model

Now we are done tuning. Finally, let's fit this final model to the whole non-test data and use our test data to estimate the model performance we expect to see with new data.

tree_fit %>%
  show_best(metric = "rmse")
# A tibble: 5 × 8
  cost_complexity tree_depth .metric .estimator  mean     n std_err .config     
            <dbl>      <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>       
1          0.0187          4 rmse    standard    364.     5    44.7 pre0_mod457…
2          0.0231          4 rmse    standard    364.     5    44.7 pre0_mod462…
3          0.0231          8 rmse    standard    364.     5    44.7 pre0_mod463…
4          0.0231         11 rmse    standard    364.     5    44.7 pre0_mod464…
5          0.0231         15 rmse    standard    364.     5    44.7 pre0_mod465…

Let’s select the best model.

best_tree <- tree_fit %>%
  select_best(metric = "rmse")
best_tree
# A tibble: 1 × 3
  cost_complexity tree_depth .config          
            <dbl>      <int> <chr>            
1          0.0187          4 pre0_mod457_post0
# Final workflow
final_wf <- tree_wf %>%
  finalize_workflow(best_tree)
final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_naomit()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (regression)

Main Arguments:
  cost_complexity = 0.0187381742286039
  tree_depth = 4
  min_n = 5

Computational engine: rpart 
# Fit the whole training set, then predict the test cases
final_fit <- 
  final_wf %>%
  last_fit(data_split)
final_fit
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits            id               .metrics .notes   .predictions .workflow 
  <list>            <chr>            <list>   <list>   <list>       <list>    
1 <split [131/132]> train/test split <tibble> <tibble> <tibble>     <workflow>
# Test metrics
final_fit %>% 
  collect_metrics()
# A tibble: 2 × 4
  .metric .estimator .estimate .config        
  <chr>   <chr>          <dbl> <chr>          
1 rmse    standard     357.    pre0_mod0_post0
2 rsq     standard       0.410 pre0_mod0_post0

Since we called GridSearchCV with refit = True, the best model fit on the whole non-test data is readily available.

search.best_estimator_
Pipeline(steps=[('model',
                 DecisionTreeRegressor(ccp_alpha=np.float64(0.03636363636363636),
                                       random_state=425))])
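Before plotting, one can check how large the CV-selected tree actually is; a quick sketch using the fitted model's get_depth() and get_n_leaves() methods (values not shown here):

# Size of the pruned tree selected by CV
best_model = search.best_estimator_['model']
best_model.get_depth(), best_model.get_n_leaves()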

Visualize the best regression tree.

plt.figure()
plot_tree(
  search.best_estimator_['model'],
  feature_names = features
  );
plt.show()

Feature importances:

vi_df = pd.DataFrame({
  "feature": features,
  "vi": search.best_estimator_['model'].feature_importances_
  })

plt.figure()
sns.barplot(
  data = vi_df,
  x = "feature",
  y = "vi"
  ).set(
    xlabel = "Feature",
    ylabel = "VI"
);
plt.xticks(rotation = 90);
plt.show()
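For a tabular view of the same information, the importances can be sorted (a small convenience, not part of the original code):

# Feature importances sorted from largest to smallest
vi_df.sort_values("vi", ascending = False)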

The final prediction mean squared error (MSE) on the test set (on the log(Salary) scale) is

from sklearn.metrics import mean_squared_error

mean_squared_error(
  y_test, 
  search.best_estimator_.predict(X_test)
  )
0.32467980212189657
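To put this on the same scale as the CV criterion above (RMSE on the log(Salary) scale), take the square root, roughly 0.57. A sketch:

# Test RMSE = square root of the test MSE above (approximately 0.57)
np.sqrt(
  mean_squared_error(
    y_test, 
    search.best_estimator_.predict(X_test)
    )
  )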

10 Visualize the final model

library(rpart.plot)
final_tree <- extract_workflow(final_fit)
final_tree
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_naomit()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
n= 131 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 131 29124950.00  542.8152  
   2) CAtBat< 1934.5 64  4818615.00  246.4948  
     4) AtBat>=135 62  1218139.00  217.9570  
       8) CAtBat< 1257.5 41   157533.40  141.3211 *
       9) CAtBat>=1257.5 21   349686.00  367.5794 *
     5) AtBat< 135 2  1984695.00 1131.1660 *
   3) CAtBat>=1934.5 67 13318820.00  825.8676  
     6) Walks< 65.5 55  5778590.00  696.4545  
      12) PutOuts< 1113.5 53  3498168.00  658.1132  
        24) AtBat< 365 20   898765.90  484.5834 *
        25) AtBat>=365 33  1632149.00  763.2828 *
      13) PutOuts>=1113.5 2   137812.50 1712.5000 *
     7) Walks>=65.5 12  2397281.00 1419.0100  
      14) Hits< 147 6   346037.10 1109.4440 *
      15) Hits>=147 6   901269.60 1728.5770  
        30) AtBat>=572.5 3    51666.67 1416.6670 *
        31) AtBat< 572.5 3   265875.80 2040.4870 *
final_tree %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE)

library(vip)

final_tree %>% 
  extract_fit_parsnip() %>% 
  vip()