Skip to content

Commit b1aff80

Browse files
authored
Merge pull request #669 from stan-dev/stan-2-32
[Stan 2.32] Laplace method and other changes
2 parents 730742a + 3514acf commit b1aff80

10 files changed

Lines changed: 605 additions & 44 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class Method(Enum):
2929
OPTIMIZE = auto()
3030
GENERATE_QUANTITIES = auto()
3131
VARIATIONAL = auto()
32+
LAPLACE = auto()
3233

3334
def __repr__(self) -> str:
3435
return '<%s.%s>' % (self.__class__.__name__, self.name)
@@ -398,8 +399,8 @@ def __init__(
398399
tol_rel_grad: Optional[float] = None,
399400
tol_param: Optional[float] = None,
400401
history_size: Optional[int] = None,
402+
jacobian: bool = False,
401403
) -> None:
402-
403404
self.algorithm = algorithm or ""
404405
self.init_alpha = init_alpha
405406
self.iter = iter
@@ -410,11 +411,10 @@ def __init__(
410411
self.tol_rel_grad = tol_rel_grad
411412
self.tol_param = tol_param
412413
self.history_size = history_size
414+
self.jacobian = jacobian
413415
self.thin = None
414416

415-
def validate(
416-
self, chains: Optional[int] = None # pylint: disable=unused-argument
417-
) -> None:
417+
def validate(self, _chains: Optional[int] = None) -> None:
418418
"""
419419
Check arguments correctness and consistency.
420420
"""
@@ -511,8 +511,7 @@ def validate(
511511
else:
512512
raise ValueError('history_size must be type of int')
513513

514-
# pylint: disable=unused-argument
515-
def compose(self, idx: int, cmd: List[str]) -> List[str]:
514+
def compose(self, _idx: int, cmd: List[str]) -> List[str]:
516515
"""compose command string for CmdStan for non-default arg values."""
517516
cmd.append('method=optimize')
518517
if self.algorithm:
@@ -535,7 +534,37 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
535534
cmd.append('iter={}'.format(self.iter))
536535
if self.save_iterations:
537536
cmd.append('save_iterations=1')
537+
if self.jacobian:
538+
cmd.append("jacobian=1")
539+
return cmd
540+
541+
542+
class LaplaceArgs:
543+
"""Arguments needed for laplace method."""
538544

545+
def __init__(
546+
self, mode: str, draws: Optional[int] = None, jacobian: bool = True
547+
) -> None:
548+
self.mode = mode
549+
self.jacobian = jacobian
550+
self.draws = draws
551+
552+
def validate(self, _chains: Optional[int] = None) -> None:
553+
"""Check arguments correctness and consistency."""
554+
if not os.path.exists(self.mode):
555+
raise ValueError(f'Invalid path for mode file: {self.mode}')
556+
if self.draws is not None:
557+
if not isinstance(self.draws, (int, np.integer)) or self.draws <= 0:
558+
raise ValueError('draws must be a positive integer')
559+
560+
def compose(self, _idx: int, cmd: List[str]) -> List[str]:
561+
"""compose command string for CmdStan for non-default arg values."""
562+
cmd.append('method=laplace')
563+
cmd.append(f'mode={self.mode}')
564+
if self.draws:
565+
cmd.append(f'draws={self.draws}')
566+
if not self.jacobian:
567+
cmd.append("jacobian=0")
539568
return cmd
540569

541570

@@ -721,7 +750,11 @@ def __init__(
721750
model_exe: OptionalPath,
722751
chain_ids: Optional[List[int]],
723752
method_args: Union[
724-
SamplerArgs, OptimizeArgs, GenerateQuantitiesArgs, VariationalArgs
753+
SamplerArgs,
754+
OptimizeArgs,
755+
GenerateQuantitiesArgs,
756+
VariationalArgs,
757+
LaplaceArgs,
725758
],
726759
data: Union[Mapping[str, Any], str, None] = None,
727760
seed: Union[int, List[int], None] = None,
@@ -753,6 +786,8 @@ def __init__(
753786
self.method = Method.GENERATE_QUANTITIES
754787
elif isinstance(method_args, VariationalArgs):
755788
self.method = Method.VARIATIONAL
789+
elif isinstance(method_args, LaplaceArgs):
790+
self.method = Method.LAPLACE
756791
self.method_args.validate(len(chain_ids) if chain_ids else None)
757792
self.validate()
758793

@@ -913,7 +948,7 @@ def compose_command(
913948
*,
914949
diagnostic_file: Optional[str] = None,
915950
profile_file: Optional[str] = None,
916-
num_chains: Optional[int] = None
951+
num_chains: Optional[int] = None,
917952
) -> List[str]:
918953
"""
919954
Compose CmdStan command for non-default arguments.

cmdstanpy/model.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,16 @@
4040
from cmdstanpy.cmdstan_args import (
4141
CmdStanArgs,
4242
GenerateQuantitiesArgs,
43+
LaplaceArgs,
44+
Method,
4345
OptimizeArgs,
4446
SamplerArgs,
4547
VariationalArgs,
4648
)
4749
from cmdstanpy.compiler_opts import CompilerOptions
4850
from cmdstanpy.stanfit import (
4951
CmdStanGQ,
52+
CmdStanLaplace,
5053
CmdStanMCMC,
5154
CmdStanMLE,
5255
CmdStanVB,
@@ -393,7 +396,7 @@ def format(
393396
+ '.bak-'
394397
+ datetime.now().strftime("%Y%m%d%H%M%S"),
395398
)
396-
with (open(self.stan_file, 'w')) as file_handle:
399+
with open(self.stan_file, 'w') as file_handle:
397400
file_handle.write(result)
398401
else:
399402
print(result)
@@ -589,6 +592,8 @@ def optimize(
589592
refresh: Optional[int] = None,
590593
time_fmt: str = "%Y%m%d%H%M%S",
591594
timeout: Optional[float] = None,
595+
jacobian: bool = False,
596+
# would be nice to move this further up, but that's a breaking change
592597
) -> CmdStanMLE:
593598
"""
594599
Run the specified CmdStan optimize algorithm to produce a
@@ -690,6 +695,11 @@ def optimize(
690695
691696
:param timeout: Duration at which optimization times out in seconds.
692697
698+
:param jacobian: Whether or not to use the Jacobian adjustment for
699+
constrained variables in optimization. By default this is false,
700+
meaning optimization yields the Maximum Likehood Estimate (MLE).
701+
Setting it to true yields the Maximum A Posteriori Estimate (MAP).
702+
693703
:return: CmdStanMLE object
694704
"""
695705
optimize_args = OptimizeArgs(
@@ -703,8 +713,15 @@ def optimize(
703713
history_size=history_size,
704714
iter=iter,
705715
save_iterations=save_iterations,
716+
jacobian=jacobian,
706717
)
707718

719+
if jacobian and cmdstan_version_before(2, 32, self.exe_info()):
720+
raise ValueError(
721+
"Jacobian adjustment for optimization is only supported "
722+
"in CmdStan 2.32 and above."
723+
)
724+
708725
with MaybeDictToFilePath(data, inits) as (_data, _inits):
709726
args = CmdStanArgs(
710727
self._name,
@@ -1606,6 +1623,8 @@ def log_prob(
16061623
self,
16071624
params: Union[Dict[str, Any], str, os.PathLike],
16081625
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
1626+
*,
1627+
jacobian: bool = True,
16091628
) -> pd.DataFrame:
16101629
"""
16111630
Calculate the log probability and gradient at the given parameter
@@ -1626,6 +1645,9 @@ def log_prob(
16261645
either as a dictionary with entries matching the data variables,
16271646
or as the path of a data file in JSON or Rdump format.
16281647
1648+
:param jacobian: Whether or not to enable the Jacobian adjustment
1649+
for constrained parameters. Defaults to ``True``.
1650+
16291651
:return: A pandas.DataFrame containing columns "lp__" and additional
16301652
columns for the gradient values. These gradients will be for the
16311653
unconstrained parameters of the model.
@@ -1641,6 +1663,7 @@ def log_prob(
16411663
str(self.exe_file),
16421664
"log_prob",
16431665
f"constrained_params={_params}",
1666+
f"jacobian={int(jacobian)}",
16441667
]
16451668
if _data is not None:
16461669
cmd += ["data", f"file={_data}"]
@@ -1669,6 +1692,104 @@ def log_prob(
16691692
result = pd.read_csv(output, comment="#")
16701693
return result
16711694

1695+
def laplace_sample(
1696+
self,
1697+
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
1698+
mode: Union[CmdStanMLE, str, os.PathLike, None] = None,
1699+
draws: Optional[int] = None,
1700+
*,
1701+
jacobian: bool = True, # NB: Different than optimize!
1702+
seed: Optional[int] = None,
1703+
output_dir: OptionalPath = None,
1704+
sig_figs: Optional[int] = None,
1705+
save_profile: bool = False,
1706+
show_console: bool = False,
1707+
refresh: Optional[int] = None,
1708+
time_fmt: str = "%Y%m%d%H%M%S",
1709+
timeout: Optional[float] = None,
1710+
opt_args: Optional[Dict[str, Any]] = None,
1711+
) -> CmdStanLaplace:
1712+
if cmdstan_version_before(2, 32, self.exe_info()):
1713+
raise ValueError(
1714+
"Method 'laplace_sample' not available for CmdStan versions "
1715+
"before 2.32"
1716+
)
1717+
if opt_args is not None and mode is not None:
1718+
raise ValueError(
1719+
"Cannot specify both 'opt_args' and 'mode' arguments"
1720+
)
1721+
if mode is None:
1722+
optimize_args = {
1723+
"seed": seed,
1724+
"sig_figs": sig_figs,
1725+
"jacobian": jacobian,
1726+
"save_profile": save_profile,
1727+
"show_console": show_console,
1728+
"refresh": refresh,
1729+
"time_fmt": time_fmt,
1730+
"timeout": timeout,
1731+
"output_dir": output_dir,
1732+
}
1733+
optimize_args.update(opt_args or {})
1734+
optimize_args['time_fmt'] = 'opt-' + time_fmt
1735+
try:
1736+
cmdstan_mode: CmdStanMLE = self.optimize(
1737+
data=data,
1738+
**optimize_args, # type: ignore
1739+
)
1740+
except Exception as e:
1741+
raise RuntimeError(
1742+
"Failed to run optimizer on model. "
1743+
"Consider supplying a mode or additional optimizer args"
1744+
) from e
1745+
elif not isinstance(mode, CmdStanMLE):
1746+
cmdstan_mode = from_csv(mode) # type: ignore # we check below
1747+
else:
1748+
cmdstan_mode = mode
1749+
1750+
if cmdstan_mode.runset.method != Method.OPTIMIZE:
1751+
raise ValueError(
1752+
"Mode must be a CmdStanMLE or a path to an optimize CSV"
1753+
)
1754+
1755+
mode_jacobian = (
1756+
cmdstan_mode.runset._args.method_args.jacobian # type: ignore
1757+
)
1758+
if mode_jacobian != jacobian:
1759+
raise ValueError(
1760+
"Jacobian argument to optimize and laplace must match!\n"
1761+
f"Laplace was run with jacobian={jacobian},\n"
1762+
f"but optimize was run with jacobian={mode_jacobian}"
1763+
)
1764+
1765+
laplace_args = LaplaceArgs(
1766+
cmdstan_mode.runset.csv_files[0], draws, jacobian
1767+
)
1768+
1769+
with MaybeDictToFilePath(data) as (_data,):
1770+
args = CmdStanArgs(
1771+
self._name,
1772+
self._exe_file,
1773+
chain_ids=None,
1774+
data=_data,
1775+
seed=seed,
1776+
output_dir=output_dir,
1777+
sig_figs=sig_figs,
1778+
save_profile=save_profile,
1779+
method_args=laplace_args,
1780+
refresh=refresh,
1781+
)
1782+
dummy_chain_id = 0
1783+
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
1784+
self._run_cmdstan(
1785+
runset,
1786+
dummy_chain_id,
1787+
show_console=show_console,
1788+
timeout=timeout,
1789+
)
1790+
runset.raise_for_timeouts()
1791+
return CmdStanLaplace(runset, cmdstan_mode)
1792+
16721793
def _run_cmdstan(
16731794
self,
16741795
runset: RunSet,

cmdstanpy/stanfit/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
1414

1515
from .gq import CmdStanGQ
16+
from .laplace import CmdStanLaplace
1617
from .mcmc import CmdStanMCMC
1718
from .metadata import InferenceMetadata
1819
from .mle import CmdStanMLE
@@ -26,6 +27,7 @@
2627
"CmdStanMLE",
2728
"CmdStanVB",
2829
"CmdStanGQ",
30+
"CmdStanLaplace",
2931
]
3032

3133

@@ -143,7 +145,7 @@ def from_csv(
143145
save_warmup=config_dict['save_warmup'],
144146
fixed_param=True,
145147
)
146-
except (ValueError) as e:
148+
except ValueError as e:
147149
raise ValueError(
148150
'Invalid or corrupt Stan CSV output file, '
149151
) from e
@@ -170,6 +172,7 @@ def from_csv(
170172
optimize_args = OptimizeArgs(
171173
algorithm=config_dict['algorithm'],
172174
save_iterations=config_dict['save_iterations'],
175+
jacobian=config_dict.get('jacobian', 0),
173176
)
174177
cmdstan_args = CmdStanArgs(
175178
model_name=config_dict['model'],

cmdstanpy/stanfit/gq.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@
3232

3333
from cmdstanpy.cmdstan_args import Method
3434
from cmdstanpy.utils import (
35-
BaseType,
3635
build_xarray_data,
3736
flatten_chains,
3837
get_logger,
3938
scan_generated_quantities_csv,
4039
)
40+
from cmdstanpy.utils.data_munging import extract_reshape
4141

4242
from .mcmc import CmdStanMCMC
4343
from .metadata import InferenceMetadata
@@ -586,21 +586,17 @@ def stan_variable(
586586

587587
# is gq variable
588588
self._assemble_generated_quantities()
589-
590589
draw1, num_draws = self._draws_start(inc_warmup)
591-
dims = [num_draws * self.chains]
590+
dims = (num_draws * self.chains,)
592591
col_idxs = self._metadata.stan_vars_cols[var]
593-
if len(col_idxs) > 0:
594-
dims.extend(self._metadata.stan_vars_dims[var])
595-
draws = self._draws[draw1:, :, col_idxs]
596-
597-
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
598-
draws = draws[..., ::2] + 1j * draws[..., 1::2]
599-
dims = dims[:-1]
600592

601-
draws = draws.reshape(dims, order='F')
602-
603-
return draws
593+
return extract_reshape(
594+
dims=dims + self._metadata.stan_vars_dims[var],
595+
col_idxs=col_idxs,
596+
var_type=self._metadata.stan_vars_types[var],
597+
start_row=draw1,
598+
draws_in=self._draws,
599+
)
604600

605601
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
606602
"""

0 commit comments

Comments
 (0)