Custom code: DML and TMLE

Cross-fit \(\hat\alpha\) from rieszreg plus a cross-fit outcome regression \(\hat\mu\) are the two ingredients of a one-step / DML / TMLE estimator for any linear-functional estimand. This page builds both estimators on the Lalonde NSW dataset in ~50 lines.
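Concretely, for a linear-functional estimand \(\psi = \mathbb{E}[m(\mu)(Z)]\) with Riesz representer \(\alpha\), the one-step estimator combines the two cross-fit pieces as

\[
\hat\psi = \frac{1}{n}\sum_{i=1}^{n}\Big[ m(\hat\mu)(Z_i) + \hat\alpha(Z_i)\,\big(Y_i - \hat\mu(Z_i)\big) \Big],
\]

and TMLE arrives at the same bias correction by updating \(\hat\mu\) before plugging in.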

Data: Lalonde NSW (treated) + CPS (controls)

The Dehejia-Wahba dataset: NSW-treated job-trainees vs CPS observational controls, outcome re78 (1978 earnings). The standard covariates are age, years of schooling, Black and Hispanic indicators, marital status, a no-degree indicator, and pre-period earnings (re74, re75).

import numpy as np, pandas as pd
from causaldata import nsw_mixtape, cps_mixtape

COVARIATES = ["age", "educ", "black", "hisp", "marr", "nodegree", "re74", "re75"]

treated  = nsw_mixtape.load_pandas().data
treated  = treated[treated["treat"] == 1].copy()
controls = cps_mixtape.load_pandas().data
df = pd.concat([treated, controls], ignore_index=True)
df = df.rename(columns={"treat": "a"})
df["y"] = df["re78"]
df = df[["a", "y"] + COVARIATES].copy()
df["a"] = df["a"].astype(float)
print(f"n = {len(df)}  ({(df['a']==1).sum()} treated, {(df['a']==0).sum()} control)")
n = 16177  (185 treated, 15992 control)
library(reticulate)
cd_py <- import("causaldata")
COVARIATES <- c("age", "educ", "black", "hisp", "marr", "nodegree", "re74", "re75")

treated  <- py_to_r(cd_py$nsw_mixtape$load_pandas()$data)
treated  <- treated[treated$treat == 1, ]
controls <- py_to_r(cd_py$cps_mixtape$load_pandas()$data)
df <- rbind(treated, controls)
df$a <- as.numeric(df$treat)
df$y <- df$re78
df <- df[, c("a", "y", COVARIATES)]
cat(sprintf("n = %d  (%d treated, %d control)\n",
            nrow(df), sum(df$a == 1), sum(df$a == 0)))
n = 16177  (185 treated, 15992 control)

Cross-fit nuisances

Both \(\hat\mu\) and \(\hat\alpha\) go through cross_val_predict so every prediction used in the influence function is out-of-fold.

import xgboost as xgb
from sklearn.model_selection import KFold, cross_val_predict
from rieszreg import ATE
from rieszboost import RieszBooster

cv = KFold(n_splits=5, shuffle=True, random_state=0)

# Outcome regression μ̂(A, X)
Z_full = df[["a"] + COVARIATES].to_numpy(dtype=float)
y      = df["y"].to_numpy(dtype=float)
mu_hat = xgb.XGBRegressor(
    objective="reg:squarederror", learning_rate=0.05,
    max_depth=4, reg_lambda=1.0, n_estimators=300,
    random_state=0, verbosity=0,
)
mu_oof = cross_val_predict(mu_hat, Z_full, y, cv=cv)

# Riesz representer α̂(A, X) for ATE
booster = RieszBooster(
    estimand=ATE(treatment="a", covariates=tuple(COVARIATES)),
    n_estimators=2000, early_stopping_rounds=20, validation_fraction=0.2,
    learning_rate=0.05, max_depth=3, reg_lambda=10.0, random_state=0,
)
alpha_hat = cross_val_predict(booster, df[["a"] + COVARIATES], cv=cv)

# Counterfactual μ̂(1, X) and μ̂(0, X) need a model fit on the full data.
mu_full = mu_hat.fit(Z_full, y)
Z1 = np.column_stack([np.ones(len(df)),  df[COVARIATES].to_numpy(dtype=float)])
Z0 = np.column_stack([np.zeros(len(df)), df[COVARIATES].to_numpy(dtype=float)])
mu1 = mu_full.predict(Z1)
mu0 = mu_full.predict(Z0)
library(xgboost)
suppressPackageStartupMessages(library(rsample))
set.seed(0)

# Shared 5-fold split — used for both μ̂ and α̂ so they're aligned.
splits <- vfold_cv(df, v = 5)$splits

# Outcome regression μ̂(A, X) via xgboost (R) with 5-fold cross-fit.
X_full <- as.matrix(df[, c("a", COVARIATES)])
y      <- df$y
mu_oof <- numeric(nrow(df))
for (split in splits) {
  te_idx <- complement(split)
  tr_idx <- setdiff(seq_len(nrow(df)), te_idx)
  bst <- xgb.train(
    params = list(objective = "reg:squarederror", eta = 0.05,
                  max_depth = 4, lambda = 1.0, verbosity = 0),
    data = xgb.DMatrix(X_full[tr_idx, , drop = FALSE], label = y[tr_idx]),
    nrounds = 300
  )
  mu_oof[te_idx] <- predict(bst, X_full[te_idx, , drop = FALSE])
}

# Riesz representer α̂ via the same 5-fold loop, calling $fit / $predict on
# the R6 booster per fold.
make_booster <- function() RieszBooster$new(
  estimand = ATE(treatment = "a", covariates = COVARIATES),
  n_estimators = 2000L, early_stopping_rounds = 20L, validation_fraction = 0.2,
  learning_rate = 0.05, max_depth = 3L, reg_lambda = 10.0
)
alpha_hat <- numeric(nrow(df))
for (split in splits) {
  m <- make_booster()
  m$fit(analysis(split))
  alpha_hat[complement(split)] <- m$predict(assessment(split))
}
booster <- make_booster()  # used by the TMLE step below to produce α at counterfactual treatment levels

# Counterfactual μ̂(1, X) and μ̂(0, X) from a full-data fit.
mu_full <- xgb.train(
  params = list(objective = "reg:squarederror", eta = 0.05,
                max_depth = 4, lambda = 1.0, verbosity = 0),
  data = xgb.DMatrix(X_full, label = y),
  nrounds = 300
)
X1 <- cbind(a = 1, as.matrix(df[, COVARIATES]))
X0 <- cbind(a = 0, as.matrix(df[, COVARIATES]))
mu1 <- predict(mu_full, X1)
mu0 <- predict(mu_full, X0)

DML / one-step

The efficient influence function for the ATE is \(\mathrm{IF}(O) = m(\hat\mu)(Z) + \hat\alpha(Z)\,(Y - \hat\mu(Z))\), with \(m(\hat\mu)(Z) = \hat\mu(1, X) - \hat\mu(0, X)\). Average it to get \(\hat\psi\); the SE is \(\sqrt{\widehat{\operatorname{Var}}(\mathrm{IF})/n}\).

eif = (mu1 - mu0) + alpha_hat * (y - mu_oof)
psi = float(eif.mean())
se  = float(np.sqrt(np.var(eif, ddof=1) / len(eif)))
ci  = (psi - 1.96 * se, psi + 1.96 * se)
print(f"ψ̂_ATE  = {psi:.2f}")
ψ̂_ATE  = -2263.58
print(f"SE     = {se:.2f}")
SE     = 220.03
print(f"95% CI = [{ci[0]:.2f}, {ci[1]:.2f}]")
95% CI = [-2694.84, -1832.32]
eif <- (mu1 - mu0) + alpha_hat * (y - mu_oof)
psi <- mean(eif)
se  <- sqrt(var(eif) / length(eif))
ci  <- c(psi - 1.96 * se, psi + 1.96 * se)
cat(sprintf("psi_ATE = %.2f\nSE      = %.2f\n95%% CI  = [%.2f, %.2f]\n",
            psi, se, ci[1], ci[2]))
psi_ATE = -2187.86
SE      = 223.51
95% CI  = [-2625.95, -1749.77]

The experimental NSW-only ATE benchmark is roughly $1,794 (Dehejia-Wahba 1999). The NSW-treated vs CPS-controls comparison lands far from it here: the measured covariates only partially absorb the selection differences between the experimental and observational samples.

TMLE

TMLE solves the same \(\mathrm{EIF}=0\) estimating equation as one-step but does it by targeting \(\hat\mu\): fit a one-parameter fluctuation \(\mu^\star(\cdot) = \hat\mu(\cdot) + \epsilon\,\hat\alpha(\cdot)\) on a “clever covariate” \(\hat\alpha\) to drive the empirical mean of \(\hat\alpha(O)\,(Y - \mu^\star(O))\) to zero, then take the plug-in \(\hat\psi = \frac{1}{n}\sum_i \big[\mu^\star(1, X_i) - \mu^\star(0, X_i)\big]\). For continuous \(Y\) the targeting is a single linear regression of the outcome residual on \(\hat\alpha\).
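With the linear fluctuation, that no-intercept OLS step has a closed form,

\[
\hat\epsilon = \frac{\sum_i \hat\alpha(Z_i)\,\big(Y_i - \hat\mu(Z_i)\big)}{\sum_i \hat\alpha(Z_i)^2},
\]

which is the eps line in the code below and makes \(\sum_i \hat\alpha(Z_i)\,\big(Y_i - \mu^\star(Z_i)\big) = 0\) hold exactly.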

To target the counterfactual predictions we also need \(\hat\alpha\) at \(A=1\) and \(A=0\). Refit the booster on the full sample and predict at the two counterfactual treatment levels.

booster.fit(df[["a"] + COVARIATES])
RieszBooster(early_stopping_rounds=20,
             estimand=<rieszreg.estimands.base.ATE object at 0x7fa09ff19910>,
             max_depth=3, n_estimators=2000, reg_lambda=10.0,
             validation_fraction=0.2)
df_a1 = df.assign(a=1.0)
df_a0 = df.assign(a=0.0)
alpha_1 = booster.predict(df_a1[["a"] + COVARIATES])
alpha_0 = booster.predict(df_a0[["a"] + COVARIATES])

# Fluctuation: ε from OLS of (y - μ̂_oof) on α̂_oof, no intercept.
eps = float(np.dot(alpha_hat, y - mu_oof) / np.dot(alpha_hat, alpha_hat))

# Targeted counterfactual regressions and observed-row regression.
mu1_star    = mu1   + eps * alpha_1
mu0_star    = mu0   + eps * alpha_0
mu_obs_star = mu_oof + eps * alpha_hat

psi_tmle = float((mu1_star - mu0_star).mean())
eif_tmle = (mu1_star - mu0_star - psi_tmle) + alpha_hat * (y - mu_obs_star)
se_tmle  = float(np.sqrt(np.var(eif_tmle, ddof=1) / len(eif_tmle)))
ci_tmle  = (psi_tmle - 1.96 * se_tmle, psi_tmle + 1.96 * se_tmle)
print(f"ψ̂_TMLE = {psi_tmle:.2f}")
ψ̂_TMLE = -2386.77
print(f"SE     = {se_tmle:.2f}")
SE     = 219.82
print(f"95% CI = [{ci_tmle[0]:.2f}, {ci_tmle[1]:.2f}]")
95% CI = [-2817.61, -1955.93]
df_a1 <- df; df_a1$a <- 1
df_a0 <- df; df_a0$a <- 0
booster$fit(df)
alpha_1 <- booster$predict(df_a1)
alpha_0 <- booster$predict(df_a0)

eps          <- sum(alpha_hat * (y - mu_oof)) / sum(alpha_hat^2)
mu1_star     <- mu1   + eps * alpha_1
mu0_star     <- mu0   + eps * alpha_0
mu_obs_star  <- mu_oof + eps * alpha_hat

psi_tmle <- mean(mu1_star - mu0_star)
eif_tmle <- (mu1_star - mu0_star - psi_tmle) + alpha_hat * (y - mu_obs_star)
se_tmle  <- sqrt(var(eif_tmle) / length(eif_tmle))
ci_tmle  <- c(psi_tmle - 1.96 * se_tmle, psi_tmle + 1.96 * se_tmle)
cat(sprintf("psi_TMLE = %.2f\nSE       = %.2f\n95%% CI   = [%.2f, %.2f]\n",
            psi_tmle, se_tmle, ci_tmle[1], ci_tmle[2]))
psi_TMLE = -1907.39
SE       = 223.73
95% CI   = [-2345.91, -1468.88]

For continuous \(Y\), the one-step and TMLE point estimates agree up to an \(O_p(n^{-1})\) difference. TMLE’s plug-in is a parameter of an updated regression, which can respect known outcome bounds: for binary (or bounded) outcomes, the logistic-fluctuation variant replaces the OLS targeting step with a single-covariate logistic-link GLM, so the targeted probabilities stay in \([0, 1]\).
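As a minimal sketch of that logistic variant, assuming a binary (or \([0,1]\)-rescaled) outcome y01, and with mu_oof, mu1, mu0, alpha_hat, alpha_1, alpha_0 reused from the ATE example above. statsmodels appears here only for illustration; it is not a rieszreg dependency, and y01 and _clip are hypothetical names.

import numpy as np
import statsmodels.api as sm
from scipy.special import expit, logit

def _clip(p, eps=1e-6):
    # Keep logit() finite if the regression predicts exactly 0 or 1.
    return np.clip(p, eps, 1 - eps)

# Fluctuate on the logit scale: logit(mu_star) = logit(mu_hat) + eps * alpha_hat.
# eps comes from a no-intercept binomial GLM with logit(mu_hat) as an offset.
glm = sm.GLM(y01, alpha_hat.reshape(-1, 1),
             family=sm.families.Binomial(),
             offset=logit(_clip(mu_oof))).fit()
eps = float(glm.params[0])

# Targeted counterfactual predictions stay in [0, 1] by construction.
mu1_star = expit(logit(_clip(mu1)) + eps * alpha_1)
mu0_star = expit(logit(_clip(mu0)) + eps * alpha_0)
psi_tmle = float((mu1_star - mu0_star).mean())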

Other estimands

The pattern generalizes verbatim: swap the estimand factory and the \(m(\hat\mu)(Z)\) formula (a hedged ATT sketch follows the table):

| Estimand | \(m(\hat\mu)(Z)\) | Riesz factory |
| --- | --- | --- |
| ATT | \(A\,(\hat\mu(1, X) - \hat\mu(0, X))\), then divide by \(\hat{P}(A=1)\) | ATT(...) |
| TSM at \(a^\star\) | \(\hat\mu(a^\star, X)\) | TSM(level=a_star) |
| Additive shift \(\delta\) | \(\hat\mu(A + \delta, X) - \hat\mu(A, X)\) | AdditiveShift(delta) |
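For example, a sketch of the ATT one-step point estimate, assuming ATT(...) accepts the same treatment/covariates arguments as ATE and that its representer targets the unnormalized moment \(\mathbb{E}[A\,(\mu(1,X) - \mu(0,X))]\) (check the rieszreg signature); everything else is reused from the ATE code above.

from rieszreg import ATT  # factory named in the table; arguments assumed to mirror ATE

booster_att = RieszBooster(
    estimand=ATT(treatment="a", covariates=tuple(COVARIATES)),
    n_estimators=2000, early_stopping_rounds=20, validation_fraction=0.2,
    learning_rate=0.05, max_depth=3, reg_lambda=10.0, random_state=0,
)
alpha_att = cross_val_predict(booster_att, df[["a"] + COVARIATES], cv=cv)

a = df["a"].to_numpy()
# One-step for the unnormalized moment, then divide by P_hat(A = 1)
# per the table's m-formula.
psi_att = float((a * (mu1 - mu0) + alpha_att * (y - mu_oof)).mean() / a.mean())

Note the two-line SE recipe from the ATE section does not carry over verbatim: the division by \(\hat{P}(A=1)\) makes this a ratio, so the SE needs the delta-method influence function.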

Worked examples for each: rieszboost/examples/ (tsm.py, stochastic_intervention.py, lee_schuler/binary_dgp.py).

See also