Save and load

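A fitted Riesz estimator can be persisted in two ways:

  • sklearn idiom: joblib.dump / joblib.load. This pickles the whole Python estimator object.
  • Portable directory format: estimator.save(path) and the load(path) classmethod. This writes a backend-native payload plus a metadata.json sidecar, and the directory can be loaded from either Python or R.

Both paths round-trip every built-in estimand and loss automatically.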

The shared metadata format is defined in rieszreg.estimator.RieszEstimator.save. Each backend registers its Predictor loader at import time, so the inherited load(...) classmethod can rehydrate the fitted predictor from its predictor_kind.
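One way to picture "register at import time, dispatch on predictor_kind" is a module-level dictionary that each backend fills in when its module is imported. The sketch below is hypothetical: register_predictor_loader, load_predictor, and the "toy" kind are illustrative names, not rieszreg's actual API.

```python
# Hypothetical sketch of a loader registry keyed by predictor_kind.
# None of these names are rieszreg's real API.
from typing import Callable, Dict

_PREDICTOR_LOADERS: Dict[str, Callable[[dict], object]] = {}


def register_predictor_loader(kind: str):
    """Decorator a backend module would apply at import time."""
    def decorator(loader: Callable[[dict], object]):
        _PREDICTOR_LOADERS[kind] = loader
        return loader
    return decorator


def load_predictor(metadata: dict) -> object:
    """Dispatch on metadata["predictor_kind"] to the registered loader."""
    kind = metadata["predictor_kind"]
    try:
        loader = _PREDICTOR_LOADERS[kind]
    except KeyError:
        # In this sketch, an unregistered kind means the backend was never imported.
        raise ValueError(
            f"No loader registered for predictor_kind={kind!r}; "
            "import the backend package first."
        ) from None
    return loader(metadata)


# A toy "backend" that registers itself as soon as this module runs.
@register_predictor_loader("toy")
def _load_toy(metadata: dict) -> dict:
    return {"kind": "toy", "base_score": metadata["base_score"]}


print(load_predictor({"predictor_kind": "toy", "base_score": 0.5}))
```

The design keeps the base class ignorant of concrete backends: it only needs the string stored in the metadata.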

sklearn idiom: joblib.dump

import joblib, tempfile, pathlib
import numpy as np, pandas as pd
from rieszboost import RieszBooster
from rieszreg import ATE

rng = np.random.default_rng(0)
n = 800
x = rng.uniform(0, 1, n)
pi = 1 / (1 + np.exp(-(-0.02 * x - x**2 + 4 * np.log(x + 0.3) + 1.5)))
a = rng.binomial(1, pi)
df = pd.DataFrame({"a": a.astype(float), "x": x})

booster = RieszBooster(
    estimand=ATE(),
    n_estimators=200, learning_rate=0.05, max_depth=4,
).fit(df)
preds_before = booster.predict(df)

with tempfile.TemporaryDirectory() as td:
    p = pathlib.Path(td) / "booster.pkl"
    joblib.dump(booster, p)
    loaded = joblib.load(p)
    print("predictions match :", np.array_equal(preds_before, loaded.predict(df)))

    # Loaded booster still composes with sklearn:
    from sklearn.base import clone
    cloned = clone(loaded)
    print("clone OK          :", cloned.estimand.name)
predictions match : True
clone OK          : ATE
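Note that sklearn.base.clone copies hyperparameters, not fitted state, so the cloned estimator above is unfitted. The sketch below shows this with a plain sklearn Ridge model standing in for the booster, since it runs without rieszboost installed.

```python
import pathlib
import tempfile

import joblib
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=100)

model = Ridge(alpha=0.3).fit(X, y)

with tempfile.TemporaryDirectory() as td:
    p = pathlib.Path(td) / "model.pkl"
    joblib.dump(model, p)   # pickles the fitted object
    loaded = joblib.load(p)

# Fitted state survives the pickle round trip exactly.
print(np.array_equal(model.predict(X), loaded.predict(X)))

cloned = clone(loaded)
print(cloned.get_params()["alpha"])  # hyperparameters are copied
print(hasattr(cloned, "coef_"))      # fitted attributes are not
```

To reuse fitted predictions, keep the loaded object; use clone only when you intend to refit.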

For R-side workflows, use the directory format below. To pickle from R via reticulate:

joblib <- reticulate::import("joblib")
joblib$dump(booster$py, "booster.pkl")
loaded_py <- joblib$load("booster.pkl")
# Wrap loaded_py back into an R6 RieszBooster if needed.

Portable directory format

import numpy as np
import pandas as pd
import tempfile, pathlib
from rieszboost import RieszBooster
from rieszreg import ATE

rng = np.random.default_rng(0)
n = 800
x = rng.uniform(0, 1, n)
pi = 1 / (1 + np.exp(-(-0.02*x - x**2 + 4*np.log(x + 0.3) + 1.5)))
a = rng.binomial(1, pi)
df = pd.DataFrame({"a": a.astype(float), "x": x})

booster = RieszBooster(
    estimand=ATE(),
    n_estimators=200, learning_rate=0.05, max_depth=4,
    early_stopping_rounds=20, validation_fraction=0.2,
).fit(df)
preds_before = booster.predict(df)

with tempfile.TemporaryDirectory() as td:
    save_path = pathlib.Path(td) / "my_alpha"
    booster.save(save_path)
    print("Saved files:", sorted(p.name for p in save_path.iterdir()))

    loaded = RieszBooster.load(save_path)
    preds_after = loaded.predict(df)
    print(f"predictions match exactly: {np.array_equal(preds_before, preds_after)}")
    print(f"best_iteration preserved : {loaded.best_iteration_}")
Saved files: ['booster.ubj', 'metadata.json']
predictions match exactly: True
best_iteration preserved : 199

The same round trip in R:

set.seed(0)
n  <- 800
x  <- runif(n)
pi <- 1 / (1 + exp(-(-0.02 * x - x^2 + 4 * log(x + 0.3) + 1.5)))
a  <- rbinom(n, 1, pi)
df <- data.frame(a = as.numeric(a), x = x)

booster <- RieszBooster$new(
  estimand = ATE("a", "x"),
  n_estimators = 200L, learning_rate = 0.05, max_depth = 4L,
  early_stopping_rounds = 20L, validation_fraction = 0.2
)
booster$fit(df)
preds_before <- booster$predict(df)

td <- tempfile()
booster$save(td)
cat("Saved files:", paste(list.files(td), collapse = ", "), "\n")
Saved files: booster.ubj, metadata.json 
loaded <- load_riesz_booster(td)
preds_after <- loaded$predict(df)
# all.equal() compares up to a small numeric tolerance, not bit for bit
cat(sprintf("predictions match (all.equal): %s\n",
            isTRUE(all.equal(preds_before, preds_after))))
predictions match (all.equal): TRUE
unlink(td, recursive = TRUE)

Cross-language

Save in one language, load in the other. The backend payload is in its native format (xgboost UBJSON, joblib, kernel npz), and the metadata sidecar is plain JSON.

# In R: save the booster
booster$save("alpha_ate")
# In Python: load the same directory and continue working
from rieszboost import RieszBooster
loaded = RieszBooster.load("alpha_ate")
loaded.predict(new_df)  # new_df: a DataFrame with the same feature columns (a, x)

The R parity test exercises both directions with tolerance = 1e-12. The Python suite asserts that predictions are bitwise-identical pre- and post-save.
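Why a binary payload plus a JSON sidecar can reproduce predictions bit for bit: binary containers store the raw float64 bytes, and Python's json module writes floats with repr(), which round-trips any finite float exactly. The toy sketch below uses NumPy's .npz format; the file names "payload.npz" and the fields are hypothetical, not the real rieszreg layout (real backends use xgboost UBJSON, joblib, or kernel npz).

```python
# Toy illustration of a "binary payload + JSON sidecar" directory.
# File names and fields are hypothetical, not rieszreg's format.
import json
import pathlib
import tempfile

import numpy as np


def save_toy(path: pathlib.Path, coef: np.ndarray, base_score: float) -> None:
    path.mkdir(parents=True, exist_ok=True)
    # Binary payload: raw float64 arrays in NumPy's .npz container.
    np.savez(path / "payload.npz", coef=coef)
    # Plain-text sidecar: scalars and descriptors a loader needs.
    meta = {"predictor_kind": "toy", "base_score": base_score}
    (path / "metadata.json").write_text(json.dumps(meta, indent=2))


def load_toy(path: pathlib.Path):
    meta = json.loads((path / "metadata.json").read_text())
    with np.load(path / "payload.npz") as data:
        coef = data["coef"].copy()
    return coef, meta["base_score"]


rng = np.random.default_rng(0)
coef = rng.normal(size=5)
base_score = float(rng.normal())

with tempfile.TemporaryDirectory() as td:
    d = pathlib.Path(td) / "toy_alpha"
    save_toy(d, coef, base_score)
    names = sorted(p.name for p in d.iterdir())
    coef2, base2 = load_toy(d)

print(names)
# Both the array bytes and the JSON float survive exactly.
print(np.array_equal(coef, coef2), base_score == base2)
```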

Custom estimands

Custom (non-built-in) estimands can’t be reconstructed from JSON: the user’s m() callable is not serialized, so it doesn’t survive a save/load cycle. The saved estimator itself is still valid; you just have to pass the estimand again on load:

import numpy as np
import pandas as pd
import tempfile, pathlib
from rieszboost import RieszBooster
from rieszreg import FiniteEvalEstimand

# Set up data and a custom estimand (not a built-in)
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({"a": rng.binomial(1, 0.5, n).astype(float),
                   "x": rng.uniform(0, 1, n)})

def m_custom(alpha):
    def inner(z, y=None):
        return alpha(a=1, x=z["x"]) - alpha(a=0, x=z["x"])
    return inner

custom = FiniteEvalEstimand(feature_keys=("a", "x"), m=m_custom, name="my_custom")
booster = RieszBooster(estimand=custom, n_estimators=20).fit(df)

with tempfile.TemporaryDirectory() as td:
    p = pathlib.Path(td) / "custom"
    booster.save(p)

    # Loading without the original estimand raises:
    try:
        RieszBooster.load(p)
    except ValueError as e:
        print("expected error:", str(e)[:80], "...")

    # Pass the original estimand to reconstruct:
    loaded = RieszBooster.load(p, estimand=custom)
    print("loaded with custom estimand OK")
expected error: Saved estimator at /var/folders/07/xf4cwwjd7nggpmy8y006ww6w0000gq/T/tmpmn976cl5/ ...
loaded with custom estimand OK
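The underlying reason is that a function object has no JSON encoding, so only built-ins, which can be rebuilt from a factory name, get a non-null spec. The sketch below is hypothetical: ToyEstimand, estimand_to_spec, and estimand_from_spec are illustrative stand-ins, not rieszreg's classes or functions.

```python
# Hypothetical sketch; ToyEstimand and the helpers are not rieszreg's API.
import json


class ToyEstimand:
    def __init__(self, name, m, factory=None):
        self.name = name
        self.m = m              # arbitrary user callable
        self.factory = factory  # factory name, set only for built-ins


def toy_ate():
    # Stand-in for a built-in factory; this m() is a placeholder.
    return ToyEstimand("ATE", m=lambda alpha: alpha, factory="ATE")


BUILTIN_FACTORIES = {"ATE": toy_ate}


def estimand_to_spec(est):
    # Built-ins are recorded by factory name; callables are never written out.
    return {"factory": est.factory} if est.factory else None


def estimand_from_spec(spec, estimand=None):
    if estimand is not None:
        return estimand  # caller re-supplies the estimand
    if spec is None:
        raise ValueError("custom estimand was not saved; pass estimand=... on load")
    return BUILTIN_FACTORIES[spec["factory"]]()


# A function object itself cannot be JSON-encoded:
try:
    json.dumps({"m": lambda alpha: alpha})
except TypeError as err:
    print("json.dumps failed:", err)

custom = ToyEstimand("my_custom", m=lambda alpha: alpha)
saved = json.dumps({"estimand_factory_spec": estimand_to_spec(custom)})
spec = json.loads(saved)["estimand_factory_spec"]  # None: nothing to rebuild from
print(estimand_from_spec(spec, estimand=custom).name)
```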

What’s in the file

Per-backend binary payload, plus metadata.json containing:

  • predictor_kind: "xgboost" (rieszboost), "sklearn" (rieszboost), "krrr" (krrr), "forestriesz" (forestriesz moment-style), "aug-forestriesz" (forestriesz augmentation-style), or "riesznet" (riesznet).
  • loss: {"type": "SquaredLoss" | "KLLoss" | …, "args": {…}} — re-instantiated with loss_from_spec(...) on load.
  • estimand_factory_spec: {"factory": "ATE" | …} for built-ins, null for custom.
  • feature_keys, base_score, best_iteration, best_score.
  • estimator_class and hyperparameters: informational; they let get_params(deep=False) on the loaded estimator return the same parameters you fit with.
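The {"type": ..., "args": {...}} loss spec follows a common pattern: store the class name plus constructor arguments, then look the class up by name and call it on load. The sketch below mimics that pattern with toy dataclasses; SquaredLoss, KLLoss, their eps argument, and toy_loss_from_spec are stand-ins, not rieszreg's classes or its loss_from_spec.

```python
# Hypothetical sketch of a {"type": ..., "args": {...}} spec round trip.
# These toy classes are stand-ins, not rieszreg's loss classes.
import json
from dataclasses import asdict, dataclass


@dataclass
class SquaredLoss:
    pass


@dataclass
class KLLoss:
    eps: float = 1e-6  # hypothetical argument, for illustration only


LOSS_TYPES = {cls.__name__: cls for cls in (SquaredLoss, KLLoss)}


def loss_to_spec(loss) -> dict:
    return {"type": type(loss).__name__, "args": asdict(loss)}


def toy_loss_from_spec(spec: dict):
    # Look the class up by name and re-instantiate with the saved args.
    return LOSS_TYPES[spec["type"]](**spec["args"])


original = KLLoss(eps=1e-3)
text = json.dumps(loss_to_spec(original))
restored = toy_loss_from_spec(json.loads(text))
print(text)
print(restored == original)
```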

The JSON is plain text and diffable; if the file format ever changes, rieszreg_format_version lets us migrate.
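A version field like rieszreg_format_version typically drives a chain of one-step migrations on load. The sketch below is purely illustrative: the version numbers, the default of 1 for files without the field, and the v1 to v2 change are invented, not rieszreg's actual history.

```python
# Hypothetical sketch of version-gated metadata migration.
# Version numbers and field changes are invented for illustration.
CURRENT_FORMAT_VERSION = 2


def _v1_to_v2(meta: dict) -> dict:
    # Pretend v1 stored the loss as a bare string instead of {"type", "args"}.
    meta = dict(meta)  # copy so the caller's dict is not mutated
    meta["loss"] = {"type": meta["loss"], "args": {}}
    return meta


MIGRATIONS = {1: _v1_to_v2}


def migrate(meta: dict) -> dict:
    # Assumption: files without the field are treated as version 1.
    version = meta.get("rieszreg_format_version", 1)
    if version > CURRENT_FORMAT_VERSION:
        raise ValueError(f"file format {version} is newer than this library supports")
    while version < CURRENT_FORMAT_VERSION:
        meta = MIGRATIONS[version](meta)
        version += 1
        meta["rieszreg_format_version"] = version
    return meta


old = {"rieszreg_format_version": 1, "loss": "SquaredLoss"}
print(migrate(old))
```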