Skip to content

Commit ae09cc3

Browse files
committed
Allow user to supply a c++ header file
1 parent 8596085 commit ae09cc3

15 files changed

Lines changed: 337 additions & 28 deletions

cmdstanpy/compiler_opts.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,30 @@ class CompilerOptions:
4848
Attributes:
4949
stanc_options - stanc compiler flags, options
5050
cpp_options - makefile options (NAME=value)
51+
user_header - path to a user .hpp file to include during compilation
5152
"""
5253

5354
def __init__(
5455
self,
5556
*,
5657
stanc_options: Optional[Dict[str, Any]] = None,
5758
cpp_options: Optional[Dict[str, Any]] = None,
59+
user_header: Optional[str] = None,
5860
logger: Optional[logging.Logger] = None,
5961
) -> None:
6062
"""Initialize object."""
6163
self._stanc_options = stanc_options if stanc_options is not None else {}
6264
self._cpp_options = cpp_options if cpp_options is not None else {}
65+
self._user_header = user_header if user_header is not None else ''
6366
if logger is not None:
6467
get_logger().warning(
6568
"Parameter 'logger' is deprecated."
6669
" Control logging behavior via logging.getLogger('cmdstanpy')"
6770
)
6871

6972
def __repr__(self) -> str:
70-
return 'stanc_options={}, cpp_options={}'.format(
71-
self._stanc_options, self._cpp_options
73+
return 'stanc_options={}, cpp_options={}, user_header={}'.format(
74+
self._stanc_options, self._cpp_options, self._user_header
7275
)
7376

7477
@property
@@ -81,13 +84,19 @@ def cpp_options(self) -> Dict[str, Union[bool, int]]:
8184
"""C++ compiler options."""
8285
return self._cpp_options
8386

87+
@property
88+
def user_header(self) -> str:
89+
"""The user header file if it exists, otherwise empty"""
90+
return self._user_header
91+
8492
def validate(self) -> None:
8593
"""
8694
Check compiler args.
8795
Raise ValueError if invalid options are found.
8896
"""
8997
self.validate_stanc_opts()
9098
self.validate_cpp_opts()
99+
self.validate_user_header()
91100

92101
def validate_stanc_opts(self) -> None:
93102
"""
@@ -104,17 +113,15 @@ def validate_stanc_opts(self) -> None:
104113
get_logger().info('ignoring compiler option: %s', key)
105114
ignore.append(key)
106115
elif key not in STANC_OPTS:
107-
raise ValueError(
108-
'unknown stanc compiler option: {}'.format(key)
109-
)
116+
raise ValueError(f'unknown stanc compiler option: {key}')
110117
elif key == 'include_paths':
111118
paths = val
112119
if isinstance(val, str):
113120
paths = val.split(',')
114121
elif not isinstance(val, list):
115122
raise ValueError(
116123
'Invalid include_paths, expecting list or '
117-
'string, found type: {}.'.format(type(val))
124+
f'string, found type: {type(val)}.'
118125
)
119126
elif key == 'use-opencl':
120127
if self._cpp_options is None:
@@ -149,10 +156,37 @@ def validate_cpp_opts(self) -> None:
149156
val = self._cpp_options[key]
150157
if not isinstance(val, int) or val < 0:
151158
raise ValueError(
152-
'{} must be a non-negative integer value,'
153-
' found {}.'.format(key, val)
159+
f'{key} must be a non-negative integer value,'
160+
f' found {val}.'
154161
)
155162

163+
def validate_user_header(self) -> None:
164+
"""
165+
User header exists.
166+
Raise ValueError if bad config is found.
167+
"""
168+
if self._user_header != "":
169+
if not (
170+
os.path.exists(self._user_header)
171+
and os.path.isfile(self._user_header)
172+
):
173+
raise ValueError(
174+
f"User header file {self._user_header} cannot be found"
175+
)
176+
if self._user_header[-4:] != '.hpp':
177+
raise ValueError(
178+
f"Header file must end in .hpp, got {self._user_header}"
179+
)
180+
if "allow_undefined" not in self._stanc_options:
181+
self._stanc_options["allow_undefined"] = True
182+
# set full path
183+
self._user_header = os.path.abspath(self._user_header)
184+
185+
if ' ' in self._user_header:
186+
raise ValueError(
187+
"User header must be in a folder with no spaces in path!"
188+
)
189+
156190
def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
157191
"""Adds options to existing set of compiler options."""
158192
if new_opts.stanc_options is not None:
@@ -167,6 +201,8 @@ def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
167201
if new_opts.cpp_options is not None:
168202
for key, val in new_opts.cpp_options.items():
169203
self._cpp_options[key] = val
204+
if new_opts.user_header != '' and self._user_header == '':
205+
self._user_header = new_opts.user_header
170206

171207
def add_include_path(self, path: str) -> None:
172208
"""Adds include path to existing set of compiler options."""
@@ -191,10 +227,12 @@ def compose(self) -> List[str]:
191227
)
192228
)
193229
elif key == 'name':
194-
opts.append('STANCFLAGS+=--{}={}'.format(key, val))
230+
opts.append(f'STANCFLAGS+=--name={val}')
195231
else:
196-
opts.append('STANCFLAGS+=--{}'.format(key))
232+
opts.append(f'STANCFLAGS+=--{key}')
197233
if self._cpp_options is not None and len(self._cpp_options) > 0:
198234
for key, val in self._cpp_options.items():
199-
opts.append('{}={}'.format(key, val))
235+
opts.append(f'{key}={val}')
236+
if self._user_header:
237+
opts.append(f'USER_HEADER={self._user_header}')
200238
return opts

cmdstanpy/model.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ class CmdStanModel:
6868
:param cpp_options: Options for C++ compiler, specified as a Python
6969
dictionary containing C++ compiler option name, value pairs.
7070
Optional.
71+
72+
:param user_header: A path to a header file to include during C++
73+
compilation.
74+
Optional.
7175
"""
7276

7377
def __init__(
@@ -78,6 +82,7 @@ def __init__(
7882
compile: bool = True,
7983
stanc_options: Optional[Dict[str, Any]] = None,
8084
cpp_options: Optional[Dict[str, Any]] = None,
85+
user_header: Optional[str] = None,
8186
logger: Optional[logging.Logger] = None,
8287
) -> None:
8388
"""
@@ -89,12 +94,16 @@ def __init__(
8994
:param compile: Whether or not to compile the model.
9095
:param stanc_options: Options for stanc compiler.
9196
:param cpp_options: Options for C++ compiler.
97+
:param user_header: A path to a header file to include during C++
98+
compilation.
9299
"""
93100
self._name = ''
94101
self._stan_file = None
95102
self._exe_file = None
96103
self._compiler_options = CompilerOptions(
97-
stanc_options=stanc_options, cpp_options=cpp_options
104+
stanc_options=stanc_options,
105+
cpp_options=cpp_options,
106+
user_header=user_header,
98107
)
99108
if logger is not None:
100109
get_logger().warning(
@@ -227,6 +236,11 @@ def cpp_options(self) -> Dict[str, Union[bool, int]]:
227236
"""Options to C++ compilers."""
228237
return self._compiler_options._cpp_options
229238

239+
@property
240+
def user_header(self) -> str:
241+
"""The user header file if it exists, otherwise empty"""
242+
return self._compiler_options._user_header
243+
230244
def code(self) -> Optional[str]:
231245
"""Return Stan program as a string."""
232246
if not self._stan_file:
@@ -247,6 +261,7 @@ def compile(
247261
force: bool = False,
248262
stanc_options: Optional[Dict[str, Any]] = None,
249263
cpp_options: Optional[Dict[str, Any]] = None,
264+
user_header: Optional[str] = None,
250265
override_options: bool = False,
251266
) -> None:
252267
"""
@@ -264,6 +279,8 @@ def compile(
264279
265280
:param stanc_options: Options for stanc compiler.
266281
:param cpp_options: Options for C++ compiler.
282+
:param user_header: A path to a header file to include during C++
283+
compilation.
267284
268285
:param override_options: When ``True``, override existing option.
269286
When ``False``, add/replace existing options. Default is ``False``.
@@ -272,9 +289,15 @@ def compile(
272289
raise RuntimeError('Please specify source file')
273290

274291
compiler_options = None
275-
if not (stanc_options is None and cpp_options is None):
292+
if not (
293+
stanc_options is None
294+
and cpp_options is None
295+
and user_header is None
296+
):
276297
compiler_options = CompilerOptions(
277-
stanc_options=stanc_options, cpp_options=cpp_options
298+
stanc_options=stanc_options,
299+
cpp_options=cpp_options,
300+
user_header=user_header,
278301
)
279302
compiler_options.validate()
280303
if self._compiler_options is None:

docsrc/examples.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ __________________
88
examples/Maximum Likelihood Estimation.ipynb
99
examples/Variational Inference.ipynb
1010
examples/Run Generated Quantities.ipynb
11+
examples/Using External C++.ipynb

docsrc/examples/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
*.hpp
88
*.exe
99
*.csv
10-
.ipynb_checkpoints
10+
.ipynb_checkpoints
11+
!make_odds.hpp
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"source": [
6+
"# Advanced Topic: Using External C++ Functions\n",
7+
"\n",
8+
"This is based on the relevant portion of the CmdStan documentation [here](https://mc-stan.org/docs/cmdstan-guide/using-external-cpp-code.html)"
9+
],
10+
"metadata": {}
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"source": [
15+
"Consider the following Stan model, based on the bernoulli example."
16+
],
17+
"metadata": {}
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"source": [
23+
"import os\n",
24+
"from cmdstanpy import cmdstan_path, CmdStanModel\n",
25+
"model_external = CmdStanModel(stan_file='bernoulli_external.stan', compile=False)\n",
26+
"print(model_external.code())"
27+
],
28+
"outputs": [],
29+
"metadata": {}
30+
},
31+
{
32+
"cell_type": "markdown",
33+
"source": [
34+
"As you can see, it features a function declaration for `make_odds`, but no definition. If we try to compile this, we will get an error. "
35+
],
36+
"metadata": {}
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": null,
41+
"source": [
42+
"model_external.compile()"
43+
],
44+
"outputs": [],
45+
"metadata": {}
46+
},
47+
{
48+
"cell_type": "markdown",
49+
"source": [
50+
"Even enabling the `--allow_undefined` flag to stanc3 will not allow this model to be compiled quite yet."
51+
],
52+
"metadata": {}
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": null,
57+
"source": [
58+
"model_external.compile(stanc_options={'allow_undefined':True})"
59+
],
60+
"outputs": [],
61+
"metadata": {}
62+
},
63+
{
64+
"cell_type": "markdown",
65+
"source": [
66+
"To resolve this, we need to both tell the Stan compiler an undefined function is okay **and** let C++ know what it should be. \n",
67+
"\n",
68+
"We can provide a definition in a C++ header file by using the `user_header` argument to either the CmdStanModel constructor or the `compile` method. \n",
69+
"\n",
70+
"This will enables the `allow_undefined` flag automatically."
71+
],
72+
"metadata": {}
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": null,
77+
"source": [
78+
"model_external.compile(user_header='make_odds.hpp')"
79+
],
80+
"outputs": [],
81+
"metadata": {}
82+
},
83+
{
84+
"cell_type": "markdown",
85+
"source": [
86+
"We can then run this model and inspect the output"
87+
],
88+
"metadata": {}
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"source": [
94+
"fit = model_external.sample(data={'N':10, 'y':[0,1,0,0,0,0,0,0,0,1]})\n",
95+
"fit.stan_variable('odds')"
96+
],
97+
"outputs": [],
98+
"metadata": {}
99+
},
100+
{
101+
"cell_type": "markdown",
102+
"source": [
103+
"The contents of this header file are a bit complicated unless you are familiar with the C++ internals of Stan, so they are presented without comment:\n",
104+
"\n",
105+
"```c++\n",
106+
"#include <boost/math/tools/promotion.hpp>\n",
107+
"#include <ostream>\n",
108+
"\n",
109+
"namespace bernoulli_model_namespace {\n",
110+
" template <typename T0__> inline typename\n",
111+
" boost::math::tools::promote_args<T0__>::type \n",
112+
" make_odds(const T0__& theta, std::ostream* pstream__) {\n",
113+
" return theta / (1 - theta); \n",
114+
" }\n",
115+
"}\n",
116+
"```"
117+
],
118+
"metadata": {}
119+
}
120+
],
121+
"metadata": {
122+
"orig_nbformat": 4,
123+
"language_info": {
124+
"name": "python",
125+
"version": "3.9.5",
126+
"mimetype": "text/x-python",
127+
"codemirror_mode": {
128+
"name": "ipython",
129+
"version": 3
130+
},
131+
"pygments_lexer": "ipython3",
132+
"nbconvert_exporter": "python",
133+
"file_extension": ".py"
134+
},
135+
"kernelspec": {
136+
"name": "python3",
137+
"display_name": "Python 3.9.5 64-bit ('stan': conda)"
138+
},
139+
"interpreter": {
140+
"hash": "d31ce8e45781476cfd394e192e0962028add96ff436d4fd4e560a347d206b9cb"
141+
}
142+
},
143+
"nbformat": 4,
144+
"nbformat_minor": 2
145+
}

0 commit comments

Comments
 (0)