
Commit 35d9593

Erik-BM and claude committed
Add quantile regression with pinball loss to EBMs
Adds a new "quantile" objective for ExplainableBoostingRegressor that enables quantile regression via pinball loss. Usage: objective="quantile:alpha=0.5", where alpha in (0, 1) selects the target quantile. This is useful for prediction intervals, asymmetric risk, and robust median regression.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ec085cc commit 35d9593

5 files changed

Lines changed: 412 additions & 1 deletion

docs/interpret/ebm-internals-quantile-regression.ipynb

Lines changed: 320 additions & 0 deletions
@@ -0,0 +1,320 @@
# EBM Internals - Quantile Regression

This notebook covers quantile regression using pinball loss with Explainable Boosting Machines. For standard regression internals, see [Part 1](./ebm-internals-regression.ipynb). For classification, see [Part 2](./ebm-internals-classification.ipynb). For multiclass, see [Part 3](./ebm-internals-multiclass.ipynb).

Standard regression models (e.g. those trained with RMSE) predict the conditional mean of the target. Quantile regression instead predicts a specific quantile (e.g. the median, or the 10th or 90th percentile). This is useful for:

- **Prediction intervals**: Fit models at the 10th and 90th percentiles to get an 80% prediction interval.
- **Asymmetric risk**: When over-predicting and under-predicting have different costs.
- **Robustness**: Median regression (alpha=0.5) is more robust to outliers than mean regression.

EBMs support quantile regression via the `"quantile"` objective, which uses the pinball loss (also called the quantile loss). The pinball loss for quantile alpha is:

$$L(y, \hat{y}) = \begin{cases} \alpha \cdot (y - \hat{y}) & \text{if } y \geq \hat{y} \\ (1 - \alpha) \cdot (\hat{y} - y) & \text{if } y < \hat{y} \end{cases}$$

This loss penalizes under-predictions by a factor of alpha and over-predictions by a factor of (1 - alpha), causing the model to learn the alpha-quantile of the conditional distribution.

```python
# boilerplate
from interpret import show
from interpret.glassbox import ExplainableBoostingRegressor
import numpy as np
```

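To make the loss concrete, here is a minimal sketch (not part of the original notebook; the `pinball_loss` helper is illustrative only, not an interpret API) that evaluates the pinball loss for single predictions:

```python
def pinball_loss(y, y_hat, alpha):
    # under-prediction (y >= y_hat) costs alpha per unit of error,
    # over-prediction costs (1 - alpha) per unit
    diff = y - y_hat
    return alpha * diff if diff >= 0 else (alpha - 1) * diff

# With alpha=0.9, under-predicting by 10 costs 9.0 while over-predicting
# by 10 costs only 1.0, so minimizing this loss pushes predictions up
# toward the 90th percentile.
print(pinball_loss(100.0, 90.0, alpha=0.9))   # 9.0
print(pinball_loss(100.0, 110.0, alpha=0.9))  # 1.0
```
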
## Median Regression (alpha=0.5)

Let's start with median regression. With alpha=0.5, the pinball loss penalizes over- and under-predictions equally, so the model learns the conditional median rather than the conditional mean.

```python
# make a dataset composed of a nominal categorical and a continuous feature
X = [["Peru", 7.0], ["Fiji", 8.0], ["Peru", 9.0]]
y = [450.0, 550.0, 350.0]

# Fit a quantile EBM for the median (alpha=0.5).
# Eliminate the validation set to handle the tiny dataset.
ebm_median = ExplainableBoostingRegressor(
    objective="quantile:alpha=0.5",
    interactions=0,
    validation_size=0, outer_bags=1, min_samples_leaf=1, min_hessian=1e-9)
ebm_median.fit(X, y)
show(ebm_median.explain_global())
```

The model structure is identical to a standard regression EBM: an intercept plus additive score contributions from each feature, looked up via binning. The only difference is the loss function used during training.

```python
print("Intercept:", ebm_median.intercept_)
print("Feature types:", ebm_median.feature_types_in_)
print("Bins:", ebm_median.bins_)
print("Categorical scores:", ebm_median.term_scores_[0])
print("Continuous scores:", ebm_median.term_scores_[1])
```

Predictions are computed identically to standard regression EBMs: start from the intercept and add lookup table scores for each feature.

```python
print("Median predictions:", ebm_median.predict(X))
print("Original y values: ", y)
```

## Prediction Intervals

A key use case for quantile regression is constructing prediction intervals. By fitting separate models at different quantiles, we can estimate the range within which future observations are likely to fall.

Let's use a larger, noisier dataset to demonstrate this. We'll fit models at the 10th, 50th, and 90th percentiles to get an 80% prediction interval.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# Generate one dataset and split it, so train and test share the same
# underlying data-generating process.
X_full, y_full = make_regression(
    n_samples=1200, n_features=5, noise=20.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X_full, y_full, test_size=200, random_state=123)

# Fit quantile models at the 10th, 50th, and 90th percentiles
ebm_10 = ExplainableBoostingRegressor(objective="quantile:alpha=0.1")
ebm_50 = ExplainableBoostingRegressor(objective="quantile:alpha=0.5")
ebm_90 = ExplainableBoostingRegressor(objective="quantile:alpha=0.9")

ebm_10.fit(X_train, y_train)
ebm_50.fit(X_train, y_train)
ebm_90.fit(X_train, y_train)

pred_10 = ebm_10.predict(X_test)
pred_50 = ebm_50.predict(X_test)
pred_90 = ebm_90.predict(X_test)

print("First 5 test samples:")
print("  10th percentile:", np.round(pred_10[:5], 2))
print("  50th percentile:", np.round(pred_50[:5], 2))
print("  90th percentile:", np.round(pred_90[:5], 2))
print("  Actual y:       ", np.round(y_test[:5], 2))
```

```python
# Verify quantile ordering: q10 < q50 < q90 for most predictions
print("Fraction where q10 < q50:", np.mean(pred_10 < pred_50))
print("Fraction where q50 < q90:", np.mean(pred_50 < pred_90))

# Check empirical coverage of the 80% prediction interval [q10, q90]
coverage = np.mean((y_test >= pred_10) & (y_test <= pred_90))
print(f"Empirical coverage of [q10, q90] interval: {coverage:.1%} (target: ~80%)")
```

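Two optional follow-ups, sketched here under stated assumptions rather than taken from the notebook: recent scikit-learn versions ship `sklearn.metrics.mean_pinball_loss`, which scores each model against its own target quantile, and sorting the stacked predictions per sample is a simple post-hoc repair for the occasional quantile crossing that independently fitted models can produce.

```python
# Assumes a scikit-learn version that provides mean_pinball_loss.
from sklearn.metrics import mean_pinball_loss

# Each model should do best on the pinball loss at its own alpha.
for model, alpha in [(ebm_10, 0.1), (ebm_50, 0.5), (ebm_90, 0.9)]:
    loss = mean_pinball_loss(y_test, model.predict(X_test), alpha=alpha)
    print(f"alpha={alpha}: mean pinball loss = {loss:.3f}")

# Independently fitted quantile models are not guaranteed to be ordered.
# Sorting the stacked predictions per sample enforces q_lo <= q_mid <= q_hi.
q_lo, q_mid, q_hi = np.sort(np.vstack([pred_10, pred_50, pred_90]), axis=0)
print("Crossings after sorting:", np.sum(q_lo > q_hi))  # always 0
```
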
## Visualizing the Prediction Interval

Let's visualize the prediction interval on the test samples, sorted by predicted median for readability.

```python
try:
    import matplotlib.pyplot as plt

    # Sort by predicted median for a cleaner plot
    sort_idx = np.argsort(pred_50)
    x_axis = np.arange(len(sort_idx))

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.fill_between(x_axis, pred_10[sort_idx], pred_90[sort_idx],
                    alpha=0.3, label="80% prediction interval (q10-q90)")
    ax.plot(x_axis, pred_50[sort_idx], label="Median prediction (q50)", linewidth=1.5)
    ax.scatter(x_axis, y_test[sort_idx], s=8, color="red", alpha=0.6, label="Actual values")
    ax.set_xlabel("Test samples (sorted by predicted median)")
    ax.set_ylabel("Target value")
    ax.set_title("EBM Quantile Regression: 80% Prediction Interval")
    ax.legend()
    plt.tight_layout()
    plt.show()
except ImportError:
    print("matplotlib not installed, skipping plot")
```

## Interpretability

One of the key advantages of quantile EBMs is that they remain fully interpretable. The global explanations show how each feature contributes to the predicted quantile, and local explanations show the additive score breakdown for individual predictions.

Let's compare the shape functions for the same feature across different quantiles; a small comparison sketch follows the two global views below.

```python
# Show global explanations for each quantile model
print("=== 10th Percentile Model ===")
show(ebm_10.explain_global())
```

```python
print("=== 90th Percentile Model ===")
show(ebm_90.explain_global())
```

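The two global views above show each model separately. As a rough way to compare response curves side by side (a sketch, not original notebook code), we can sweep one feature over a grid while holding the others at their training medians and plot each model's predictions:

```python
try:
    import matplotlib.pyplot as plt

    # Sweep feature 0 over its observed range; hold the other features at
    # their training medians so only feature 0 varies.
    grid = np.linspace(X_train[:, 0].min(), X_train[:, 0].max(), 200)
    X_grid = np.tile(np.median(X_train, axis=0), (len(grid), 1))
    X_grid[:, 0] = grid

    for model, label in [(ebm_10, "q10"), (ebm_50, "q50"), (ebm_90, "q90")]:
        plt.plot(grid, model.predict(X_grid), label=label)
    plt.xlabel("feature 0")
    plt.ylabel("predicted quantile")
    plt.title("Response to feature 0 across quantile models")
    plt.legend()
    plt.show()
except ImportError:
    print("matplotlib not installed, skipping plot")
```
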
```python
# Local explanation for a single test sample
show(ebm_50.explain_local(X_test[:5], y_test[:5]), 0)
```

## Making Predictions Manually

Just like standard regression EBMs, quantile EBM predictions are computed by summing the intercept with lookup table scores for each feature. The prediction logic is identical; only the training loss differs.

```python
# Use the small dataset to demonstrate manual predictions
X_small = [["Peru", 7.0], ["Fiji", 8.0], ["Peru", 9.0]]
y_small = [450.0, 550.0, 350.0]

sample_scores = []
for sample in X_small:
    # every prediction starts from the intercept
    score = ebm_median.intercept_
    print("intercept: " + str(score))

    for feature_idx, feature_val in enumerate(sample):
        bins = ebm_median.bins_[feature_idx][0]
        if isinstance(bins, dict):
            # nominal feature: the dict maps category -> bin index
            bin_idx = bins[feature_val]
        else:
            # continuous feature: locate the bin from the cut points;
            # +1 because bin index 0 is reserved for missing values
            bin_idx = np.digitize(feature_val, bins) + 1

        local_score = ebm_median.term_scores_[feature_idx][bin_idx]
        print(ebm_median.feature_names_in_[feature_idx] + ": " + str(local_score))
        score += local_score
    sample_scores.append(score)
    print()

print("PREDICTIONS (manual):")
print(np.array(sample_scores))
print("PREDICTIONS (ebm.predict):")
print(ebm_median.predict(X_small))
```

## Summary

- Use `objective="quantile:alpha=0.5"` for median regression, or any alpha in (0, 1) for other quantiles.
- The prediction mechanism is identical to standard regression EBMs (intercept + additive score lookups). Only the training loss function changes.
- Fitting multiple quantile models (e.g. alpha=0.1 and alpha=0.9) provides interpretable prediction intervals.
- All EBM interpretability tools (global/local explanations) work with quantile models.
- For the complete prediction code that handles interactions, missing values, and all model types, see [Part 3](./ebm-internals-multiclass.ipynb).


docs/interpret/ebm-internals.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,3 +6,4 @@ This section is divided into 3 parts that build upon each other:
 [Part 1](./ebm-internals-regression.ipynb) Covers regression for pure GAM models (no interactions).
 [Part 2](./ebm-internals-classification.ipynb) Covers binary classification with interactions, ordinals, and missing values.
 [Part 3](./ebm-internals-multiclass.ipynb) Covers multiclass, and unseen values.
+[Quantile Regression](./ebm-internals-quantile-regression.ipynb) Covers quantile regression with pinball loss and prediction intervals.
```

python/interpret-core/interpret/glassbox/_ebm.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -3940,7 +3940,8 @@ class EBMRegressor(EBMRegressorMixin, EBMModel):
     objective : str, default="rmse"
         The objective to optimize. Options include: "rmse",
         "poisson_deviance", "tweedie_deviance:variance_power=1.5", "gamma_deviance",
-        "pseudo_huber:delta=1.0", "rmse_log" (rmse with a log link function)
+        "pseudo_huber:delta=1.0", "rmse_log" (rmse with a log link function),
+        "quantile:alpha=0.5" (quantile regression with pinball loss)
     n_jobs : int, default=-2
         Number of jobs to run in parallel. Negative integers are interpreted as following joblib's formula
         (n_cpus + 1 + n_jobs), just like scikit-learn. Eg: -2 means using all threads except 1.
```
