riesznet trains the Riesz representer α(x) with a PyTorch model, in the spirit of the RieszNet and ForestRiesz estimators of Chernozhukov et al. (2021). It fits the Riesz representer only — outcome regression is handled separately.
rieszreg exposes the entry point as MomentBackend.fit_rows. The neural backend uses the per-row moment formulation: for each original row z_i (written here for the default squared loss),

L_i = α(x_i)² − 2 Σ_j c_j α(p_j),

where the (c_j, p_j) pairs come from rieszreg.trace(estimand, z_i) and the model produces α via a Bregman link. The training loop minimizes the mean of L_i with autograd; no augmented dataset is built.
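Concretely, the per-row objective is a plain function of the trace pairs. A toy sketch with scalar evaluation points (all names here are hypothetical stand-ins; the real model evaluates a network at structured points):

```python
def per_row_riesz_loss(alpha, x_i, pairs):
    # L_i = alpha(x_i)^2 - 2 * sum_j c_j * alpha(p_j)
    return alpha(x_i) ** 2 - 2.0 * sum(c_j * alpha(p_j) for c_j, p_j in pairs)

# Stand-in for the trained network: a linear function of a scalar point.
alpha = lambda p: 3.0 * p + 1.0

# Hypothetical ATE-style trace for one row: coefficients (+1, -1).
pairs = [(+1.0, 1.5), (-1.0, 0.5)]
loss_i = per_row_riesz_loss(alpha, 0.5, pairs)  # 2.5**2 - 2*(5.5 - 2.5) = 0.25
```

The training loop averages this quantity over a minibatch of original rows and backpropagates through α.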
Two surfaces
RieszNet — convenience class with a default MLP. Surfaces hidden_sizes, activation, dropout, learning_rate, weight_decay, epochs, device, dtype. Trains with minibatch Adam (pass batch_size=None for full-batch).
TorchBackend — power-user surface. Pass any nn.Module factory and any optimizer factory. Compose with rieszreg.RieszEstimator for full control.
```
RieszNet(early_stopping_rounds=30, epochs=400,
         estimand=<rieszreg.estimands.base.ATE object at 0x7f0c0d92bb60>,
         learning_rate=0.005, validation_fraction=0.2)
```
```python
alpha_hat = rn.predict(df)
true_alpha = a / pi - (1 - a) / (1 - pi)
print(f"corr(α̂, α₀) = {np.corrcoef(alpha_hat, true_alpha)[0, 1]:.3f}")
```
```
RieszNet(early_stopping_rounds=30, epochs=400,
         estimand=<rieszreg.estimands.base.TSM object at 0x7f0c021cb170>,
         learning_rate=0.005,
         loss=<rieszreg.losses.kl.KLLoss object at 0x7f0c02328230>,
         validation_fraction=0.2)
```
```python
alpha_hat = rn.predict(df)
true_alpha = (a == 1).astype(float) / pi
print(f"min α̂ = {alpha_hat.min():.4f} (positive by construction)")
```
KLLoss and BernoulliLoss reject ATE / ATT / shift-style data at fit time (they require non-negative m-coefficients, which difference-of-evaluations estimands violate). BoundedSquaredLoss(lo, hi) accepts any signed coefficients and clips α̂ into the interval.
Custom architectures
The convenience class RieszNet covers MLPs. For an arbitrary nn.Module, instantiate TorchBackend directly and compose with rieszreg.RieszEstimator.
module_factory and optimizer_factory must be top-level callables (functools.partial over a top-level function is fine). Closures and lambdas raise on save() because they cannot be reconstructed by qualname.
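A sketch of a custom composition under these constraints. The factory bodies are plain PyTorch; the commented-out TorchBackend/RieszEstimator call at the end is illustrative and its exact keyword names are assumptions:

```python
import torch
import torch.nn as nn
from functools import partial

def make_mlp(n_features: int) -> nn.Module:
    # Top-level factory: reconstructable by qualname on save()/load().
    return nn.Sequential(
        nn.Linear(n_features, 128), nn.GELU(),
        nn.Linear(128, 128), nn.GELU(),
        nn.Linear(128, 1),  # scalar output per row
    )

def make_adamw(params, lr=1e-3):
    # Top-level optimizer factory; functools.partial over it is also fine.
    return torch.optim.AdamW(params, lr=lr)

# Illustrative composition (keyword names are assumptions):
# backend = TorchBackend(module_factory=make_mlp,
#                        optimizer_factory=partial(make_adamw, lr=5e-4))
# est = rieszreg.RieszEstimator(estimand=my_estimand, backend=backend)
```

Because both factories are importable top-level functions, the fitted estimator can be serialized and rebuilt by qualname.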
Hyperparameters
RieszNet constructor:
| Knob | Default | Notes |
| --- | --- | --- |
| `hidden_sizes` | `(64, 64)` | MLP layer widths. Empty tuple gives a linear model. |
| `activation` | `"relu"` | One of `relu`, `tanh`, `gelu`, `elu`, `silu`, `leaky_relu`. |
| `dropout` | `0.0` | Dropout probability after each activation. |
| `learning_rate` | `1e-3` | Adam learning rate. |
| `weight_decay` | `0.0` | Adam weight decay. |
| `epochs` | `200` | Max epochs. Combine with `early_stopping_rounds` for adaptive stopping. |
| `batch_size` | `64` | Original rows per minibatch. `None` is full-batch GD. |
| `device` | `"cpu"` | Also `"cuda"`, `"mps"`, `"auto"`. |
| `dtype` | `"float32"` | Also `"float64"` (slower, more precise). |
| `grad_clip_norm` | `None` | Global L2 clip on the gradient. |
| `loss` | `SquaredLoss()` | Any of the four built-in Bregman losses. |
| `init` | `None` | α-space initialization. Default is m̄ = E[m(Z, 1)] projected into the loss’s domain (the loss-minimizing constant); pass a float to override. |
| `validation_fraction` | `0.0` | Fraction of rows held out for early stopping. |
| `early_stopping_rounds` | `None` | Patience (in epochs) on the validation Riesz loss. |
| `random_state` | `0` | Seeds torch, torch.cuda, and the DataLoader generator. |
TorchBackend constructor (power-user surface):
| Knob | Default | Notes |
| --- | --- | --- |
| `module_factory` | required | `Callable[[int], nn.Module]` returning a model with scalar output. |
| `optimizer_factory` | required | `Callable[[Iterable[Parameter]], Optimizer]`. |
| `scheduler_factory` | `None` | Optional `Callable[[Optimizer], LRScheduler]`. Stepped per epoch. |
Save writes state_dict.pt (PyTorch weights, on CPU) plus predictor.json (carries the factory’s module + qualname, the input dim, and the Bregman-loss spec). Load re-imports the factory by qualname, rebuilds the module, and calls load_state_dict.
For built-in estimands the metadata round-trips automatically. For a custom Estimand, pass estimand= to load.
Sklearn integration
RieszNet is a BaseEstimator, so it composes with clone, GridSearchCV, cross_val_predict, and Pipeline. See Tuning and cross-fitting.
Loss support
| Loss | Supported | Link | Notes |
| --- | --- | --- | --- |
| `SquaredLoss` | ✅ | identity | Default. Works on every estimand. |
| `KLLoss` | ✅ | exp | Density-ratio targets only (TSM, IPSI). Requires non-negative m-coefficients. |
| `BernoulliLoss` | ✅ | sigmoid | Forces α̂ ∈ (0, 1). Density-ratio targets only. |
| `BoundedSquaredLoss(lo, hi)` | ✅ | scaled-sigmoid | Forces α̂ ∈ (lo, hi). Works on every estimand. |
The autograd implementation matches the analytic gradient in the loss spec exactly — verified per loss in the test suite.
Device, dtype, reproducibility
The training loop seeds torch, torch.cuda, and a torch.Generator for the data loader, so single-device fits are reproducible run-to-run on the same hardware. Bitwise reproducibility on CUDA across machines is not promised; torch.use_deterministic_algorithms is not enabled by default because it slows training significantly.
device="auto" picks cuda → mps → cpu in that order. dtype="float64" runs end-to-end in double precision, useful for tightly-coupled gradient checks but slower.
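The "auto" device order and the seeding described above can be sketched in plain PyTorch (resolve_device and seed_everything are hypothetical helper names, not library API):

```python
import torch

def resolve_device(spec: str = "auto") -> str:
    # Documented "auto" preference order: cuda -> mps -> cpu.
    if spec != "auto":
        return spec
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

def seed_everything(seed: int = 0) -> torch.Generator:
    torch.manual_seed(seed)  # seeds CPU and all CUDA devices
    # A dedicated generator handed to DataLoader(..., generator=g)
    # pins shuffling order independently of other torch RNG use.
    return torch.Generator().manual_seed(seed)
```

Two fits with the same seed on the same device then draw identical minibatch orders, which is the run-to-run reproducibility claim above.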