Exploring Interaction Effects and S-Learners

By Ken Koon Wong in r R interaction meta learner s learner boost tree lightgbm

September 4, 2023

Interaction adventures through simulations and gradient boosting trees using the S-learner approach. I hadn’t realized that LightGBM and XGBoost could reveal interaction terms without explicit specification. Quite intriguing!

picture resembles interaction 🤣

Objectives:

What is interaction?

In statistics, an interaction occurs when the effect of one variable on an outcome depends on the value of another variable. In other words, the variables’ contributions are not simply additive. Understanding interactions is crucial for capturing complex relationships, building accurate predictive models, and drawing correct conclusions from data in many fields.

Still don’t understand? No worries, you’re not alone. I’ve been there too! However, I found that simulating and visualizing interactions really helped solidify my understanding of their significance. Missing out on understanding interactions is like skipping a few chapters in the story – it’s an essential part of grasping the whole picture.

Simulate interaction

library(tidymodels)
library(tidyverse)
library(bonsai)
library(kableExtra)
library(ggpubr)

set.seed(1)
n <- 1000
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
y1 <- 0.2*x1 + rnorm(n)
y2 <- 1 + 0.6*x2 + rnorm(n)
y3 <- 2 + -0.2*x3 + rnorm(n)

# combining all y_i to 1 vector
y <- c(y1,y2,y3)

# categorize x1, x2, and x3
df <- tibble(y=y,x=c(x1,x2,x3),x_i=c(rep("x1",n),rep("x2",n),rep("x3",n))) 

kable(df |> head(5))
y x x_i
0.6138242 -0.6264538 x1
0.4233374 0.1836433 x1
1.1292714 -0.8356286 x1
-0.4845022 1.5952808 x1
-1.5367241 0.3295078 x1

But hold on, Ken, that’s not interaction, right? Doesn’t interaction involve terms like y = x + w + x*w? Well, yes, you’re right about that equation, but the simulation above offers a more intuitive grasp of what interaction entails. It’s like taking the nominal categories x1, x2, x3, treating them as on-off switches, and then consolidating all three equations into one primary equation. That single equation can then give each group its own intercept and its own slope. Trust me. Also, this was a great resource for understanding the interaction formula.

This is essentially our formula:

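\(Y_i = B_0 + B_1X_i + B_2D_{2i} + B_3D_{3i} + B_4X_iD_{2i} + B_5X_iD_{3i} + e_i\),

where \(X_i\) is x, and \(D_{2i}\) and \(D_{3i}\) are 0/1 switches that turn on when x_i is "x2" or "x3" respectively (x1 is the baseline, with both switches off). In the lm() output below, \(B_0, \dots, B_5\) appear as (Intercept), x, x_ix2, x_ix3, x:x_ix2 and x:x_ix3.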

Wow, there are too many \(B\)’s. I can see why it is hard to follow. We’ll carry this formula with us and unpack it later on.
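A quick way to see the "on-off switches" is to look at the design matrix R builds from the formula. This sketch uses base R's model.matrix(), which lm() relies on internally, on three made-up rows (one per group; the x values are arbitrary and chosen only for illustration). The x_ix2 and x_ix3 columns are the switches, and x:x_ix2 and x:x_ix3 are x multiplied by those switches, the same names that appear in the lm() output of the true model further down.

```r
# Three illustrative rows, one per group (x values are arbitrary)
toy <- data.frame(
  x   = c(0.5, -1.2, 2.0),
  x_i = c("x1", "x2", "x3")
)

# model.matrix() converts the character column x_i to a factor, creates
# 0/1 dummy columns for x2 and x3 (x1 is the baseline), and multiplies
# x by each dummy to form the interaction columns
X <- model.matrix(~ x * x_i, data = toy)
X
```

Reading across a row shows which coefficients are active: an x1 row only uses (Intercept) and x, while an x2 row also switches on x_ix2 and x:x_ix2.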

Visualize interaction

df |>
  ggplot(aes(x=x,y=y,color=as.factor(x_i))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  theme_minimal()

Wow, look at the lines for x1, x2, and x3. We can all agree that they have different intercepts and different slopes. The different slopes, my friend, are the interaction: the effect of x on y depends on which group x_i we are in.

True Model ✅

model_true <- lm(y~x*x_i,df)
summary(model_true)
## 
## Call:
## lm(formula = y ~ x * x_i, data = df)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -3.6430 -0.6684 -0.0105  0.6703  3.0552 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  0.01700    0.03172   0.536    0.592    
## x            0.22356    0.03066   7.291 3.91e-13 ***
## x_ix2        0.96309    0.04486  21.471  < 2e-16 ***
## x_ix3        1.97089    0.04486  43.938  < 2e-16 ***
## x:x_ix2      0.38554    0.04325   8.913  < 2e-16 ***
## x:x_ix3     -0.39672    0.04344  -9.133  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1.003 on 2994 degrees of freedom
## Multiple R-squared:  0.4465,	Adjusted R-squared:  0.4455 
## F-statistic:   483 on 5 and 2994 DF,  p-value: < 2.2e-16

Intercepts of x1, x2, and x3

Notice that the first intercept, (Intercept), belongs to the first equation, y1 = 0.2*x1 + noise, whose true intercept is zero; the estimate (0.017) is indeed close to zero. The second intercept is x2’s. Notice that the equation that simulates x2’s data has an intercept of 1? Where does this reside in the true model? Yes, you’re right, on x_ix2! It’s not exactly 1 (0.963), but it’s close. And the third intercept, for x3, is on x_ix3 (1.971, close to the true 2).

What if intercept for x1 is not zero? Read on below.

Slopes of x1, x2, and x3

What about the slopes? The slope for x1 is on x, x2’s is on x:x_ix2, and x3’s is on x:x_ix3. The interaction is the slope! How cool! Wait, you may say, the numbers don’t add up! x2’s coefficient should be 0.6, so why is x:x_ix2 (0.386) about 0.22 less? Wait a minute, does a coefficient of 0.22 look familiar to you? It’s x1’s coefficient (listed here as x).

Wowowow, hold the phone, so the slope for x2 is the sum of x1’s coefficient and the interaction term x:x_ix2?! YES, precisely: 0.224 + 0.386 = 0.609, very close to the true 0.6. And the same goes for x3? Let’s do the math: 0.224 + (-0.397) = -0.173, which is very close to -0.2, the true x3 coefficient! You can even see the negative slope in the visualization (the blue points). Superb! 🙌

So, what if x1’s intercept is not zero? It works the same way as the slopes: you simply add them up. x2’s intercept is (Intercept) + x_ix2 = 0.017 + 0.963 = 0.980, and x3’s is (Intercept) + x_ix3 = 0.017 + 1.971 = 1.988. How neat to be able to combine all 3 equations (of the same measurements, of course, meaning y and x measure the same things) into 1 equation!
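The adding-up rule can be checked in code. This base-R sketch rebuilds the same simulated data as a plain data.frame (named sim here, a name chosen for this example), assembles each group’s intercept and slope from the single interaction model, and compares them with three separate per-group regressions. They should agree to floating-point precision, because a fully interacted model gives every group a free intercept and a free slope.

```r
# Rebuild the simulated data from the post (same seed and draw order)
set.seed(1)
n <- 1000
x1 <- rnorm(n); x2 <- rnorm(n); x3 <- rnorm(n)
y1 <- 0.2 * x1 + rnorm(n)
y2 <- 1 + 0.6 * x2 + rnorm(n)
y3 <- 2 - 0.2 * x3 + rnorm(n)
sim <- data.frame(y = c(y1, y2, y3), x = c(x1, x2, x3),
                  x_i = rep(c("x1", "x2", "x3"), each = n))

b <- coef(lm(y ~ x * x_i, data = sim))

# x1 is the reference group; x2 and x3 add their offsets to x1's values
group_coefs <- rbind(
  x1 = c(intercept = b[["(Intercept)"]],                slope = b[["x"]]),
  x2 = c(intercept = b[["(Intercept)"]] + b[["x_ix2"]], slope = b[["x"]] + b[["x:x_ix2"]]),
  x3 = c(intercept = b[["(Intercept)"]] + b[["x_ix3"]], slope = b[["x"]] + b[["x:x_ix3"]])
)
group_coefs

# The same quantities from fitting each group on its own
separate <- t(sapply(split(sim, sim$x_i),
                     function(d) coef(lm(y ~ x, data = d))))
separate
```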

Wrong Model ❌

model_wrong <- lm(y~x + x_i,df)
summary(model_wrong)
## 
## Call:
## lm(formula = y ~ x + x_i, data = df)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -3.9952 -0.7085 -0.0088  0.7068  3.4459 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  0.01698    0.03339   0.509    0.611    
## x            0.22206    0.01863  11.922   <2e-16 ***
## x_ix2        0.95681    0.04721  20.265   <2e-16 ***
## x_ix3        1.96486    0.04722  41.614   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1.056 on 2996 degrees of freedom
## Multiple R-squared:  0.3862,	Adjusted R-squared:  0.3856 
## F-statistic: 628.4 on 3 and 2996 DF,  p-value: < 2.2e-16

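Notice that if we naively model the y and x relationship without taking the interaction terms into account, we’re not seeing the whole picture. Without x:x_ix2 and x:x_ix3, the model forces all three groups to share a single slope (x = 0.222), while x_ix2 and x_ix3 only shift the intercepts. We miss x2’s steeper slope and x3’s negative slope entirely, and anyone reading x_ix3 (1.965) as x3’s slope would be badly wrong, since the true slope is negative.

If we were to visualize it, it would look something like this: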

b_wrong <- coef(model_wrong)

df |>
  ggplot(aes(x=x,y=y,color=x_i)) +
  geom_point(alpha = 0.5) +
  # the additive model fits parallel lines: one shared slope (x),
  # while x_ix2 and x_ix3 only shift the intercept
  geom_abline(intercept = b_wrong[["(Intercept)"]], slope = b_wrong[["x"]], color = "red") +
  geom_abline(intercept = b_wrong[["(Intercept)"]] + b_wrong[["x_ix2"]], slope = b_wrong[["x"]], color = "green") +
  geom_abline(intercept = b_wrong[["(Intercept)"]] + b_wrong[["x_ix3"]], slope = b_wrong[["x"]], color = "blue") +
  theme_minimal()

Doesn’t look right, does it?
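Beyond eyeballing the plot, one standard check (not part of the original analysis) is a nested-model F-test. This sketch refits both models on the same simulated data with base R (the names sim, fit_additive and fit_interaction are chosen here) and passes them to anova(), which, for two nested lm fits, tests whether the two interaction terms reduce the residual sum of squares by more than chance would.

```r
# Rebuild the simulated data from the post (same seed and draw order)
set.seed(1)
n <- 1000
x1 <- rnorm(n); x2 <- rnorm(n); x3 <- rnorm(n)
y1 <- 0.2 * x1 + rnorm(n)
y2 <- 1 + 0.6 * x2 + rnorm(n)
y3 <- 2 - 0.2 * x3 + rnorm(n)
sim <- data.frame(y = c(y1, y2, y3), x = c(x1, x2, x3),
                  x_i = rep(c("x1", "x2", "x3"), each = n))

fit_additive    <- lm(y ~ x + x_i, data = sim)  # one shared slope
fit_interaction <- lm(y ~ x * x_i, data = sim)  # a slope per group

# F-test comparing the residual sums of squares of the nested models
cmp <- anova(fit_additive, fit_interaction)
cmp
```

A tiny p-value in the second row says the shared-slope model is missing real structure.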

What is an S-Learner?

The S-learner is one of the early machine learning meta-learners for estimating the conditional average treatment effect (CATE) in causal inference. It is also the simplest, in my opinion, which is great: once you understand how to construct it, the other meta-learners, such as T, X, R, and DL, become easier to follow. We simply treat the treatment variable like any other covariate, train a machine learning model of our choice, and then use that model to estimate CATE, all in ONE model!

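Response function:
\(\mu(x,z) := \mathbb{E}(Y^{obs} \mid X=x, Z=z)\)

Estimate CATE by:
\(\hat{t}(z) = \hat{\mu}(1, z) - \hat{\mu}(0, z)\)

\(\hat{\mu}\): Fitted model of choice (an estimate of \(\mu\)).
Y: Outcome.
X: Binary treatment variable.
Z: Other covariates.
\(\hat{t}\): Estimated CATE.

In our simulation, the "treatment" is the group x_i (e.g., x2 versus the baseline x1), and the continuous x plays the role of the covariate we condition on. So the CATE at x = 3 is the model’s prediction for x2 minus its prediction for x1, both at x = 3. We don’t have any additional covariates, such as confounders.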

For excellent explanations of the S-learner, please check this link by José Luis Cañadas Reche and this link on Statistical Odds & Ends.
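To make the recipe concrete, here is a minimal base-R sketch under stated assumptions: s_learner_cate() is a hypothetical helper (not from any package), the toy data with a binary treat column and one covariate z are invented for illustration, and a linear model with a treat * z term stands in for the "model of choice". It follows the two formulas above: fit one model with the treatment as an ordinary column, predict each row twice (treat set to 1, then to 0), and take the difference.

```r
# Hypothetical helper: CATE = prediction with treatment on minus
# prediction with treatment off, holding the covariates fixed
s_learner_cate <- function(fit, newdata, treat = "treat") {
  treated <- newdata
  treated[[treat]] <- 1
  control <- newdata
  control[[treat]] <- 0
  unname(predict(fit, newdata = treated) - predict(fit, newdata = control))
}

# Toy data (illustrative assumption): the true CATE is 1 + 0.5 * z
set.seed(42)
n_obs <- 2000
z <- rnorm(n_obs)
treat <- rbinom(n_obs, 1, 0.5)
y_obs <- 0.3 * z + treat * (1 + 0.5 * z) + rnorm(n_obs)
dat <- data.frame(y = y_obs, z = z, treat = treat)

# ONE model; the treatment is just another column.
# The treat:z term is what lets the CATE vary with z.
fit <- lm(y ~ treat * z, data = dat)

grid <- data.frame(z = c(-2, 0, 2))
cate <- s_learner_cate(fit, grid)
cate  # true values at these z are 0, 1 and 2
```

The base learner matters: with an additive lm(y ~ treat + z), the two predictions differ only by the treat coefficient, so the CATE would be flat in z, just like the wrong model above. A flexible learner such as a boosted tree may pick up that dependence without being told, which is what we test next.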

What is CATE?

Conditional Average Treatment Effect (CATE) is a foundational concept in causal inference: the average difference between the outcome a unit would have with treatment and without it, among units with specific covariate values, \(\tau(z) = \mathbb{E}[Y(1) - Y(0) \mid Z = z]\), where \(Y(1)\) and \(Y(0)\) are the potential outcomes with and without treatment. Unlike a single overall average effect, CATE lets the effect vary with individual characteristics, so it answers questions like "how much does the treatment help units like this one?" and can guide decision-making across many domains.

Boost Tree Model

I’m a big supporter of tidymodels, so we’ll use this framework to apply boosted tree methods and check whether they can reveal interaction terms without explicit specification. For the S-learner, we train a single model on all the data, then use that model to calculate CATE.

LightGBM

#split
# split <- initial_split(df, prop = 0.8)
# train <- training(split)
# test <- testing(split)

#preprocess
rec <- recipe(y ~ ., data = df) |>
  step_dummy(all_nominal())

rec |> prep() |> juice() |> head(5) |> kable()
x y x_i_x2 x_i_x3
-0.6264538 0.6138242 0 0
0.1836433 0.4233374 0 0
-0.8356286 1.1292714 0 0
1.5952808 -0.4845022 0 0
0.3295078 -1.5367241 0 0
#cv
cv <- vfold_cv(data = df, v = 5, repeats = 5)

#engine
gbm <- boost_tree() |>
  set_engine("lightgbm") |>
  set_mode("regression")


#workflow
gbm_wf <- workflow() |>
  add_recipe(rec) |>
  add_model(gbm)

#assess
gbm_assess <- gbm_wf %>%
  fit_resamples(
    resamples = cv,
    metrics = metric_set(rmse, rsq, ccc),
    control = control_resamples(save_pred = TRUE, verbose = TRUE)
  )

gbm_assess |>
  collect_metrics() |> kable()
.metric .estimator mean n std_err .config
ccc standard 0.5971348 25 0.0028353 Preprocessor1_Model1
rmse standard 1.0401944 25 0.0054528 Preprocessor1_Model1
rsq standard 0.4072577 25 0.0033576 Preprocessor1_Model1
#fit
gbm_fit <- gbm_wf |>
  fit(df)

Alright, we used LightGBM’s default hyperparameters without any tuning; let’s see how things go. Observe that during preprocessing we did not specify interaction terms such as y ~ x*x_i; the recipe only dummy-encodes x_i.

Let’s look at our True Model’s CATE of x2 compared to x1 when x is 3:

predict(model_true, newdata=tibble(x=3,x_i="x2")) - predict(model_true, newdata=tibble(x=3,x_i="x1")) 
##        1 
## 2.119704

Let’s look at our Wrong Model’s CATE of x2 compared to x1 when x is 3:

predict(model_wrong, newdata=tibble(x=3,x_i="x2")) - predict(model_wrong, newdata=tibble(x=3,x_i="x1")) 
##         1 
## 0.9568113
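Both numbers can be recovered by hand from the coefficient tables above. Writing \(\hat\beta_0\) for (Intercept), \(\hat\beta_x\) for x, \(\hat\gamma_2\) for x_ix2 and \(\hat\delta_2\) for x:x_ix2 (labels introduced here only for this derivation), x2’s fitted line minus x1’s fitted line leaves just the x2-specific terms:

\[
\begin{aligned}
\widehat{\text{CATE}}_{\text{true}}(x) &= \big[(\hat\beta_0 + \hat\gamma_2) + (\hat\beta_x + \hat\delta_2)\,x\big] - \big[\hat\beta_0 + \hat\beta_x\,x\big] = \hat\gamma_2 + \hat\delta_2\,x,\\
\widehat{\text{CATE}}_{\text{true}}(3) &= 0.96309 + 0.38554 \times 3 \approx 2.1197,\\
\widehat{\text{CATE}}_{\text{wrong}}(x) &= \hat\gamma_2^{\text{wrong}} = 0.95681 \quad \text{for every } x.
\end{aligned}
\]

Without the x:x_ix2 term, the wrong model’s CATE cannot change with x at all. This is also why regressing the true model’s CATE on x returns exactly x_ix2 and x:x_ix2 with an R² of 1: its CATE is a straight line by construction.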

Let’s look at LightGBM’s CATE of x2 compared to x1 when x is 3. 🌸

predict(gbm_fit, new_data=tibble(x=3,x_i="x2")) - predict(gbm_fit, new_data=tibble(x=3,x_i="x1")) 
##      .pred
## 1 1.952928

Wow, LightGBM (1.95) is quite close to the true model (2.12), while the wrong model (0.96) is far off. Let’s compute CATE over a sequence of x values and assess all 3 models.

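# CATE of x_i versus x_base at each value of x, for an lm or a fitted workflow
assess <- function(model, x, x_i, x_base = "x1") {
  treated <- tibble(x = x, x_i = x_i)
  control <- tibble(x = x, x_i = x_base)
  if (inherits(model, "workflow")) {
    # workflows take new_data and return a tibble with a .pred column
    diff <- predict(model, new_data = treated)$.pred -
      predict(model, new_data = control)$.pred
  } else {
    # lm takes newdata and returns a numeric vector
    diff <- predict(model, newdata = treated) - predict(model, newdata = control)
  }
  tibble(x = x, diff = unname(diff))
}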

# sequence of x's
x <- seq(-3,3,0.1)

# type of x_i of interest
x_i <- "x2"

gbm_g<- assess(model=gbm_fit,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point() +
  # geom_smooth(method="lm") +
  theme_minimal() +
  ggtitle("LightGBM") +
  ylab("CATE")

true_g <- assess(model=model_true,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point()  +
  theme_minimal() +
  ggtitle("Linear Model With Interaction Terms") +
  ylab("CATE")

wrong_g <- assess(model=model_wrong,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point()  +
  theme_minimal() +
  ggtitle("Linear Model Without Interaction Terms") +
  ylab("CATE")

ggarrange(gbm_g,true_g,wrong_g, ncol = 1)

Very interesting indeed! LightGBM seems to get a CATE very similar to the true model’s! The sequence of LightGBM’s predictions also looks quite familiar, doesn’t it? If we regress its CATE estimates on x, the slope might be very similar to the true model’s. Let’s give it a try.

LightGBM’s CATE regressing on x

gbm_pred <- assess(gbm_fit, x = x, x_i = "x2")
summary(lm(diff~x,gbm_pred))
## 
## Call:
## lm(formula = diff ~ x, data = gbm_pred)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.65741 -0.10984 -0.01303  0.13675  0.54059 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  0.98884    0.03152   31.38   <2e-16 ***
## x            0.37032    0.01790   20.69   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.2462 on 59 degrees of freedom
## Multiple R-squared:  0.8788,	Adjusted R-squared:  0.8768 
## F-statistic:   428 on 1 and 59 DF,  p-value: < 2.2e-16

True model’s CATE regressing on x

model_true_pred <- assess(model_true, x, "x2")
summary(lm(diff~x,model_true_pred))
## 
## Call:
## lm(formula = diff ~ x, data = model_true_pred)
## 
## Residuals:
##        Min         1Q     Median         3Q        Max 
## -2.883e-15 -2.088e-17  3.261e-17  7.979e-17  2.103e-15 
## 
## Coefficients:
##              Estimate Std. Error   t value Pr(>|t|)    
## (Intercept) 9.631e-01  6.155e-17 1.565e+16   <2e-16 ***
## x           3.855e-01  3.496e-17 1.103e+16   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4.808e-16 on 59 degrees of freedom
## Multiple R-squared:      1,	Adjusted R-squared:      1 
## F-statistic: 1.216e+32 on 1 and 59 DF,  p-value: < 2.2e-16

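Very, very similar indeed! LightGBM’s CATE line has an intercept of 0.989 and a slope of 0.370, versus 0.963 and 0.386 for the true model. Let’s visualize them side by side.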

gbm_g<- assess(model=gbm_fit,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point() +
  geom_smooth(method="lm") +
  theme_minimal() +
  ggtitle("LightGBM") +
  ylab("CATE")

ggarrange(gbm_g,true_g)

What if we look at x3?

# sequence of x's
x <- seq(-3,3,0.1)

# type of x_i of interest
x_i <- "x3"

gbm_g<- assess(model=gbm_fit,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point() +
  geom_smooth(method="lm") +
  theme_minimal() +
  ggtitle("LightGBM") +
  ylab("CATE")

true_g <- assess(model=model_true,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point()  +
  theme_minimal() +
  ggtitle("Linear Model With Interaction Terms") +
  ylab("CATE")

wrong_g <- assess(model=model_wrong,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point()  +
  theme_minimal() +
  ggtitle("Linear Model Without Interaction Terms") +
  ylab("CATE")

ggarrange(gbm_g,true_g,wrong_g, ncol = 1)

XGBoost

gbb <- boost_tree() |>
  set_engine("xgboost") |>
  set_mode("regression")

#workflow
gbb_wf <- workflow() |>
  add_recipe(rec) |>
  add_model(gbb)

#assess
gbb_assess <- gbb_wf %>%
  fit_resamples(
    resamples = cv,
    metrics = metric_set(rmse, rsq, ccc),
    control = control_resamples(save_pred = TRUE, verbose = TRUE)
  )

gbb_assess |>
  collect_metrics() |> kable()
.metric .estimator mean n std_err .config
ccc standard 0.5968061 25 0.0031061 Preprocessor1_Model1
rmse standard 1.0318791 25 0.0057834 Preprocessor1_Model1
rsq standard 0.4144199 25 0.0038228 Preprocessor1_Model1
gbb_fit <- gbb_wf |>
  fit(df)

# sequence of x's
x <- seq(-3,3,0.1)

# type of x_i of interest
x_i <- "x2"

gbb_g<- assess(model=gbb_fit,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point() +
  geom_smooth(method="lm") +
  theme_minimal() +
  ggtitle("XGBoost") +
  ylab("CATE")

true_g <- assess(model=model_true,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point()  +
  theme_minimal() +
  ggtitle("Linear Model With Interaction Terms") +
  ylab("CATE")

ggarrange(gbb_g,true_g)
x_i <- "x3"

gbb_g<- assess(model=gbb_fit,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point() +
  geom_smooth(method="lm") +
  theme_minimal() +
  ggtitle("XGBoost") +
  ylab("CATE")

true_g <- assess(model=model_true,x=x,x_i=x_i) |>
  ggplot(aes(x=x,y=diff)) +
  geom_point()  +
  theme_minimal() +
  ggtitle("Linear Model With Interaction Terms") +
  ylab("CATE")

ggarrange(gbb_g,true_g)

Not too shabby either!


Limitations

  • We haven’t truly added confounders such as z to the mix; we’ll have to see how that pans out with interaction. I sense another blog coming up!
  • We may not actually be using the S-learner method in its usual setting. It is typically used to estimate the effect of a binary treatment on an outcome, whereas here we focus more on interaction terms and how boosted trees can tease them out.
  • If you see any mistakes or any comments, please feel free to reach out!

Acknowledgement

Lessons Learnt

  • geom_abline to draw custom line based on intercept + slopes
  • \mathbb{E} symbol for expectation in LaTeX
  • Gradient boosting models such as LightGBM can tease out interactions without explicit specification, quite a handy tool!
  • Learnt what interaction is and how to interpret its summary
  • Learnt S-learner


Length:
12 minute read, 2526 words