Kernel backend (krrr)

krrr provides KernelRidgeBackend and the KernelRieszRegressor convenience class. It implements Singh's Kernel Ridge Riesz Representers (arXiv:2102.11076) for the full set of estimands by piping rieszreg's augmentation engine into a closed-form kernel solve.

The augmented dataset gives per-row coefficients \(a_k\) and \(b_k\). The squared Riesz loss with kernel ridge regularization

\[ L_n(\alpha) = \frac{1}{n} \sum_k \big[ a_k\,\alpha(p_k)^2 + b_k\,\alpha(p_k) \big] + \lambda\,\|\alpha\|^2_\mathcal{H} \]

has the representer-theorem solution \(\hat\alpha = \sum_k \gamma_k\,k(\cdot, p_k)\) with \(\gamma\) solving the linear system

\[ (\mathrm{diag}(a)\,K + n\lambda\,I)\,\gamma \;=\; -b/2. \]
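To see where the system comes from, substitute the expansion: \(\alpha(p_j) = (K\gamma)_j\) and \(\|\alpha\|^2_\mathcal{H} = \gamma^\top K\gamma\), so

\[ L_n(\gamma) = \frac{1}{n}\big(\gamma^\top K\,\mathrm{diag}(a)\,K\gamma + b^\top K\gamma\big) + \lambda\,\gamma^\top K\gamma. \]

Setting \(\nabla_\gamma L_n = 0\) gives \(K\big(\tfrac{2}{n}\,\mathrm{diag}(a)\,K\gamma + \tfrac{1}{n}\,b + 2\lambda\,\gamma\big) = 0\); dropping the leading factor of \(K\) (a null-space component of \(K\) leaves \(\hat\alpha\) unchanged) and multiplying through by \(n/2\) yields the system above.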

Partition the rows into \(o = \{k : a_k > 0\}\) (original rows) and \(c = \{k : a_k = 0\}\) (counterfactual evaluation points). On the c-block the system decouples, giving the closed form \(\gamma_c = -b_c/(2n\lambda)\); \(\gamma_o\) then solves a symmetric PSD system on the o-block. A single eigendecomposition of the rescaled kernel matrix \(\mathrm{diag}(\sqrt{a_o})\,K_{oo}\,\mathrm{diag}(\sqrt{a_o})\) solves the entire λ path, as the sketch below shows.
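A minimal NumPy sketch of the partitioned solve, assuming the Gram matrix K and augmentation coefficients a, b are already in hand (illustrative only, not krrr's internal implementation):

import numpy as np

def krr_lambda_path(K, a, b, lambdas):
    # Solve (diag(a) K + nλ I) γ = -b/2 for every λ in the grid,
    # paying for one eigendecomposition up front.
    n = len(a)
    o, c = a > 0, a == 0
    d = np.sqrt(a[o])                         # D = diag(d)
    M = d[:, None] * K[np.ix_(o, o)] * d      # D K_oo D — symmetric PSD
    evals, V = np.linalg.eigh(M)              # shared across the whole λ path
    for lam in lambdas:
        gamma = np.empty(n)
        gamma[c] = -b[c] / (2 * n * lam)      # closed form on the c-block
        r = -b[o] / 2 - a[o] * (K[np.ix_(o, c)] @ gamma[c])
        u = V @ ((V.T @ (r / d)) / (evals + n * lam))
        gamma[o] = d * u                      # undo the sqrt(a) rescaling
        yield lam, gamma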

Quickstart

import numpy as np, pandas as pd
from krrr import KernelRieszRegressor, Gaussian
from rieszreg import ATE

rng = np.random.default_rng(0)
n = 1000
x = rng.uniform(0, 1, n)
pi = 1 / (1 + np.exp(-(-0.02*x - x**2 + 4*np.log(x + 0.3) + 1.5)))
a = rng.binomial(1, pi)
df = pd.DataFrame({"a": a.astype(float), "x": x})

krr = KernelRieszRegressor(
    estimand=ATE(treatment="a", covariates=("x",)),
    kernel=Gaussian(length_scale="median"),
    lambda_grid=np.logspace(-4, 0, 21),
    solver="auto",
    validation_fraction=0.25,
)
krr.fit(df)
KernelRieszRegressor(estimand=<rieszreg.estimands.base.ATE object at 0x7fc0f0460590>,
                     kernel=Gaussian(length_scale='median',
                                     _resolved=1.0000011555386372),
                     lambda_grid=array([1.00000000e-04, 1.58489319e-04, 2.51188643e-04, 3.98107171e-04,
       6.30957344e-04, 1.00000000e-03, 1.58489319e-03, 2.51188643e-03,
       3.98107171e-03, 6.30957344e-03, 1.00000000e-02, 1.58489319e-02,
       2.51188643e-02, 3.98107171e-02, 6.30957344e-02, 1.00000000e-01,
       1.58489319e-01, 2.51188643e-01, 3.98107171e-01, 6.30957344e-01,
       1.00000000e+00]),
                     validation_fraction=0.25)
alpha_hat  = krr.predict(df)
true_alpha = a / pi - (1 - a) / (1 - pi)
print(f"selected lambda = {krr.lambda_:.4g}")
selected lambda = 0.0001
print(f"corr(α̂, α₀)     = {np.corrcoef(alpha_hat, true_alpha)[0, 1]:.3f}")
corr(α̂, α₀)     = 0.950
print(f"RMSE             = {np.sqrt(np.mean((alpha_hat - true_alpha)**2)):.3f}")
RMSE             = 0.795
The same workflow runs from R through the reticulate-backed wrapper:

# Load the krrr R wrapper (one-time per session) and configure reticulate.
pkgload::load_all("../packages/krrr/r/krrr")
use_python_krrr(file.path(getwd(), "../.venv/bin/python"))

set.seed(0)
n  <- 1000
x  <- runif(n)
pi <- 1 / (1 + exp(-(-0.02 * x - x^2 + 4 * log(x + 0.3) + 1.5)))
a  <- rbinom(n, 1, pi)
df <- data.frame(a = as.numeric(a), x = x)

krr <- KernelRieszRegressor$new(
  estimand = ATE("a", "x"),
  kernel = Gaussian(length_scale = "median"),
  lambda_grid = 10^seq(-4, 0, length.out = 21),
  solver = "auto",
  validation_fraction = 0.25
)
krr$fit(df)

alpha_hat  <- krr$predict(df)
true_alpha <- a / pi - (1 - a) / (1 - pi)
cat(sprintf("selected lambda = %.4g\n", reticulate::py_to_r(krr$py$lambda_)))
cat(sprintf("corr             = %.3f\n", cor(alpha_hat, true_alpha)))
cat(sprintf("RMSE             = %.3f\n", sqrt(mean((alpha_hat - true_alpha)^2))))

Kernels

krrr ships eight kernels plus a composition algebra (Sum, Product, Scaled, Tensor):

from krrr import Gaussian, Matern, Linear, Polynomial, Tensor

Gaussian(length_scale="median")          # default; median pairwise distance
Gaussian(length_scale="scott")            # Scott's rule
Gaussian(length_scale=0.5)                # fixed
Matern(nu=2.5, length_scale="median")     # ν ∈ {0.5, 1.5, 2.5}
Linear()                                  # k(x, y) = x · y
Polynomial(degree=3, gamma=1.0, coef0=1.0)

# Algebra
Gaussian() + Linear()                     # Sum
0.5 * Gaussian()                           # Scaled
Gaussian() * Linear()                      # Product (Hadamard on Gram)
Tensor(Gaussian(), [0, 1], Linear(), [2])  # tensor product over disjoint columns

Length-scale "median", "scott", and "silverman" heuristics resolve at fit time on the augmented training points, so the bandwidth adapts to whatever scale matters for this dataset.

Solver tier

Solver          When to use                                 Cost
"direct"        n_aug ≤ 3,000                               One eigendecomposition; entire λ-path is O(n²) per λ. Exact.
"nystrom_cg"    n_aug ≤ 50,000                              Preconditioned CG on the o-block; m landmarks.
"rff"           n_aug very large; shift-invariant kernel    Primal D × D solve via random Fourier features.
"falkon"        n_aug very large; GPU available             Wraps the optional falkon package.
"auto"          default                                     Dispatches by n_aug.

The solver consumes the augmented dataset directly; you never deal with kernel matrices yourself.
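Forcing a tier is a constructor argument away; a sketch reusing the quickstart's setup:

krr = KernelRieszRegressor(
    estimand=ATE(treatment="a", covariates=("x",)),
    kernel=Gaussian(length_scale="median"),
    lambda_grid=np.logspace(-4, 0, 21),
    solver="nystrom_cg",    # bypass the n_aug dispatch and go straight to CG
    validation_fraction=0.25,
)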

Warning: Falkon limitation

The Falkon backend currently drops the \(K_{oc}\,b_c\) coupling on the o-block — Falkon’s standalone API only solves vanilla KRR, not the modified-RHS system the augmentation produces. For estimands where \(n_c\) is small or λ is moderate the bias is small; for tight overlap or extreme λ it is not. Use solver="nystrom_cg" if exactness matters more than scale.

λ selection

Pass a grid of λ values; the backend selects by validation Riesz loss when validation_fraction > 0 or eval_set is given.

krr = KernelRieszRegressor(
    estimand=ATE(treatment="a", covariates=("x",)),
    lambda_grid=np.logspace(-5, 1, 31),
    validation_fraction=0.25,
).fit(df)
print(krr.lambda_)   # the chosen λ

For the consistency theory to apply, λ should shrink at rate \(O(1/n)\). Cross-fitting users should re-tune λ within each fold; sklearn's cross_val_predict does this if KernelRieszRegressor is wrapped in GridSearchCV. A manual per-fold sketch follows.
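This sketch re-tunes λ inside every fold by hand, reusing the quickstart's df; the fold loop is ours, not a krrr API:

import numpy as np
from sklearn.model_selection import KFold
from krrr import KernelRieszRegressor, Gaussian
from rieszreg import ATE

alpha_cf = np.empty(len(df))   # cross-fitted Riesz representer values
for train_idx, test_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(df):
    fold = KernelRieszRegressor(
        estimand=ATE(treatment="a", covariates=("x",)),
        kernel=Gaussian(length_scale="median"),
        lambda_grid=np.logspace(-5, 1, 31),
        validation_fraction=0.25,   # λ is re-selected within each training fold
    ).fit(df.iloc[train_idx])
    alpha_cf[test_idx] = fold.predict(df.iloc[test_idx])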

Loss support

KernelRidgeBackend currently supports SquaredLoss only; non-quadratic losses (KLLoss, BernoulliLoss, BoundedSquaredLoss) require Newton iteration on the kernel system, planned for v0.2. Passing an unsupported loss raises at fit time with a clear error.

For non-quadratic losses today, use the boosting backend.
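For illustration, the fail-fast behavior might look like the sketch below; the loss= keyword is hypothetical (check your version's signature), and KLLoss is assumed importable from rieszreg:

from rieszreg import KLLoss  # non-quadratic; unsupported by the kernel backend

try:
    KernelRieszRegressor(
        estimand=ATE(treatment="a", covariates=("x",)),
        loss=KLLoss(),            # hypothetical keyword argument
    ).fit(df)
except Exception as err:          # raised at fit time, before any kernel solve
    print(err)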

Reference parity

krrr runs a TSM1 numerical-parity test against the dml-tmle R reference implementation, with agreement to a tolerance of 1e-8.

Sharp edges

  • Median-heuristic bandwidth on the augmented dataset. The median is computed over the augmented points (originals plus the counterfactuals generated from \(m\)), so for shift-style estimands it includes the shifted treatment values.
  • solver="falkon" drops \(K_{oc}\,b_c\). See the warning callout above.