Skip to content

Commit a002f80

Browse files
Properly handle complex outputs (#537)
* parse complex names, dims * stan program for unit tests * save var type info in metadata * complex conversion OK, needs unit tests * logic fix * Start testing * Add MLE, VB and tests * Cleanup * Fix pytest errors, test parsing * Fix test * Fix typo Co-authored-by: Mitzi Morris <mitzi@panix.com>
1 parent 5ed096e commit a002f80

11 files changed

Lines changed: 189 additions & 39 deletions

File tree

cmdstanpy/stanfit/mcmc.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from cmdstanpy.cmdstan_args import Method, SamplerArgs
3333
from cmdstanpy.utils import (
3434
EXTENSION,
35+
BaseType,
3536
check_sampler_csv,
3637
cmdstan_path,
3738
cmdstan_version_before,
@@ -610,6 +611,7 @@ def draws_xr(
610611
self._metadata.stan_vars_cols[var],
611612
0,
612613
self.draws(inc_warmup=inc_warmup),
614+
self._metadata.stan_vars_types[var],
613615
)
614616
return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
615617
'chain', 'draw', ...
@@ -674,8 +676,10 @@ def stan_variable(
674676
col_idxs = self._metadata.stan_vars_cols[var]
675677
if len(col_idxs) > 0:
676678
dims.extend(self._metadata.stan_vars_dims[var])
677-
# pylint: disable=redundant-keyword-arg
678-
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
679+
draws = self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
680+
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
681+
draws = draws[..., 0] + 1j * draws[..., 1]
682+
return draws
679683

680684
def stan_variables(self) -> Dict[str, np.ndarray]:
681685
"""
@@ -1106,6 +1110,7 @@ def draws_xr(
11061110
self._metadata.stan_vars_cols[var],
11071111
0,
11081112
self.draws(inc_warmup=inc_warmup),
1113+
self._metadata.stan_vars_types[var],
11091114
)
11101115
if inc_sample:
11111116
for var in mcmc_vars_list:
@@ -1116,6 +1121,7 @@ def draws_xr(
11161121
self.mcmc_sample.metadata.stan_vars_cols[var],
11171122
0,
11181123
self.mcmc_sample.draws(inc_warmup=inc_warmup),
1124+
self.mcmc_sample._metadata.stan_vars_types[var],
11191125
)
11201126

11211127
return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
@@ -1193,7 +1199,10 @@ def stan_variable(
11931199
if len(col_idxs) > 0:
11941200
dims.extend(self._metadata.stan_vars_dims[var])
11951201
# pylint: disable=redundant-keyword-arg
1196-
return self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
1202+
draws = self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
1203+
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
1204+
draws = draws[..., 0] + 1j * draws[..., 1]
1205+
return draws
11971206

11981207
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
11991208
"""
@@ -1262,6 +1271,7 @@ def build_xarray_data(
12621271
col_idxs: Tuple[int, ...],
12631272
start_row: int,
12641273
drawset: np.ndarray,
1274+
var_type: BaseType,
12651275
) -> None:
12661276
"""
12671277
Adds Stan variable name, labels, and values to a dictionary
@@ -1270,12 +1280,19 @@ def build_xarray_data(
12701280
var_dims: Tuple[str, ...] = ('draw', 'chain')
12711281
if dims:
12721282
var_dims += tuple(f"{var_name}_dim_{i}" for i in range(len(dims)))
1283+
1284+
draws = drawset[start_row:, :, col_idxs].reshape(
1285+
*drawset.shape[:2], *dims, order="F"
1286+
)
1287+
if var_type == BaseType.COMPLEX:
1288+
draws = draws[..., 0] + 1j * draws[..., 1]
1289+
var_dims = var_dims[:-1]
1290+
12731291
data[var_name] = (
12741292
var_dims,
1275-
drawset[start_row:, :, col_idxs].reshape(
1276-
*drawset.shape[:2], *dims, order="F"
1277-
),
1293+
draws,
12781294
)
1295+
12791296
else:
12801297
data[var_name] = (
12811298
var_dims,

cmdstanpy/stanfit/metadata.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import copy
44
from typing import Any, Dict, Tuple
55

6-
from cmdstanpy.utils import parse_method_vars, parse_stan_vars
6+
from cmdstanpy.utils import parse_method_vars, parse_stan_vars, BaseType
77

88

99
class InferenceMetadata:
@@ -17,11 +17,12 @@ def __init__(self, config: Dict[str, Any]) -> None:
1717
"""Initialize object from CSV headers"""
1818
self._cmdstan_config = config
1919
self._method_vars_cols = parse_method_vars(names=config['column_names'])
20-
stan_vars_dims, stan_vars_cols = parse_stan_vars(
20+
stan_vars_dims, stan_vars_cols, stan_vars_types = parse_stan_vars(
2121
names=config['column_names']
2222
)
2323
self._stan_vars_dims = stan_vars_dims
2424
self._stan_vars_cols = stan_vars_cols
25+
self._stan_vars_types = stan_vars_types
2526

2627
def __repr__(self) -> str:
2728
return 'Metadata:\n{}\n'.format(self._cmdstan_config)
@@ -66,3 +67,11 @@ def stan_vars_dims(self) -> Dict[str, Tuple[int, ...]]:
6667
Uses deepcopy for immutability.
6768
"""
6869
return copy.deepcopy(self._stan_vars_dims)
70+
71+
@property
def stan_vars_types(self) -> Dict[str, BaseType]:
    """
    Map of Stan program variable names to each variable's base type.

    The internal mapping is deep-copied before it is handed back, so
    changes made by the caller do not affect this object.
    """
    types_by_var = copy.deepcopy(self._stan_vars_types)
    return types_by_var

cmdstanpy/stanfit/mle.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88

99
from cmdstanpy.cmdstan_args import Method, OptimizeArgs
10-
from cmdstanpy.utils import get_logger, scan_optimize_csv
10+
from cmdstanpy.utils import BaseType, get_logger, scan_optimize_csv
1111

1212
from .metadata import InferenceMetadata
1313
from .runset import RunSet
@@ -199,22 +199,24 @@ def stan_variable(
199199
else:
200200
num_rows = 1
201201

202-
result: Union[np.ndarray, float]
203202
if len(col_idxs) > 1: # container var
204203
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
205204
# pylint: disable=redundant-keyword-arg
206205
if num_rows > 1:
207206
result = self._all_iters[:, col_idxs].reshape(dims, order='F')
208207
else:
209208
result = self._mle[col_idxs].reshape(dims[1:], order="F")
209+
210+
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
211+
result = result[..., 0] + 1j * result[..., 1]
212+
return result
213+
210214
else: # scalar var
211215
col_idx = col_idxs[0]
212216
if num_rows > 1:
213-
result = self._all_iters[:, col_idx]
217+
return self._all_iters[:, col_idx]
214218
else:
215-
result = float(self._mle[col_idx])
216-
217-
return result
219+
return float(self._mle[col_idx])
218220

219221
def stan_variables(
220222
self, inc_iterations: bool = False

cmdstanpy/stanfit/vb.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88

99
from cmdstanpy.cmdstan_args import Method
10-
from cmdstanpy.utils import scan_variational_csv
10+
from cmdstanpy.utils import BaseType, scan_variational_csv
1111

1212
from .metadata import InferenceMetadata
1313
from .runset import RunSet
@@ -126,15 +126,17 @@ def stan_variable(
126126
raise ValueError('Unknown variable name: {}'.format(var))
127127
col_idxs = list(self._metadata.stan_vars_cols[var])
128128
shape: Tuple[int, ...] = ()
129-
result: Union[np.ndarray, float]
130129
if len(col_idxs) > 1:
131130
shape = self._metadata.stan_vars_dims[var]
132-
result = np.asarray(self._variational_mean)[col_idxs].reshape(
133-
shape, order="F"
134-
)
131+
result: np.ndarray = np.asarray(self._variational_mean)[
132+
col_idxs
133+
].reshape(shape, order="F")
134+
135+
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
136+
result = result[..., 0] + 1j * result[..., 1]
137+
return result
135138
else:
136-
result = float(self._variational_mean[col_idxs[0]])
137-
return result
139+
return float(self._variational_mean[col_idxs[0]])
138140

139141
def stan_variables(self) -> Dict[str, Union[np.ndarray, float]]:
140142
"""

cmdstanpy/utils.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import tempfile
1515
from collections import OrderedDict
1616
from collections.abc import Collection
17+
from enum import Enum, auto
1718
from typing import (
1819
Any,
1920
Callable,
@@ -46,6 +47,16 @@
4647
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
4748

4849

50+
class BaseType(Enum):
    """Stan language base type of an output variable."""

    # Variable whose elements are complex numbers.
    COMPLEX = auto()
    # Any other (primitive) variable.
    PRIM = auto()  # future: int / real

    def __repr__(self) -> str:
        """Return a compact ``<BaseType.NAME>`` form (omits the value)."""
        return f'<{self.__class__.__name__}.{self.name}>'
58+
59+
4960
@functools.lru_cache(maxsize=None)
5061
def get_logger() -> logging.Logger:
5162
"""cmdstanpy logger"""
@@ -794,10 +805,14 @@ def munge_varnames(names: List[str]) -> List[str]:
794805
"""
795806
if names is None:
796807
raise ValueError('missing argument "names"')
797-
return [
798-
re.sub(r',([\d,]+)$', r'[\1]', column.replace('.', ','))
799-
for column in names
800-
]
808+
result = []
809+
for name in names:
810+
if '.' not in name:
811+
result.append(name)
812+
else:
813+
head, *rest = name.split('.')
814+
result.append(''.join([head, '[', ','.join(rest), ']']))
815+
return result
801816

802817

803818
def parse_method_vars(names: Tuple[str, ...]) -> Dict[str, Tuple[int, ...]]:
@@ -816,38 +831,52 @@ def parse_method_vars(names: Tuple[str, ...]) -> Dict[str, Tuple[int, ...]]:
816831

817832
def parse_stan_vars(
    names: Tuple[str, ...]
) -> Tuple[
    Dict[str, Tuple[int, ...]], Dict[str, Tuple[int, ...]], Dict[str, BaseType]
]:
    """
    Parses out Stan variable names (i.e., names not ending in `__`)
    from list of CSV file column names.

    Returns three dicts which map variable names to dimensions, CSV file
    columns, and base type, respectively, using zero-based column indexing.
    A complex variable is reported with an extra trailing dimension of
    size 2 (its real and imaginary columns).

    Note: assumes: (a) munged varnames and (b) container vars are non-ragged
    and dense; no checks on size, indices.

    :param names: munged CSV column names, in file order.
    :raises ValueError: if ``names`` is None.
    """
    if names is None:
        raise ValueError('missing argument "names"')
    dims_map: Dict[str, Tuple[int, ...]] = {}
    cols_map: Dict[str, Tuple[int, ...]] = {}
    types_map: Dict[str, BaseType] = {}
    idxs: List[int] = []
    dims: Union[List[str], List[int]]
    for idx, name in enumerate(names):
        # A column whose last index is `real` or `imag` belongs to a
        # complex variable; everything else is treated as primitive.
        if name.endswith(('real]', 'imag]')):
            basetype = BaseType.COMPLEX
        else:
            basetype = BaseType.PRIM
        idxs.append(idx)
        var, *dims = name.split('[')
        if var.endswith('__'):
            # names ending in `__` are not Stan program variables
            idxs = []
        elif len(dims) == 0:
            # scalar: a single column, no dimensions
            dims_map[var] = ()
            cols_map[var] = tuple(idxs)
            types_map[var] = basetype
            idxs = []
        else:
            # keep accumulating columns until the last element of `var`
            if idx < len(names) - 1 and names[idx + 1].split('[')[0] == var:
                continue
            # indices of the last element give the container's dimensions
            coords = dims[0][:-1].split(',')
            if coords[-1] == 'imag':
                dims = [int(x) for x in coords[:-1]]
                # trailing axis of size 2 holds the real/imag pair
                dims.append(2)
            else:
                dims = [int(x) for x in coords]
            dims_map[var] = tuple(dims)
            cols_map[var] = tuple(idxs)
            types_map[var] = basetype
            idxs = []
    return (dims_map, cols_map, types_map)
851880

852881

853882
def scan_hmc_params(

test/data/complex_var.stan

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Test fixture: a small Bernoulli model whose generated quantities emit
// integer and complex outputs with fixed, known values.
transformed data {
2+
array[10] int<lower=0,upper=1> y = {0,1,0,0,0,0,0,0,0,1};
3+
}
4+
parameters {
5+
real<lower=0,upper=1> theta;
6+
}
7+
model {
8+
theta ~ beta(1,1);
9+
y ~ bernoulli(theta);
10+
}
11+
12+
// model segment is just so that VB works
13+
14+
generated quantities {
15+
int a = 1;
16+
// integer array whose last dimension has size 2
17+
array[2,3,2] int ys = {{{3,0},{0,4}, {5,0}}, {{0,1}, {0,2}, {0,3}}};
17+
// 2x3 array of complex values
array[2,3] complex zs = {{3,4i,5},{1i,2i,3i}};
18+
// complex scalar
complex z = 3 + 4i;
19+
20+
// plain integer array that is merely named `imag`
array[2] int imag = {3,4};
21+
}

test/test_generate_quantities.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,27 @@ def test_show_console(self):
423423
self.assertTrue('Chain [3] method = generate' in console)
424424
self.assertTrue('Chain [4] method = generate' in console)
425425

426+
def test_complex_output(self):
    """Complex generated quantities come back as complex-valued arrays."""
    stan_bern = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    model_bern = CmdStanModel(stan_file=stan_bern)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    fit_sampling = model_bern.sample(chains=1, iter_sampling=10, data=jdata)

    stan = os.path.join(DATAFILES_PATH, 'complex_var.stan')
    model = CmdStanModel(stan_file=stan)
    fit = model.generate_quantities(mcmc_sample=fit_sampling)

    self.assertEqual(fit.stan_variable('zs').shape, (10, 2, 3))
    self.assertEqual(fit.stan_variable('z')[0], 3 + 4j)
    # make sure the name 'imag' isn't magic
    self.assertEqual(fit.stan_variable('imag').shape, (10, 2))

    draws_xr = fit.draws_xr()
    # `name in Dataset` matches variable/coordinate names, so a dimension
    # without a coordinate variable would never be found there; inspect
    # the dimensions of the variable itself instead.
    self.assertNotIn("zs_dim_2", draws_xr['zs'].dims)
    # getting a raw scalar out of xarray is heavy
    self.assertEqual(draws_xr.z.isel(chain=0, draw=1).data[()], 3 + 4j)
446+
426447

427448
if __name__ == '__main__':
428449
unittest.main()

test/test_optimize.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,16 @@ def test_exe_only(self):
599599
mle.optimized_params_np[1], mle.optimized_params_dict['theta']
600600
)
601601

602+
def test_complex_output(self):
    """Check shapes and values of complex outputs from optimization."""
    stan_path = os.path.join(DATAFILES_PATH, 'complex_var.stan')
    mle = CmdStanModel(stan_file=stan_path).optimize()

    zs_shape = mle.stan_variable('zs').shape
    self.assertEqual(zs_shape, (2, 3))

    z_value = mle.stan_variable('z')
    self.assertEqual(z_value, 3 + 4j)

    # a variable that is only *named* 'imag' must not be treated specially
    imag_shape = mle.stan_variable('imag').shape
    self.assertEqual(imag_shape, (2,))
611+
602612

603613
if __name__ == '__main__':
604614
unittest.main()

test/test_sample.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,22 @@ def test_single_row_csv(self):
17621762
self.assertEqual(int(z_as_ndarray[0, i, j]), i + 1)
17631763
self.assertEqual(int(z_as_xr.z.data[0, 0, i, j]), i + 1)
17641764

1765+
def test_complex_output(self):
    """Sampling a model with complex outputs yields complex-valued draws."""
    stan = os.path.join(DATAFILES_PATH, 'complex_var.stan')
    model = CmdStanModel(stan_file=stan)
    fit = model.sample(chains=1, iter_sampling=10)

    self.assertEqual(fit.stan_variable('zs').shape, (10, 2, 3))
    self.assertEqual(fit.stan_variable('z')[0], 3 + 4j)
    # make sure the name 'imag' isn't magic
    self.assertEqual(fit.stan_variable('imag').shape, (10, 2))

    draws_xr = fit.draws_xr()
    # `name in Dataset` matches variable/coordinate names, so a dimension
    # without a coordinate variable would never be found there; inspect
    # the dimensions of the variable itself instead.
    self.assertNotIn("zs_dim_2", draws_xr['zs'].dims)
    # getting a raw scalar out of xarray is heavy
    self.assertEqual(draws_xr.z.isel(chain=0, draw=1).data[()], 3 + 4j)
1780+
17651781

17661782
if __name__ == '__main__':
17671783
unittest.main()

0 commit comments

Comments
 (0)