Skip to content

Commit 683a135

Browse files
authored
Merge pull request #511 from stan-dev/drop-py36
Remove Python 3.6 from test set
2 parents a14e9bd + 5b27740 commit 683a135

10 files changed

Lines changed: 32 additions & 42 deletions

File tree

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ jobs:
3939
strategy:
4040
matrix:
4141
os: [ubuntu-latest, macos-latest, windows-latest]
42-
python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
43-
fail-fast: false
42+
python-version: [3.7, 3.8, 3.9, "3.10"]
4443
env:
4544
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
4645
steps:
@@ -60,6 +59,7 @@ jobs:
6059
pip install codecov
6160
6261
- name: Run flake8, pylint, mypy
62+
if: matrix.python-version == '3.10'
6363
run: |
6464
flake8 cmdstanpy test
6565
pylint -v cmdstanpy test

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ repos:
2525
- id: mypy
2626
# Copied from setup.cfg
2727
exclude: ^test/
28-
additional_dependencies: [ numpy >= 1.21 ]
28+
additional_dependencies: [ numpy >= 1.22, types-ujson ]
2929
# local uses the user-installed pylint, this allows dependency checking
3030
- repo: local
3131
hooks:

cmdstanpy/cmdstan_args.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,7 @@ def validate(self, chains: Optional[int]) -> None:
202202
if all(isinstance(elem, dict) for elem in self.metric):
203203
metric_files: List[str] = []
204204
for i, metric in enumerate(self.metric):
205-
assert isinstance(
206-
metric, dict
207-
) # make the typechecker happy
208-
metric_dict: Dict[str, Any] = metric
205+
metric_dict: Dict[str, Any] = metric # type: ignore
209206
if 'inv_metric' not in metric_dict:
210207
raise ValueError(
211208
'Entry "inv_metric" not found in metric dict '

cmdstanpy/stanfit/mcmc.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def __init__(
9595
self._save_warmup = sampler_args.save_warmup
9696
self._sig_figs = runset._args.sig_figs
9797
# info from CSV values, instantiated lazily
98-
self._metric = np.array(())
99-
self._step_size = np.array(())
100-
self._draws = np.array(())
98+
self._metric: np.ndarray = np.array(())
99+
self._step_size: np.ndarray = np.array(())
100+
self._draws: np.ndarray = np.array(())
101101
# info from CSV initial comments and header
102102
config = self._validate_csv_files()
103103
self._metadata: InferenceMetadata = InferenceMetadata(config)
@@ -246,7 +246,7 @@ def draws(
246246

247247
if concat_chains:
248248
return flatten_chains(self._draws[start_idx:, :, :])
249-
return self._draws[start_idx:, :, :] # type: ignore
249+
return self._draws[start_idx:, :, :]
250250

251251
def _validate_csv_files(self) -> Dict[str, Any]:
252252
"""
@@ -675,9 +675,7 @@ def stan_variable(
675675
if len(col_idxs) > 0:
676676
dims.extend(self._metadata.stan_vars_dims[var])
677677
# pylint: disable=redundant-keyword-arg
678-
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
679-
dims, order='F'
680-
)
678+
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
681679

682680
def stan_variables(self) -> Dict[str, np.ndarray]:
683681
"""
@@ -748,7 +746,7 @@ def __init__(
748746
)
749747
self.runset = runset
750748
self.mcmc_sample = mcmc_sample
751-
self._draws = np.array(())
749+
self._draws: np.ndarray = np.array(())
752750
config = self._validate_csv_files()
753751
self._metadata = InferenceMetadata(config)
754752

@@ -765,7 +763,7 @@ def __repr__(self) -> str:
765763
)
766764
return repr
767765

768-
def _validate_csv_files(self) -> dict:
766+
def _validate_csv_files(self) -> Dict[str, Any]:
769767
"""
770768
Checks that Stan CSV output files for all chains are consistent
771769
and returns dict containing config and column names.
@@ -910,13 +908,13 @@ def draws(
910908
if concat_chains:
911909
return flatten_chains(self._draws[start_idx:, :, :])
912910
if inc_sample:
913-
return np.dstack( # type: ignore
911+
return np.dstack(
914912
(
915913
np.delete(self.mcmc_sample.draws(), drop_cols, axis=1),
916914
self._draws,
917915
)
918916
)[start_idx:, :, :]
919-
return self._draws[start_idx:, :, :] # type: ignore
917+
return self._draws[start_idx:, :, :]
920918

921919
def draws_pd(
922920
self,
@@ -1195,9 +1193,7 @@ def stan_variable(
11951193
if len(col_idxs) > 0:
11961194
dims.extend(self._metadata.stan_vars_dims[var])
11971195
# pylint: disable=redundant-keyword-arg
1198-
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
1199-
dims, order='F'
1200-
)
1196+
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
12011197

12021198
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
12031199
"""
@@ -1229,7 +1225,7 @@ def _assemble_generated_quantities(self) -> None:
12291225
# use numpy loadtxt
12301226
warmup = self.mcmc_sample.metadata.cmdstan_config['save_warmup']
12311227
num_draws = self.mcmc_sample.draws(inc_warmup=warmup).shape[0]
1232-
gq_sample = np.empty(
1228+
gq_sample: np.ndarray = np.empty(
12331229
(num_draws, self.chains, len(self.column_names)),
12341230
dtype=float,
12351231
order='F',

cmdstanpy/stanfit/mle.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,9 @@ def _set_mle_attrs(self, sample_csv_0: str) -> None:
5454
meta = scan_optimize_csv(sample_csv_0, self._save_iterations)
5555
self._metadata = InferenceMetadata(meta)
5656
self._column_names: Tuple[str, ...] = meta['column_names']
57-
assert isinstance(meta['mle'], np.ndarray) # make the typechecker happy
58-
self._mle = meta['mle']
57+
self._mle: np.ndarray = meta['mle']
5958
if self._save_iterations:
60-
assert isinstance(
61-
meta['all_iters'], np.ndarray
62-
) # make the typechecker happy
63-
self._all_iters = meta['all_iters']
59+
self._all_iters: np.ndarray = meta['all_iters']
6460

6561
@property
6662
def column_names(self) -> Tuple[str, ...]:
@@ -202,13 +198,13 @@ def stan_variable(
202198
num_rows = self._all_iters.shape[0]
203199
else:
204200
num_rows = 1
201+
202+
result: Union[np.ndarray, float]
205203
if len(col_idxs) > 1: # container var
206204
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
207205
# pylint: disable=redundant-keyword-arg
208206
if num_rows > 1:
209-
result = self._all_iters[:, col_idxs].reshape( # type: ignore
210-
dims, order='F'
211-
)
207+
result = self._all_iters[:, col_idxs].reshape(dims, order='F')
212208
else:
213209
result = self._mle[col_idxs].reshape(dims[1:], order="F")
214210
else: # scalar var
@@ -217,9 +213,7 @@ def stan_variable(
217213
result = self._all_iters[:, col_idx]
218214
else:
219215
result = float(self._mle[col_idx])
220-
assert isinstance(
221-
result, (np.ndarray, float)
222-
) # make the typechecker happy
216+
223217
return result
224218

225219
def stan_variables(

cmdstanpy/stanfit/vb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,14 @@ def stan_variable(
126126
raise ValueError('Unknown variable name: {}'.format(var))
127127
col_idxs = list(self._metadata.stan_vars_cols[var])
128128
shape: Tuple[int, ...] = ()
129+
result: Union[np.ndarray, float]
129130
if len(col_idxs) > 1:
130131
shape = self._metadata.stan_vars_dims[var]
131132
result = np.asarray(self._variational_mean)[col_idxs].reshape(
132133
shape, order="F"
133134
)
134135
else:
135136
result = float(self._variational_mean[col_idxs[0]])
136-
assert isinstance(result, (np.ndarray, float))
137137
return result
138138

139139
def stan_variables(self) -> Dict[str, Union[np.ndarray, float]]:

cmdstanpy/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
639639
for line in fd:
640640
iters += 1
641641
if save_iters:
642-
all_iters = np.empty(
642+
all_iters: np.ndarray = np.empty(
643643
(iters, len(dict['column_names'])), dtype=float, order='F'
644644
)
645645
# rescan to capture estimates
@@ -658,7 +658,7 @@ def scan_optimize_csv(path: str, save_iters: bool = False) -> Dict[str, Any]:
658658
if save_iters:
659659
all_iters[i, :] = [float(x) for x in xs]
660660
if i == iters - 1:
661-
mle = np.array(xs, dtype=float)
661+
mle: np.ndarray = np.array(xs, dtype=float)
662662
dict['mle'] = mle
663663
if save_iters:
664664
dict['all_iters'] = all_iters
@@ -944,7 +944,7 @@ def read_metric(path: str) -> List[int]:
944944
with open(path, 'r') as fd:
945945
metric_dict = json.load(fd)
946946
if 'inv_metric' in metric_dict:
947-
dims_np = np.asarray(metric_dict['inv_metric'])
947+
dims_np: np.ndarray = np.asarray(metric_dict['inv_metric'])
948948
return list(dims_np.shape)
949949
else:
950950
raise ValueError(

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,17 @@ line_length = 80
1414
disallow_untyped_defs = true
1515
disallow_incomplete_defs = true
1616
no_implicit_optional = true
17-
# disallow_any_generics = true # disabled due to issues with numpy < 1.20
17+
# disallow_any_generics = true # disabled due to issues with numpy
1818
warn_return_any = true
1919
# warn_unused_ignores = true # can't be run on CI due to windows having different ctypes
20+
check_untyped_defs = true
21+
warn_redundant_casts = true
22+
strict_equality = true
23+
disallow_untyped_calls = true
2024

2125
[[tool.mypy.overrides]]
2226
module = [
2327
'tqdm.auto',
2428
'pandas',
25-
'ujson',
26-
'numpy', # these two are required for py36, which numpy 1.21 doesn't support
27-
'numpy.random'
2829
]
2930
ignore_missing_imports = true

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ mypy
77
testfixtures
88
tqdm
99
xarray
10+
types-ujson

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def get_version() -> str:
8080
]
8181
},
8282
install_requires=INSTALL_REQUIRES,
83+
python_requires='>=3.7',
8384
extras_require=EXTRAS_REQUIRE,
8485
classifiers=_classifiers.strip().split('\n'),
8586
)

0 commit comments

Comments
 (0)