Skip to content

Commit 329d168

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents 0649ae3 + a14e9bd commit 329d168

49 files changed

Lines changed: 1726 additions & 1623 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/main.yml

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ on:
88
tags:
99
- '**'
1010
pull_request:
11+
workflow_dispatch:
12+
inputs:
13+
cmdstan-version:
14+
description: 'Version to test'
15+
required: false
16+
default: 'latest'
17+
1118
jobs:
1219
get-cmdstan-version:
1320
# get the latest cmdstan version to use as part of the cache key
@@ -17,7 +24,11 @@ jobs:
1724
- name: Get CmdStan version
  id: check-cmdstan
  run: |
    # The version feeds the CmdStan cache key.
    # `cmdstan-version` is only populated for manual (workflow_dispatch) runs;
    # it is empty on push / pull_request, and 'latest' is its dispatch default.
    # Honour an explicitly requested version; otherwise resolve the newest
    # release tag so the cache key tracks a concrete version number.
    requested="${{ github.event.inputs.cmdstan-version }}"
    if [[ -n "$requested" && "$requested" != "latest" ]]; then
      echo "::set-output name=version::$requested"
    else
      echo "::set-output name=version::$(python -c 'import requests;print(requests.get("https://api.github.com/repos/stan-dev/cmdstan/releases/latest").json()["tag_name"][1:])')"
    fi
2132
outputs:
2233
version: ${{ steps.check-cmdstan.outputs.version }}
2334

@@ -70,20 +81,28 @@ jobs:
7081
- name: Show libraries
7182
run: python -m pip freeze
7283

84+
- name: Get system info
85+
uses: kenchan0130/actions-system-info@v1.0.0
86+
id: system-info
87+
7388
- name: CmdStan installation cacheing
7489
uses: actions/cache@v2
7590
with:
7691
path: ~/.cmdstan
77-
key: ${{ runner.os }}-cmdstan-${{ needs.get-cmdstan-version.outputs.version }}-${{ hashFiles('**/install_cmdstan.py') }}
92+
key: ${{ runner.os }}-${{ steps.system-info.outputs.release }}-cmdstan-${{ needs.get-cmdstan-version.outputs.version }}-${{ hashFiles('**/install_cmdstan.py') }}
7893

7994
- name: Install CmdStan (Linux, macOS)
8095
if: matrix.os != 'windows-latest'
8196
run: |
97+
install_cmdstan -h
98+
install_cxx_toolchain -h
8299
python -m cmdstanpy.install_cmdstan
83100
84101
- name: Install CmdStan (Windows)
85102
if: matrix.os == 'windows-latest'
86103
run: |
104+
install_cmdstan -h
105+
install_cxx_toolchain -h
87106
python -m cmdstanpy.install_cmdstan --compiler
88107
89108
- name: Run tests

cmdstanpy/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""PyPi Version"""
22

3-
__version__ = '1.0.0rc2'
3+
__version__ = '1.0.0'

cmdstanpy/install_cmdstan.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,5 +538,9 @@ def main(args: Dict[str, Any]) -> None:
538538
print('CmdStan version {} already installed'.format(version))
539539

540540

541-
if __name__ == '__main__':
541+
def __main__() -> None:
    """Parse the command-line arguments and hand them to ``main``."""
    cmdline_options = parse_cmdline_args()
    main(cmdline_options)


# Also runnable directly as a script or via ``python -m``.
if __name__ == '__main__':
    __main__()

cmdstanpy/install_cxx_toolchain.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,5 +360,9 @@ def parse_cmdline_args() -> Dict[str, Any]:
360360
return vars(parser.parse_args(sys.argv[1:]))
361361

362362

363-
if __name__ == '__main__':
363+
def __main__() -> None:
    """Parse the options given in ``sys.argv`` and hand them to ``main``."""
    cmdline_options = parse_cmdline_args()
    main(cmdline_options)


# Also runnable directly as a script or via ``python -m``.
if __name__ == '__main__':
    __main__()

cmdstanpy/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ def sample(
10001000
show_progress = False
10011001
else:
10021002
show_progress = show_progress and progbar.allow_show_progress()
1003-
get_logger().info('CmdStan start procesing')
1003+
get_logger().info('CmdStan start processing')
10041004

10051005
progress_hook: Optional[Callable[[str, int], None]] = None
10061006
if show_progress:

cmdstanpy/stanfit/__init__.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""Container objects for results of CmdStan run(s)."""
2+
3+
import glob
4+
import os
5+
from typing import Any, Dict, List, Optional, Union
6+
7+
from cmdstanpy.cmdstan_args import (
8+
CmdStanArgs,
9+
OptimizeArgs,
10+
SamplerArgs,
11+
VariationalArgs,
12+
)
13+
from cmdstanpy.utils import check_sampler_csv, get_logger, scan_config
14+
15+
from .mcmc import CmdStanGQ, CmdStanMCMC
16+
from .metadata import InferenceMetadata
17+
from .mle import CmdStanMLE
18+
from .runset import RunSet
19+
from .vb import CmdStanVB
20+
21+
__all__ = [
22+
"RunSet",
23+
"InferenceMetadata",
24+
"CmdStanMCMC",
25+
"CmdStanMLE",
26+
"CmdStanVB",
27+
"CmdStanGQ",
28+
]
29+
30+
31+
def _resolve_csv_files(path: Union[str, List[str]]) -> List[str]:
    """
    Expand a path specification into a list of candidate Stan CSV files.

    :param path: a list of filenames (returned as given), a directory name
        (every entry ending in ``.csv`` is used), a single filename, or a
        pathname pattern containing ``*`` (expanded with ``glob.glob``).

    :return: the candidate file paths; may be empty.

    :raises ValueError: if ``path`` is neither a list nor a string, does not
        exist, or is a pattern whose directory part is not a directory.
    """
    if isinstance(path, list):
        return path
    if not isinstance(path, str):
        raise ValueError('Invalid path specification: {}'.format(path))

    if '*' in path:
        dirname = os.path.dirname(path)
        # A bare pattern such as '*.csv' has an empty directory part, which
        # means "current directory" - only validate an explicit directory.
        if dirname and not os.path.isdir(dirname):
            raise ValueError(
                'Invalid path specification, {} '
                ' unknown directory: {}'.format(path, dirname)
            )
        return glob.glob(path)
    if os.path.isdir(path):
        return [
            os.path.join(path, file)
            for file in os.listdir(path)
            if file.endswith(".csv")
        ]
    if os.path.exists(path):
        return [path]
    raise ValueError('Invalid path specification: {}'.format(path))


def _completed_runset(
    cmdstan_args: CmdStanArgs, csvfiles: List[str], **runset_kwargs: Any
) -> RunSet:
    """
    Build a RunSet that points at existing CSV files.

    The CSV files already exist, so no process is launched: the file list is
    attached directly and every return code is set to 0 (success).

    :param cmdstan_args: arguments reconstructed from the CSV header
    :param csvfiles: the Stan CSV files to attach
    :param runset_kwargs: extra keyword arguments for ``RunSet``
        (e.g. ``chains``)

    :return: the populated RunSet
    """
    runset = RunSet(args=cmdstan_args, **runset_kwargs)
    runset._csv_files = csvfiles
    for i in range(len(runset._retcodes)):
        runset._set_retcode(i, 0)
    return runset


def from_csv(
    path: Union[str, List[str], None] = None, method: Optional[str] = None
) -> Union[CmdStanMCMC, CmdStanMLE, CmdStanVB, None]:
    """
    Instantiate a CmdStan object from the Stan CSV files from a CmdStan run.
    CSV files are specified from either a list of Stan CSV files or a single
    filepath which can be either a directory name, a Stan CSV filename, or
    a pathname pattern (i.e., a Python glob). The optional argument 'method'
    checks that the CSV files were produced by that method.
    Stan CSV files from CmdStan methods 'sample', 'optimize', and 'variational'
    result in objects of class CmdStanMCMC, CmdStanMLE, and CmdStanVB,
    respectively.

    :param path: list of Stan CSV filenames, or a single directory name,
        Stan CSV filename, or pathname pattern containing ``*``
    :param method: method name (optional)

    :return: either a CmdStanMCMC, CmdStanMLE, or CmdStanVB object, or
        ``None`` if the files were produced by any other method

    :raises ValueError: if the arguments are invalid, no CSV files are found,
        a file cannot be read or is not a Stan CSV file, or the files were
        produced by a method other than the one requested
    """
    if path is None:
        raise ValueError('Must specify path to Stan CSV files.')
    if method is not None and method not in [
        'sample',
        'optimize',
        'variational',
    ]:
        raise ValueError(
            'Bad method argument {}, must be one of: '
            '"sample", "optimize", "variational"'.format(method)
        )

    csvfiles = _resolve_csv_files(path)
    if not csvfiles:
        raise ValueError('No CSV files found in directory {}'.format(path))
    for file in csvfiles:
        if not (os.path.exists(file) and file.endswith('.csv')):
            raise ValueError(
                'Bad CSV file path spec,'
                ' includes non-csv file: {}'.format(file)
            )

    # Only the first file's header is scanned to identify model and method.
    config_dict: Dict[str, Any] = {}
    try:
        with open(csvfiles[0], 'r') as fd:
            scan_config(fd, config_dict, 0)
    except OSError as e:  # covers IOError and PermissionError as well
        raise ValueError('Cannot read CSV file: {}'.format(csvfiles[0])) from e
    if 'model' not in config_dict or 'method' not in config_dict:
        raise ValueError("File {} is not a Stan CSV file.".format(csvfiles[0]))
    if method is not None and method != config_dict['method']:
        raise ValueError(
            'Expecting Stan CSV output files from method {}, '
            ' found outputs from method {}'.format(
                method, config_dict['method']
            )
        )
    try:
        if config_dict['method'] == 'sample':
            chains = len(csvfiles)
            sampler_config = {
                'iter_sampling': config_dict['num_samples'],
                'iter_warmup': config_dict['num_warmup'],
                'thin': config_dict['thin'],
                'save_warmup': config_dict['save_warmup'],
            }
            sampler_args = SamplerArgs(**sampler_config)
            # bugfix 425, check for fixed_params output
            try:
                check_sampler_csv(csvfiles[0], **sampler_config)
            except ValueError:
                # Not valid as regular sampler output; retry as output of
                # the fixed_param sampler before giving up.
                try:
                    check_sampler_csv(
                        csvfiles[0], is_fixed_param=True, **sampler_config
                    )
                    sampler_args = SamplerArgs(
                        fixed_param=True, **sampler_config
                    )
                except (ValueError) as e:
                    raise ValueError(
                        'Invalid or corrupt Stan CSV output file, '
                    ) from e

            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=[x + 1 for x in range(chains)],
                method_args=sampler_args,
            )
            runset = _completed_runset(cmdstan_args, csvfiles, chains=chains)
            fit = CmdStanMCMC(runset)
            # Read the draws now, so unreadable files fail here.
            fit.draws()
            return fit
        elif config_dict['method'] == 'optimize':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find optimization algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            optimize_args = OptimizeArgs(
                algorithm=config_dict['algorithm'],
                save_iterations=config_dict['save_iterations'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=optimize_args,
            )
            return CmdStanMLE(_completed_runset(cmdstan_args, csvfiles))
        elif config_dict['method'] == 'variational':
            if 'algorithm' not in config_dict:
                raise ValueError(
                    "Cannot find variational algorithm"
                    " in file {}.".format(csvfiles[0])
                )
            variational_args = VariationalArgs(
                algorithm=config_dict['algorithm'],
                iter=config_dict['iter'],
                grad_samples=config_dict['grad_samples'],
                elbo_samples=config_dict['elbo_samples'],
                eta=config_dict['eta'],
                tol_rel_obj=config_dict['tol_rel_obj'],
                eval_elbo=config_dict['eval_elbo'],
                output_samples=config_dict['output_samples'],
            )
            cmdstan_args = CmdStanArgs(
                model_name=config_dict['model'],
                model_exe=config_dict['model'],
                chain_ids=None,
                method_args=variational_args,
            )
            return CmdStanVB(_completed_runset(cmdstan_args, csvfiles))
        else:
            get_logger().info(
                'Unable to process CSV output files from method %s.',
                (config_dict['method']),
            )
            return None
    except OSError as e:  # covers IOError and PermissionError as well
        raise ValueError(
            'An error occured processing the CSV files:\n\t{}'.format(str(e))
        ) from e

0 commit comments

Comments
 (0)