Causal AI, exploring the integration of causal reasoning into machine learning
Welcome to my series on Causal AI, where we will explore the integration of causal reasoning into machine learning models. Expect to explore a number of practical applications across different business contexts.
In the last article we covered measuring the intrinsic causal influence of your marketing campaigns. In this article we will move on to validating the causal impact estimated by synthetic controls.
If you missed the last article on intrinsic causal influence, check it out here:
In this article we will focus on understanding the synthetic control method and exploring how we can validate the estimated causal impact.
The following aspects will be covered:
- What is the synthetic control method?
- What challenge does it try to overcome?
- How can we validate the estimated causal impact?
- A Python case study using realistic Google Trends data, demonstrating how we can validate the estimated causal impact of synthetic controls.
The full notebook can be found here:
What is it?
The synthetic control method is a causal technique which can be used to assess the causal impact of an intervention or treatment when a randomised controlled trial (RCT) or A/B test was not possible. It was originally proposed in 2003 by Abadie and Gardeazabal. The following paper includes a great case study to help you understand the proposed method:
https://web.stanford.edu/~jhain/Paper/JASA2010.pdf
Let’s cover some of the basics ourselves… The synthetic control method creates a counterfactual version of the treatment unit by creating a weighted combination of control units that did not receive the intervention or treatment.
- Treated unit: The unit which receives the intervention.
- Control units: A set of similar units which did not receive the intervention.
- Counterfactual: Created as a weighted combination of the control units. The aim is to find weights for each control unit that result in a counterfactual which closely matches the treated unit in the pre-intervention period.
- Causal impact: The difference between the treated unit and the counterfactual in the post-intervention period.
If we wanted to really simplify things, we could think of it as linear regression where each control unit is a feature and the treated unit is the target. The pre-intervention period is our train set, and we use the model to score the post-intervention period. The difference between the actual and predicted values is the causal impact.
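To make that regression framing concrete, here is a minimal sketch (my own toy illustration, not the implementation we use later in the case study). The column names, including the `intervention` flag and the treated/control columns, are assumptions for illustration only:
import pandas as pd
from sklearn.linear_model import LinearRegression

def regression_framing(df: pd.DataFrame, treated_col: str, control_cols: list) -> float:
    # Pre-intervention period acts as the train set (assumed 'intervention' flag column)
    pre = df[df['intervention'] == 0]
    post = df[df['intervention'] == 1]

    # Each control unit is a feature, the treated unit is the target
    model = LinearRegression().fit(pre[control_cols], pre[treated_col])

    # Score the post-intervention period to get the counterfactual
    counterfactual = model.predict(post[control_cols])

    # Causal impact = actual minus predicted in the post-intervention period
    return float((post[treated_col] - counterfactual).sum())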
Below are a couple of examples of when we might consider using it, to bring it to life:
- When running a TV marketing campaign, we are unable to randomly assign the audience into those that can and can’t see the campaign. We could, however, carefully select a region to trial the campaign and use the remaining regions as control units. Once we have measured the effect, the campaign could be rolled out to other regions. This is often called a geo-lift test.
- Policy changes which are brought into some regions but not others. For example, a local council may bring a policy change into force to reduce unemployment. Other regions where the policy wasn’t in place could be used as control units.
What challenge does it try to overcome?
When we combine high-dimensionality (lots of features) with limited observations, we can get a model which overfits.
Let’s take the geo-lift example to illustrate. If we use weekly data from the last year as our pre-intervention period, this gives us 52 observations. If we then decide to use the other countries in Europe as control units, that gives us an observation-to-feature ratio of roughly 1:1!
Earlier we talked about how the synthetic control method could be implemented using linear regression. However, that observation-to-feature ratio means it is very likely linear regression will overfit, resulting in a poor causal impact estimate in the post-intervention period.
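To see how quickly this bites, here is a quick toy check (made-up random data, purely for illustration): with 52 weekly observations and 50 features there are almost as many parameters as data points, so ordinary linear regression can fit pure noise near-perfectly in-sample.
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

rng = np.random.default_rng(42)

# 52 weekly observations, 50 control "units" that are pure noise
X_noise = rng.normal(size=(52, 50))
y_noise = rng.normal(size=52)

model = LinearRegression().fit(X_noise, y_noise)
r2_in_sample = r2_score(y_noise, model.predict(X_noise))
print(round(r2_in_sample, 3))  # close to 1.0 despite there being no real signal
A near-perfect in-sample fit on noise is exactly the kind of overfitting that would produce a misleading counterfactual in the post-intervention period.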
In linear regression the weights (coefficients) for each feature (control unit) could be negative or positive and they may sum to a number greater than 1. However, the synthetic control method learns the weights whilst applying the below constraints:
- Constraining weights to sum to 1
- Constraining weights to be ≥ 0
These constraints help with regularisation and avoid extrapolation beyond the range of the observed data.
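To see why, note that non-negative weights which sum to 1 make the counterfactual a weighted average (a convex combination) of the control units, so it can never drift outside the range the controls actually span. A tiny illustration with made-up numbers:
import numpy as np

# Three control units observed in one week (made-up values)
controls = np.array([100.0, 120.0, 150.0])

# Synthetic-control-style weights: non-negative and summing to 1
weights = np.array([0.2, 0.5, 0.3])
print(controls @ weights)  # 125.0, guaranteed to lie between 100 and 150

# Unconstrained regression coefficients are free to extrapolate far outside that range
coefs = np.array([-1.5, 2.0, 1.1])
print(controls @ coefs)  # 255.0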
It is worth noting that Ridge and Lasso regression can also provide this kind of regularisation, and in some cases they are reasonable alternatives. We will test this out in the case study!
How can we validate the estimated causal impact?
An arguably bigger challenge is the fact that we are unable to validate the estimated causal impact in the post-intervention period.
How long should my pre-intervention period be? Are we sure we haven’t overfit our pre-intervention period? How can we know whether our model generalises well in the post-intervention period? What if I want to try out different implementations of the synthetic control method?
We could randomly select a few observations from the pre-intervention period and hold them back for validation. But we have already highlighted the challenge that comes from having limited observations, so we may make things even worse!
What if we could run some sort of pre-intervention simulation? Could that help us answer some of the questions highlighted above and gain confidence in our model’s estimated causal impact? All will be explained in the case study!
Background
After convincing Finance that brand marketing is driving some serious value, the marketing team approach you to ask about geo-lift testing. Someone from Facebook has told them it’s the next big thing (although it was the same person who told them Prophet was a good forecasting model) and they want to know whether they could use it to measure their new TV campaign which is coming up.
You are a little concerned, as the last time you ran a geo-lift test the marketing analytics team thought it was a good idea to play around with the pre-intervention period used until they had a nice big causal impact.
This time round, you suggest that they run a “pre-intervention simulation”, after which the pre-intervention period can be agreed before the test begins.
So let’s explore what a “pre-intervention simulation” looks like!
Creating the data
To make this as realistic as possible, I extracted some Google Trends data for the majority of countries in Europe. What the search term was isn’t relevant, just pretend it’s the sales for your company (and that you operate across Europe).
However, if you are interested in how I got the Google Trends data, check out my notebook:
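The notebook linked above has the full details. As a rough sketch of the general approach, Google Trends data can be pulled in Python with the pytrends package; the search term, timeframe and country codes below are placeholders, not what I actually used:
import pandas as pd
from pytrends.request import TrendReq

pytrends = TrendReq(hl='en-GB', tz=0)

frames = []
for geo in ['GB', 'DE', 'FR']:  # placeholder list of country codes
    # 'example term' is a placeholder for whichever search term you choose
    pytrends.build_payload(kw_list=['example term'], timeframe='today 5-y', geo=geo)
    frames.append(pytrends.interest_over_time()['example term'].rename(geo))

# One column per country, indexed by week
df_trends = pd.concat(frames, axis=1).reset_index()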
Below we can see the dataframe. We have sales for the past 3 years across 50 European countries. The marketing team plan to run their TV campaign in Great Britain.
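Before we get into the code, a quick note on imports: the snippets below assume the usual notebook imports are already in place. The exact list in my notebook may differ slightly, but something along these lines covers every call used in this article:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.optimize import minimize
from sklearn.linear_model import LinearRegression, RidgeCV, LassoCV
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split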
Now here comes the clever bit. We will simulate an intervention in the last 7 weeks of the time series.
np.random.seed(1234)

# Create intervention flag
mask = (df['date'] >= "2024-04-14") & (df['date'] <= "2024-06-02")
df['intervention'] = mask.astype(int)
row_count = len(df)
# Create intervention uplift
df['uplift_perc'] = np.random.uniform(0.10, 0.20, size=row_count)
df['uplift_abs'] = round(df['uplift_perc'] * df['GB'])
df['y'] = df['GB']
df.loc[df['intervention'] == 1, 'y'] = df['GB'] + df['uplift_abs']
Now let’s plot the actual and counterfactual sales across GB to bring what we have done to life:
def synth_plot(df, counterfactual):

    plt.figure(figsize=(14, 8))
    sns.set_style("white")

    # Create plot
    sns.lineplot(data=df, x='date', y='y', label='Actual', color='b', linewidth=2.5)
    sns.lineplot(data=df, x='date', y=counterfactual, label='Counterfactual', color='r', linestyle='--', linewidth=2.5)
    plt.title('Synthetic Control Method: Actual vs. Counterfactual', fontsize=24)
    plt.xlabel('Date', fontsize=20)
    plt.ylabel('Metric Value', fontsize=20)
    plt.legend(fontsize=16)
    plt.gca().xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d'))
    plt.xticks(rotation=90)
    plt.grid(True, linestyle='--', alpha=0.5)

    # Highlight the intervention point
    intervention_date = '2024-04-07'
    plt.axvline(pd.to_datetime(intervention_date), color='k', linestyle='--', linewidth=1)
    plt.text(pd.to_datetime(intervention_date), plt.ylim()[1]*0.95, 'Intervention', color='k', fontsize=18, ha='right')

    plt.tight_layout()
    plt.show()
synth_plot(df, 'GB')
Now that we have simulated an intervention, we can explore how well the synthetic control method works.
Pre-processing
All of the European countries apart from GB are set as control units (features). The treatment unit (target) is the sales in GB with the intervention applied.
# Delete the original target column so we don't use it as a feature by accident
del df['GB']

# Set features & target
X = df.columns[1:50]
y = 'y'
Regression
Below I have set up a function which we can re-use with different pre-intervention periods and different regression models (e.g. Ridge, Lasso):
def train_reg(df, start_index, reg_class):

    df_temp = df.iloc[start_index:].copy().reset_index()

    # Train on the pre-intervention period only
    X_pre = df_temp[df_temp['intervention'] == 0][X]
    y_pre = df_temp[df_temp['intervention'] == 0][y]
    X_train, X_test, y_train, y_test = train_test_split(X_pre, y_pre, test_size=0.10, random_state=42)

    model = reg_class
    model.fit(X_train, y_train)

    yhat_train = model.predict(X_train)
    yhat_test = model.predict(X_test)

    mse_train = mean_squared_error(y_train, yhat_train)
    mse_test = mean_squared_error(y_test, yhat_test)
    print(f"Mean Squared Error train: {round(mse_train, 2)}")
    print(f"Mean Squared Error test: {round(mse_test, 2)}")

    r2_train = r2_score(y_train, yhat_train)
    r2_test = r2_score(y_test, yhat_test)
    print(f"R2 train: {round(r2_train, 2)}")
    print(f"R2 test: {round(r2_test, 2)}")

    # Score the full period and compare the predicted lift to the simulated lift
    df_temp['pred'] = model.predict(df_temp.loc[:, X])
    df_temp['delta'] = df_temp['y'] - df_temp['pred']

    pred_lift = df_temp[df_temp['intervention'] == 1]['delta'].sum()
    actual_lift = df_temp[df_temp['intervention'] == 1]['uplift_abs'].sum()
    abs_error_perc = abs(pred_lift - actual_lift) / actual_lift
    print(f"Predicted lift: {round(pred_lift, 2)}")
    print(f"Actual lift: {round(actual_lift, 2)}")
    print(f"Absolute error percentage: {round(abs_error_perc, 2)}")

    return df_temp, abs_error_perc
To start off, we keep things simple and use linear regression to estimate the causal impact, with a small pre-intervention period:
df_lin_reg_100, pred_lift_lin_reg_100 = train_reg(df, 100, LinearRegression())
Looking at the results, linear regression doesn’t do great. But this isn’t surprising given the observation to feature ratio.
synth_plot(df_lin_reg_100, 'pred')
Synthetic control method
Let’s jump right in and see how the synthetic control method compares. Below I have set up a similar function to the one before, but this time applying the synthetic control method using SciPy:
def synthetic_control(weights, control_units, treated_unit):

    synthetic = np.dot(control_units.values, weights)

    return np.sqrt(np.sum((treated_unit - synthetic)**2))

def train_synth(df, start_index):

    df_temp = df.iloc[start_index:].copy().reset_index()

    # Train on the pre-intervention period only
    X_pre = df_temp[df_temp['intervention'] == 0][X]
    y_pre = df_temp[df_temp['intervention'] == 0][y]
    X_train, X_test, y_train, y_test = train_test_split(X_pre, y_pre, test_size=0.10, random_state=42)

    # Learn weights which are bounded between 0 and 1 and sum to 1
    initial_weights = np.ones(len(X)) / len(X)
    constraints = ({'type': 'eq', 'fun': lambda w: np.sum(w) - 1})
    bounds = [(0, 1) for _ in range(len(X))]

    result = minimize(synthetic_control,
                      initial_weights,
                      args=(X_train, y_train),
                      method='SLSQP',
                      bounds=bounds,
                      constraints=constraints,
                      options={'disp': False, 'maxiter': 1000, 'ftol': 1e-9},
                      )

    optimal_weights = result.x

    yhat_train = np.dot(X_train.values, optimal_weights)
    yhat_test = np.dot(X_test.values, optimal_weights)

    mse_train = mean_squared_error(y_train, yhat_train)
    mse_test = mean_squared_error(y_test, yhat_test)
    print(f"Mean Squared Error train: {round(mse_train, 2)}")
    print(f"Mean Squared Error test: {round(mse_test, 2)}")

    r2_train = r2_score(y_train, yhat_train)
    r2_test = r2_score(y_test, yhat_test)
    print(f"R2 train: {round(r2_train, 2)}")
    print(f"R2 test: {round(r2_test, 2)}")

    # Score the full period and compare the predicted lift to the simulated lift
    df_temp['pred'] = np.dot(df_temp.loc[:, X].values, optimal_weights)
    df_temp['delta'] = df_temp['y'] - df_temp['pred']

    pred_lift = df_temp[df_temp['intervention'] == 1]['delta'].sum()
    actual_lift = df_temp[df_temp['intervention'] == 1]['uplift_abs'].sum()
    abs_error_perc = abs(pred_lift - actual_lift) / actual_lift
    print(f"Predicted lift: {round(pred_lift, 2)}")
    print(f"Actual lift: {round(actual_lift, 2)}")
    print(f"Absolute error percentage: {round(abs_error_perc, 2)}")

    return df_temp, abs_error_perc
I keep the pre-intervention period the same to create a fair comparison to linear regression:
df_synth_100, pred_lift_synth_100 = train_synth(df, 100)
Wow! I’ll be the first to admit I wasn’t expecting such a significant improvement!
synth_plot(df_synth_100, 'pred')
Comparison of results
Let’s not get too carried away yet. Below we run a few more experiments, exploring model types and pre-intervention periods:
# run regression experiments
df_lin_reg_00, pred_lift_lin_reg_00 = train_reg(df, 0, LinearRegression())
df_lin_reg_100, pred_lift_lin_reg_100 = train_reg(df, 100, LinearRegression())
df_ridge_00, pred_lift_ridge_00 = train_reg(df, 0, RidgeCV())
df_ridge_100, pred_lift_ridge_100 = train_reg(df, 100, RidgeCV())
df_lasso_00, pred_lift_lasso_00 = train_reg(df, 0, LassoCV())
df_lasso_100, pred_lift_lasso_100 = train_reg(df, 100, LassoCV())

# run synthetic control experiments
df_synth_00, pred_lift_synth_00 = train_synth(df, 0)
df_synth_100, pred_lift_synth_100 = train_synth(df, 100)
experiment_data = {
"Method": ["Linear", "Linear", "Ridge", "Ridge", "Lasso", "Lasso", "Synthetic Control", "Synthetic Control"],
"Data Size": ["Large", "Small", "Large", "Small", "Large", "Small", "Large", "Small"],
"Value": [pred_lift_lin_reg_00, pred_lift_lin_reg_100, pred_lift_ridge_00, pred_lift_ridge_100,pred_lift_lasso_00, pred_lift_lasso_100, pred_lift_synth_00, pred_lift_synth_100]
}
df_experiments = pd.DataFrame(experiment_data)
We will use the code below to visualise the results:
# Set the style
sns.set_style("whitegrid")

# Create the bar plot
plt.figure(figsize=(10, 6))
bar_plot = sns.barplot(x="Method", y="Value", hue="Data Size", data=df_experiments, palette="muted")
# Add labels and title
plt.xlabel("Method")
plt.ylabel("Absolute error percentage")
plt.title("Synthetic Controls - Comparison of Methods Across Different Data Sizes")
plt.legend(title="Data Size")
# Show the plot
plt.show()
The results for the small dataset are really interesting! As expected, regularisation helped improve the causal impact estimates. The synthetic control method then took it one step further!
The results of the large dataset suggest that longer pre-intervention periods aren’t always better.
However, the thing I want you to take away is how valuable carrying out a pre-intervention simulation is. There are so many avenues you could explore with your own dataset!
Today we explored the synthetic control method and how we can validate the estimated causal impact. I’ll leave you with a few final thoughts:
- The simplicity of the synthetic control method makes it one of the most widely used techniques in the causal AI toolbox.
- Unfortunately it is also the most widely abused: let’s run the R CausalImpact package, changing the pre-intervention period until we see an uplift we like. 😭
- This is where I highly recommend running pre-intervention simulations to agree test design upfront.
- The synthetic control method is a heavily researched area. It’s worth checking out the proposed adaptations: Augmented SC, Robust SC and Penalized SC.
Alberto Abadie, Alexis Diamond & Jens Hainmueller (2010) Synthetic Control Methods for Comparative Case Studies: Estimating the Effect of California’s Tobacco Control Program, Journal of the American Statistical Association, 105:490, 493–505, DOI: 10.1198/jasa.2009.ap08746