@@ -48,27 +48,30 @@ class CompilerOptions:
4848 Attributes:
4949 stanc_options - stanc compiler flags, options
5050 cpp_options - makefile options (NAME=value)
51+ user_header - path to a user .hpp file to include during compilation
5152 """
5253
5354 def __init__ (
5455 self ,
5556 * ,
5657 stanc_options : Optional [Dict [str , Any ]] = None ,
5758 cpp_options : Optional [Dict [str , Any ]] = None ,
59+ user_header : Optional [str ] = None ,
5860 logger : Optional [logging .Logger ] = None ,
5961 ) -> None :
6062 """Initialize object."""
6163 self ._stanc_options = stanc_options if stanc_options is not None else {}
6264 self ._cpp_options = cpp_options if cpp_options is not None else {}
65+ self ._user_header = user_header if user_header is not None else ''
6366 if logger is not None :
6467 get_logger ().warning (
6568 "Parameter 'logger' is deprecated."
6669 " Control logging behavior via logging.getLogger('cmdstanpy')"
6770 )
6871
6972 def __repr__ (self ) -> str :
70- return 'stanc_options={}, cpp_options={}' .format (
71- self ._stanc_options , self ._cpp_options
73+ return 'stanc_options={}, cpp_options={}, user_header={} ' .format (
74+ self ._stanc_options , self ._cpp_options , self . _user_header
7275 )
7376
7477 @property
@@ -81,13 +84,19 @@ def cpp_options(self) -> Dict[str, Union[bool, int]]:
8184 """C++ compiler options."""
8285 return self ._cpp_options
8386
87+ @property
88+ def user_header (self ) -> str :
89+ """The user header file if it exists, otherwise empty"""
90+ return self ._user_header
91+
8492 def validate (self ) -> None :
8593 """
8694 Check compiler args.
8795 Raise ValueError if invalid options are found.
8896 """
8997 self .validate_stanc_opts ()
9098 self .validate_cpp_opts ()
99+ self .validate_user_header ()
91100
92101 def validate_stanc_opts (self ) -> None :
93102 """
@@ -104,17 +113,15 @@ def validate_stanc_opts(self) -> None:
104113 get_logger ().info ('ignoring compiler option: %s' , key )
105114 ignore .append (key )
106115 elif key not in STANC_OPTS :
107- raise ValueError (
108- 'unknown stanc compiler option: {}' .format (key )
109- )
116+ raise ValueError (f'unknown stanc compiler option: { key } ' )
110117 elif key == 'include_paths' :
111118 paths = val
112119 if isinstance (val , str ):
113120 paths = val .split (',' )
114121 elif not isinstance (val , list ):
115122 raise ValueError (
116123 'Invalid include_paths, expecting list or '
117- 'string, found type: {}.' . format ( type (val ))
124+ f 'string, found type: { type (val )} .'
118125 )
119126 elif key == 'use-opencl' :
120127 if self ._cpp_options is None :
@@ -149,10 +156,37 @@ def validate_cpp_opts(self) -> None:
149156 val = self ._cpp_options [key ]
150157 if not isinstance (val , int ) or val < 0 :
151158 raise ValueError (
152- '{ } must be a non-negative integer value,'
153- ' found {}.' . format ( key , val )
159+ f' { key } must be a non-negative integer value,'
160+ f ' found { val } .'
154161 )
155162
163+ def validate_user_header (self ) -> None :
164+ """
165+ User header exists.
166+ Raise ValueError if bad config is found.
167+ """
168+ if self ._user_header != "" :
169+ if not (
170+ os .path .exists (self ._user_header )
171+ and os .path .isfile (self ._user_header )
172+ ):
173+ raise ValueError (
174+ f"User header file { self ._user_header } cannot be found"
175+ )
176+ if self ._user_header [- 4 :] != '.hpp' :
177+ raise ValueError (
178+ f"Header file must end in .hpp, got { self ._user_header } "
179+ )
180+ if "allow_undefined" not in self ._stanc_options :
181+ self ._stanc_options ["allow_undefined" ] = True
182+ # set full path
183+ self ._user_header = os .path .abspath (self ._user_header )
184+
185+ if ' ' in self ._user_header :
186+ raise ValueError (
187+ "User header must be in a folder with no spaces in path!"
188+ )
189+
156190 def add (self , new_opts : "CompilerOptions" ) -> None : # noqa: disable=Q000
157191 """Adds options to existing set of compiler options."""
158192 if new_opts .stanc_options is not None :
@@ -167,6 +201,8 @@ def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
167201 if new_opts .cpp_options is not None :
168202 for key , val in new_opts .cpp_options .items ():
169203 self ._cpp_options [key ] = val
204+ if new_opts .user_header != '' and self ._user_header == '' :
205+ self ._user_header = new_opts .user_header
170206
171207 def add_include_path (self , path : str ) -> None :
172208 """Adds include path to existing set of compiler options."""
@@ -191,10 +227,12 @@ def compose(self) -> List[str]:
191227 )
192228 )
193229 elif key == 'name' :
194- opts .append ('STANCFLAGS+=--{}={}' . format ( key , val ) )
230+ opts .append (f 'STANCFLAGS+=--name= { val } ' )
195231 else :
196- opts .append ('STANCFLAGS+=--{}' . format ( key ) )
232+ opts .append (f 'STANCFLAGS+=--{ key } ' )
197233 if self ._cpp_options is not None and len (self ._cpp_options ) > 0 :
198234 for key , val in self ._cpp_options .items ():
199- opts .append ('{}={}' .format (key , val ))
235+ opts .append (f'{ key } ={ val } ' )
236+ if self ._user_header :
237+ opts .append (f'USER_HEADER={ self ._user_header } ' )
200238 return opts
0 commit comments