Skip to content

Commit a28bb95

Browse files
authored
Merge pull request #681 from stan-dev/feature/tuple-io
Stan 2.33: Move IO munging to external package, refactors
2 parents 107a347 + 7945536 commit a28bb95

37 files changed

Lines changed: 406 additions & 890 deletions

.github/workflows/main.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ on:
1515
required: false
1616
default: ''
1717

18+
# only run one copy per PR
19+
concurrency:
20+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
21+
cancel-in-progress: true
22+
1823
jobs:
1924
get-cmdstan-version:
2025
# get the latest cmdstan version to use as part of the cache key
@@ -27,7 +32,7 @@ jobs:
2732
if [[ "${{ github.event.inputs.cmdstan-version }}" != "" ]]; then
2833
echo "version=${{ github.event.inputs.cmdstan-version }}" >> $GITHUB_OUTPUT
2934
else
30-
python -c 'import requests;print("version="+requests.get("https://api.github.com/repos/stan-dev/cmdstan/releases/latest").json()["tag_name"][1:])' >> $GITHUB_OUTPUT
35+
python -c 'import requests;print("version="+requests.get("https://api.github.com/repos/stan-dev/cmdstan/releases/latest").json()["tag_name"][1:])' >> $GITHUB_OUTPUT
3136
fi
3237
outputs:
3338
version: ${{ steps.check-cmdstan.outputs.version }}
@@ -39,7 +44,7 @@ jobs:
3944
strategy:
4045
matrix:
4146
os: [ubuntu-latest, macos-latest, windows-latest]
42-
python-version: ["3.7.1 - 3.7.16", "3.8", "3.9", "3.10", "3.11"]
47+
python-version: ["3.8", "3.9", "3.10", "3.11"]
4348
env:
4449
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
4550
steps:

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ repos:
1212
- id: isort
1313
# https://github.com/python/black#version-control-integration
1414
- repo: https://github.com/psf/black
15-
rev: 22.10.0
15+
rev: 23.7.0
1616
hooks:
1717
- id: black
1818
- repo: https://github.com/pycqa/flake8
1919
rev: 3.9.2
2020
hooks:
2121
- id: flake8
2222
- repo: https://github.com/pre-commit/mirrors-mypy
23-
rev: v0.982
23+
rev: v1.5.0
2424
hooks:
2525
- id: mypy
2626
# Copied from setup.cfg

cmdstanpy/compiler_opts.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,12 @@ def add_include_path(self, path: str) -> None:
275275
elif path not in self._stanc_options['include-paths']:
276276
self._stanc_options['include-paths'].append(path)
277277

278-
def compose_stanc(self) -> List[str]:
278+
def compose_stanc(self, filename_in_msg: Optional[str]) -> List[str]:
279279
opts = []
280+
281+
if filename_in_msg is not None:
282+
opts.append(f'--filename-in-msg={filename_in_msg}')
283+
280284
if self._stanc_options is not None and len(self._stanc_options) > 0:
281285
for key, val in self._stanc_options.items():
282286
if key == 'include-paths':
@@ -295,11 +299,19 @@ def compose_stanc(self) -> List[str]:
295299
opts.append(f'--{key}')
296300
return opts
297301

298-
def compose(self) -> List[str]:
299-
"""Format makefile options as list of strings."""
302+
def compose(self, filename_in_msg: Optional[str] = None) -> List[str]:
303+
"""
304+
Format makefile options as list of strings.
305+
306+
Parameters
307+
----------
308+
filename_in_msg : str, optional
309+
filename to be displayed in stanc3 error messages
310+
(if different from actual filename on disk), by default None
311+
"""
300312
opts = [
301313
'STANCFLAGS+=' + flag.replace(" ", "\\ ")
302-
for flag in self.compose_stanc()
314+
for flag in self.compose_stanc(filename_in_msg)
303315
]
304316
if self._cpp_options is not None and len(self._cpp_options) > 0:
305317
for key, val in self._cpp_options.items():

cmdstanpy/install_cxx_toolchain.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Linux: Not implemented
88
Optional command line arguments:
99
-v, --version : version, defaults to latest
10-
-d, --dir : install directory, defaults to '~/.cmdstan(py)
10+
-d, --dir : install directory, defaults to '~/.cmdstan
1111
-s (--silent) : install with /VERYSILENT instead of /SILENT for RTools
1212
-m --no-make : don't install mingw32-make (Windows RTools 4.0 only)
1313
--progress : flag, when specified show progress bar for RTools download
@@ -27,7 +27,7 @@
2727
from cmdstanpy.utils import pushd, validate_dir, wrap_url_progress_hook
2828

2929
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
30-
IS_64BITS = sys.maxsize > 2 ** 32
30+
IS_64BITS = sys.maxsize > 2**32
3131

3232

3333
def usage() -> None:
@@ -333,7 +333,7 @@ def parse_cmdline_args() -> Dict[str, Any]:
333333
parser = argparse.ArgumentParser()
334334
parser.add_argument('--version', '-v', help="version, defaults to latest")
335335
parser.add_argument(
336-
'--dir', '-d', help="install directory, defaults to '~/.cmdstan(py)"
336+
'--dir', '-d', help="install directory, defaults to '~/.cmdstan"
337337
)
338338
parser.add_argument(
339339
'--silent',

cmdstanpy/model.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Dict,
2323
Iterable,
2424
List,
25+
Literal,
2526
Mapping,
2627
Optional,
2728
TypeVar,
@@ -117,8 +118,7 @@ def __init__(
117118
model_name: Optional[str] = None,
118119
stan_file: OptionalPath = None,
119120
exe_file: OptionalPath = None,
120-
# TODO should be Literal['force'] not str
121-
compile: Union[bool, str] = True,
121+
compile: Union[bool, Literal['force']] = True,
122122
stanc_options: Optional[Dict[str, Any]] = None,
123123
cpp_options: Optional[Dict[str, Any]] = None,
124124
user_header: OptionalPath = None,
@@ -300,7 +300,7 @@ def src_info(self) -> Dict[str, Any]:
300300
cmd = (
301301
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
302302
# handle include-paths, allow-undefined etc
303-
+ self._compiler_options.compose_stanc()
303+
+ self._compiler_options.compose_stanc(None)
304304
+ ['--info', str(self.stan_file)]
305305
)
306306
proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
@@ -343,7 +343,7 @@ def format(
343343
cmd = (
344344
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
345345
# handle include-paths, allow-undefined etc
346-
+ self._compiler_options.compose_stanc()
346+
+ self._compiler_options.compose_stanc(None)
347347
+ [str(self.stan_file)]
348348
)
349349

@@ -528,7 +528,7 @@ def compile(
528528
)
529529
cmd = [make]
530530
if self._compiler_options is not None:
531-
cmd.extend(self._compiler_options.compose())
531+
cmd.extend(self._compiler_options.compose(self._stan_file))
532532
cmd.append(Path(exe_file).as_posix())
533533

534534
sout = io.StringIO()
@@ -1005,10 +1005,7 @@ def sample(
10051005
fixed_param = self._fixed_param
10061006

10071007
if chains is None:
1008-
if fixed_param:
1009-
chains = 1
1010-
else:
1011-
chains = 4
1008+
chains = 4
10121009
if chains < 1:
10131010
raise ValueError(
10141011
'Chains must be a positive integer value, found {}.'.format(
@@ -1045,7 +1042,7 @@ def sample(
10451042
info_dict = self.exe_info()
10461043
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
10471044
# run multi-chain sampler unless algo is fixed_param or 1 chain
1048-
if fixed_param or (chains == 1):
1045+
if chains == 1:
10491046
force_one_process_per_chain = True
10501047

10511048
if (
@@ -1224,19 +1221,6 @@ def sample(
12241221
sampler_args.fixed_param = True
12251222
runset._args.method_args = sampler_args
12261223

1227-
# if there was an exe-file only initialization,
1228-
# this could happen, so throw a nice error
1229-
if (
1230-
sampler_args.fixed_param
1231-
and not one_process_per_chain
1232-
and chains > 1
1233-
):
1234-
raise RuntimeError(
1235-
"Cannot use single-process multichain parallelism"
1236-
" with algorithm fixed_param.\nTry setting argument"
1237-
" force_one_process_per_chain to True"
1238-
)
1239-
12401224
errors = runset.get_err_msgs()
12411225
if not runset._check_retcodes():
12421226
msg = (

cmdstanpy/stanfit/gq.py

Lines changed: 43 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
get_logger,
3838
scan_generated_quantities_csv,
3939
)
40-
from cmdstanpy.utils.data_munging import extract_reshape
4140

4241
from .mcmc import CmdStanMCMC
4342
from .metadata import InferenceMetadata
@@ -242,7 +241,9 @@ def draws(
242241
]
243242
drop_cols: List[int] = []
244243
for dup in dups:
245-
drop_cols.extend(self.previous_fit.metadata.stan_vars_cols[dup])
244+
drop_cols.extend(
245+
self.previous_fit._metadata.stan_vars[dup].columns()
246+
)
246247

247248
start_idx, _ = self._draws_start(inc_warmup)
248249
previous_draws = self._previous_draws(True)
@@ -324,18 +325,24 @@ def draws_pd(
324325

325326
self._assemble_generated_quantities()
326327

327-
gq_cols = []
328-
mcmc_vars = []
328+
gq_cols: List[str] = []
329+
mcmc_vars: List[str] = []
329330
if vars is not None:
330331
for var in vars_list:
331-
if var in self.metadata.stan_vars_cols:
332-
for idx in self.metadata.stan_vars_cols[var]:
333-
gq_cols.append(self.column_names[idx])
332+
if var in self._metadata.stan_vars:
333+
info = self._metadata.stan_vars[var]
334+
gq_cols.extend(
335+
self.column_names[info.start_idx : info.end_idx]
336+
)
334337
elif (
335-
inc_sample
336-
and var in self.previous_fit.metadata.stan_vars_cols
338+
inc_sample and var in self.previous_fit._metadata.stan_vars
337339
):
338-
mcmc_vars.append(var)
340+
info = self.previous_fit._metadata.stan_vars[var]
341+
mcmc_vars.extend(
342+
self.previous_fit.column_names[
343+
info.start_idx : info.end_idx
344+
]
345+
)
339346
else:
340347
raise ValueError('Unknown variable: {}'.format(var))
341348
else:
@@ -463,18 +470,18 @@ def draws_xr(
463470
else:
464471
vars_list = vars
465472
for var in vars_list:
466-
if var not in self.metadata.stan_vars_cols:
473+
if var not in self._metadata.stan_vars:
467474
if inc_sample and (
468-
var in self.previous_fit.metadata.stan_vars_cols
475+
var in self.previous_fit._metadata.stan_vars
469476
):
470477
mcmc_vars_list.append(var)
471478
dup_vars.append(var)
472479
else:
473480
raise ValueError('Unknown variable: {}'.format(var))
474481
else:
475-
vars_list = list(self.metadata.stan_vars_cols.keys())
482+
vars_list = list(self._metadata.stan_vars.keys())
476483
if inc_sample:
477-
for var in self.previous_fit.metadata.stan_vars_cols.keys():
484+
for var in self.previous_fit._metadata.stan_vars.keys():
478485
if var not in vars_list and var not in mcmc_vars_list:
479486
mcmc_vars_list.append(var)
480487
for var in dup_vars:
@@ -483,7 +490,7 @@ def draws_xr(
483490
self._assemble_generated_quantities()
484491

485492
num_draws = self.previous_fit.num_draws_sampling
486-
sample_config = self.previous_fit.metadata.cmdstan_config
493+
sample_config = self.previous_fit._metadata.cmdstan_config
487494
attrs: MutableMapping[Hashable, Any] = {
488495
"stan_version": f"{sample_config['stan_version_major']}."
489496
f"{sample_config['stan_version_minor']}."
@@ -504,23 +511,15 @@ def draws_xr(
504511
for var in vars_list:
505512
build_xarray_data(
506513
data,
507-
var,
508-
self._metadata.stan_vars_dims[var],
509-
self._metadata.stan_vars_cols[var],
510-
0,
514+
self._metadata.stan_vars[var],
511515
self.draws(inc_warmup=inc_warmup),
512-
self._metadata.stan_vars_types[var],
513516
)
514517
if inc_sample:
515518
for var in mcmc_vars_list:
516519
build_xarray_data(
517520
data,
518-
var,
519-
self.previous_fit.metadata.stan_vars_dims[var],
520-
self.previous_fit.metadata.stan_vars_cols[var],
521-
0,
521+
self.previous_fit._metadata.stan_vars[var],
522522
self.previous_fit.draws(inc_warmup=inc_warmup),
523-
self.previous_fit._metadata.stan_vars_types[var],
524523
)
525524

526525
return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
@@ -545,13 +544,13 @@ def stan_variable(
545544
the next M are from chain 2, and the last M elements are from chain N.
546545
547546
* If the variable is a scalar variable, the return array has shape
548-
( draws X chains, 1).
547+
( draws * chains, 1).
549548
* If the variable is a vector, the return array has shape
550-
( draws X chains, len(vector))
549+
( draws * chains, len(vector))
551550
* If the variable is a matrix, the return array has shape
552-
( draws X chains, size(dim 1) X size(dim 2) )
551+
( draws * chains, size(dim 1), size(dim 2) )
553552
* If the variable is an array with N dimensions, the return array
554-
has shape ( draws X chains, size(dim 1) X ... X size(dim N))
553+
has shape ( draws * chains, size(dim 1), ..., size(dim N))
555554
556555
For example, if the Stan program variable ``theta`` is a 3x3 matrix,
557556
and the sample consists of 4 chains with 1000 post-warmup draws,
@@ -573,8 +572,8 @@ def stan_variable(
573572
CmdStanMLE.stan_variable
574573
CmdStanVB.stan_variable
575574
"""
576-
model_var_names = self.previous_fit.metadata.stan_vars_cols.keys()
577-
gq_var_names = self.metadata.stan_vars_cols.keys()
575+
model_var_names = self.previous_fit._metadata.stan_vars.keys()
576+
gq_var_names = self._metadata.stan_vars.keys()
578577
if not (var in model_var_names or var in gq_var_names):
579578
raise ValueError(
580579
f'Unknown variable name: {var}\n'
@@ -588,30 +587,21 @@ def stan_variable(
588587
)
589588
elif isinstance(self.previous_fit, CmdStanMLE):
590589
return np.atleast_1d( # type: ignore
591-
np.asarray(
592-
self.previous_fit.stan_variable(
593-
var, inc_iterations=inc_warmup
594-
)
590+
self.previous_fit.stan_variable(
591+
var, inc_iterations=inc_warmup
595592
)
596593
)
597594
else:
598595
return np.atleast_1d( # type: ignore
599-
np.asarray(self.previous_fit.stan_variable(var))
596+
self.previous_fit.stan_variable(var)
600597
)
601-
602598
# is gq variable
603599
self._assemble_generated_quantities()
604-
draw1, num_draws = self._draws_start(inc_warmup)
605-
dims = (num_draws * self.chains,)
606-
col_idxs = self._metadata.stan_vars_cols[var]
607-
608-
return extract_reshape(
609-
dims=dims + self._metadata.stan_vars_dims[var],
610-
col_idxs=col_idxs,
611-
var_type=self._metadata.stan_vars_types[var],
612-
start_row=draw1,
613-
draws_in=self._draws,
614-
)
600+
601+
draw1, _ = self._draws_start(inc_warmup)
602+
draws = flatten_chains(self._draws[draw1:])
603+
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(draws)
604+
return out
615605

616606
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
617607
"""
@@ -630,8 +620,8 @@ def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
630620
CmdStanVB.stan_variables
631621
"""
632622
result = {}
633-
sample_var_names = self.previous_fit.metadata.stan_vars_cols.keys()
634-
gq_var_names = self.metadata.stan_vars_cols.keys()
623+
sample_var_names = self.previous_fit._metadata.stan_vars.keys()
624+
gq_var_names = self._metadata.stan_vars.keys()
635625
for name in gq_var_names:
636626
result[name] = self.stan_variable(name, inc_warmup)
637627
for name in sample_var_names:
@@ -697,9 +687,9 @@ def _previous_draws(self, inc_warmup: bool) -> np.ndarray:
697687
if inc_warmup and p_fit._save_iterations:
698688
return p_fit.optimized_iterations_np[:, None] # type: ignore
699689

700-
return np.atleast_2d(p_fit.optimized_params_np,)[ # type: ignore
701-
:, None
702-
]
690+
return np.atleast_2d( # type: ignore
691+
p_fit.optimized_params_np,
692+
)[:, None]
703693
else: # CmdStanVB:
704694
if inc_warmup:
705695
return np.vstack(

0 commit comments

Comments
 (0)