Neural-network backend (riesznet)

riesznet trains the Riesz representer α(x) with a PyTorch model, in the spirit of Chernozhukov et al. (2021), RieszNet and ForestRiesz. It fits the Riesz representer only — outcome regression is handled separately.

rieszreg exposes the entry point as MomentBackend.fit_rows. The neural backend uses a per-row moment formulation: for each original row z_i, the loss contribution is

\[ L_i = \psi(\alpha(x_i)) - \sum_{j} c_j \cdot \varphi'(\alpha(p_j)) \]

where the (c_j, p_j) pairs come from rieszreg.trace(estimand, z_i) and the model produces α via a Bregman link. The training loop minimizes the mean of L_i with autograd; no augmented dataset is built.
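A minimal sketch of that objective in raw PyTorch may help. Everything named here is illustrative, not package API: psi and phi_prime stand in for the Bregman-loss spec, and the hard-coded (c_j, p_j) pairs stand in for rieszreg.trace output. For the squared loss, psi(t) = t^2 and phi'(t) = 2t recover the classic alpha^2 - 2*m(alpha) Riesz objective:

import torch
import torch.nn as nn

# Illustrative stand-ins, not package API.
psi = lambda t: t ** 2        # squared-loss psi
phi_prime = lambda t: 2 * t   # squared-loss phi'

model = nn.Linear(2, 1)                        # alpha(a, x) with scalar output
x_i = torch.tensor([[1.0, 0.3]])               # features of the original row z_i
pairs = [(1.0, torch.tensor([[1.0, 0.3]])),    # (c_j, p_j) pairs, e.g. an ATE-style
         (-1.0, torch.tensor([[0.0, 0.3]]))]   # difference of evaluations

L_i = psi(model(x_i)) - sum(c * phi_prime(model(p)) for c, p in pairs)
L_i.mean().backward()                          # autograd; no augmented dataset built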

Two surfaces

  • RieszNet — convenience class with a default MLP. Surfaces hidden_sizes, activation, dropout, learning_rate, weight_decay, epochs, batch_size, device, dtype, and more (see Hyperparameters below). Minibatch training by default (batch_size=64); pass batch_size=None for full-batch.
  • TorchBackend — power-user surface. Pass any nn.Module factory and any optimizer factory. Compose with rieszreg.RieszEstimator for full control.

Quickstart — ATE

import numpy as np, pandas as pd
from riesznet import RieszNet, ATE

rng = np.random.default_rng(0)
n = 1500
x = rng.uniform(0, 1, n)
pi = 1 / (1 + np.exp(-(-0.02*x - x**2 + 4*np.log(x + 0.3) + 1.5)))
a = rng.binomial(1, pi).astype(float)
df = pd.DataFrame({"a": a, "x": x})

rn = RieszNet(
    estimand=ATE(treatment="a", covariates=("x",)),
    hidden_sizes=(64, 64),
    learning_rate=5e-3,
    epochs=400,
    batch_size=64,
    validation_fraction=0.2,
    early_stopping_rounds=30,
    random_state=0,
)
rn.fit(df)
RieszNet(early_stopping_rounds=30, epochs=400,
         estimand=<rieszreg.estimands.base.ATE object at 0x7f0c0d92bb60>,
         learning_rate=0.005, validation_fraction=0.2)
alpha_hat  = rn.predict(df)
true_alpha = a / pi - (1 - a) / (1 - pi)
print(f"corr(α̂, α₀) = {np.corrcoef(alpha_hat, true_alpha)[0, 1]:.3f}")
corr(α̂, α₀) = 0.963
print(f"RMSE         = {np.sqrt(np.mean((alpha_hat - true_alpha)**2)):.3f}")
RMSE         = 0.699
print(f"best_iter    = {rn.best_iteration_}")
best_iter    = 17
The same quickstart, driven from the R package (pointed at a Python environment with riesznet installed):

pkgload::load_all("../packages/riesznet/r/riesznet")
use_python_riesznet(file.path(getwd(), "../.venv/bin/python"))

set.seed(0)
n  <- 1500
x  <- runif(n)
pi <- 1 / (1 + exp(-(-0.02 * x - x^2 + 4 * log(x + 0.3) + 1.5)))
a  <- rbinom(n, 1, pi)
df <- data.frame(a = as.numeric(a), x = x)

rn <- RieszNet$new(
  estimand = ATE(treatment = "a", covariates = "x"),
  hidden_sizes = c(64L, 64L),
  learning_rate = 5e-3,
  epochs = 400L,
  validation_fraction = 0.2,
  early_stopping_rounds = 30L,
  random_state = 0L
)
rn$fit(df)
alpha_hat <- rn$predict(df)

Quickstart — TSM with KL loss

KLLoss is matched to density-ratio estimands like TSM. The exp link keeps predictions positive without any post-hoc clipping.

from riesznet import KLLoss, RieszNet, TSM

rn = RieszNet(
    estimand=TSM(level=1, treatment="a", covariates=("x",)),
    loss=KLLoss(),
    hidden_sizes=(64, 64),
    learning_rate=5e-3,
    epochs=400,
    validation_fraction=0.2,
    early_stopping_rounds=30,
    random_state=0,
)
rn.fit(df)
RieszNet(early_stopping_rounds=30, epochs=400,
         estimand=<rieszreg.estimands.base.TSM object at 0x7f0c021cb170>,
         learning_rate=0.005,
         loss=<rieszreg.losses.kl.KLLoss object at 0x7f0c02328230>,
         validation_fraction=0.2)
alpha_hat  = rn.predict(df)
true_alpha = (a == 1).astype(float) / pi
print(f"min α̂        = {alpha_hat.min():.4f}  (positive by construction)")
min α̂        = 0.0003  (positive by construction)
print(f"corr(α̂, α₀) = {np.corrcoef(alpha_hat, true_alpha)[0, 1]:.3f}")
corr(α̂, α₀) = 0.977

KLLoss and BernoulliLoss reject ATE / ATT / shift-style estimands at fit time (they require non-negative m-coefficients, which difference-of-evaluations estimands violate). BoundedSquaredLoss(lo, hi) accepts any signed coefficients and clips α̂ into the interval, as in the sketch below.
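For example, to keep range control on a signed-coefficient estimand such as ATE, swap in the bounded loss. The bounds below are arbitrary, and this assumes BoundedSquaredLoss is importable from riesznet alongside the other losses:

from riesznet import ATE, BoundedSquaredLoss, RieszNet

rn_bounded = RieszNet(
    estimand=ATE(treatment="a", covariates=("x",)),
    loss=BoundedSquaredLoss(-20.0, 20.0),   # clips alpha-hat into (-20, 20)
    random_state=0,
)
rn_bounded.fit(df)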

Custom architectures

The convenience class RieszNet covers MLPs. For an arbitrary nn.Module, instantiate TorchBackend directly and compose with rieszreg.RieszEstimator:

import functools
import torch
import torch.nn as nn

from rieszreg import ATE, RieszEstimator
from riesznet import TorchBackend


def my_factory(input_dim):
    """Top-level factory; importable by qualname so save/load works."""
    return nn.Sequential(
        nn.Linear(input_dim, 32), nn.GELU(),
        nn.Linear(32, 32), nn.GELU(),
        nn.Linear(32, 1),
    )

backend = TorchBackend(
    module_factory=my_factory,
    optimizer_factory=functools.partial(torch.optim.AdamW, lr=3e-4, weight_decay=1e-3),
    epochs=300,
    batch_size=256,        # in original rows
    device="cpu",
    dtype="float32",
    grad_clip_norm=1.0,
    validation_fraction=0.2,
    early_stopping_rounds=30,
)

est = RieszEstimator(
    estimand=ATE(treatment="a", covariates=("x",)),
    backend=backend,
    random_state=0,
)
est.fit(df)
RieszEstimator(backend=TorchBackend(module_factory=<function my_factory at 0x7f0c0da86660>,
                                    optimizer_factory=functools.partial(<class 'torch.optim.adamw.AdamW'>, lr=0.0003, weight_decay=0.001),
                                    scheduler_factory=None,
                                    epochs=300,
                                    batch_size=256,
                                    device='cpu',
                                    dtype='float32',
                                    grad_clip_norm=1.0,
                                    early_stopping_rounds=30,
                                    validation_fraction=0.2,
                                    snapshot_epochs=()),
               estimand=<rieszreg.estimands.base.ATE object at 0x7f0c02043d40>)
alpha_hat = est.predict(df)
print(f"α̂ range: [{alpha_hat.min():.3f}, {alpha_hat.max():.3f}]")
α̂ range: [-4.868, 5.654]

module_factory and optimizer_factory must be top-level callables (functools.partial over a top-level function is fine). Closures and lambdas raise on save() because they cannot be reconstructed by qualname.

Hyperparameters

RieszNet constructor:

| Knob | Default | Notes |
| --- | --- | --- |
| hidden_sizes | (64, 64) | MLP layer widths. An empty tuple gives a linear model. |
| activation | "relu" | One of relu, tanh, gelu, elu, silu, leaky_relu. |
| dropout | 0.0 | Dropout probability after each activation. |
| learning_rate | 1e-3 | Adam learning rate. |
| weight_decay | 0.0 | Adam weight decay. |
| epochs | 200 | Max epochs. Combine with early_stopping_rounds for adaptive stopping. |
| batch_size | 64 | Original rows per minibatch. None is full-batch GD. |
| device | "cpu" | Also "cuda", "mps", "auto". |
| dtype | "float32" | Also "float64" (slower, more precise). |
| grad_clip_norm | None | Global L2 clip on the gradient. |
| loss | SquaredLoss() | Any of the four built-in Bregman losses. |
| init | None | α-space initialization. Default is m̄ = E[m(Z, 1)] projected into the loss’s domain (the loss-minimizing constant); pass a float to override. |
| validation_fraction | 0.0 | Fraction held out for early stopping. |
| early_stopping_rounds | None | Patience (in epochs) on the validation Riesz loss. |
| random_state | 0 | Seeds torch, torch.cuda, and the DataLoader generator. |
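Several of these combine naturally; for instance, a full-batch, double-precision fit with a manual constant initialization (the specific values here are arbitrary):

rn_fb = RieszNet(
    estimand=ATE(treatment="a", covariates=("x",)),
    batch_size=None,      # full-batch gradient descent
    dtype="float64",      # double precision end-to-end
    grad_clip_norm=5.0,   # global L2 clip on the gradient
    init=0.0,             # override the default alpha-space initialization
    random_state=0,
)
rn_fb.fit(df)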

TorchBackend constructor (power-user surface):

| Knob | Default | Notes |
| --- | --- | --- |
| module_factory | required | Callable[[int], nn.Module] returning a model with scalar output. |
| optimizer_factory | required | Callable[[Iterable[Parameter]], Optimizer]. |
| scheduler_factory | None | Optional Callable[[Optimizer], LRScheduler]. Stepped per epoch. |
| epochs | 200 | |
| batch_size | None | Original-row batch size. None is full-batch. |
| device | "cpu" | |
| dtype | "float32" | |
| grad_clip_norm | None | |
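scheduler_factory composes the same way as the other factories. A sketch with a stock PyTorch scheduler, reusing my_factory from the custom-architecture example (step_size and gamma are arbitrary):

import functools
import torch

backend = TorchBackend(
    module_factory=my_factory,
    optimizer_factory=functools.partial(torch.optim.AdamW, lr=3e-4),
    scheduler_factory=functools.partial(
        torch.optim.lr_scheduler.StepLR, step_size=100, gamma=0.5
    ),  # stepped once per epoch by the training loop
    epochs=300,
)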

Save / load

import tempfile, os

pred_before = rn.predict(df)   # baseline predictions from the fitted model
with tempfile.TemporaryDirectory() as tmp:
    save_dir = os.path.join(tmp, "fitted")
    rn.save(save_dir)

    loaded = RieszNet.load(save_dir)
    pred_loaded = loaded.predict(df)
    print(f"max abs diff after round-trip: {np.max(np.abs(pred_loaded - pred_before)):.2e}")
max abs diff after round-trip: 0.00e+00

Save writes state_dict.pt (PyTorch weights, on CPU) plus predictor.json (carries the factory’s module + qualname, the input dim, and the Bregman-loss spec). Load re-imports the factory by qualname, rebuilds the module, and calls load_state_dict.

For built-in estimands the metadata round-trips automatically. For a custom Estimand, pass estimand= to load.
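For example, with a hypothetical user-defined estimand class MyEstimand (the constructor arguments are illustrative):

# MyEstimand is a hypothetical custom rieszreg Estimand; built-ins need no estimand=.
loaded = RieszNet.load(save_dir, estimand=MyEstimand(treatment="a", covariates=("x",)))
alpha_loaded = loaded.predict(df)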

Sklearn integration

RieszNet is a BaseEstimator, so it composes with clone, GridSearchCV, cross_val_predict, Pipeline. See Tuning and cross-fitting.
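For example, clone produces an unfitted copy with identical hyperparameters, which is what GridSearchCV and cross_val_predict rely on under the hood:

from sklearn.base import clone

rn_copy = clone(rn)                      # unfitted copy, same hyperparameters
rn_copy.set_params(learning_rate=1e-3)   # standard sklearn parameter interface
rn_copy.fit(df)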

Loss support

| Loss | Link | Notes |
| --- | --- | --- |
| SquaredLoss | identity | Default. Works on every estimand. |
| KLLoss | exp | Density-ratio targets only (TSM, IPSI). Requires non-negative m-coefficients. |
| BernoulliLoss | sigmoid | Forces α̂ ∈ (0, 1). Density-ratio targets only. |
| BoundedSquaredLoss(lo, hi) | scaled-sigmoid | Forces α̂ ∈ (lo, hi). Works on every estimand. |

The autograd implementation matches the analytic gradient in the loss spec exactly — verified per loss in the test suite.

Device, dtype, reproducibility

The training loop seeds torch, torch.cuda, and a torch.Generator for the data loader, so single-device fits are reproducible run-to-run on the same hardware. Bitwise reproducibility on CUDA across machines is not promised; torch.use_deterministic_algorithms is not enabled by default because it slows training significantly.

device="auto" picks cudampscpu in that order. dtype="float64" runs end-to-end in double precision, useful for tightly-coupled gradient checks but slower.