|
47 | 47 | " + an [xarray.Dataset](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html)\n", |
48 | 48 | "\n", |
49 | 49 | "\n", |
50 | | - "In addtion, the `CmdStanMCMC` object has accessor methods for\n", |
| 50 | + "In addition, the `CmdStanMCMC` object has accessor methods for\n", |
51 | 51 | "\n", |
52 | 52 | "- The per-chain HMC tuning parameters `step_size` and `metric` \n", |
53 | 53 | "\n", |
|
168 | 168 | "Your model may take a long time to fit. The `sample` method provides two arguments:\n", |
169 | 169 | " \n", |
170 | 170 | "- visual progress bar: `show_progress=True`\n", |
171 | | - "- stream CmdStan ouput to the console - `show_console=True`\n", |
| 171 | + "- stream CmdStan output to the console - `show_console=True`\n", |
172 | 172 | " \n", |
173 | 173 | "To illustrate how progress bars work, we will run the bernoulli model. Since the progress bars are only visible while the sampler is running and the bernoulli model takes no time at all to fit, we run this model for 200K iterations, in order to see the progress bars in action." |
174 | 174 | ] |
|
205 | 205 | "cell_type": "markdown", |
206 | 206 | "metadata": {}, |
207 | 207 | "source": [ |
208 | | - "## Accessing the sampler outputs" |
| 208 | + "## Checking the fit\n", |
| 209 | + "\n", |
| 210 | + "The first question to ask of the `CmdStanMCMC` object is: _is this a valid sample from the posterior?_" |
209 | 211 | ] |
210 | 212 | }, |
211 | 213 | { |
212 | 214 | "cell_type": "markdown", |
213 | 215 | "metadata": {}, |
214 | 216 | "source": [ |
215 | | - "### Extracting the draws in tabular format\n", |
| 217 | + "It is important to check whether or not the sampler was able to fit the model given the data. Often, this is not possible, for any number of reasons.\n", |
| 218 | + "To appreciate the sampler diagnostics, we use a hierarchical model which, given a small amount of data, encounters difficulty: the centered parameterization of the \n", |
| 219 | + "\"8-schools\" model (Rubin, 1981).\n", |
| 220 | + "The \"8-schools\" model is a simple hierarchical model, first developed on a dataset taken from\n", |
| 221 | + "an experiment conducted in 8 schools, with only treatment effects and their standard errors reported.\n", |
216 | 222 | "\n", |
217 | | - "The sample can be accessed either as a `numpy` array or a pandas `DataFrame`:" |
| 223 | + "The Stan model and the original dataset are in files `eight_schools.stan` and `eight_schools.data.json`." |
218 | 224 | ] |
219 | 225 | }, |
220 | 226 | { |
221 | | - "cell_type": "code", |
222 | | - "execution_count": null, |
| 227 | + "cell_type": "markdown", |
223 | 228 | "metadata": {}, |
224 | | - "outputs": [], |
225 | 229 | "source": [ |
226 | | - "print(f'sample as ndarray: {fit.draws().shape}\\nfirst 2 draws, chain 1:\\n{fit.draws()[:2, 0, :]}')" |
| 230 | + "**eight_schools.stan**" |
227 | 231 | ] |
228 | 232 | }, |
229 | 233 | { |
|
232 | 236 | "metadata": {}, |
233 | 237 | "outputs": [], |
234 | 238 | "source": [ |
235 | | - "fit.draws_pd().head()" |
| 239 | + "with open('eight_schools.stan', 'r') as fd:\n", |
| 240 | + " print(fd.read())" |
236 | 241 | ] |
237 | 242 | }, |
238 | 243 | { |
239 | 244 | "cell_type": "markdown", |
240 | 245 | "metadata": {}, |
241 | 246 | "source": [ |
242 | | - "### Extracting the draws as structured Stan program variables" |
| 247 | + "**eight_schools.data.json**" |
243 | 248 | ] |
244 | 249 | }, |
245 | 250 | { |
|
248 | 253 | "metadata": {}, |
249 | 254 | "outputs": [], |
250 | 255 | "source": [ |
251 | | - "for k, v in fit.stan_variables().items():\n", |
252 | | - " print(f'name: {k}, shape: {v.shape}')" |
| 256 | + "with open('eight_schools.data.json', 'r') as fd:\n", |
| 257 | + " print(fd.read())" |
253 | 258 | ] |
254 | 259 | }, |
255 | 260 | { |
256 | | - "cell_type": "code", |
257 | | - "execution_count": null, |
| 261 | + "cell_type": "markdown", |
258 | 262 | "metadata": {}, |
259 | | - "outputs": [], |
260 | 263 | "source": [ |
261 | | - "fit.draws_xr('theta')" |
| 264 | + "Because there is not much data, the geometry of the posterior distribution is highly curved, \n", |
| 265 | + "so the sampler may encounter difficulty in fitting the model.\n", |
| 266 | + "By specifying the initial seed for the pseudo-random number generator,\n", |
| 267 | + "we ensure that the sampler will have difficulty in fitting this model.\n", |
| 268 | + "In particular, some post-warmup iterations diverge, resulting in a biased sample.\n", |
| 269 | + "In addition, some post-warmup iterations hit the maximum allowed treedepth before\n", |
| 270 | + "the trajectory hits the \"U-turn\" condition of the NUTS algorithm,\n", |
| 271 | + "in which case the sampler may fail to properly explore the entire posterior.\n", |
| 272 | + "\n", |
| 273 | + "These diagnostics are checked for automatically at the end of each run; if problems are detected, a WARNING message is logged." |
262 | 274 | ] |
263 | 275 | }, |
264 | 276 | { |
265 | | - "cell_type": "markdown", |
| 277 | + "cell_type": "code", |
| 278 | + "execution_count": null, |
266 | 279 | "metadata": {}, |
| 280 | + "outputs": [], |
267 | 281 | "source": [ |
268 | | - "### Extracting sampler method diagnostics" |
| 282 | + "eight_schools_model = CmdStanModel(stan_file='eight_schools.stan')\n", |
| 283 | + "eight_schools_fit = eight_schools_model.sample(data='eight_schools.data.json', seed=55157)" |
269 | 284 | ] |
270 | 285 | }, |
271 | 286 | { |
272 | | - "cell_type": "code", |
273 | | - "execution_count": null, |
| 287 | + "cell_type": "markdown", |
274 | 288 | "metadata": {}, |
275 | | - "outputs": [], |
276 | 289 | "source": [ |
277 | | - "for k, v in fit.method_variables().items():\n", |
278 | | - " print(f'name: {k}, shape: {v.shape}')" |
| 290 | + "The number of post-warmup divergences and iterations which hit the maximum treedepth can be inspected directly via the `divergences` and `max_treedepths` properties." |
279 | 291 | ] |
280 | 292 | }, |
281 | 293 | { |
|
284 | 296 | "metadata": {}, |
285 | 297 | "outputs": [], |
286 | 298 | "source": [ |
287 | | - "print(f'divergences per chain?\\n{fit.divergences}\\niterations at maxtreedepth per chain?\\n{fit.max_treedepths}')" |
| 299 | + "print(f'divergences:\\n{eight_schools_fit.divergences}\\niterations at max_treedepth:\\n{eight_schools_fit.max_treedepths}')" |
288 | 300 | ] |
289 | 301 | }, |
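To make the two counts concrete, here is a minimal sketch of how per-chain totals like these could be tallied from per-iteration flags. It does not call CmdStanPy: the arrays `divergent` and `treedepth`, the cap `max_depth = 10`, and the 1000 x 4 shape are all hypothetical stand-ins for real sampler output.

```python
import numpy as np

# Hypothetical stand-in for sampler output: 1000 post-warmup iterations x 4 chains.
rng = np.random.default_rng(0)

# 1 if the iteration diverged, 0 otherwise (synthetic data, not a real fit).
divergent = (rng.random((1000, 4)) < 0.02).astype(int)

# Tree depth reached on each iteration; the cap of 10 is an assumption.
max_depth = 10
treedepth = rng.integers(1, max_depth + 1, size=(1000, 4))

# Per-chain counts: sum over the iteration axis (axis 0).
divergences_per_chain = divergent.sum(axis=0)
max_treedepth_hits_per_chain = (treedepth == max_depth).sum(axis=0)

print(f'divergences per chain: {divergences_per_chain}')
print(f'iterations at max treedepth per chain: {max_treedepth_hits_per_chain}')
```

Any nonzero entry in either vector is a reason to look more closely at that chain before trusting the sample.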
290 | 302 | { |
291 | 303 | "cell_type": "markdown", |
292 | 304 | "metadata": {}, |
293 | 305 | "source": [ |
294 | | - "### Extracting the per-chain HMC tuning parameters" |
| 306 | + "### Summarizing the sample\n", |
| 307 | + "\n", |
| 308 | + "The `summary` method reports the R-hat statistic, a measure of how well the sampler chains have converged." |
295 | 309 | ] |
296 | 310 | }, |
297 | 311 | { |
298 | 312 | "cell_type": "code", |
299 | 313 | "execution_count": null, |
300 | | - "metadata": {}, |
| 314 | + "metadata": { |
| 315 | + "scrolled": true |
| 316 | + }, |
301 | 317 | "outputs": [], |
302 | 318 | "source": [ |
303 | | - "print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\nmetric:\\n{fit.metric}')" |
| 319 | + "eight_schools_fit.summary()" |
304 | 320 | ] |
305 | 321 | }, |
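For intuition about what R-hat measures, the sketch below computes the classic (non-split) potential scale reduction for a single parameter from synthetic chains. This is an illustration only: the helper `basic_rhat` is hypothetical, and the value reported by `summary` may be computed with a more refined variant.

```python
import numpy as np

def basic_rhat(draws):
    """Classic (non-split) potential scale reduction for one parameter.

    draws: array of shape (n_iterations, n_chains).
    """
    n, _ = draws.shape
    chain_means = draws.mean(axis=0)
    chain_vars = draws.var(axis=0, ddof=1)
    between = n * chain_means.var(ddof=1)  # between-chain variance B
    within = chain_vars.mean()             # within-chain variance W
    # Pooled estimate of the posterior variance.
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(1)

# Four synthetic chains that all sample the same distribution.
mixed = rng.normal(size=(1000, 4))

# Same chains, but the last one is shifted, as if stuck in another region.
stuck = mixed + np.array([0.0, 0.0, 0.0, 3.0])

print(f'R-hat, well-mixed chains: {basic_rhat(mixed):.3f}')
print(f'R-hat, one shifted chain: {basic_rhat(stuck):.3f}')
```

Values close to 1 indicate that the chains agree with one another; values well above 1 indicate that at least one chain is exploring a different region.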
306 | 322 | { |
307 | 323 | "cell_type": "markdown", |
308 | 324 | "metadata": {}, |
309 | 325 | "source": [ |
310 | | - "### Extracting the sample meta-data" |
| 326 | + "### Sampler Diagnostics\n", |
| 327 | + "\n", |
| 328 | + "The `diagnose()` method provides more information about the sample." |
311 | 329 | ] |
312 | 330 | }, |
313 | 331 | { |
|
316 | 334 | "metadata": {}, |
317 | 335 | "outputs": [], |
318 | 336 | "source": [ |
319 | | - "print('sample method variables:\\n{}\\n'.format(fit.metadata.method_vars_cols.keys()))\n", |
320 | | - "print('stan model variables:\\n{}'.format(fit.metadata.stan_vars_cols.keys()))" |
| 337 | + "print(eight_schools_fit.diagnose())" |
321 | 338 | ] |
322 | 339 | }, |
323 | 340 | { |
324 | 341 | "cell_type": "markdown", |
325 | 342 | "metadata": {}, |
326 | 343 | "source": [ |
327 | | - "## Summarizing the sample" |
| 344 | + "## Accessing the sampler outputs" |
| 345 | + ] |
| 346 | + }, |
| 347 | + { |
| 348 | + "cell_type": "markdown", |
| 349 | + "metadata": {}, |
| 350 | + "source": [ |
| 351 | + "### Extracting the draws in tabular format\n", |
| 352 | + "\n", |
| 353 | + "The sample can be accessed either as a `numpy` array or a pandas `DataFrame`:" |
328 | 354 | ] |
329 | 355 | }, |
330 | 356 | { |
331 | 357 | "cell_type": "code", |
332 | 358 | "execution_count": null, |
333 | | - "metadata": { |
334 | | - "scrolled": true |
335 | | - }, |
| 359 | + "metadata": {}, |
336 | 360 | "outputs": [], |
337 | 361 | "source": [ |
338 | | - "fit.summary()" |
| 362 | + "print(f'sample as ndarray: {fit.draws().shape}\\nfirst 2 draws, chain 1:\\n{fit.draws()[:2, 0, :]}')" |
339 | 363 | ] |
340 | 364 | }, |
341 | 365 | { |
342 | | - "cell_type": "markdown", |
| 366 | + "cell_type": "code", |
| 367 | + "execution_count": null, |
343 | 368 | "metadata": {}, |
| 369 | + "outputs": [], |
344 | 370 | "source": [ |
345 | | - "## Sampler Diagnostics" |
| 371 | + "fit.draws_pd().head()" |
346 | 372 | ] |
347 | 373 | }, |
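The indexing used above, `[:2, 0, :]`, treats the array as laid out by iteration, then chain, then column. The sketch below mimics that layout with a synthetic `numpy` array; the name `sample`, the shape (1000, 4, 8), and the chain-by-chain stacking order are assumptions made for illustration, not a statement about how `draws_pd` orders its rows.

```python
import numpy as np

# Hypothetical 3-D sample laid out as (iterations, chains, columns).
n_iter, n_chains, n_cols = 1000, 4, 8
sample = np.arange(n_iter * n_chains * n_cols, dtype=float).reshape(
    n_iter, n_chains, n_cols
)

# First 2 draws of chain 1 (index 0), all columns.
first_two_chain_1 = sample[:2, 0, :]

# One way to get a 2-D table: stack the chains one after another,
# so rows 0..999 come from chain 1, rows 1000..1999 from chain 2, etc.
flat = sample.transpose(1, 0, 2).reshape(-1, n_cols)

print(first_two_chain_1.shape)
print(flat.shape)
```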
348 | 374 | { |
349 | 375 | "cell_type": "markdown", |
350 | 376 | "metadata": {}, |
351 | 377 | "source": [ |
352 | | - "It is important to check whether or not the sampler was able to fit the model given the data. Often, this is not possible, for any number of reasons.\n", |
353 | | - "To appreciate the sampler diagnostics, we use a hierachical model which, given a small amount of data, encounters difficulty: the centered parameterization of the \n", |
354 | | - "\"8-schools\" model (Rubin, 1981).\n", |
355 | | - "The \"8-schools\" model is a simple hiearchical model, first developed on a dataset taken from\n", |
356 | | - "an experiment was conducted in 8 schools, with only treatment effects and their standard errors reported.\n", |
357 | | - "\n", |
358 | | - "The Stan model and the original dataset are in files `eight_schools.stan` and `eight_schools.data.json`." |
| 378 | + "### Extracting the draws as structured Stan program variables" |
359 | 379 | ] |
360 | 380 | }, |
361 | 381 | { |
362 | | - "cell_type": "markdown", |
| 382 | + "cell_type": "code", |
| 383 | + "execution_count": null, |
363 | 384 | "metadata": {}, |
| 385 | + "outputs": [], |
364 | 386 | "source": [ |
365 | | - "**eight_schools.stan**" |
| 387 | + "for k, v in fit.stan_variables().items():\n", |
| 388 | + " print(f'name: {k}, shape: {v.shape}')" |
366 | 389 | ] |
367 | 390 | }, |
368 | 391 | { |
|
371 | 394 | "metadata": {}, |
372 | 395 | "outputs": [], |
373 | 396 | "source": [ |
374 | | - "with open('eight_schools.stan', 'r') as fd:\n", |
375 | | - " print(fd.read())" |
| 397 | + "fit.draws_xr('theta')" |
376 | 398 | ] |
377 | 399 | }, |
378 | 400 | { |
379 | 401 | "cell_type": "markdown", |
380 | 402 | "metadata": {}, |
381 | 403 | "source": [ |
382 | | - "**eight_schools.data.json**" |
| 404 | + "### Extracting sampler method diagnostics" |
383 | 405 | ] |
384 | 406 | }, |
385 | 407 | { |
|
388 | 410 | "metadata": {}, |
389 | 411 | "outputs": [], |
390 | 412 | "source": [ |
391 | | - "with open('eight_schools.data.json', 'r') as fd:\n", |
392 | | - " print(fd.read())" |
| 413 | + "for k, v in fit.method_variables().items():\n", |
| 414 | + " print(f'name: {k}, shape: {v.shape}')" |
| 415 | + ] |
| 416 | + }, |
| 417 | + { |
| 418 | + "cell_type": "code", |
| 419 | + "execution_count": null, |
| 420 | + "metadata": {}, |
| 421 | + "outputs": [], |
| 422 | + "source": [ |
| 423 | + "print(f'divergences per chain?\\n{fit.divergences}\\niterations at maxtreedepth per chain?\\n{fit.max_treedepths}')" |
393 | 424 | ] |
394 | 425 | }, |
395 | 426 | { |
396 | 427 | "cell_type": "markdown", |
397 | 428 | "metadata": {}, |
398 | 429 | "source": [ |
399 | | - "Because there is not much data, the geometry of posterior distribution is highly curved, \n", |
400 | | - "thus the sampler may encouter difficulty in fitting the model.\n", |
401 | | - "By specifying the initial seed for the psuedo-random number generator,\n", |
402 | | - "we insure that the sampler will have difficulty in fitting this model.\n", |
403 | | - "In particular, some post-warmup iterations diverge, resulting in a biased sample.\n", |
404 | | - "In addition, some post-warmup iterations hit the maximum allowed treedepth before\n", |
405 | | - "the trajectory hits the \"U-turn\" condition of the NUTS algorithm,\n", |
406 | | - "in which case the sampler may fail to properly explore the entire posterior.\n", |
407 | | - "\n", |
408 | | - "If any iterations diverged or hit the maximum treedepth, these are reported, along with the\n", |
409 | | - "recommendation to run the `diagnose()` method, which provides more information about the sample." |
| 430 | + "### Extracting the per-chain HMC tuning parameters" |
410 | 431 | ] |
411 | 432 | }, |
412 | 433 | { |
|
415 | 436 | "metadata": {}, |
416 | 437 | "outputs": [], |
417 | 438 | "source": [ |
418 | | - "eight_schools_model = CmdStanModel(stan_file='eight_schools.stan')\n", |
419 | | - "eight_schools_fit = eight_schools_model.sample(data='eight_schools.data.json', seed=55157)" |
| 439 | + "print(f'adapted step_size per chain\\n{fit.step_size}\\nmetric_type: {fit.metric_type}\\nmetric:\\n{fit.metric}')" |
| 440 | + ] |
| 441 | + }, |
| 442 | + { |
| 443 | + "cell_type": "markdown", |
| 444 | + "metadata": {}, |
| 445 | + "source": [ |
| 446 | + "### Extracting the sample meta-data" |
420 | 447 | ] |
421 | 448 | }, |
422 | 449 | { |
|
425 | 452 | "metadata": {}, |
426 | 453 | "outputs": [], |
427 | 454 | "source": [ |
428 | | - "print(eight_schools_fit.diagnose())" |
| 455 | + "print('sample method variables:\\n{}\\n'.format(fit.metadata.method_vars_cols.keys()))\n", |
| 456 | + "print('stan model variables:\\n{}'.format(fit.metadata.stan_vars_cols.keys()))" |
429 | 457 | ] |
430 | 458 | }, |
431 | 459 | { |
|