Skip to content

Commit ef439a5

Browse files
committed
Merge branch 'develop' of https://github.com/stan-dev/cmdstanpy into develop
2 parents 329d168 + 6e9ae89 commit ef439a5

16 files changed

Lines changed: 158 additions & 101 deletions

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ jobs:
3939
strategy:
4040
matrix:
4141
os: [ubuntu-latest, macos-latest, windows-latest]
42-
python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
43-
fail-fast: false
42+
python-version: [3.7, 3.8, 3.9, "3.10"]
4443
env:
4544
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
4645
steps:
@@ -60,6 +59,7 @@ jobs:
6059
pip install codecov
6160
6261
- name: Run flake8, pylint, mypy
62+
if: matrix.python-version == '3.10'
6363
run: |
6464
flake8 cmdstanpy test
6565
pylint -v cmdstanpy test

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ repos:
2525
- id: mypy
2626
# Copied from setup.cfg
2727
exclude: ^test/
28-
additional_dependencies: [ numpy >= 1.21 ]
28+
additional_dependencies: [ numpy >= 1.22, types-ujson ]
2929
# local uses the user-installed pylint, this allows dependency checking
3030
- repo: local
3131
hooks:

cmdstanpy/cmdstan_args.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,7 @@ def validate(self, chains: Optional[int]) -> None:
202202
if all(isinstance(elem, dict) for elem in self.metric):
203203
metric_files: List[str] = []
204204
for i, metric in enumerate(self.metric):
205-
assert isinstance(
206-
metric, dict
207-
) # make the typechecker happy
208-
metric_dict: Dict[str, Any] = metric
205+
metric_dict: Dict[str, Any] = metric # type: ignore
209206
if 'inv_metric' not in metric_dict:
210207
raise ValueError(
211208
'Entry "inv_metric" not found in metric dict '

cmdstanpy/compiler_opts.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,27 @@
33
"""
44

55
import os
6+
from copy import copy
67
from pathlib import Path
78
from typing import Any, Dict, List, Optional, Union
89

910
from cmdstanpy.utils import get_logger
1011

1112
STANC_OPTS = [
1213
'O',
13-
'allow_undefined',
14+
'allow-undefined',
1415
'use-opencl',
1516
'warn-uninitialized',
16-
'include_paths',
17+
'include-paths',
1718
'name',
1819
'warn-pedantic',
1920
]
2021

22+
STANC_DEPRECATED_OPTS = {
23+
'allow_undefined': 'allow-undefined',
24+
'include_paths': 'include-paths',
25+
}
26+
2127
STANC_IGNORE_OPTS = [
2228
'debug-lex',
2329
'debug-parse',
@@ -121,19 +127,37 @@ def validate_stanc_opts(self) -> None:
121127
return
122128
ignore = []
123129
paths = None
130+
for deprecated, replacement in STANC_DEPRECATED_OPTS.items():
131+
if deprecated in self._stanc_options:
132+
if replacement:
133+
get_logger().warning(
134+
'compiler option "%s" is deprecated, use "%s" instead',
135+
deprecated,
136+
replacement,
137+
)
138+
self._stanc_options[replacement] = copy(
139+
self._stanc_options[deprecated]
140+
)
141+
del self._stanc_options[deprecated]
142+
else:
143+
get_logger().warning(
144+
'compiler option "%s" is deprecated and '
145+
'should not be used',
146+
deprecated,
147+
)
124148
for key, val in self._stanc_options.items():
125149
if key in STANC_IGNORE_OPTS:
126150
get_logger().info('ignoring compiler option: %s', key)
127151
ignore.append(key)
128152
elif key not in STANC_OPTS:
129153
raise ValueError(f'unknown stanc compiler option: {key}')
130-
elif key == 'include_paths':
154+
elif key == 'include-paths':
131155
paths = val
132156
if isinstance(val, str):
133157
paths = val.split(',')
134158
elif not isinstance(val, list):
135159
raise ValueError(
136-
'Invalid include_paths, expecting list or '
160+
'Invalid include-paths, expecting list or '
137161
f'string, found type: {type(val)}.'
138162
)
139163
elif key == 'use-opencl':
@@ -145,17 +169,16 @@ def validate_stanc_opts(self) -> None:
145169
for opt in ignore:
146170
del self._stanc_options[opt]
147171
if paths is not None:
148-
self._stanc_options['include_paths'] = paths
149-
bad_paths = [
150-
dir
151-
for dir in self._stanc_options['include_paths']
152-
if not os.path.exists(dir)
153-
]
172+
bad_paths = [dir for dir in paths if not os.path.exists(dir)]
154173
if any(bad_paths):
155174
raise ValueError(
156175
'invalid include paths: {}'.format(', '.join(bad_paths))
157176
)
158177

178+
self._stanc_options['include-paths'] = [
179+
os.path.abspath(os.path.expanduser(path)) for path in paths
180+
]
181+
159182
def validate_cpp_opts(self) -> None:
160183
"""
161184
Check cpp compiler args.
@@ -190,8 +213,8 @@ def validate_user_header(self) -> None:
190213
raise ValueError(
191214
f"Header file must end in .hpp, got {self._user_header}"
192215
)
193-
if "allow_undefined" not in self._stanc_options:
194-
self._stanc_options["allow_undefined"] = True
216+
if "allow-undefined" not in self._stanc_options:
217+
self._stanc_options["allow-undefined"] = True
195218
# set full path
196219
self._user_header = os.path.abspath(self._user_header)
197220

@@ -218,7 +241,7 @@ def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
218241
self._stanc_options = new_opts.stanc_options
219242
else:
220243
for key, val in new_opts.stanc_options.items():
221-
if key == 'include_paths':
244+
if key == 'include-paths':
222245
self.add_include_path(str(val))
223246
else:
224247
self._stanc_options[key] = val
@@ -230,30 +253,35 @@ def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
230253

231254
def add_include_path(self, path: str) -> None:
232255
"""Adds include path to existing set of compiler options."""
233-
if 'include_paths' not in self._stanc_options:
234-
self._stanc_options['include_paths'] = [path]
235-
elif path not in self._stanc_options['include_paths']:
236-
self._stanc_options['include_paths'].append(path)
256+
path = os.path.abspath(os.path.expanduser(path))
257+
if 'include-paths' not in self._stanc_options:
258+
self._stanc_options['include-paths'] = [path]
259+
elif path not in self._stanc_options['include-paths']:
260+
self._stanc_options['include-paths'].append(path)
237261

238-
def compose(self) -> List[str]:
239-
"""Format makefile options as list of strings."""
262+
def compose_stanc(self) -> List[str]:
240263
opts = []
241264
if self._stanc_options is not None and len(self._stanc_options) > 0:
242265
for key, val in self._stanc_options.items():
243-
if key == 'include_paths':
266+
if key == 'include-paths':
244267
opts.append(
245-
'STANCFLAGS+=--include_paths='
268+
'--include-paths='
246269
+ ','.join(
247270
(
248271
Path(p).as_posix()
249-
for p in self._stanc_options['include_paths']
272+
for p in self._stanc_options['include-paths']
250273
)
251274
)
252275
)
253276
elif key == 'name':
254-
opts.append(f'STANCFLAGS+=--name={val}')
277+
opts.append(f'--name={val}')
255278
else:
256-
opts.append(f'STANCFLAGS+=--{key}')
279+
opts.append(f'--{key}')
280+
return opts
281+
282+
def compose(self) -> List[str]:
283+
"""Format makefile options as list of strings."""
284+
opts = ['STANCFLAGS+=' + flag for flag in self.compose_stanc()]
257285
if self._cpp_options is not None and len(self._cpp_options) > 0:
258286
for key, val in self._cpp_options.items():
259287
opts.append(f'{key}={val}')

cmdstanpy/install_cmdstan.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,9 @@ def main(args: Dict[str, Any]) -> None:
460460
print('Installing CmdStan version: {}'.format(version))
461461
else:
462462
raise ValueError(
463-
'Invalid version requested: {}, cannot install.'.format(version)
463+
f'Version {version} cannot be downloaded. '
464+
'Connection to GitHub failed. '
465+
'Check firewall settings or ensure this version exists.'
464466
)
465467

466468
cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))

cmdstanpy/model.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ def __init__(
128128
)
129129
self._name = model_name.strip()
130130

131+
self._compiler_options.validate()
132+
131133
if stan_file is None:
132134
if exe_file is None:
133135
raise ValueError(
@@ -146,19 +148,14 @@ def __init__(
146148
if not self._name:
147149
self._name, _ = os.path.splitext(filename)
148150

149-
# TODO: When minimum version is 2.27, use --info instead
150151
# if program has include directives, record path
151152
with open(self._stan_file, 'r') as fd:
152153
program = fd.read()
153154
if '#include' in program:
154155
path, _ = os.path.split(self._stan_file)
155-
if self._compiler_options is None:
156-
self._compiler_options = CompilerOptions(
157-
stanc_options={'include_paths': [path]}
158-
)
159-
elif self._compiler_options._stanc_options is None:
156+
if self._compiler_options._stanc_options is None:
160157
self._compiler_options._stanc_options = {
161-
'include_paths': [path]
158+
'include-paths': [path]
162159
}
163160
else:
164161
self._compiler_options.add_include_path(path)
@@ -186,8 +183,6 @@ def __init__(
186183
' found: {}.'.format(self._name, exename)
187184
)
188185

189-
self._compiler_options.validate()
190-
191186
if platform.system() == 'Windows':
192187
try:
193188
do_command(['where.exe', 'tbb.dll'], fd_out=None)
@@ -279,12 +274,15 @@ def src_info(self) -> Dict[str, Any]:
279274
if self.stan_file is None:
280275
return result
281276
try:
282-
283-
cmd = [
284-
os.path.join('.', 'bin', 'stanc' + EXTENSION),
285-
'--info',
286-
self.stan_file,
287-
]
277+
cmd = (
278+
[os.path.join('.', 'bin', 'stanc' + EXTENSION)]
279+
# handle include-paths, allow-undefined etc
280+
+ self._compiler_options.compose_stanc()
281+
+ [
282+
'--info',
283+
self.stan_file,
284+
]
285+
)
288286
sout = io.StringIO()
289287
do_command(cmd=cmd, cwd=cmdstan_path(), fd_out=sout)
290288
result = json.loads(sout.getvalue())

cmdstanpy/stanfit/mcmc.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def __init__(
9595
self._save_warmup = sampler_args.save_warmup
9696
self._sig_figs = runset._args.sig_figs
9797
# info from CSV values, instantiated lazily
98-
self._metric = np.array(())
99-
self._step_size = np.array(())
100-
self._draws = np.array(())
98+
self._metric: np.ndarray = np.array(())
99+
self._step_size: np.ndarray = np.array(())
100+
self._draws: np.ndarray = np.array(())
101101
# info from CSV initial comments and header
102102
config = self._validate_csv_files()
103103
self._metadata: InferenceMetadata = InferenceMetadata(config)
@@ -246,7 +246,7 @@ def draws(
246246

247247
if concat_chains:
248248
return flatten_chains(self._draws[start_idx:, :, :])
249-
return self._draws[start_idx:, :, :] # type: ignore
249+
return self._draws[start_idx:, :, :]
250250

251251
def _validate_csv_files(self) -> Dict[str, Any]:
252252
"""
@@ -675,9 +675,7 @@ def stan_variable(
675675
if len(col_idxs) > 0:
676676
dims.extend(self._metadata.stan_vars_dims[var])
677677
# pylint: disable=redundant-keyword-arg
678-
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
679-
dims, order='F'
680-
)
678+
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
681679

682680
def stan_variables(self) -> Dict[str, np.ndarray]:
683681
"""
@@ -748,7 +746,7 @@ def __init__(
748746
)
749747
self.runset = runset
750748
self.mcmc_sample = mcmc_sample
751-
self._draws = np.array(())
749+
self._draws: np.ndarray = np.array(())
752750
config = self._validate_csv_files()
753751
self._metadata = InferenceMetadata(config)
754752

@@ -765,7 +763,7 @@ def __repr__(self) -> str:
765763
)
766764
return repr
767765

768-
def _validate_csv_files(self) -> dict:
766+
def _validate_csv_files(self) -> Dict[str, Any]:
769767
"""
770768
Checks that Stan CSV output files for all chains are consistent
771769
and returns dict containing config and column names.
@@ -910,13 +908,13 @@ def draws(
910908
if concat_chains:
911909
return flatten_chains(self._draws[start_idx:, :, :])
912910
if inc_sample:
913-
return np.dstack( # type: ignore
911+
return np.dstack(
914912
(
915913
np.delete(self.mcmc_sample.draws(), drop_cols, axis=1),
916914
self._draws,
917915
)
918916
)[start_idx:, :, :]
919-
return self._draws[start_idx:, :, :] # type: ignore
917+
return self._draws[start_idx:, :, :]
920918

921919
def draws_pd(
922920
self,
@@ -1195,9 +1193,7 @@ def stan_variable(
11951193
if len(col_idxs) > 0:
11961194
dims.extend(self._metadata.stan_vars_dims[var])
11971195
# pylint: disable=redundant-keyword-arg
1198-
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
1199-
dims, order='F'
1200-
)
1196+
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
12011197

12021198
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
12031199
"""
@@ -1229,7 +1225,7 @@ def _assemble_generated_quantities(self) -> None:
12291225
# use numpy loadtxt
12301226
warmup = self.mcmc_sample.metadata.cmdstan_config['save_warmup']
12311227
num_draws = self.mcmc_sample.draws(inc_warmup=warmup).shape[0]
1232-
gq_sample = np.empty(
1228+
gq_sample: np.ndarray = np.empty(
12331229
(num_draws, self.chains, len(self.column_names)),
12341230
dtype=float,
12351231
order='F',

0 commit comments

Comments
 (0)