1414import tempfile
1515from collections import OrderedDict
1616from collections .abc import Collection
17+ from enum import Enum , auto
1718from typing import (
1819 Any ,
1920 Callable ,
4647EXTENSION = '.exe' if platform .system () == 'Windows' else ''
4748
4849
50+ class BaseType (Enum ):
51+ """Stan langauge base type"""
52+
53+ COMPLEX = auto ()
54+ PRIM = auto () # future: int / real
55+
56+ def __repr__ (self ) -> str :
57+ return '<%s.%s>' % (self .__class__ .__name__ , self .name )
58+
59+
4960@functools .lru_cache (maxsize = None )
5061def get_logger () -> logging .Logger :
5162 """cmdstanpy logger"""
@@ -794,10 +805,14 @@ def munge_varnames(names: List[str]) -> List[str]:
794805 """
795806 if names is None :
796807 raise ValueError ('missing argument "names"' )
797- return [
798- re .sub (r',([\d,]+)$' , r'[\1]' , column .replace ('.' , ',' ))
799- for column in names
800- ]
808+ result = []
809+ for name in names :
810+ if '.' not in name :
811+ result .append (name )
812+ else :
813+ head , * rest = name .split ('.' )
814+ result .append ('' .join ([head , '[' , ',' .join (rest ), ']' ]))
815+ return result
801816
802817
803818def parse_method_vars (names : Tuple [str , ...]) -> Dict [str , Tuple [int , ...]]:
@@ -816,38 +831,52 @@ def parse_method_vars(names: Tuple[str, ...]) -> Dict[str, Tuple[int, ...]]:
816831
817832def parse_stan_vars (
818833 names : Tuple [str , ...]
819- ) -> Tuple [Dict [str , Tuple [int , ...]], Dict [str , Tuple [int , ...]]]:
834+ ) -> Tuple [
835+ Dict [str , Tuple [int , ...]], Dict [str , Tuple [int , ...]], Dict [str , BaseType ]
836+ ]:
820837 """
821838 Parses out Stan variable names (i.e., names not ending in `__`)
822839 from list of CSV file column names.
823- Returns a pair of dicts which map variable names to dimensions and
824- variable names to columns, respectively, using zero-based column indexing.
840+ Returns three dicts which map variable names to base type, dimensions and
841+ CSV file columns, respectively, using zero-based column indexing.
825842 Note: assumes: (a) munged varnames and (b) container vars are non-ragged
826- and dense; no checks size, indices.
843+ and dense; no checks on size, indices.
827844 """
828845 if names is None :
829846 raise ValueError ('missing argument "names"' )
830847 dims_map : Dict [str , Tuple [int , ...]] = {}
831848 cols_map : Dict [str , Tuple [int , ...]] = {}
849+ types_map : Dict [str , BaseType ] = {}
832850 idxs = []
833851 dims : Union [List [str ], List [int ]]
834852 for (idx , name ) in enumerate (names ):
853+ if name .endswith ('real]' ) or name .endswith ('imag]' ):
854+ basetype = BaseType .COMPLEX
855+ else :
856+ basetype = BaseType .PRIM
835857 idxs .append (idx )
836858 var , * dims = name .split ('[' )
837859 if var .endswith ('__' ):
838860 idxs = []
839861 elif len (dims ) == 0 :
840862 dims_map [var ] = ()
841863 cols_map [var ] = tuple (idxs )
864+ types_map [var ] = basetype
842865 idxs = []
843866 else :
844867 if idx < len (names ) - 1 and names [idx + 1 ].split ('[' )[0 ] == var :
845868 continue
846- dims = [int (x ) for x in dims [0 ][:- 1 ].split (',' )]
869+ coords = dims [0 ][:- 1 ].split (',' )
870+ if coords [- 1 ] == 'imag' :
871+ dims = [int (x ) for x in coords [:- 1 ]]
872+ dims .append (2 )
873+ else :
874+ dims = [int (x ) for x in coords ]
847875 dims_map [var ] = tuple (dims )
848876 cols_map [var ] = tuple (idxs )
877+ types_map [var ] = basetype
849878 idxs = []
850- return (dims_map , cols_map )
879+ return (dims_map , cols_map , types_map )
851880
852881
853882def scan_hmc_params (
0 commit comments