Exploring Interaction Effects and S-Learners
By Ken Koon Wong in r R interaction meta learner s learner boost tree lightgbm
September 4, 2023
Interaction adventures through simulations and gradient boosting trees using the S-learner approach. I hadn’t realized that lightGBM and XGBoost could reveal interaction terms without explicit specification. Quite intriguing!
Objectives:
- What is interaction?
- Simulate interaction
- Visualize interaction
- True Model ✅
- Wrong Model ❌
- What is S Learner?
- What is CATE?
- Boost Tree Model
- Limitation
- Acknowledgement
- Lessons Learnt
What is interaction?
In statistics, interaction refers to the phenomenon where the effect of one variable on an outcome is influenced by the presence or change in another variable. It indicates that the relationship between variables is not simply additive, but rather depends on the interaction between their values. Understanding interactions is crucial for capturing complex relationships and building accurate predictive models that consider the combined influence of variables, providing deeper insights into data analysis across various fields.
Still don’t understand? No worries, you’re not alone. I’ve been there too! However, I found that simulating and visualizing interactions really helped solidify my understanding of their significance. Missing out on understanding interactions is like skipping a few chapters in the story – it’s an essential part of grasping the whole picture.
Simulate interaction
library(tidymodels)
library(tidyverse)
library(bonsai)
library(kableExtra)
library(ggpubr)
set.seed(1)
n <- 1000
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
y1 <- 0.2*x1 + rnorm(n)
y2 <- 1 + 0.6*x2 + rnorm(n)
y3 <- 2 + -0.2*x3 + rnorm(n)
# combining all y_i to 1 vector
y <- c(y1,y2,y3)
# categorize x1, x2, and x3
df <- tibble(y=y,x=c(x1,x2,x3),x_i=c(rep("x1",n),rep("x2",n),rep("x3",n)))
kable(df |> head(5))
y | x | x_i |
---|---|---|
0.6138242 | -0.6264538 | x1 |
0.4233374 | 0.1836433 | x1 |
1.1292714 | -0.8356286 | x1 |
-0.4845022 | 1.5952808 | x1 |
-1.5367241 | 0.3295078 | x1 |
But hold on, Ken, that’s not interaction, right? Doesn’t interaction involve terms like y = x + w + x*w? Well, yes, you’re right about that equation, but the explanation above offers a more intuitive grasp of what interaction entails. It’s like taking the nominal categories x1, x2, x3, treating them as on-off switches, and then consolidating all three equations into one primary equation. This approach helps in calculating the interrelationships between them. Trust me. Also, this was a great resource for understanding the interaction formula.
This is essentially our formula:
\(Y_i = B_0 + B_1X_{1i} + B_2X_{2i} + B_3X_{3i} + B_4X_{1i}X_{2i} + B_5X_{1i}X_{3i} + e_i\).
Wow, there are too many \(B\)’s. I can see why it is hard to follow. We’ll carry this formula with us and unpack it later on.
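To make the on-off switch idea concrete, here is a small worked mapping. It is a sketch under my own reading of the formula: I am assuming \(X_{1i}\) is the continuous x, while \(X_{2i}\) and \(X_{3i}\) are 0/1 switches for the x2 and x3 categories. With that assumption, flipping the switches collapses the single formula back into the three simulated equations:

```latex
\begin{align*}
\text{x1 } (X_{2i}=0,\ X_{3i}=0):\quad & Y_i = B_0 + B_1 X_{1i} + e_i \\
\text{x2 } (X_{2i}=1,\ X_{3i}=0):\quad & Y_i = (B_0 + B_2) + (B_1 + B_4) X_{1i} + e_i \\
\text{x3 } (X_{2i}=0,\ X_{3i}=1):\quad & Y_i = (B_0 + B_3) + (B_1 + B_5) X_{1i} + e_i
\end{align*}
% Matching against the simulation (y1 = 0.2*x1, y2 = 1 + 0.6*x2, y3 = 2 - 0.2*x3):
% B_0 = 0, B_1 = 0.2, B_2 = 1, B_3 = 2, B_4 = 0.6 - 0.2 = 0.4, B_5 = -0.2 - 0.2 = -0.4
```

So \(B_2\) and \(B_3\) shift the intercept, and \(B_4\) and \(B_5\) shift the slope, each relative to the x1 baseline.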
Visualize interaction
df |>
ggplot(aes(x=x,y=y,color=as.factor(x_i))) +
geom_point(alpha = 0.5) +
geom_smooth(method = "lm") +
theme_minimal()
Wow, look at those slopes of x1, x2, x3. We can all agree that they all have different intercepts and slopes. That, my friend, is interaction.
True Model ✅
model_true <- lm(y~x*x_i,df)
summary(model_true)
##
## Call:
## lm(formula = y ~ x * x_i, data = df)
##
## Residuals:
## Min 1Q Median 3Q Max
## -3.6430 -0.6684 -0.0105 0.6703 3.0552
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 0.01700 0.03172 0.536 0.592
## x 0.22356 0.03066 7.291 3.91e-13 ***
## x_ix2 0.96309 0.04486 21.471 < 2e-16 ***
## x_ix3 1.97089 0.04486 43.938 < 2e-16 ***
## x:x_ix2 0.38554 0.04325 8.913 < 2e-16 ***
## x:x_ix3 -0.39672 0.04344 -9.133 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1.003 on 2994 degrees of freedom
## Multiple R-squared: 0.4465, Adjusted R-squared: 0.4455
## F-statistic: 483 on 5 and 2994 DF, p-value: < 2.2e-16
Intercepts of x1, x2, and x3
Notice that the first intercept belongs to the first equation, y1 = 0.2*x1 + noise, so it should be zero, or at least close to zero. The second intercept is x2’s. Notice that the equation that simulates the data has a 1 as its intercept? Where does this reside in the true model? Yes, you’re right, on x_ix2! It’s not exactly 1, but it’s close to 1. And the third intercept, for x3, is on x_ix3.
What if the intercept for x1 is not zero? Read on below.
Slopes of x1, x2, and x3
What about the slopes? The slope for x1 is on x, x2 is on x:x_ix2, and x3 is on x:x_ix3. The interaction is the slope! How cool! Wait, you may say, they don’t add up! The x2 coefficient should be 0.6, so why is x:x_ix2 about 0.22 less? Wait a minute, does a coefficient of 0.22 look familiar to you? It’s x1’s coefficient (or in this case listed as x).
Wowowow, hold the phone, so the slope for x2 is the sum of the x1 coefficient and the x:x_ix2 interaction term!?! YES, precisely! And the same would be for x3 too? Let’s do the math: 0.224 + (-0.397) = -0.173, which is very close to -0.2, the true x3 coefficient! And you can even see the negative slope from the visualization too (represented by blue color). Superb! 🙌
So, if the intercept of x1 is not zero, it works the same way as the slope coefficients: you simply add them up! How neat to be able to combine all 3 equations (of the same measurements of course, meaning y and x are measuring the same things) into 1 equation!
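The adding-up can be done in code as well. Below is a minimal sketch that re-runs the simulation with base R only (a data.frame instead of a tibble, which is my substitution) and rebuilds each group's own intercept and slope from the coefficients. The name group_lines is just an illustrative helper, not part of the original analysis.

```r
# Re-create the simulated data with base R only
set.seed(1)
n <- 1000
x1 <- rnorm(n); x2 <- rnorm(n); x3 <- rnorm(n)
y1 <- 0.2 * x1 + rnorm(n)
y2 <- 1 + 0.6 * x2 + rnorm(n)
y3 <- 2 - 0.2 * x3 + rnorm(n)
df <- data.frame(y = c(y1, y2, y3),
                 x = c(x1, x2, x3),
                 x_i = rep(c("x1", "x2", "x3"), each = n))

# Same "true" model as above: main effects plus interaction
model_true <- lm(y ~ x * x_i, data = df)
b <- coef(model_true)

# Baseline (x1) terms plus the group-specific shifts
group_lines <- data.frame(
  group = c("x1", "x2", "x3"),
  intercept = c(b[["(Intercept)"]],
                b[["(Intercept)"]] + b[["x_ix2"]],
                b[["(Intercept)"]] + b[["x_ix3"]]),
  slope = c(b[["x"]],
            b[["x"]] + b[["x:x_ix2"]],
            b[["x"]] + b[["x:x_ix3"]])
)
print(group_lines)
```

The rows should land near the simulated values (intercepts near 0, 1, 2 and slopes near 0.2, 0.6, -0.2), and they should match what you would get by fitting each of the three equations separately.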
Wrong Model ❌
model_wrong <- lm(y~x + x_i,df)
summary(model_wrong)
##
## Call:
## lm(formula = y ~ x + x_i, data = df)
##
## Residuals:
## Min 1Q Median 3Q Max
## -3.9952 -0.7085 -0.0088 0.7068 3.4459
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 0.01698 0.03339 0.509 0.611
## x 0.22206 0.01863 11.922 <2e-16 ***
## x_ix2 0.95681 0.04721 20.265 <2e-16 ***
## x_ix3 1.96486 0.04722 41.614 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1.056 on 2996 degrees of freedom
## Multiple R-squared: 0.3862, Adjusted R-squared: 0.3856
## F-statistic: 628.4 on 3 and 2996 DF, p-value: < 2.2e-16
Notice that if we naively model the relationship between y and x without taking the interaction terms into account, we’re not seeing the whole picture. We’ll miss the actual slopes and may falsely think that x_ix3 is the slope, when the true slope is supposed to be negative.
If we were to visualize it, it would be something like this:
df |>
ggplot(aes(x=x,y=y,color=x_i)) +
geom_point() +
geom_abline(intercept = model_wrong$coefficients[[1]], slope = model_wrong$coefficients[["x"]], color = "red") +
geom_abline(intercept = model_wrong$coefficients[[1]], slope = model_wrong$coefficients[["x_ix2"]], color = "green") +
geom_abline(intercept = model_wrong$coefficients[[1]], slope = model_wrong$coefficients[["x_ix3"]], color = "blue") +
theme_minimal()
Doesn’t look right, does it?
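For comparison, here is a sketch of what the additive model fits when its coefficients are read as intended: one shared slope and a separate intercept per group, that is, three parallel lines. This is my own addition using base R (a data.frame rather than a tibble), and wrong_lines is a hypothetical helper name.

```r
# Re-create the simulated data with base R only
set.seed(1)
n <- 1000
x1 <- rnorm(n); x2 <- rnorm(n); x3 <- rnorm(n)
y1 <- 0.2 * x1 + rnorm(n)
y2 <- 1 + 0.6 * x2 + rnorm(n)
y3 <- 2 - 0.2 * x3 + rnorm(n)
df <- data.frame(y = c(y1, y2, y3),
                 x = c(x1, x2, x3),
                 x_i = rep(c("x1", "x2", "x3"), each = n))

# Additive model: no interaction terms
model_wrong <- lm(y ~ x + x_i, data = df)
b <- coef(model_wrong)

# x_ix2 and x_ix3 are intercept shifts; every group shares the slope on x
wrong_lines <- data.frame(
  group = c("x1", "x2", "x3"),
  intercept = c(b[["(Intercept)"]],
                b[["(Intercept)"]] + b[["x_ix2"]],
                b[["(Intercept)"]] + b[["x_ix3"]]),
  slope = rep(b[["x"]], 3)
)
print(wrong_lines)
```

These intercept and slope pairs could be handed to geom_abline to draw the parallel lines. Either way the conclusion is the same: without interaction terms, the model cannot give x3 its negative slope.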
What is S Learner?
The S-learner is one of the early machine learning meta-learners that can be used for estimating the conditional average treatment effect (CATE) in causal inference. It is the simplest in my opinion, which is great: understanding how to construct this one first makes the other meta-learners, such as T, X, R, and DL, easier to understand. We basically treat the treatment variable just like any other covariate, train a machine learning model of your choice, and then use that to estimate CATE, all in ONE model!
Response function:
\(\mu(x,z) := \mathbb{E}(Y^{obs}|X=x,Z=z)\)
Estimate CATE by:
\(\hat{t}(z) = \mu(X=1,Z=z) - \mu(X=0, Z=z)\)
\(\mu\): Model of choice.
\(Y\): Outcome.
\(X\): Binary treatment variable.
\(Z\): Other covariates. Though, we don’t have any in our current simulation.
\(\hat{t}\): CATE.
For an excellent explanation of the S-learner, please check this link by José Luis Cañadas Reche and this link on Statistical Odds & Ends.
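The two steps in the formulas above can be sketched in a few lines. This is a toy illustration under my own assumptions, not the analysis in this post: the data are freshly simulated with a randomized binary treatment, the variable names (treat, z, s_learner_cate) are made up for the example, and plain lm with an explicit interaction stands in for the "model of choice" to keep it dependency-free.

```r
set.seed(42)
n <- 2000
z <- rnorm(n)                  # a single covariate (assumed)
treat <- rbinom(n, 1, 0.5)     # randomized binary treatment (assumed)
# Assumed true CATE for this toy example: 1 + 0.5 * z
y <- 0.3 * z + treat * (1 + 0.5 * z) + rnorm(n)
dat <- data.frame(y = y, treat = treat, z = z)

# Step 1: fit ONE model with treatment as just another feature
mu_hat <- lm(y ~ treat * z, data = dat)

# Step 2: predict with the treatment switched on and off, then take the difference
s_learner_cate <- function(model, z) {
  unname(predict(model, newdata = data.frame(treat = 1, z = z)) -
         predict(model, newdata = data.frame(treat = 0, z = z)))
}

z_grid <- c(-1, 0, 1)
cate_hat <- s_learner_cate(mu_hat, z_grid)
print(data.frame(z = z_grid, cate_hat = cate_hat, cate_true = 1 + 0.5 * z_grid))
```

Swapping lm for a boosted tree changes only step 1. The predict-twice-and-subtract step stays the same, which is what the rest of this post relies on.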
What is CATE?
Conditional Average Treatment Effect (CATE) is a foundational concept in causal inference, focusing on the difference in expected outcomes between individuals who receive a treatment and those who do not, while considering their unique characteristics or covariates. Representing the average impact of treatment while accounting for individual differences, CATE helps answer how treatments influence outcomes in a given context. It’s calculated as the difference between the expected outcomes under treatment and no treatment conditions for individuals with specific covariate values, providing crucial insights into causal relationships and guiding decision-making across various domains.
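In symbols, and reusing the notation from the S-learner section (X for the binary treatment, Z for covariates), the definition can be written as below. The potential-outcome notation \(Y(1)\), \(Y(0)\) and the symbol \(\tau\) are my additions, and the second equality holds only under the usual assumptions of no unmeasured confounding and overlap.

```latex
\begin{align*}
\tau(z) &= \mathbb{E}\left[\, Y(1) - Y(0) \mid Z = z \,\right] \\
        &= \mathbb{E}\left[\, Y^{obs} \mid X = 1, Z = z \,\right]
         - \mathbb{E}\left[\, Y^{obs} \mid X = 0, Z = z \,\right]
         = \mu(1, z) - \mu(0, z)
\end{align*}
```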
Boost Tree Model
I’m a great supporter of tidymodels. In this context, we’ll utilize this framework to apply boosting tree methods and determine whether they can reveal interaction terms without requiring explicit specification. For S-learner, we basically will use a model, train with all data, then use the model to calculate CATE.
Light GBM
#split
# split <- initial_split(df, prop = 0.8)
# train <- training(split)
# test <- testing(split)
#preprocess
rec <- recipe(y ~ ., data = df) |>
step_dummy(all_nominal())
rec |> prep() |> juice() |> head(5) |> kable()
x | y | x_i_x2 | x_i_x3 |
---|---|---|---|
-0.6264538 | 0.6138242 | 0 | 0 |
0.1836433 | 0.4233374 | 0 | 0 |
-0.8356286 | 1.1292714 | 0 | 0 |
1.5952808 | -0.4845022 | 0 | 0 |
0.3295078 | -1.5367241 | 0 | 0 |
#cv
cv <- vfold_cv(data = df, v = 5, repeats = 5)
#engine
gbm <- boost_tree() |>
set_engine("lightgbm") |>
set_mode("regression")
#workflow
gbm_wf <- workflow() |>
add_recipe(rec) |>
add_model(gbm)
#assess
gbm_assess <- gbm_wf %>%
fit_resamples(
resamples = cv,
metrics = metric_set(rmse, rsq, ccc),
control = control_resamples(save_pred = TRUE, verbose = TRUE)
)
gbm_assess |>
collect_metrics() |> kable()
.metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|
ccc | standard | 0.5971348 | 25 | 0.0028353 | Preprocessor1_Model1 |
rmse | standard | 1.0401944 | 25 | 0.0054528 | Preprocessor1_Model1 |
rsq | standard | 0.4072577 | 25 | 0.0033576 | Preprocessor1_Model1 |
#fit
gbm_fit <- gbm_wf |>
fit(df)
Alright, let’s use the default hyperparameters without tuning lightgbm and see how things go. Observe that during our preprocessing, we did not specify interaction terms such as y ~ x*x_i.
Let’s look at our True Model CATE of x2 compared to x1 when x is 3 ✅
predict(model_true, newdata=tibble(x=3,x_i="x2")) - predict(model_true, newdata=tibble(x=3,x_i="x1"))
## 1
## 2.119704
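As a sanity check, the simulation equations themselves give the exact answer, since the noise has mean zero. The sketch below is my addition (dgp_mean and cate_dgp are hypothetical helper names) and simply subtracts the two conditional means from the data-generating process.

```r
# Conditional mean of y for each group, copied from the simulation equations
dgp_mean <- function(x, group) {
  switch(group,
         x1 = 0.2 * x,
         x2 = 1 + 0.6 * x,
         x3 = 2 - 0.2 * x,
         stop("unknown group"))
}

# CATE of a group versus the baseline group at a given x
cate_dgp <- function(x, group, base = "x1") {
  dgp_mean(x, group) - dgp_mean(x, base)
}

cate_dgp(3, "x2")  # (1 + 0.6 * 3) - (0.2 * 3) = 2.2
```

So the value to aim for at x = 3 is 2.2, and the true model's estimate sits close to it.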
Let’s look at our Wrong Model CATE of x2 compared to x1 when x is 3 ❌
predict(model_wrong, newdata=tibble(x=3,x_i="x2")) - predict(model_wrong, newdata=tibble(x=3,x_i="x1"))
## 1
## 0.9568113
Let’s look at LightGBM CATE of x2 compared to x1 when x is 3. 🌸
predict(gbm_fit, new_data=tibble(x=3,x_i="x2")) - predict(gbm_fit, new_data=tibble(x=3,x_i="x1"))
## .pred
## 1 1.952928
Wow, LightGBM is quite close to the true model. Let’s sequence a vector of x and assess all 3 models.
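# CATE of group x_i versus x_base over a vector of x, for lm or workflow models
assess <- function(model, x, x_i, x_base = "x1") {
  treated <- tibble(x = x, x_i = x_i)
  control <- tibble(x = x, x_i = x_base)
  if (inherits(model, "workflow")) {
    # workflows take new_data and return a tibble with a .pred column
    diff <- predict(model, new_data = treated)$.pred - predict(model, new_data = control)$.pred
  } else {
    # base R models such as lm take newdata and return a numeric vector
    diff <- unname(predict(model, newdata = treated) - predict(model, newdata = control))
  }
  tibble(x = x, diff = diff)
}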
# sequence of x's
x <- seq(-3,3,0.1)
# type of x_i of interest
x_i <- "x2"
gbm_g<- assess(model=gbm_fit,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
# geom_smooth(method="lm") +
theme_minimal() +
ggtitle("LightGBM") +
ylab("CATE")
true_g <- assess(model=model_true,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
theme_minimal() +
ggtitle("Linear Model With Interaction Terms") +
ylab("CATE")
wrong_g <- assess(model=model_wrong,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
theme_minimal() +
ggtitle("Linear Model Without Interaction Terms") +
ylab("CATE")
ggarrange(gbm_g,true_g,wrong_g, ncol = 1)
Very interesting indeed! LightGBM seems to be able to get a very similar CATE compared to the true model! The sequence of predictions by LightGBM also looks quite familiar, doesn’t it? If we were to regress its predictions on x, it looks like the slope might be very similar to the true model. Let’s give it a try.
LightGBM’s CATE regressing on x
gbm_pred <- assess(gbm_fit, x = x, x_i = "x2")
summary(lm(diff~x,gbm_pred))
##
## Call:
## lm(formula = diff ~ x, data = gbm_pred)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.65741 -0.10984 -0.01303 0.13675 0.54059
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 0.98884 0.03152 31.38 <2e-16 ***
## x 0.37032 0.01790 20.69 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.2462 on 59 degrees of freedom
## Multiple R-squared: 0.8788, Adjusted R-squared: 0.8768
## F-statistic: 428 on 1 and 59 DF, p-value: < 2.2e-16
True model’s CATE regressing on x
model_true_pred <- assess(model_true, x, "x2")
summary(lm(diff~x,model_true_pred))
##
## Call:
## lm(formula = diff ~ x, data = model_true_pred)
##
## Residuals:
## Min 1Q Median 3Q Max
## -2.883e-15 -2.088e-17 3.261e-17 7.979e-17 2.103e-15
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 9.631e-01 6.155e-17 1.565e+16 <2e-16 ***
## x 3.855e-01 3.496e-17 1.103e+16 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4.808e-16 on 59 degrees of freedom
## Multiple R-squared: 1, Adjusted R-squared: 1
## F-statistic: 1.216e+32 on 1 and 59 DF, p-value: < 2.2e-16
Very, very similar indeed! Let’s visualize them side by side.
gbm_g<- assess(model=gbm_fit,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
geom_smooth(method="lm") +
theme_minimal() +
ggtitle("LightGBM") +
ylab("CATE")
ggarrange(gbm_g,true_g)
What if we look at x3?
# sequence of x's
x <- seq(-3,3,0.1)
# type of x_i of interest
x_i <- "x3"
gbm_g<- assess(model=gbm_fit,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
geom_smooth(method="lm") +
theme_minimal() +
ggtitle("LightGBM") +
ylab("CATE")
true_g <- assess(model=model_true,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
theme_minimal() +
ggtitle("Linear Model With Interaction Terms") +
ylab("CATE")
wrong_g <- assess(model=model_wrong,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
theme_minimal() +
ggtitle("Linear Model Without Interaction Terms") +
ylab("CATE")
ggarrange(gbm_g,true_g,wrong_g, ncol = 1)
XGBoost
gbb <- boost_tree() |>
set_engine("xgboost") |>
set_mode("regression")
#workflow
gbb_wf <- workflow() |>
add_recipe(rec) |>
add_model(gbb)
#assess
gbb_assess <- gbb_wf %>%
fit_resamples(
resamples = cv,
metrics = metric_set(rmse, rsq, ccc),
control = control_resamples(save_pred = TRUE, verbose = TRUE)
)
gbb_assess |>
collect_metrics() |> kable()
.metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|
ccc | standard | 0.5968061 | 25 | 0.0031061 | Preprocessor1_Model1 |
rmse | standard | 1.0318791 | 25 | 0.0057834 | Preprocessor1_Model1 |
rsq | standard | 0.4144199 | 25 | 0.0038228 | Preprocessor1_Model1 |
gbb_fit <- gbb_wf |>
fit(df)
# sequence of x's
x <- seq(-3,3,0.1)
# type of x_i of interest
x_i <- "x2"
gbb_g<- assess(model=gbb_fit,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
geom_smooth(method="lm") +
theme_minimal() +
ggtitle("XGBoost") +
ylab("CATE")
true_g <- assess(model=model_true,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
theme_minimal() +
ggtitle("Linear Model With Interaction Terms") +
ylab("CATE")
ggarrange(gbb_g,true_g)
x_i <- "x3"
gbb_g<- assess(model=gbb_fit,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
geom_smooth(method="lm") +
theme_minimal() +
ggtitle("XGBoost") +
ylab("CATE")
true_g <- assess(model=model_true,x=x,x_i=x_i) |>
ggplot(aes(x=x,y=diff)) +
geom_point() +
theme_minimal() +
ggtitle("Linear Model With Interaction Terms") +
ylab("CATE")
ggarrange(gbb_g,true_g)
Not too shabby either!
Limitation
- We haven’t truly added confounders such as z to the mix; we’ll have to see how that pans out with interaction. I sense another blog coming up!
- We may not actually be using the S-learner method in the right setting. It’s usually used to assess a binary treatment and its outcome. Here we’re focused more on interaction terms and how boosted trees can tease them out.
- If you see any mistakes or have any comments, please feel free to reach out!
Acknowledgement
- Thanks to José Luis Cañadas Reche’s inspiring S-learner blog, please check this link. I wasn’t planning on doing interaction and S-learner at the same time, but this gave me an opportunity.
- Barr, Dale J. (2021). Learning statistical models through simulation in R: An interactive textbook really helped me to understand interaction.
- Last but not least, Aleksander Molak’s book is awesome; I learnt about meta-learners from it.
Lessons Learnt
- geom_abline to draw a custom line based on intercept and slope
- \mathbb{E} symbol for expectation in LaTeX
- Gradient boosting models such as lightGBM can tease out interaction, quite a handy tool!
- Learnt what interaction is and how to interpret its summary
- Learnt S-learner
If you like this article:
- please feel free to send me a comment or visit my other blogs
- please feel free to follow me on twitter, GitHub or Mastodon
- if you would like collaborate please feel free to contact me