RieszReg

RieszReg — Riesz regression for semiparametric inference

rieszreg estimates the Riesz representer \(\alpha\) of a linear estimand \(\psi = \mathbb{E}[m(\mu)(Z,Y)]\) — the building block of one-step, TMLE, and DML inference — with gradient boosting (rieszboost), kernel ridge (krrr), random forests (forestriesz), or neural nets (riesznet). Think of it as the sklearn, tidymodels, or superlearner of Riesz regression.

rieszreg ships built-in factories for ATE, ATT, TSM, additive shifts, local shifts, stochastic interventions, and arbitrary user-defined estimands; the squared Riesz loss plus the wider Bregman family (KL, Bernoulli, bounded squared); and sklearn-native tuning, model selection, and cross-fitting. Python and R, with bitwise-identical predictions on the same input.

What is Riesz regression?

You want to estimate some causal or structural estimand — say, the average treatment effect \(\psi = \mathbb{E}[\mu(1, X) - \mu(0, X)]\), where \(\mu(a, x) = \mathbb{E}[Y \mid A = a, X = x]\). Plugging a flexible regression \(\hat\mu\) into this formula inherits its regularization bias; the standard fix is a debiasing weight. For the ATE, that weight is the inverse-propensity weight:

\[ \alpha(a, x) \;=\; \frac{a}{\pi(x)} - \frac{1-a}{1-\pi(x)}, \]

where \(\pi(x) = \mathbb{P}(A = 1 \mid X = x)\) is the propensity score.

Under the Riesz representation theorem, analogous weights exist for a broad class of linear estimands. rieszreg learns these weights directly from data: no propensity model, no ad hoc truncation. The same machinery handles continuous shifts, stochastic interventions, and arbitrary user-defined estimands, for which closed-form weights would otherwise involve tricky or slow density estimates. The learned weights then feed downstream weighting estimators or debiased estimators such as one-step, TMLE, and DML.
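
Concretely, in the notation of \(\psi = \mathbb{E}[m(\mu)(Z, Y)]\) above, the default squared Riesz loss (see the feature list) targets \(\alpha_0\) directly:

\[ \hat\alpha \;=\; \arg\min_{\alpha} \; \frac{1}{n} \sum_{i=1}^{n} \Big[ \alpha(Z_i)^2 - 2\, m(\alpha)(Z_i, Y_i) \Big], \]

a sample version of \(\mathbb{E}[\alpha(Z)^2 - 2\,m(\alpha)(Z, Y)] = \mathbb{E}[(\alpha - \alpha_0)^2] - \mathbb{E}[\alpha_0^2]\), which the true representer minimizes.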

For the longer derivation and a survey of the closed-form weights for each built-in estimand, see What is Riesz regression?.

Quickstart

Fit a Riesz representer for ATE on a synthetic binary-treatment dataset, select between two backends with cross-validation, and use the cross-fit \(\hat\alpha\) for a DML point estimate and Wald CI.

Simulate
import os
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 2000
x = rng.uniform(0, 1, n)
pi = 1 / (1 + np.exp(-(-0.02 * x - x**2 + 4 * np.log(x + 0.3) + 1.5)))
a = rng.binomial(1, pi)
mu = 5 * x + 9 * x * a + 5 * np.sin(x * np.pi) + 25 * (a - 2)
y = mu + rng.normal(0, 1.0, n)
df = pd.DataFrame({"a": a.astype(float), "x": x, "y": y})
true_alpha = a / pi - (1 - a) / (1 - pi)
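
As a side check (derived from the mu formula above, not part of the workflow): \(\mu(1, x) - \mu(0, x) = 9x + 25\), so the true ATE of this DGP is \(9\,\mathbb{E}[X] + 25 = 29.5\).

# True ATE of the DGP: mu(1, x) - mu(0, x) = 9 * x + 25, with X ~ Uniform(0, 1).
print(9 * x.mean() + 25)  # sample analogue of 9 * E[X] + 25 = 29.5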

We now have a data frame df with treatment a, covariate x, and outcome y. Because the data are simulated, we also know the true Riesz representer \(\alpha_0\) (true_alpha) to evaluate against later. Define the estimand and loss once, then set up two candidate backends:

import torch
torch.set_num_threads(1)
from rieszreg import ATE, BoundedSquaredLoss
from rieszboost import RieszBooster
from riesznet import RieszNet

estimand = ATE(treatment="a", covariates=("x",))
# Bounded squared loss on [-50, 50]; IPW-style weights can be large when pi(x) nears 0 or 1.
loss     = BoundedSquaredLoss(lo=-50.0, hi=50.0)

boost = RieszBooster(
    estimand=estimand, loss=loss,
    n_estimators=200, learning_rate=0.05, max_depth=3,
)
net   = RieszNet(
    estimand=estimand, loss=loss,
    hidden_sizes=(32, 32), epochs=60, learning_rate=1e-2,
)
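
Each candidate is a plain sklearn-style estimator (a quick in-sample smoke test, before any tuning; in-sample weights are optimistic, and the cross-fit below is what inference should use):

boost.fit(df)                 # learn alpha-hat on all rows
alpha_in = boost.predict(df)  # in-sample weights, for eyeballing only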

Wrap the candidates in a Pipeline and let GridSearchCV pick the winner via inner CV; nest that selector inside cross_val_predict so each outer fold runs its own inner model selection on the training portion only. The result alpha_hat is out-of-fold — every row was predicted by a model fit without seeing it.

from sklearn.model_selection import GridSearchCV, KFold, cross_val_predict
from sklearn.pipeline import Pipeline

cv = KFold(n_splits=5, shuffle=True, random_state=0)
selector = GridSearchCV(
    Pipeline([("riesz", boost)]),
    param_grid=[{"riesz": [boost, net]}],
    cv=3, refit=True,
)
alpha_hat = cross_val_predict(selector, df, cv=cv)
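
To see which backend wins on the full sample (a side check; cross_val_predict already reran this selection inside each outer fold), refit the selector once and inspect its choice:

selector.fit(df)              # inner 3-fold CV over {boost, net} on all rows
print(selector.best_params_)  # shows which candidate won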

Sanity check

Out-of-fold \(\hat\alpha\) vs the closed-form \(\alpha_0\) on the synthetic ATE DGP.
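
A numeric version of the same check, comparing the out-of-fold predictions to true_alpha from the simulation block:

# alpha_hat should track the closed-form weights row by row.
print(f"corr = {np.corrcoef(alpha_hat, true_alpha)[0, 1]:.3f}")
print(f"rmse = {np.sqrt(np.mean((alpha_hat - true_alpha) ** 2)):.3f}")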

DML point estimate and CI
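
The code below assembles the standard one-step/DML estimator from the two cross-fit nuisances:

\[ \hat\psi \;=\; \frac{1}{n} \sum_{i=1}^{n} \Big[ \hat\mu(1, X_i) - \hat\mu(0, X_i) + \hat\alpha(A_i, X_i)\,\big(Y_i - \hat\mu(A_i, X_i)\big) \Big], \qquad \widehat{\mathrm{se}} \;=\; \sqrt{\widehat{\mathrm{Var}}(\mathrm{EIF})/n}. \]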

Cross-fit outcome regression
from sklearn.ensemble import GradientBoostingRegressor

mu_model = GradientBoostingRegressor(
    n_estimators=300, learning_rate=0.05, max_depth=3, random_state=0,
)
Z_full = df[["a", "x"]].to_numpy()
mu_oof = cross_val_predict(mu_model, Z_full, df["y"].to_numpy(), cv=cv)

# For brevity, the counterfactual predictions use a single model refit on the full
# data; a fully cross-fit version would predict mu1 and mu0 per outer fold.
mu_full = mu_model.fit(Z_full, df["y"].to_numpy())
mu1 = mu_full.predict(np.column_stack([np.ones(n), x]))   # set a = 1, keep x
mu0 = mu_full.predict(np.column_stack([np.zeros(n), x]))  # set a = 0, keep x
eif  = (mu1 - mu0) + alpha_hat * (df["y"].to_numpy() - mu_oof)
psi  = float(eif.mean())
se   = float(np.sqrt(np.var(eif, ddof=1) / n))
ci   = (psi - 1.96 * se, psi + 1.96 * se)
print(f"ψ̂_ATE  = {psi:.3f}")
ψ̂_ATE  = 29.459
print(f"SE     = {se:.3f}")
SE     = 0.096
print(f"95% CI = [{ci[0]:.3f}, {ci[1]:.3f}]")
95% CI = [29.270, 29.647]
The same quickstart runs in R.

Load riesznet R wrapper
.project_root <- Sys.getenv("QUARTO_PROJECT_DIR", unset = getwd())
suppressMessages(pkgload::load_all(
  file.path(.project_root, "../packages/riesznet/r/riesznet"), quiet = TRUE))
Simulate
set.seed(0)
n     <- 2000
x     <- runif(n)
prop  <- 1 / (1 + exp(-(-0.02 * x - x^2 + 4 * log(x + 0.3) + 1.5)))
a     <- rbinom(n, 1, prop)
mu_dgp <- 5 * x + 9 * x * a + 5 * sin(x * pi) + 25 * (a - 2)
y     <- mu_dgp + rnorm(n, 0, 1)
df    <- data.frame(a = as.numeric(a), x = x, y = y)
true_alpha <- a / prop - (1 - a) / (1 - prop)
Define the estimand, loss, and two candidate factories:

estimand <- ATE(treatment = "a", covariates = "x")
loss     <- BoundedSquaredLoss(lo = -50, hi = 50)
make_boost <- function() RieszBooster$new(
  estimand = estimand, loss = loss,
  n_estimators = 200L, learning_rate = 0.05, max_depth = 3L
)
make_net <- function() RieszNet$new(
  estimand = estimand, loss = loss,
  hidden_sizes = c(32L, 32L), epochs = 60L, learning_rate = 0.01
)

Use rsample (tidymodels) for the fold splits and a small loop for the nested cross-fit with per-fold model selection. Each outer fold picks its winner by inner CV on its training portion only.

suppressPackageStartupMessages(library(rsample))

cv_outer <- vfold_cv(df, v = 5)
alpha_hat <- numeric(nrow(df))
for (split in cv_outer$splits) {
  tr <- analysis(split); te <- assessment(split); te_idx <- complement(split)
  cv_inner <- vfold_cv(tr, v = 3)
  inner_score <- function(make) mean(vapply(cv_inner$splits, function(s) {
    m <- make(); m$fit(analysis(s)); m$score(assessment(s))
  }, numeric(1)))
  best_make <- if (inner_score(make_boost) >= inner_score(make_net)) make_boost else make_net
  best <- best_make(); best$fit(tr)
  alpha_hat[te_idx] <- best$predict(te)
}

Sanity check \(\hat\alpha\) against the closed-form \(\alpha_0\):

Out-of-fold \(\hat\alpha\) vs the closed-form \(\alpha_0\) (R).

Cross-fit outcome regression \(\hat\mu\) via xgboost, then assemble the DML estimate and Wald CI:

Cross-fit outcome regression
suppressPackageStartupMessages(library(xgboost))

xgb_params <- list(objective = "reg:squarederror", eta = 0.05,
                   max_depth = 3, lambda = 1, verbosity = 0)
mu_oof <- numeric(nrow(df))
for (split in cv_outer$splits) {
  tr <- analysis(split); te <- assessment(split); te_idx <- complement(split)
  bst <- xgb.train(
    params = xgb_params,
    data = xgb.DMatrix(as.matrix(tr[, c("a", "x")]), label = tr$y),
    nrounds = 300
  )
  mu_oof[te_idx] <- predict(bst, as.matrix(te[, c("a", "x")]))
}
mu_full <- xgb.train(
  params = xgb_params,
  data = xgb.DMatrix(as.matrix(df[, c("a", "x")]), label = df$y),
  nrounds = 300
)
mu1 <- predict(mu_full, cbind(a = 1, as.matrix(df[, "x", drop = FALSE])))
mu0 <- predict(mu_full, cbind(a = 0, as.matrix(df[, "x", drop = FALSE])))
eif <- (mu1 - mu0) + alpha_hat * (df$y - mu_oof)
psi <- mean(eif)
se  <- sqrt(var(eif) / length(eif))
ci  <- c(psi - 1.96 * se, psi + 1.96 * se)
cat(sprintf("psi_ATE = %.3f\nSE      = %.3f\n95%% CI  = [%.3f, %.3f]\n",
            psi, se, ci[1], ci[2]))
psi_ATE = 29.497
SE      = 0.096
95% CI  = [29.308, 29.685]

Both runs recover the simulation's true ATE of 29.5 within their 95% intervals. For TMLE on top of these nuisances, and for integration with DoubleML and other downstream packages, see Estimation. For backend-specific tuning, see Backends.

Tip: Code execution

Every code block on this page ran at build time. The numbers and figures come from the versions of rieszreg, rieszboost, krrr, forestriesz, and riesznet used to build this page.