Backends

A backend fits the Riesz representer for a given estimand and loss. Four implementation packages ship today; each slots into RieszEstimator through one of two Protocols. rieszboost and krrr consume the augmented dataset of evaluation points with weights \((D_r, C_r)\) (Backend.fit_augmented). forestriesz ships both styles: an augmentation-style AugForestRieszRegressor that works on every estimand without per-estimand configuration (Backend.fit_augmented), and a moment-style ForestRieszRegressor that consumes original rows plus a user-supplied list of basis functions (MomentBackend.fit_rows) and supports honest-split confidence intervals. riesznet consumes original rows plus per-row moments via MomentBackend.fit_rows, training a PyTorch model with autograd.

Pick a backend

| | Boosting (rieszboost) | Kernel ridge (krrr) | Forest (forestriesz) | Neural (riesznet) |
| --- | --- | --- | --- | --- |
| Method | Gradient boosting (xgboost or sklearn) | Closed-form solve of \((\mathrm{diag}(a) K + n\lambda I)\gamma = -b/2\) | Ensemble of single-tree Riesz regressors (augmented) or per-leaf linear-moment solve (moment) | PyTorch autograd on the per-row Bregman-Riesz loss |
| Algorithm | Lee & Schuler (2501.04871) | Singh (2102.11076) | Chernozhukov, Newey, Quintas-Martínez, Syrgkanis (ICML 2022) for the moment-style backend | Chernozhukov, Newey, Quintas-Martínez, Syrgkanis (2110.03031) |
| Best for | Tabular data, \(n \gtrsim 1{,}000\), mixed feature types | Smooth \(\alpha_0\), small/medium \(n\), low-dimensional \(X\) | Many covariates; the moment-style flavor adds asymptotic CIs on \(\alpha(z)\) | High-dimensional inputs, custom architectures, image/text features |
| Iterative? | Yes (rounds + early stopping) | No (closed-form per \(\lambda\)) | No (single forest fit) | Yes (epochs + early stopping) |
| Hyperparameters | n_estimators, learning_rate, max_depth, reg_lambda, subsample | kernel, lambda_grid, solver | n_estimators, max_depth, min_samples_leaf, plus honest / riesz_feature_fns for the moment-style flavor | hidden_sizes, learning_rate, epochs, dropout, or any module_factory / optimizer_factory |
| Loss support | SquaredLoss, KLLoss, BernoulliLoss, BoundedSquaredLoss | SquaredLoss only (today) | All four for the augmented flavor; SquaredLoss only for the moment flavor | SquaredLoss, KLLoss, BernoulliLoss, BoundedSquaredLoss |
| Scaling | Linear in \(n\) via xgboost | \(O(n^3)\) direct; Nyström + RFF + Falkon for larger \(n\) | \(O(n \log n)\) per tree | Per-epoch cost is one forward pass over \(n + \sum_i k_i\) points; scales to GPU |
| Confidence intervals | No | No | Yes: predict_interval on single-basis fits (moment-style only) | No |
| Convenience class | rieszboost.RieszBooster | krrr.KernelRieszRegressor | forestriesz.AugForestRieszRegressor (augmented; every estimand, no per-estimand configuration) and forestriesz.ForestRieszRegressor (moment-style; user-supplied basis functions; CIs) | riesznet.RieszNet (default MLP); riesznet.TorchBackend for arbitrary nn.Module factories |

Compose explicitly

Each convenience class is a subclass of rieszreg.RieszEstimator with the backend defaulted. To swap backends or build your own, compose explicitly:

import functools
from rieszreg import RieszEstimator, ATE
from rieszboost.backends import XGBoostBackend, SklearnBackend
from krrr import KernelRidgeBackend, Gaussian
from forestriesz import AugForestRieszBackend, ForestRieszBackend
from riesznet import TorchBackend
from riesznet.modules import build_adam, build_mlp

# Default: xgboost
est = RieszEstimator(estimand=ATE(), backend=XGBoostBackend())

# Sklearn boosting with KernelRidge as the base learner
from sklearn.kernel_ridge import KernelRidge
est = RieszEstimator(
    estimand=ATE(),
    backend=SklearnBackend(lambda: KernelRidge(alpha=1.0, kernel="rbf", gamma=2.0)),
)

# Kernel ridge directly
est = RieszEstimator(
    estimand=ATE(),
    backend=KernelRidgeBackend(kernel=Gaussian(length_scale="median")),
)

# Random forest (augmentation-style backend; fit_augmented path)
est = RieszEstimator(
    estimand=ATE(),
    backend=AugForestRieszBackend(n_estimators=500),
)

# Random forest (moment-style backend; fit_rows path; supports CIs)
est = RieszEstimator(
    estimand=ATE(),
    backend=ForestRieszBackend(n_estimators=500),
)

# Neural network (moment-style backend; fit_rows path)
est = RieszEstimator(
    estimand=ATE(),
    backend=TorchBackend(
        module_factory=functools.partial(build_mlp, hidden_sizes=(64, 64)),
        optimizer_factory=functools.partial(build_adam, lr=5e-3),
        epochs=300,
    ),
)

The convenience classes are thin subclasses that bake in the backend and surface backend-specific hyperparameters as constructor args, so that they fit cleanly into GridSearchCV. See the boosting, kernel, forest, and neural pages for each.
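
For example, tuning a convenience class directly (a sketch, assuming RieszBooster accepts estimand as a constructor argument like its RieszEstimator parent; the hyperparameter names come from the table above):

from sklearn.model_selection import GridSearchCV
from rieszreg import ATE
from rieszboost import RieszBooster

# Backend hyperparameters are plain constructor args on the convenience
# class, so no backend__ prefix is needed in the grid.
search = GridSearchCV(
    estimator=RieszBooster(estimand=ATE(treatment="a", covariates=("x",))),
    param_grid={"n_estimators": [200, 500], "learning_rate": [0.05, 0.1]},
    cv=5,
)
# search.fit(Z, y); then search.best_params_, as with any sklearn estimator.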

Adding a backend

The RieszEstimator orchestrator is learner-agnostic. To run it with an algorithm we don’t ship, wrap that algorithm as a small class that satisfies one Protocol and register a predictor loader. This section walks through the minimum viable wrapper, a worked example, and how to plug it into cross_val_predict, GridSearchCV, and Pipeline.

This is the end-user route: keep the wrapper in your own module, depend on rieszreg, ship nothing back. Contributing a learner package to the family adds a few requirements (R parity, the full conformance suite, in-tree tier placement) covered in DESIGN.md Part B.

Pick an entry point

Two Protocols live in rieszreg.backends.base. Each backend implements one.

Backend.fit_augmented(aug_train, aug_valid, loss, …) is for learners whose loss decomposes naturally over augmented evaluation points. The orchestrator hands you an AugmentedDataset of pseudo-rows with is_original and potential_deriv_coef columns. The empirical Bregman-Riesz loss at \(\alpha\) is the sum over augmented rows \(r\) of \(D_r\,\tilde h(\alpha(Z_r)) + C_r\,h'(\alpha(Z_r))\), where \(D_r\) is is_original and \(C_r\) is potential_deriv_coef. Reference learners using this entry point: kernel ridge (KernelRidgeBackend), gradient boosting (XGBoostBackend, SklearnBackend), and random forests (AugForestRieszBackend).
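
Spelled out in code, the quantity an augmentation-style learner minimizes looks like the sketch below. h_tilde and h_prime are hypothetical accessor names for the two loss pieces; check the actual attributes on rieszreg's loss classes.

import numpy as np

def augmented_riesz_loss(alpha_vals, is_original, potential_deriv_coef, loss):
    # sum over augmented rows r of D_r · h_tilde(α(Z_r)) + C_r · h_prime(α(Z_r)),
    # with D_r = is_original and C_r = potential_deriv_coef.
    return float(np.sum(
        is_original * loss.h_tilde(alpha_vals)             # hypothetical accessor
        + potential_deriv_coef * loss.h_prime(alpha_vals)  # hypothetical accessor
    ))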

MomentBackend.fit_rows(rows_train, rows_valid, estimand, loss, …) is for learners whose loss decomposes per original row. You receive raw rows plus the estimand, and you call rieszreg.trace(estimand, row, y) per row to read the \((c_j, a_j)\) pairs. Each row contributes a single loss term, so there is no \(k+1\)-fold augmentation blow-up. Reference learners using this entry point: random forests (ForestRieszBackend) and neural nets (TorchBackend).

Choose by your learner’s natural loss decomposition. Closed-form linear solves and gradient boosting suit fit_augmented, because the augmented sum has the same shape as a per-row weighted regression. Random forests and neural nets suit fit_rows, because each original row is a single training example to the underlying learner. When either would work, prefer fit_augmented: the orchestrator dispatches to the moment path only when fit_rows is the only method present, as sketched below.
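
A sketch of that dispatch rule in illustrative pseudo-Python (not the orchestrator’s actual code; variable names follow the wrapper signatures below):

if hasattr(backend, "fit_augmented"):
    # Augmented path: pseudo-rows carrying (D, C) weights.
    result = backend.fit_augmented(
        aug_train, aug_valid, loss,
        base_score=base_score, random_state=random_state, hyperparams=hyperparams,
    )
else:
    # Moment path: original rows, traced per-row moments.
    result = backend.fit_rows(
        rows_train, rows_valid, estimand, loss,
        base_score=base_score, random_state=random_state, hyperparams=hyperparams,
    )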

Minimum viable wrapper

Three pieces: a backend class implementing one Protocol method, a predictor class with predict_eta and predict_alpha, and a one-line registration of the predictor’s loader. All three live in your own module.

Augmentation-style

import numpy as np
from rieszreg.backends import (
    FitResult, register_predictor_loader,
)

class MyLearnerPredictor:
    kind = "my-learner"          # registry key, must be unique

    def __init__(self, params, base_score, loss):
        self.params = params      # whatever your learner returns
        self.base_score = base_score
        self.loss = loss

    def predict_eta(self, features):
        return self.base_score + my_learner_predict(self.params, features)

    def predict_alpha(self, features):
        return self.loss.link_to_alpha(self.predict_eta(features))

    def save(self, dir_path):
        np.save(dir_path / "params.npy", self.params)

    @classmethod
    def load(cls, dir_path, *, base_score, loss, best_iteration):
        # best_iteration is part of the loader signature; this learner has
        # no notion of iterations, so it is accepted and ignored.
        params = np.load(dir_path / "params.npy")
        return cls(params=params, base_score=base_score, loss=loss)


class MyLearnerBackend:
    """Augmentation-style wrapper around `my_learner_fit`."""

    def __init__(self, regularization=1.0, validation_fraction=0.0):
        self.regularization = regularization
        self.validation_fraction = validation_fraction

    def fit_augmented(self, aug_train, aug_valid, loss,
                      *, base_score, random_state, hyperparams):
        # Each augmented row contributes D · h_tilde(α) + C · h'(α) to the loss.
        # Translate (features, D, C) into whatever your learner consumes.
        params = my_learner_fit(
            features=aug_train.features,
            is_original=aug_train.is_original,
            potential_deriv_coef=aug_train.potential_deriv_coef,
            regularization=self.regularization,
            random_state=random_state,
        )
        predictor = MyLearnerPredictor(params, base_score=base_score, loss=loss)
        return FitResult(predictor=predictor)


register_predictor_loader("my-learner", MyLearnerPredictor.load)

Moment-style

Replace fit_augmented with fit_rows. The predictor stays the same.

from rieszreg import trace

class MyMomentBackend:
    def __init__(self, learning_rate=1e-3):
        self.learning_rate = learning_rate

    def fit_rows(self, rows_train, rows_valid, estimand, loss,
                 *, base_score, random_state, hyperparams,
                 ys_train=None, ys_valid=None):
        # For each row i, trace(...) returns the pairs (c_j, a_j) where
        # m(α)(Z_i, Y_i) = sum_j c_j · α(a_j). Each row's contribution
        # to the empirical loss is α(Z_i)² − 2 · sum_j c_j · α(a_j) for
        # squared loss; the general Bregman form lives on `loss`.
        moments = []
        for i, row in enumerate(rows_train):
            y_i = ys_train[i] if ys_train is not None else None
            moments.append(trace(estimand, row, y_i))

        params = my_moment_fit(
            rows_train, moments,
            learning_rate=self.learning_rate,
            random_state=random_state,
        )
        predictor = MyLearnerPredictor(params, base_score=base_score, loss=loss)
        return FitResult(predictor=predictor)

Backend, MomentBackend, Predictor, and FitResult are all importable from rieszreg.backends. The Protocols are structural (typing.Protocol), so your wrapper does not need to inherit from anything as long as the method signatures match.
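
Because conformance is structural, a static type checker can verify it without any inheritance; a minimal sketch:

from rieszreg.backends import Backend, MomentBackend

# These annotations type-check under mypy/pyright exactly when the method
# signatures match the Protocols; no base class is involved.
aug_backend: Backend = MyLearnerBackend(regularization=0.5)
moment_backend: MomentBackend = MyMomentBackend(learning_rate=1e-3)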

Worked example

The canonical linear-Gaussian DGP from rieszreg.testing.dgps is a binary treatment \(A \in \{0, 1\}\) with covariate \(X \in \mathbb R\) and propensity \(\pi(x) = \mathrm{logit}^{-1}(0.5\,x)\). The true representer for ATE is the inverse propensity weight \(\alpha_0(A, X) = (2A - 1) \,/\, [A\pi(X) + (1-A)(1-\pi(X))]\).

import numpy as np
from rieszreg import RieszEstimator, ATE
from rieszreg.testing.dgps import linear_gaussian_ate

dgp = linear_gaussian_ate()
df = dgp.sample(n=2000, rng=np.random.default_rng(0))
Z = df[["a", "x"]]
y = df["y"]

est = RieszEstimator(
    estimand=ATE(treatment="a", covariates=("x",)),
    backend=MyLearnerBackend(regularization=0.1),
).fit(Z, y)

alpha_hat = est.predict(Z)
alpha_true = dgp.true_alpha(df)
print(f"Pearson(α̂, α₀) = {np.corrcoef(alpha_hat, alpha_true)[0, 1]:.3f}")

RieszEstimator.fit reads feature_keys off the estimand, builds the augmented dataset by tracing \(m\) row by row, computes base_score from the loss and the empirical mean \(\bar m = \mathbb E_n[m(\alpha = 1)(Z, Y)]\), and dispatches to MyLearnerBackend.fit_augmented. The fitted predictor gets stored on est.predictor_. est.predict(Z) calls predictor.predict_alpha(features) on the rows you pass in.

y is the per-row outcome vector. Built-in estimands ignore it; pass it anyway so the call follows the sklearn convention. Custom \(Y\)-dependent estimands receive it inside m(\alpha)(z, y).

Sklearn composition

RieszEstimator already inherits from sklearn.base.BaseEstimator. Composition with cross_val_predict, GridSearchCV, clone, and Pipeline works as long as the wrapper plays by sklearn’s two structural rules.

The constructor stores its arguments by name without modification. MyLearnerBackend.__init__ writes self.regularization = regularization and stops. No coercion, no validation, no derived attributes. clone(MyLearnerBackend(regularization=0.1)) then returns a fresh equivalent object, which sklearn relies on between folds.

Fit-time state lives on the predictor, not on the backend. The backend stays clone-clean across folds; fitted parameters live on the Predictor returned in FitResult, attached to est.predictor_ after fit.

With those two in place, the standard sklearn idioms compose:

from sklearn.model_selection import cross_val_predict, GridSearchCV, KFold

cv = KFold(n_splits=5, shuffle=True, random_state=0)

# Cross-fit α̂ for downstream DML / TMLE.
alpha_oof = cross_val_predict(est, Z, y, cv=cv)

# Tune the regularization on held-out Riesz loss.
grid = GridSearchCV(
    estimator=est,
    param_grid={"backend__regularization": [0.01, 0.1, 1.0, 10.0]},
    cv=cv,
).fit(Z, y)
print(grid.best_params_)

backend__regularization is sklearn’s nested-parameter syntax. It forwards to est.backend.regularization so any backend hyperparameter exposed as a constructor attribute is reachable.

If your backend uses a held-out slice for fit-time logic such as early stopping or regularization-path selection, expose validation_fraction as a constructor attribute. The orchestrator reads it via getattr and splits the rows before augmentation. Backends that compute their own internal CV or that don’t need a holdout leave the attribute off.
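
The attribute name is the whole contract; the orchestrator-side read is, illustratively:

frac = getattr(backend, "validation_fraction", 0.0)  # 0.0 → no holdout split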

Save, load, conformance

RieszEstimator.save(path) writes metadata.json (loss spec, estimand factory_spec, base_score, your predictor.kind) and calls predictor.save(dir_path) for the binary payload. RieszEstimator.load(path) reads the metadata and looks up your registered predictor loader by kind. Built-in estimands round-trip automatically; custom estimands require RieszEstimator.load(path, estimand=...) so the user’s \(m\) is back in scope.
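
A round trip, as a sketch (assuming path strings are accepted; my_custom_estimand is a hypothetical user-defined estimand):

est.save("riesz_model")            # writes metadata.json + predictor payload

# Built-in estimands round-trip from metadata alone:
est2 = RieszEstimator.load("riesz_model")

# Custom estimands must be re-supplied so the user's m is back in scope:
# est3 = RieszEstimator.load("riesz_model", estimand=my_custom_estimand)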

Run register_predictor_loader("my-learner", MyLearnerPredictor.load) at the top level of your wrapper module so import my_wrapper registers the loader.

rieszreg.testing.conformance ships small helpers (assert_clone_roundtrip, assert_get_params_round_trip) that verify the load-bearing sklearn properties on your wrapper:

from rieszreg.testing.conformance import assert_clone_roundtrip

def test_clone():
    assert_clone_roundtrip(
        lambda: RieszEstimator(
            estimand=ATE(treatment="a", covariates=("x",)),
            backend=MyLearnerBackend(regularization=0.1),
        )
    )

For a private wrapper, those two helpers plus a fit/predict/score smoke test on linear_gaussian_ate cover the regression surface. The full suite (GridSearchCV, cross_val_predict, save/load on every built-in estimand, the estimator-consistency suite over the canonical DGPs) is the bar an in-tree contributed learner has to clear.
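
A minimal smoke test in that spirit, reusing the MyLearnerBackend wrapper from above (a sketch; the finiteness assertion is a stand-in for whatever accuracy bar you care about):

import numpy as np
from rieszreg import RieszEstimator, ATE
from rieszreg.testing.dgps import linear_gaussian_ate

def test_fit_predict_smoke():
    dgp = linear_gaussian_ate()
    df = dgp.sample(n=1000, rng=np.random.default_rng(1))
    Z, y = df[["a", "x"]], df["y"]
    est = RieszEstimator(
        estimand=ATE(treatment="a", covariates=("x",)),
        backend=MyLearnerBackend(regularization=0.1),
    ).fit(Z, y)
    alpha_hat = est.predict(Z)
    assert alpha_hat.shape == (len(df),)
    assert np.isfinite(alpha_hat).all()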