Skip to content

Commit 27a219f

Browse files
committed
Merge branch 'develop' into model-formatting
2 parents e598918 + 683a135 commit 27a219f

11 files changed

Lines changed: 50 additions & 55 deletions

File tree

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,8 +39,7 @@ jobs:
3939
strategy:
4040
matrix:
4141
os: [ubuntu-latest, macos-latest, windows-latest]
42-
python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
43-
fail-fast: false
42+
python-version: [3.7, 3.8, 3.9, "3.10"]
4443
env:
4544
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
4645
steps:
@@ -60,6 +59,7 @@ jobs:
6059
pip install codecov
6160
6261
- name: Run flake8, pylint, mypy
62+
if: matrix.python-version == '3.10'
6363
run: |
6464
flake8 cmdstanpy test
6565
pylint -v cmdstanpy test

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ repos:
2525
- id: mypy
2626
# Copied from setup.cfg
2727
exclude: ^test/
28-
additional_dependencies: [ numpy >= 1.21 ]
28+
additional_dependencies: [ numpy >= 1.22, types-ujson ]
2929
# local uses the user-installed pylint, this allows dependency checking
3030
- repo: local
3131
hooks:

cmdstanpy/cmdstan_args.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -202,10 +202,7 @@ def validate(self, chains: Optional[int]) -> None:
202202
if all(isinstance(elem, dict) for elem in self.metric):
203203
metric_files: List[str] = []
204204
for i, metric in enumerate(self.metric):
205-
assert isinstance(
206-
metric, dict
207-
) # make the typechecker happy
208-
metric_dict: Dict[str, Any] = metric
205+
metric_dict: Dict[str, Any] = metric # type: ignore
209206
if 'inv_metric' not in metric_dict:
210207
raise ValueError(
211208
'Entry "inv_metric" not found in metric dict '

cmdstanpy/model.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -311,6 +311,7 @@ def format_model(
311311
if self.stan_file is None or not os.path.isfile(self.stan_file):
312312
raise ValueError("No Stan file found for this module")
313313
try:
314+
# TODO need include paths if they exist.
314315
cmd = [
315316
os.path.join('.', 'bin', 'stanc' + EXTENSION),
316317
'--auto-format',

cmdstanpy/stanfit/mcmc.py

Lines changed: 28 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -95,9 +95,9 @@ def __init__(
9595
self._save_warmup = sampler_args.save_warmup
9696
self._sig_figs = runset._args.sig_figs
9797
# info from CSV values, instantiated lazily
98-
self._metric = np.array(())
99-
self._step_size = np.array(())
100-
self._draws = np.array(())
98+
self._metric: np.ndarray = np.array(())
99+
self._step_size: np.ndarray = np.array(())
100+
self._draws: np.ndarray = np.array(())
101101
# info from CSV initial comments and header
102102
config = self._validate_csv_files()
103103
self._metadata: InferenceMetadata = InferenceMetadata(config)
@@ -231,7 +231,7 @@ def draws(
231231
CmdStanMCMC.draws_xr
232232
CmdStanGQ.draws
233233
"""
234-
if self._draws.size == 0:
234+
if self._draws.shape == (0,):
235235
self._assemble_draws()
236236

237237
if inc_warmup and not self._save_warmup:
@@ -246,7 +246,7 @@ def draws(
246246

247247
if concat_chains:
248248
return flatten_chains(self._draws[start_idx:, :, :])
249-
return self._draws[start_idx:, :, :] # type: ignore
249+
return self._draws[start_idx:, :, :]
250250

251251
def _validate_csv_files(self) -> Dict[str, Any]:
252252
"""
@@ -309,9 +309,6 @@ def _assemble_draws(self) -> None:
309309
Allocates and populates the step size, metric, and sample arrays
310310
by parsing the validated stan_csv files.
311311
"""
312-
if self._draws.shape != (0,):
313-
return
314-
315312
num_draws = self.num_draws_sampling
316313
sampling_iter_start = 0
317314
if self._save_warmup:
@@ -527,7 +524,8 @@ def draws_pd(
527524
' must run sampler with "save_warmup=True".'
528525
)
529526

530-
self._assemble_draws()
527+
if self._draws.shape == (0,):
528+
self._assemble_draws()
531529
cols = []
532530
if vars is not None:
533531
for var in set(vars_list):
@@ -583,7 +581,8 @@ def draws_xr(
583581
else:
584582
vars_list = vars
585583

586-
self._assemble_draws()
584+
if self._draws.shape == (0,):
585+
self._assemble_draws()
587586

588587
num_draws = self.num_draws_sampling
589588
meta = self._metadata.cmdstan_config
@@ -663,7 +662,8 @@ def stan_variable(
663662
raise ValueError('No variable name specified.')
664663
if var not in self._metadata.stan_vars_dims:
665664
raise ValueError('Unknown variable name: {}'.format(var))
666-
self._assemble_draws()
665+
if self._draws.shape == (0,):
666+
self._assemble_draws()
667667
draw1 = 0
668668
if not inc_warmup and self._save_warmup:
669669
draw1 = self.num_draws_warmup
@@ -675,9 +675,7 @@ def stan_variable(
675675
if len(col_idxs) > 0:
676676
dims.extend(self._metadata.stan_vars_dims[var])
677677
# pylint: disable=redundant-keyword-arg
678-
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
679-
dims, order='F'
680-
)
678+
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
681679

682680
def stan_variables(self) -> Dict[str, np.ndarray]:
683681
"""
@@ -705,7 +703,8 @@ def method_variables(self) -> Dict[str, np.ndarray]:
705703
containing per-draw diagnostic values.
706704
"""
707705
result = {}
708-
self._assemble_draws()
706+
if self._draws.shape == (0,):
707+
self._assemble_draws()
709708
for idxs in self.metadata.method_vars_cols.values():
710709
for idx in idxs:
711710
result[self.column_names[idx]] = self._draws[:, :, idx]
@@ -747,7 +746,7 @@ def __init__(
747746
)
748747
self.runset = runset
749748
self.mcmc_sample = mcmc_sample
750-
self._draws = np.array(())
749+
self._draws: np.ndarray = np.array(())
751750
config = self._validate_csv_files()
752751
self._metadata = InferenceMetadata(config)
753752

@@ -764,7 +763,7 @@ def __repr__(self) -> str:
764763
)
765764
return repr
766765

767-
def _validate_csv_files(self) -> dict:
766+
def _validate_csv_files(self) -> Dict[str, Any]:
768767
"""
769768
Checks that Stan CSV output files for all chains are consistent
770769
and returns dict containing config and column names.
@@ -868,7 +867,7 @@ def draws(
868867
CmdStanGQ.draws_xr
869868
CmdStanMCMC.draws
870869
"""
871-
if self._draws.size == 0:
870+
if self._draws.shape == (0,):
872871
self._assemble_generated_quantities()
873872
if (
874873
inc_warmup
@@ -909,13 +908,13 @@ def draws(
909908
if concat_chains:
910909
return flatten_chains(self._draws[start_idx:, :, :])
911910
if inc_sample:
912-
return np.dstack( # type: ignore
911+
return np.dstack(
913912
(
914913
np.delete(self.mcmc_sample.draws(), drop_cols, axis=1),
915914
self._draws,
916915
)
917916
)[start_idx:, :, :]
918-
return self._draws[start_idx:, :, :] # type: ignore
917+
return self._draws[start_idx:, :, :]
919918

920919
def draws_pd(
921920
self,
@@ -955,7 +954,8 @@ def draws_pd(
955954
'Draws from warmup iterations not available,'
956955
' must run sampler with "save_warmup=True".'
957956
)
958-
self._assemble_generated_quantities()
957+
if self._draws.shape == (0,):
958+
self._assemble_generated_quantities()
959959

960960
gq_cols = []
961961
mcmc_vars = []
@@ -1076,7 +1076,8 @@ def draws_xr(
10761076
for var in dup_vars:
10771077
vars_list.remove(var)
10781078

1079-
self._assemble_generated_quantities()
1079+
if self._draws.shape == (0,):
1080+
self._assemble_generated_quantities()
10801081

10811082
num_draws = self.mcmc_sample.num_draws_sampling
10821083
sample_config = self.mcmc_sample.metadata.cmdstan_config
@@ -1173,7 +1174,8 @@ def stan_variable(
11731174
if var not in gq_var_names:
11741175
return self.mcmc_sample.stan_variable(var, inc_warmup=inc_warmup)
11751176
else: # is gq variable
1176-
self._assemble_generated_quantities()
1177+
if self._draws.shape == (0,):
1178+
self._assemble_generated_quantities()
11771179
draw1 = 0
11781180
if (
11791181
not inc_warmup
@@ -1191,9 +1193,7 @@ def stan_variable(
11911193
if len(col_idxs) > 0:
11921194
dims.extend(self._metadata.stan_vars_dims[var])
11931195
# pylint: disable=redundant-keyword-arg
1194-
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
1195-
dims, order='F'
1196-
)
1196+
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
11971197

11981198
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
11991199
"""
@@ -1222,10 +1222,10 @@ def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
12221222
return result
12231223

12241224
def _assemble_generated_quantities(self) -> None:
1225-
# use numpy genfromtext
1225+
# use numpy loadtxt
12261226
warmup = self.mcmc_sample.metadata.cmdstan_config['save_warmup']
12271227
num_draws = self.mcmc_sample.draws(inc_warmup=warmup).shape[0]
1228-
gq_sample = np.empty(
1228+
gq_sample: np.ndarray = np.empty(
12291229
(num_draws, self.chains, len(self.column_names)),
12301230
dtype=float,
12311231
order='F',

cmdstanpy/stanfit/mle.py

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -54,13 +54,9 @@ def _set_mle_attrs(self, sample_csv_0: str) -> None:
5454
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
5555
self._metadata = InferenceMetadata(meta)
5656
self._column_names: Tuple[str, ...] = meta['column_names']
57-
assert isinstance(meta['mle'], np.ndarray) # make the typechecker happy
58-
self._mle = meta['mle']
57+
self._mle: np.ndarray = meta['mle']
5958
if self._save_iterations:
60-
assert isinstance(
61-
meta['all_iters'], np.ndarray
62-
) # make the typechecker happy
63-
self._all_iters = meta['all_iters']
59+
self._all_iters: np.ndarray = meta['all_iters']
6460

6561
@property
6662
def column_names(self) -> Tuple[str, ...]:
@@ -202,13 +198,13 @@ def stan_variable(
202198
num_rows = self._all_iters.shape[0]
203199
else:
204200
num_rows = 1
201+
202+
result: Union[np.ndarray, float]
205203
if len(col_idxs) > 1: # container var
206204
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
207205
# pylint: disable=redundant-keyword-arg
208206
if num_rows > 1:
209-
result = self._all_iters[:, col_idxs].reshape( # type: ignore
210-
dims, order='F'
211-
)
207+
result = self._all_iters[:, col_idxs].reshape(dims, order='F')
212208
else:
213209
result = self._mle[col_idxs].reshape(dims[1:], order="F")
214210
else: # scalar var
@@ -217,9 +213,7 @@ def stan_variable(
217213
result = self._all_iters[:, col_idx]
218214
else:
219215
result = float(self._mle[col_idx])
220-
assert isinstance(
221-
result, (np.ndarray, float)
222-
) # make the typechecker happy
216+
223217
return result
224218

225219
def stan_variables(

cmdstanpy/stanfit/vb.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,14 +126,14 @@ def stan_variable(
126126
raise ValueError('Unknown variable name: {}'.format(var))
127127
col_idxs = list(self._metadata.stan_vars_cols[var])
128128
shape: Tuple[int, ...] = ()
129+
result: Union[np.ndarray, float]
129130
if len(col_idxs) > 1:
130131
shape = self._metadata.stan_vars_dims[var]
131132
result = np.asarray(self._variational_mean)[col_idxs].reshape(
132133
shape, order="F"
133134
)
134135
else:
135136
result = float(self._variational_mean[col_idxs[0]])
136-
assert isinstance(result, (np.ndarray, float))
137137
return result
138138

139139
def stan_variables(self) -> Dict[str, Union[np.ndarray, float]]:

cmdstanpy/utils.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -639,7 +639,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
639639
for line in fd:
640640
iters += 1
641641
if save_iters:
642-
all_iters = np.empty(
642+
all_iters: np.ndarray = np.empty(
643643
(iters, len(dict['column_names'])), dtype=float, order='F'
644644
)
645645
# rescan to capture estimates
@@ -658,7 +658,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
658658
if save_iters:
659659
all_iters[i, :] = [float(x) for x in xs]
660660
if i == iters - 1:
661-
mle = np.array(xs, dtype=float)
661+
mle: np.ndarray = np.array(xs, dtype=float)
662662
dict['mle'] = mle
663663
if save_iters:
664664
dict['all_iters'] = all_iters
@@ -944,7 +944,7 @@ def read_metric(path: str) -> List[int]:
944944
with open(path, 'r') as fd:
945945
metric_dict = json.load(fd)
946946
if 'inv_metric' in metric_dict:
947-
dims_np = np.asarray(metric_dict['inv_metric'])
947+
dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
948948
return list(dims_np.shape)
949949
else:
950950
raise ValueError(

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,16 +14,17 @@ line_length = 80
1414
disallow_untyped_defs = true
1515
disallow_incomplete_defs = true
1616
no_implicit_optional = true
17-
# disallow_any_generics = true # disabled due to issues with numpy < 1.20
17+
# disallow_any_generics = true # disabled due to issues with numpy
1818
warn_return_any = true
1919
# warn_unused_ignores = true # can't be run on CI due to windows having different ctypes
20+
check_untyped_defs = true
21+
warn_redundant_casts = true
22+
strict_equality = true
23+
disallow_untyped_calls = true
2024

2125
[[tool.mypy.overrides]]
2226
module = [
2327
'tqdm.auto',
2428
'pandas',
25-
'ujson',
26-
'numpy', # these two are required for py36, which numpy 1.21 doesn't support
27-
'numpy.random'
2829
]
2930
ignore_missing_imports = true

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,3 +7,4 @@ mypy
77
testfixtures
88
tqdm
99
xarray
10+
types-ujson

0 commit comments

Comments
 (0)