forestriesz ships two random-forest backends. They differ in whether the user must make per-estimand modeling choices before fitting.
AugForestRieszRegressor — augmentation-style. An ensemble of single-tree Riesz regressors fit on the augmented dataset of evaluation points \(z_r\) with weights \((D_r, C_r)\) that Estimand.augment produces. Each tree uses sklearn-standard hyperparameters and a loss-aware splitter that handles every built-in Bregman loss directly. Works on every estimand without per-estimand configuration — built-in estimands (ATE, ATT, TSM, AdditiveShift, LocalShift) and any user-defined Estimand use the same call.
ForestRieszRegressor — moment-style. Implements Chernozhukov, Newey, Quintas-Martínez, Syrgkanis (ICML 2022). To fit the model the user must supply a list of basis functions of the data — the riesz_feature_fns parameter — that the forest combines linearly inside each leaf. The required list is per-estimand: ATE, ATT, and TSM get reasonable defaults from riesz_feature_fns="auto", but other estimands need a user-constructed list. Supports honest-split confidence intervals via predict_interval for single-basis fits.
rieszreg exposes the two entry points as Backend.fit_augmented and MomentBackend.fit_rows respectively.
How they work
Augmentation-style (AugForestRieszRegressor)
For each original row \(i\), Estimand.augment emits the original observation \(Z_i\) and one evaluation point per linear-form term in \(m(\mu)(Z_i, Y_i)\). Each augmented row \(r\) carries an evaluation point \(z_r\), a weight \(D_r \in \{0, 1\}\) indicating whether \(z_r\) is the original observation, and a coefficient \(C_r\) collecting the trace coefficients at \(z_r\). The empirical Bregman-Riesz loss is
\[
\hat L_n(\alpha) = \frac{1}{n} \sum_{r} \Big( D_r\, \psi\big(\alpha(z_r)\big) + 2\, C_r\, \alpha(z_r) \Big),
\]
with \(\psi\) the convex generator of the chosen loss (\(\psi(t) = t^2\) for SquaredLoss), minimized over a constant \(\hat\alpha_\ell\) in each leaf \(\ell\).
For SquaredLoss this has the closed form \(\hat\alpha_\ell = -\sum_{r \in \ell} C_r / \sum_{r \in \ell} D_r\); for KLLoss, BernoulliLoss, and BoundedSquaredLoss the per-leaf optimum has its own closed form (see Tree backend for the derivations). The forest averages the per-tree predictions of \(\hat\alpha\).
Because \(D_r\) and \(C_r\) already vary across augmented rows for every estimand — original observations have \(D_r = 1, C_r = 0\) and counterfactual evaluation points have \(D_r = 0, C_r \ne 0\) — the splitter learns from the full feature space directly. The user does not supply any basis functions.
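As a concrete illustration of the per-leaf solve, here is a minimal pure-Python sketch (toy hand-picked rows, not the library's implementation): for SquaredLoss the per-leaf objective is \(\sum_r (D_r \alpha^2 + 2 C_r \alpha)\), whose minimizer is the closed form \(-\sum_r C_r / \sum_r D_r\).

```python
# Toy leaf of augmented rows (hypothetical values): original observations
# carry (D=1, C=0); counterfactual evaluation points carry (D=0, C != 0).
rows = [
    {"D": 1.0, "C": 0.0},   # original observation
    {"D": 1.0, "C": 0.0},   # original observation
    {"D": 0.0, "C": -1.0},  # evaluation point
    {"D": 0.0, "C": 1.0},   # evaluation point
    {"D": 0.0, "C": -1.0},  # evaluation point
]

def leaf_loss(alpha, rows):
    """Per-leaf SquaredLoss Riesz objective: sum_r (D_r * a^2 + 2 * C_r * a)."""
    return sum(r["D"] * alpha**2 + 2 * r["C"] * alpha for r in rows)

# Closed-form minimizer from the text: alpha_hat = -sum(C) / sum(D).
alpha_hat = -sum(r["C"] for r in rows) / sum(r["D"] for r in rows)
print(alpha_hat)  # 0.5

# Sanity check: the closed form beats nearby values of alpha.
assert leaf_loss(alpha_hat, rows) <= leaf_loss(alpha_hat - 0.1, rows)
assert leaf_loss(alpha_hat, rows) <= leaf_loss(alpha_hat + 0.1, rows)
```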
Moment-style (ForestRieszRegressor)
The user supplies a list of functions \(\varphi = (\varphi_1, \dots, \varphi_p)\) of the data. Inside each leaf \(\ell\) the per-leaf parameter \(\theta_\ell \in \mathbb R^p\) solves the linear-moment equation
\[
\frac{1}{|\ell|} \sum_{i \in \ell} \Big( \varphi(Z_i)\, \varphi(Z_i)^\top \theta_\ell - m(\varphi)(Z_i, Y_i) \Big) = 0,
\]
where \(m(\varphi)\) applies the estimand's linear form to each basis function.
and predictions are \(\hat\alpha(z) = \hat\theta(z_{\text{split}}) \cdot \varphi(z)\). The forest is fit on the \(n\) original rows, and the MSE splitting criterion picks splits to minimize the sum of in-leaf residuals against this leaf optimum.
The list \(\varphi\) has to be chosen per estimand. The default riesz_feature_fns="auto" picks defaults for built-in estimands:
| Estimand | Default list | Effect |
|---|---|---|
| ATE, ATT | [1{A=0}, 1{A=1}] | Forest splits on covariates only; the basis resolves treatment in each leaf. |
| TSM(level=v) | [1{A=v}] | Single-basis fit; intervals supported. |
| AdditiveShift, LocalShift, custom | none — auto-resolution falls back to a constant function, which is row-degenerate | Backend raises with a hint. Use AugForestRieszRegressor instead, or pass a custom list. |
For ATE the leaf solve recovers \(1/\hat P(A = a \mid X\text{-leaf})\) — the IPW representer.
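A minimal numpy sketch of that leaf solve (toy data; my own illustration, not library internals, assuming the ATE linear form \(m(\mu)(Z) = \mu(1, X) - \mu(0, X)\) and the default basis \(\varphi = (1\{A=0\}, 1\{A=1\})\)):

```python
import numpy as np

# Leaf with 10 rows: 4 treated, 6 control (toy data).
a = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=float)

# ATE default basis: phi(z) = (1{A=0}, 1{A=1}).
phi = np.stack([1.0 - a, a], axis=1)          # shape (10, 2)

# ATE linear form applied to each basis function:
# m(phi_j)(Z_i) = phi_j(A=1, X_i) - phi_j(A=0, X_i) = (-1, +1) for every row.
m_phi = np.tile([-1.0, 1.0], (len(a), 1))     # shape (10, 2)

# Per-leaf moment equation: mean(phi phi^T) @ theta = mean(m_phi).
theta = np.linalg.solve(phi.T @ phi / len(a), m_phi.mean(axis=0))

# theta recovers the IPW representer within the leaf:
# (-1 / P_hat(A=0), +1 / P_hat(A=1)) = (-1/0.6, 1/0.4).
print(theta)
```

Predictions \(\hat\theta \cdot \varphi(z)\) then equal \(1/\hat P(A=1)\) on treated rows and \(-1/\hat P(A=0)\) on control rows, which is exactly the IPW representer the text describes.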
How they differ
| | AugForestRieszRegressor | ForestRieszRegressor |
|---|---|---|
| Per-estimand setup | none — works on every estimand directly | user supplies riesz_feature_fns (auto-resolved for ATE, ATT, TSM) |
| Confidence intervals | no | yes — predict_interval on single-basis fits with honest=True, inference=True |
| Splits on | full feature space | covariates only when the basis handles treatment |
| Backend protocol | Backend.fit_augmented | MomentBackend.fit_rows |
| When to pick | default for general use; required for shift-style or custom estimands | when you need CIs on ATE, ATT, or TSM |
Bregman losses
AugForestRieszRegressor accepts every loss rieszreg ships. The per-tree splitter dispatches to a per-loss leaf-loss formula (squared, KL, Bernoulli, bounded-squared) so splits are chosen against the loss the user actually supplied. The forest averages per-tree \(\hat\alpha\) predictions.
```python
import numpy as np, pandas as pd

from forestriesz import AugForestRieszRegressor, TSM
from rieszreg import KLLoss

rng = np.random.default_rng(0)
n = 1500
x = rng.uniform(0, 1, n)
pi = 1 / (1 + np.exp(-(0.5 * x - 0.3)))
a = rng.binomial(1, pi).astype(float)
df_tsm = pd.DataFrame({"a": a, "x": x})

# KLLoss is matched to density-ratio estimands like TSM and IPSI;
# it requires non-negative C-coefficients in the augmented dataset.
est = AugForestRieszRegressor(
    estimand=TSM(level=1, treatment="a", covariates=("x",)),
    loss=KLLoss(),
    n_estimators=200,
    min_samples_leaf=10,
    random_state=0,
)
est.fit(df_tsm)
```
AugForestRieszRegressor(estimand=<rieszreg.estimands.base.TSM object at 0x7feacbdea120>,
loss=<rieszreg.losses.kl.KLLoss object at 0x7feacbdea330>,
min_samples_leaf=10, n_estimators=200)
KLLoss and BernoulliLoss reject ATE / ATT / shift-style data at fit time — they require non-negative \(C\)-coefficients, which difference-style estimands violate. BoundedSquaredLoss(lo, hi) accepts any signed coefficients and clips \(\hat\alpha\) into the interval.
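One plausible reading of the BoundedSquaredLoss behavior, as a pure-Python sketch (my illustration only; the library's actual per-leaf solve may differ): take the squared-loss leaf optimum and clip it into \([lo, hi]\).

```python
def bounded_leaf_alpha(D, C, lo, hi):
    """Sketch of BoundedSquaredLoss-style behavior: the squared-loss leaf
    optimum -sum(C)/sum(D), clipped into [lo, hi].  Illustrative only."""
    alpha = -sum(C) / sum(D)
    return max(lo, min(hi, alpha))

# Signed coefficients (ATE-style) are accepted; the result is clipped.
print(bounded_leaf_alpha([1, 1, 0, 0], [0, 0, -4, -4], lo=-5.0, hi=5.0))  # 4.0
print(bounded_leaf_alpha([1, 1, 0, 0], [0, 0, -8, -8], lo=-5.0, hi=5.0))  # 5.0
```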
ForestRieszRegressor(estimand=<rieszreg.estimands.base.ATE object at 0x7feacadfa180>,
min_samples_leaf=10, n_estimators=500)
```python
alpha_hat = fr.predict(df)
true_alpha = a / pi - (1 - a) / (1 - pi)
print(f"corr(α̂, α₀) = {np.corrcoef(alpha_hat, true_alpha)[0, 1]:.3f}")
```
```r
pkgload::load_all("../packages/forestriesz/r/forestriesz")
use_python_forestriesz(file.path(getwd(), "../.venv/bin/python"))

set.seed(0)
n <- 1500
x <- runif(n)
pi <- 1 / (1 + exp(-(-0.02 * x - x^2 + 4 * log(x + 0.3) + 1.5)))
a <- rbinom(n, 1, pi)
df <- data.frame(a = as.numeric(a), x = x)

# ATE / ATT use a multi-basis list, which is Python-only in v1.
# Call into Python directly via reticulate.
fr_py <- reticulate::import("forestriesz", convert = FALSE)
pd_py <- reticulate::import("pandas", convert = FALSE)
py_df <- pd_py$DataFrame(reticulate::r_to_py(list(a = as.numeric(a), x = x)))
fr <- fr_py$ForestRieszRegressor(
  estimand = fr_py$ATE(treatment = "a", covariates = list("x")),
  n_estimators = 500L, min_samples_leaf = 10L, random_state = 0L
)
fr$fit(py_df)
alpha_hat <- as.numeric(reticulate::py_to_r(fr$predict(py_df)))
```
Quickstart — AdditiveShift with the augmentation-style backend
AdditiveShift has no canonical list of basis functions, so the moment-style backend would require the user to construct one. AugForestRieszRegressor works on it directly.
AugForestRieszRegressor(estimand=<rieszreg.estimands.base.AdditiveShift object at 0x7feacae4ad50>,
min_samples_leaf=10, n_estimators=500)
ForestRieszRegressor(estimand=<rieszreg.estimands.base.TSM object at 0x7feacadf7f80>,
honest=True, inference=True, n_estimators=500)
AugForestRieszRegressor

| Knob | Default | Notes |
|---|---|---|
| min_samples_split | 2 | Minimum count of original (\(D > 0\)) augmented rows in a node before considering a split. |
| min_samples_leaf | 1 | Minimum count of original rows in each child of a candidate split. |
| max_features | 1.0 | Per-split feature-subsampling rule. |
| max_leaf_nodes | None | Cap on per-tree leaf count. |
| min_impurity_decrease | 0.0 | Reject splits below this gain. |
| ccp_alpha | 0.0 | Cost-complexity pruning per tree. |
| bootstrap | True | Per-tree resampling at the original-row level (block bootstrap). |
| max_samples | None | Per-tree subsample size as a fraction or count of original rows. |
| n_jobs | None | Parallelism across trees via joblib.Parallel. |
| splitter | "exact" | "exact", "hist", "random", or "python" — riesztree splitter dispatch. |
| max_bins | 255 | Histogram-splitter bin count. |
| categorical_features | None | Column indices treated as integer category labels. |
When splitter="hist" and the configuration is "simple" (no categoricals, no max_features subsampling, no ccp_alpha, no leaf-count cap, built-in loss), the bin mapper is fitted once on the full augmented training data and the binned matrix is shared across joblib workers, following the convention of sklearn's HistGradientBoostingRegressor. The win is largest at shallow depths, where per-tree binning dominates tree-build wall time (~2× speedup at max_depth=8).
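The fit-once / share-everywhere binning idea can be sketched in a few lines of numpy (an illustration of the technique, not forestriesz internals): per-feature quantile thresholds are computed once on the full matrix, and every tree then works from the same compact uint8-coded copy.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10_000, 3))  # stand-in for the augmented training matrix

# Fit the bin mapper ONCE on the full data: at most max_bins quantile bins
# per feature (same idea as sklearn's HistGradientBoostingRegressor binning).
max_bins = 255
thresholds = [
    np.unique(np.quantile(X[:, j], np.linspace(0, 1, max_bins + 1)[1:-1]))
    for j in range(X.shape[1])
]

# Encode every feature as a small-integer bin code; this binned matrix is
# what gets shared across the per-tree workers instead of rebuilt per tree.
X_binned = np.stack(
    [np.searchsorted(t, X[:, j]).astype(np.uint8) for j, t in enumerate(thresholds)],
    axis=1,
)
print(X_binned.shape, X_binned.dtype)
```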
ForestRieszRegressor
Forwarded to EconML’s BaseGRF:
| Knob | Default | Notes |
|---|---|---|
| n_estimators | 100 | Number of trees. With inference=True must be divisible by subforest_size (default 4). |
| max_depth | None | Grow until leaves saturate min_samples_leaf. |
| min_samples_leaf | 5 | Tighter leaves give lower bias, higher variance. Paper guidance: 5–20. |
| min_samples_split | 10 | |
| max_samples | 0.45 | Per-tree subsample fraction. GRF default. |
| honest | False | Enable for valid-coverage CIs (50/50 sample split). |
| inference | False | Enable for predict_interval; preserves subforest structure. |
| subforest_size | 4 | Trees per subforest when computing inference variance. |
For built-in estimands the moment-style basis is auto-resolved on load. For custom bases, repass the callables: ForestRieszRegressor.load(path, riesz_feature_fns=my_basis). AugForestRieszRegressor round-trips without any extra arguments.
Sklearn integration
Both regressors are BaseEstimators, so they compose with clone, GridSearchCV, cross_val_predict, Pipeline. See Tuning and Cross-Fitting.
When the constant basis raises (moment-style only)
For built-in estimands the per-row moment \(m(\varphi)(Z_i, Y_i)\) does not depend on \(Z_i\) when \(\varphi \equiv 1\), so the constant basis is degenerate: every leaf would predict the same constant. ForestRieszRegressor detects this and raises with a hint to use riesz_feature_fns="auto" (the default) or to supply a custom list; the error fires only when the user explicitly passes riesz_feature_fns=None. AugForestRieszRegressor is unaffected — its splitter does not depend on a user-supplied basis.
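The degeneracy is easy to see numerically. In this sketch (toy data, my own illustration) the constant basis makes the per-row moment a constant \(c\) independent of \(Z_i\) (\(c = 0\) for ATE, \(c = 1\) for TSM); taking \(c = 1\), the leaf solve returns the same \(\theta\) for every subset of rows, so no split can change the prediction:

```python
import numpy as np

a = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0])
phi = np.ones((len(a), 1))    # constant basis: phi(z) = (1,)

# With phi = 1 the per-row moment is a constant c independent of Z_i.
# Take c = 1 (the TSM case); for ATE it would be c = 0.
m_phi = np.ones((len(a), 1))

def leaf_theta(idx):
    """Per-leaf linear-moment solve restricted to the rows in idx."""
    P, M = phi[idx], m_phi[idx]
    return np.linalg.solve(P.T @ P / len(idx), M.mean(axis=0))

# Any candidate split yields children with identical theta, so the MSE
# criterion has no signal and the fitted forest is a single constant.
left, right = leaf_theta([0, 1, 2]), leaf_theta([3, 4, 5])
print(left, right)  # [1.] [1.]
```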