Skip to content

Commit 0cbf6a2

Browse files
authored
Merge pull request #505 from stan-dev/stanfit-reorg
Reorganize stanfit into smaller files
2 parents 0b50ae9 + d141409 commit 0cbf6a2

9 files changed

Lines changed: 1039 additions & 974 deletions

File tree

cmdstanpy/stanfit/__init__.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""Container objects for results of CmdStan run(s)."""
2+
3+
import glob
4+
import os
5+
from typing import Any, Dict, List, Optional, Union
6+
7+
from cmdstanpy.cmdstan_args import (
8+
CmdStanArgs,
9+
OptimizeArgs,
10+
SamplerArgs,
11+
VariationalArgs,
12+
)
13+
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
14+
15+
from .mcmc import CmdStanGQ, CmdStanMCMC
16+
from .metadata import InferenceMetadata
17+
from .mle import CmdStanMLE
18+
from .runset import RunSet
19+
from .vb import CmdStanVB
20+
21+
# Names exported by ``from cmdstanpy.stanfit import *``; each one is imported
# from a submodule of this package above.
__all__ = [
    "RunSet",
    "InferenceMetadata",
    "CmdStanMCMC",
    "CmdStanMLE",
    "CmdStanVB",
    "CmdStanGQ",
]
29+
30+
31+
def from_csv(
    path: Union[str, List[str], None] = None, method: Optional[str] = None
) -> Union[CmdStanMCMC, CmdStanMLE, CmdStanVB, None]:
    """
    Instantiate a CmdStan object from the Stan CSV files from a CmdStan run.

    CSV files are specified from either a list of Stan CSV files or a single
    filepath which can be either a directory name, a Stan CSV filename, or
    a pathname pattern (i.e., a Python glob). The optional argument 'method'
    checks that the CSV files were produced by that method.
    Stan CSV files from CmdStan methods 'sample', 'optimize', and 'variational'
    result in objects of class CmdStanMCMC, CmdStanMLE, and CmdStanVB,
    respectively.

    Files found via a directory or a glob pattern are used in sorted order,
    so the assignment of files to chain ids is reproducible; an explicit
    list is used in the order given.

    :param path: a list of Stan CSV file names, or a single string naming
        a directory, a Stan CSV file, or a glob pattern containing ``*``
    :param method: method name (optional), one of 'sample', 'optimize',
        'variational'

    :return: either a CmdStanMCMC, CmdStanMLE, or CmdStanVB object, or
        ``None`` if the files were produced by any other method

    :raises ValueError: if ``path`` or ``method`` is invalid, no CSV files
        are found, a file does not end in ``.csv`` or cannot be read, the
        first file lacks the 'model' or 'method' config entries, or the
        recorded method differs from ``method``
    """
    if path is None:
        raise ValueError('Must specify path to Stan CSV files.')
    if method is not None and method not in [
        'sample',
        'optimize',
        'variational',
    ]:
        raise ValueError(
            'Bad method argument {}, must be one of: '
            '"sample", "optimize", "variational"'.format(method)
        )

    csvfiles: List[str] = []
    if isinstance(path, list):
        # Copy so the RunSet built below never aliases the caller's list.
        csvfiles = list(path)
    elif isinstance(path, str):
        if '*' in path:
            glob_dir = os.path.split(path)[0]
            # os.path.split yields '' (never None) when the pattern has no
            # directory part, e.g. '*.csv'; such a pattern is relative to
            # the current directory, so only a non-empty head is validated.
            if glob_dir and not os.path.isdir(glob_dir):
                raise ValueError(
                    'Invalid path specification, {} '
                    ' unknown directory: {}'.format(path, glob_dir)
                )
            # glob.glob returns matches in arbitrary order; sort so that
            # chain ids map onto files deterministically.
            csvfiles = sorted(glob.glob(path))
        elif os.path.isdir(path):
            # os.listdir order is arbitrary as well, hence the sort.
            csvfiles = sorted(
                os.path.join(path, file)
                for file in os.listdir(path)
                if file.endswith(".csv")
            )
        elif os.path.exists(path):
            csvfiles.append(path)
        else:
            raise ValueError('Invalid path specification: {}'.format(path))
    else:
        raise ValueError('Invalid path specification: {}'.format(path))

    if len(csvfiles) == 0:
        raise ValueError('No CSV files found in directory {}'.format(path))
    for file in csvfiles:
        if not (os.path.exists(file) and file.endswith('.csv')):
            raise ValueError(
                'Bad CSV file path spec,'
                ' includes non-csv file: {}'.format(file)
            )

    # Only the first file's header is scanned; its config is used to build
    # the arguments for the whole set of files.
    config_dict: Dict[str, Any] = {}
    try:
        with open(csvfiles[0], 'r') as fd:
            scan_config(fd, config_dict, 0)
    except OSError as e:  # IOError is an alias, PermissionError a subclass
        raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
    if 'model' not in config_dict or 'method' not in config_dict:
        raise ValueError("File {} is not a Stan CSV file.".format(csvfiles[0]))
    if method is not None and method != config_dict['method']:
        raise ValueError(
            'Expecting Stan CSV output files from method {}, '
            ' found outputs from method {}'.format(
                method, config_dict['method']
            )
        )
    try:
        if config_dict['method'] == 'sample':
            chains = len(csvfiles)
            sampler_args = SamplerArgs(
                iter_sampling=config_dict['num_samples'],
                iter_warmup=config_dict['num_warmup'],
                thin=config_dict['thin'],
                save_warmup=config_dict['save_warmup'],
            )
            # bugfix 425, check for fixed_params output: validate the first
            # file as regular sampler output and, if that fails, retry the
            # check with is_fixed_param=True before giving up.
            try:
                check_sampler_csv(
                    csvfiles[0],
                    iter_sampling=config_dict['num_samples'],
                    iter_warmup=config_dict['num_warmup'],
                    thin=config_dict['thin'],
                    save_warmup=config_dict['save_warmup'],
                )
            except ValueError:
                try:
                    check_sampler_csv(
                        csvfiles[0],
                        is_fixed_param=True,
                        iter_sampling=config_dict['num_samples'],
                        iter_warmup=config_dict['num_warmup'],
                        thin=config_dict['thin'],
                        save_warmup=config_dict['save_warmup'],
                    )
                    sampler_args = SamplerArgs(
                        iter_sampling=config_dict['num_samples'],
                        iter_warmup=config_dict['num_warmup'],
                        thin=config_dict['thin'],
                        save_warmup=config_dict['save_warmup'],
                        fixed_param=True,
                    )
                except (ValueError) as e:
                    raise ValueError(
                        'Invalid or corrupt Stan CSV output file, '
                    ) from e

            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=[x + 1 for x in range(chains)],
                method_args=sampler_args,
            )
            runset = RunSet(args=cmdstan_args, chains=chains)
            runset._csv_files = csvfiles
            # Set the return code of every run to 0.
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            fit = CmdStanMCMC(runset)
            # NOTE(review): presumably called so the CSV files are parsed
            # before the object is returned -- confirm against CmdStanMCMC.
            fit.draws()
            return fit
        elif config_dict['method'] == 'optimize':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find optimization algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            optimize_args = OptimizeArgs(
                algorithm=config_dict['algorithm'],
                save_iterations=config_dict['save_iterations'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=optimize_args,
            )
            runset = RunSet(args=cmdstan_args)
            runset._csv_files = csvfiles
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            return CmdStanMLE(runset)
        elif config_dict['method'] == 'variational':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find variational algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            variational_args = VariationalArgs(
                algorithm=config_dict['algorithm'],
                iter=config_dict['iter'],
                grad_samples=config_dict['grad_samples'],
                elbo_samples=config_dict['elbo_samples'],
                eta=config_dict['eta'],
                tol_rel_obj=config_dict['tol_rel_obj'],
                eval_elbo=config_dict['eval_elbo'],
                output_samples=config_dict['output_samples'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=variational_args,
            )
            runset = RunSet(args=cmdstan_args)
            runset._csv_files = csvfiles
            for i in range(len(runset._retcodes)):
                runset._set_retcode(i, 0)
            return CmdStanVB(runset)
        else:
            # Any other recorded method is reported and yields None.
            get_logger().info(
                'Unable to process CSV output files from method %s.',
                (config_dict['method']),
            )
            return None
    except OSError as e:
        raise ValueError(
            'An error occured processing the CSV files:\n\t{}'.format(str(e))
        ) from e

0 commit comments

Comments
 (0)