Skip to content

Commit 0857b52

Browse files
committed
Doc updates
1 parent c290397 commit 0857b52

3 files changed

Lines changed: 19 additions & 4 deletions

File tree

cmdstanpy/stanfit/vb.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ def stan_variable(self, var: str) -> np.ndarray:
121121
a leading axis added for the number of draws from the variational
122122
approximation.
123123
124+
* If the variable is a scalar variable, the return array has shape
125+
( draws, ).
126+
* If the variable is a vector, the return array has shape
127+
( draws, len(vector))
128+
* If the variable is a matrix, the return array has shape
129+
( draws, size(dim 1), size(dim 2) )
130+
* If the variable is an array with N dimensions, the return array
131+
has shape ( draws, size(dim 1), ..., size(dim N))
132+
124133
This functionaltiy is also available via a shortcut using ``.`` -
125134
writing ``fit.a`` is a synonym for ``fit.stan_variable("a")``
126135

docsrc/api.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ CmdStanVB
6262
.. autoclass:: cmdstanpy.CmdStanVB
6363
:members:
6464

65+
CmdStanLaplace
66+
==============
67+
68+
.. autoclass:: cmdstanpy.CmdStanLaplace
69+
:members:
6570

6671
*********
6772
Functions

docsrc/users-guide/examples/VI as Sampler Inits.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@
6464
"The ADVI algorithm provides estimates of all model parameters.\n",
6565
"\n",
6666
"The `variational` method returns a `CmdStanVB` object, with method `stan_variables`, which\n",
67-
"returns the approximate estimates of all model parameters as a Python dictionary."
67+
"returns the approximat posterior samples of all model parameters as a Python dictionary. \n",
68+
"Here, we report the approximate posterior mean."
6869
]
6970
},
7071
{
@@ -73,7 +74,8 @@
7374
"metadata": {},
7475
"outputs": [],
7576
"source": [
76-
"print(vb_fit.stan_variables())"
77+
"vb_mean = {var: samples.mean(axis=0) for var, samples in vb_fit.stan_variables().items()}\n",
78+
"print(vb_mean)"
7779
]
7880
},
7981
{
@@ -93,9 +95,8 @@
9395
"metadata": {},
9496
"outputs": [],
9597
"source": [
96-
"vb_vars = vb_fit.stan_variables()\n",
9798
"mcmc_vb_inits_fit = model.sample(\n",
98-
" data=data_file, inits=vb_vars, iter_warmup=75, seed=12345\n",
99+
" data=data_file, inits=vb_mean, iter_warmup=75, seed=12345\n",
99100
")"
100101
]
101102
},

0 commit comments

Comments
 (0)