Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 212 additions & 3 deletions docs/docs/tutorials/fitting-bayesian.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"%matplotlib widget"
"%matplotlib inline"
]
},
{
Expand Down Expand Up @@ -492,11 +492,220 @@
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "bcadc094",
"metadata": {},
"source": [
"## Persist and reload a chain\n",
"\n",
"In a real analysis you may want to **save a running chain to disk** for later inspection, or to resume it across compute sessions.\n",
"\n",
"`easyscience` provides robust I/O via ``save_sampler_state(state, prefix)`` and ``load_sampler_state(prefix)``. Additionally, ``Fitter.mcmc_sample()`` accepts a ``resume_state`` parameter — pass the ``MCMCDraw`` object from a previous run (or one reloaded from disk via ``load_sampler_state``) and DREAM **continues** the saved chain instead of starting cold.\n",
"\n",
"**Ring-buffer contract (important!)** DREAM stores draws in a fixed-size ring buffer sized to *samples*. To add N new draws on top of an existing chain of M draws without losing the old ones, pass ``samples = M + N`` and ``burn = 0``:\n",
"\n",
"```python\n",
"extended = fitter.mcmc_sample(..., samples=M + N, burn=0,\n",
" resume_state=previous_state)\n",
"```\n",
"\n",
"Here we save the chain from the previous exercise to disk, reload it, and verify the round-trip preserved the draws."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7ef160c",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"\n",
"from easyscience.fitting.minimizers import save_sampler_state\n",
"\n",
"# Save the chain from the previous exercise to disk.\n",
"# save_sampler_state writes several files using 'prefix' as a filename stem.\n",
"prefix = os.path.abspath('dream_chain_example')\n",
"save_sampler_state(result['internal_bumps_object'], prefix)\n",
"\n",
"saved_files = sorted(glob.glob(prefix + '-*'))\n",
"print(f'Saved {len(saved_files)} file(s):')\n",
"for f in saved_files:\n",
" print(f' {os.path.basename(f)}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3acaa5c6",
"metadata": {},
"outputs": [],
"source": [
"from easyscience.fitting.minimizers import load_sampler_state\n",
"\n",
"# Reload the saved state from disk.\n",
"reloaded = load_sampler_state(prefix)\n",
"\n",
"# Compare the draws before and after the round-trip. We draw with\n",
"# portion=1.0 on both so the comparison uses the full chain: an in-memory\n",
"# state may carry a convergence-trimmed `portion` (so draw() returns only a\n",
"# tail of the samples), whereas a state reloaded from disk always reports\n",
"# portion=1.0. The underlying chain data is identical either way.\n",
"_draw_orig = result['internal_bumps_object'].draw(portion=1.0)\n",
"_draw_reloaded = reloaded.draw(portion=1.0)\n",
"\n",
"print('Draws match:', np.allclose(_draw_orig.points, _draw_reloaded.points))\n",
"print('logp match:', np.allclose(_draw_orig.logp, _draw_reloaded.logp))\n",
"print('Nvar (params):', reloaded.Nvar)\n",
"print('Npop (pop. size):', reloaded.Npop)\n",
"print('Ncr:', reloaded.Ncr)\n",
"print('Labels:', reloaded.labels)\n",
"\n",
"# The reloaded state is a fully functional MCMCDraw that can be passed\n",
"# as resume_state= to MultiFitter.sample() to continue the chain.\n",
"# (Note: BUMPS' DREAM resume has a known limitation with thin>1 and\n",
"# outlier removal. load_sampler_state also works around a BUMPS >=1.0.4\n",
"# regression where a short chain's single-row buffers are read back as 1-D\n",
"# arrays; prefer it over calling bumps.dream.state.load_state directly.)"
]
},
{
"cell_type": "markdown",
"id": "03339658",
"metadata": {},
"source": [
"## Extend the chain and check convergence\n",
"\n",
"The original run used ``samples=10000`` (5000 retained after thinning). Here we **extend the chain by 5000 more raw samples** using ``resume_state`` — DREAM continues from the saved population instead of starting cold.\n",
"\n",
"**Ring-buffer contract:** DREAM stores draws in a fixed-size ring buffer sized to *samples* (raw samples before thinning). To keep everything, pass ``samples = old_raw + new_raw`` and ``burn=0``:\n",
"\n",
"```python\n",
"extended = fitter.mcmc_sample(..., samples=old_raw + new_raw, burn=0,\n",
" resume_state=previous_state)\n",
"```\n",
"\n",
"After the extension we compare the posterior summaries and check convergence with Gelman-Rubin R-hat."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "293b140b",
"metadata": {},
"outputs": [],
"source": [
"# Extend the existing chain by 5000 more raw samples.\n",
"# samples = original_raw + new_raw = 10000 + 5000 = 15000 raw\n",
"# burn=0 because the chain is already warm.\n",
"total_raw = 10000 + 5000\n",
"\n",
"extended = mle_fitter.mcmc_sample(\n",
" x=omega,\n",
" y=intensity_obs,\n",
" weights=1 / intensity_error,\n",
" samples=total_raw,\n",
" burn=0,\n",
" thin=2,\n",
" resume_state=result['internal_bumps_object'], # continue from where we left off\n",
")\n",
"\n",
"short_draws = result['draws']\n",
"extended_draws = extended['draws']\n",
"\n",
"print(f'Short chain: {short_draws.shape[0]} retained draws')\n",
"print(f'Extended chain: {extended_draws.shape[0]} retained draws (requested {total_raw} raw)')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ec4302c",
"metadata": {},
"outputs": [],
"source": [
"# Build helpers to look up parameter columns.\n",
"name_to_col_ext = {name: idx for idx, name in enumerate(extended['param_names'])}\n",
"name_to_col_orig = {name: idx for idx, name in enumerate(result['param_names'])}\n",
"\n",
"\n",
"def col_orig(par):\n",
" return result['draws'][:, name_to_col_orig[par.unique_name]]\n",
"\n",
"\n",
"def col_ext(par):\n",
" return extended['draws'][:, name_to_col_ext[par.unique_name]]\n",
"\n",
"\n",
"# Side-by-side posterior summary: compare the first 5000 draws (before extension)\n",
"# with the FULL extended set.\n",
"print(f'{\"param\":<10s} {\"metric\":>12s} {\"first 5k\":>12s} {\"full ext\":>12s} diff')\n",
"print('-' * 65)\n",
"for label, par in (('area', area), ('gamma', gamma), ('omega_0', omega_0), ('sigma', sigma)):\n",
" c_first = col_orig(par) # first 5000 draws\n",
" c_full = col_ext(par) # full extended ~7500 draws\n",
" for metric, fn in [\n",
" ('mean', np.mean),\n",
" ('std', np.std),\n",
" ('q2.5%', lambda c: np.percentile(c, 2.5)),\n",
" ('q50%', lambda c: np.percentile(c, 50)),\n",
" ('q97.5%', lambda c: np.percentile(c, 97.5)),\n",
" ]:\n",
" vf = fn(c_first)\n",
" vx = fn(c_full)\n",
" diff = vx - vf\n",
" print(f'{label:<10s} {metric:>12s} {vf:12.4g} {vx:12.4g} {diff:+.2e}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0f30be6",
"metadata": {},
"outputs": [],
"source": [
"# Visual comparison: overlay the first 5000 draws with the extended set.\n",
"fig, axes = plt.subplots(1, 4, figsize=(14, 3))\n",
"for ax, label, par in zip(\n",
" axes, ('area', 'gamma', 'omega_0', 'sigma'), (area, gamma, omega_0, sigma)\n",
"):\n",
" c_first = col_orig(par)\n",
" c_full = col_ext(par)\n",
" ax.hist(c_first, bins=40, density=True, alpha=0.5, color='C0', label=f'first 5k')\n",
" ax.hist(\n",
" c_full, bins=40, density=True, alpha=0.5, color='C3', label=f'extended ({len(c_full)})'\n",
" )\n",
" ax.set_title(label)\n",
" ax.set_yticks([])\n",
"axes[0].legend(fontsize=9)\n",
"fig.suptitle('Marginal posterior: first 5000 draws vs extended chain')\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3449e0a7",
"metadata": {},
"outputs": [],
"source": [
"# Convergence diagnostic: Gelman-Rubin R-hat from the extended chain state.\n",
"# Values close to 1.0 indicate good convergence.\n",
"print('Gelman-Rubin R-hat (extended chain) — values < 1.05 indicate convergence:')\n",
"rhat = extended['internal_bumps_object'].gelman()\n",
"for name, val in zip(extended['param_names'], rhat):\n",
" status = '✓' if val < 1.05 else '?' if val < 1.1 else '✗'\n",
" print(f' {name:<20s} {val:.4f} {status}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "p312",
"display_name": ".venv (3.11.9)",
"language": "python",
"name": "python3"
},
Expand All @@ -510,7 +719,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
Loading
Loading