Skip to content

Commit f048803

Browse files
authored
Merge pull request #464 from stan-dev/external-fns
Allow user to supply a c++ header file
2 parents c35d42d + d90af13 commit f048803

17 files changed

Lines changed: 354 additions & 87 deletions

cmdstanpy/compiler_opts.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,21 @@ class CompilerOptions:
4848
Attributes:
4949
stanc_options - stanc compiler flags, options
5050
cpp_options - makefile options (NAME=value)
51+
user_header - path to a user .hpp file to include during compilation
5152
"""
5253

5354
def __init__(
5455
self,
5556
*,
5657
stanc_options: Optional[Dict[str, Any]] = None,
5758
cpp_options: Optional[Dict[str, Any]] = None,
59+
user_header: Optional[str] = None,
5860
logger: Optional[logging.Logger] = None,
5961
) -> None:
6062
"""Initialize object."""
6163
self._stanc_options = stanc_options if stanc_options is not None else {}
6264
self._cpp_options = cpp_options if cpp_options is not None else {}
65+
self._user_header = user_header if user_header is not None else ''
6366
if logger is not None:
6467
get_logger().warning(
6568
"Parameter 'logger' is deprecated."
@@ -88,6 +91,7 @@ def validate(self) -> None:
8891
"""
8992
self.validate_stanc_opts()
9093
self.validate_cpp_opts()
94+
self.validate_user_header()
9195

9296
def validate_stanc_opts(self) -> None:
9397
"""
@@ -104,17 +108,15 @@ def validate_stanc_opts(self) -> None:
104108
get_logger().info('ignoring compiler option: %s', key)
105109
ignore.append(key)
106110
elif key not in STANC_OPTS:
107-
raise ValueError(
108-
'unknown stanc compiler option: {}'.format(key)
109-
)
111+
raise ValueError(f'unknown stanc compiler option: {key}')
110112
elif key == 'include_paths':
111113
paths = val
112114
if isinstance(val, str):
113115
paths = val.split(',')
114116
elif not isinstance(val, list):
115117
raise ValueError(
116118
'Invalid include_paths, expecting list or '
117-
'string, found type: {}.'.format(type(val))
119+
f'string, found type: {type(val)}.'
118120
)
119121
elif key == 'use-opencl':
120122
if self._cpp_options is None:
@@ -149,10 +151,48 @@ def validate_cpp_opts(self) -> None:
149151
val = self._cpp_options[key]
150152
if not isinstance(val, int) or val < 0:
151153
raise ValueError(
152-
'{} must be a non-negative integer value,'
153-
' found {}.'.format(key, val)
154+
f'{key} must be a non-negative integer value,'
155+
f' found {val}.'
154156
)
155157

158+
def validate_user_header(self) -> None:
159+
"""
160+
User header exists.
161+
Raise ValueError if bad config is found.
162+
"""
163+
if self._user_header != "":
164+
if not (
165+
os.path.exists(self._user_header)
166+
and os.path.isfile(self._user_header)
167+
):
168+
raise ValueError(
169+
f"User header file {self._user_header} cannot be found"
170+
)
171+
if self._user_header[-4:] != '.hpp':
172+
raise ValueError(
173+
f"Header file must end in .hpp, got {self._user_header}"
174+
)
175+
if "allow_undefined" not in self._stanc_options:
176+
self._stanc_options["allow_undefined"] = True
177+
# set full path
178+
self._user_header = os.path.abspath(self._user_header)
179+
180+
if ' ' in self._user_header:
181+
raise ValueError(
182+
"User header must be in a location with no spaces in path!"
183+
)
184+
185+
if (
186+
'USER_HEADER' in self._cpp_options
187+
and self._user_header != self._cpp_options['USER_HEADER']
188+
):
189+
raise ValueError(
190+
"Disagreement in user_header C++ options found!\n"
191+
f"{self._user_header}, {self._cpp_options['USER_HEADER']}"
192+
)
193+
194+
self._cpp_options['USER_HEADER'] = self._user_header
195+
156196
def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
157197
"""Adds options to existing set of compiler options."""
158198
if new_opts.stanc_options is not None:
@@ -167,6 +207,8 @@ def add(self, new_opts: "CompilerOptions") -> None: # noqa: disable=Q000
167207
if new_opts.cpp_options is not None:
168208
for key, val in new_opts.cpp_options.items():
169209
self._cpp_options[key] = val
210+
if new_opts._user_header != '' and self._user_header == '':
211+
self._user_header = new_opts._user_header
170212

171213
def add_include_path(self, path: str) -> None:
172214
"""Adds include path to existing set of compiler options."""
@@ -191,10 +233,10 @@ def compose(self) -> List[str]:
191233
)
192234
)
193235
elif key == 'name':
194-
opts.append('STANCFLAGS+=--{}={}'.format(key, val))
236+
opts.append(f'STANCFLAGS+=--name={val}')
195237
else:
196-
opts.append('STANCFLAGS+=--{}'.format(key))
238+
opts.append(f'STANCFLAGS+=--{key}')
197239
if self._cpp_options is not None and len(self._cpp_options) > 0:
198240
for key, val in self._cpp_options.items():
199-
opts.append('{}={}'.format(key, val))
241+
opts.append(f'{key}={val}')
200242
return opts

cmdstanpy/model.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ class CmdStanModel:
7676
:param cpp_options: Options for C++ compiler, specified as a Python
7777
dictionary containing C++ compiler option name, value pairs.
7878
Optional.
79+
80+
:param user_header: A path to a header file to include during C++
81+
compilation.
82+
Optional.
7983
"""
8084

8185
def __init__(
@@ -86,6 +90,7 @@ def __init__(
8690
compile: bool = True,
8791
stanc_options: Optional[Dict[str, Any]] = None,
8892
cpp_options: Optional[Dict[str, Any]] = None,
93+
user_header: Optional[str] = None,
8994
logger: Optional[logging.Logger] = None,
9095
) -> None:
9196
"""
@@ -97,12 +102,16 @@ def __init__(
97102
:param compile: Whether or not to compile the model.
98103
:param stanc_options: Options for stanc compiler.
99104
:param cpp_options: Options for C++ compiler.
105+
:param user_header: A path to a header file to include during C++
106+
compilation.
100107
"""
101108
self._name = ''
102109
self._stan_file = None
103110
self._exe_file = None
104111
self._compiler_options = CompilerOptions(
105-
stanc_options=stanc_options, cpp_options=cpp_options
112+
stanc_options=stanc_options,
113+
cpp_options=cpp_options,
114+
user_header=user_header,
106115
)
107116
if logger is not None:
108117
get_logger().warning(
@@ -235,6 +244,11 @@ def cpp_options(self) -> Dict[str, Union[bool, int]]:
235244
"""Options to C++ compilers."""
236245
return self._compiler_options._cpp_options
237246

247+
@property
248+
def user_header(self) -> str:
249+
"""The user header file if it exists, otherwise empty"""
250+
return self._compiler_options._user_header
251+
238252
def code(self) -> Optional[str]:
239253
"""Return Stan program as a string."""
240254
if not self._stan_file:
@@ -255,6 +269,7 @@ def compile(
255269
force: bool = False,
256270
stanc_options: Optional[Dict[str, Any]] = None,
257271
cpp_options: Optional[Dict[str, Any]] = None,
272+
user_header: Optional[str] = None,
258273
override_options: bool = False,
259274
) -> None:
260275
"""
@@ -272,6 +287,8 @@ def compile(
272287
273288
:param stanc_options: Options for stanc compiler.
274289
:param cpp_options: Options for C++ compiler.
290+
:param user_header: A path to a header file to include during C++
291+
compilation.
275292
276293
:param override_options: When ``True``, override existing option.
277294
When ``False``, add/replace existing options. Default is ``False``.
@@ -280,9 +297,15 @@ def compile(
280297
raise RuntimeError('Please specify source file')
281298

282299
compiler_options = None
283-
if not (stanc_options is None and cpp_options is None):
300+
if not (
301+
stanc_options is None
302+
and cpp_options is None
303+
and user_header is None
304+
):
284305
compiler_options = CompilerOptions(
285-
stanc_options=stanc_options, cpp_options=cpp_options
306+
stanc_options=stanc_options,
307+
cpp_options=cpp_options,
308+
user_header=user_header,
286309
)
287310
compiler_options.validate()
288311
if self._compiler_options is None:

docsrc/env.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ dependencies:
66
- python=3.7
77
- ipykernel
88
- ipython
9+
- ipywidgets
910
- numpy>=1.15
1011
- pandas
1112
- xarray

docsrc/examples.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ __________________
88
examples/Maximum Likelihood Estimation.ipynb
99
examples/Variational Inference.ipynb
1010
examples/Run Generated Quantities.ipynb
11+
examples/Using External C++.ipynb

docsrc/examples/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
*.hpp
88
*.exe
99
*.csv
10-
.ipynb_checkpoints
10+
.ipynb_checkpoints
11+
!make_odds.hpp
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Advanced Topic: Using External C++ Functions\n",
8+
"\n",
9+
"This is based on the relevant portion of the CmdStan documentation [here](https://mc-stan.org/docs/cmdstan-guide/using-external-cpp-code.html)"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"Consider the following Stan model, based on the bernoulli example."
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"metadata": {"nbsphinx": "hidden"},
23+
"outputs": [],
24+
"source": [
25+
"import os\n",
26+
"try:\n",
27+
" os.remove('bernoulli_external')\n",
28+
"except:\n",
29+
" pass"
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": null,
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"from cmdstanpy import CmdStanModel\n",
39+
"model_external = CmdStanModel(stan_file='bernoulli_external.stan', compile=False)\n",
40+
"print(model_external.code())"
41+
]
42+
},
43+
{
44+
"cell_type": "markdown",
45+
"metadata": {},
46+
"source": [
47+
"As you can see, it features a function declaration for `make_odds`, but no definition. If we try to compile this, we will get an error. "
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"model_external.compile()"
57+
]
58+
},
59+
{
60+
"cell_type": "markdown",
61+
"metadata": {},
62+
"source": [
63+
"Even enabling the `--allow_undefined` flag to stanc3 will not allow this model to be compiled quite yet."
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": null,
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"model_external.compile(stanc_options={'allow_undefined':True})"
73+
]
74+
},
75+
{
76+
"cell_type": "markdown",
77+
"metadata": {},
78+
"source": [
79+
"To resolve this, we need to both tell the Stan compiler an undefined function is okay **and** let C++ know what it should be. \n",
80+
"\n",
81+
"We can provide a definition in a C++ header file by using the `user_header` argument to either the CmdStanModel constructor or the `compile` method. \n",
82+
"\n",
83+
"This will enables the `allow_undefined` flag automatically."
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {},
90+
"outputs": [],
91+
"source": [
92+
"model_external.compile(user_header='make_odds.hpp')"
93+
]
94+
},
95+
{
96+
"cell_type": "markdown",
97+
"metadata": {},
98+
"source": [
99+
"We can then run this model and inspect the output"
100+
]
101+
},
102+
{
103+
"cell_type": "code",
104+
"execution_count": null,
105+
"metadata": {},
106+
"outputs": [],
107+
"source": [
108+
"fit = model_external.sample(data={'N':10, 'y':[0,1,0,0,0,0,0,0,0,1]})\n",
109+
"fit.stan_variable('odds')"
110+
]
111+
},
112+
{
113+
"cell_type": "markdown",
114+
"metadata": {},
115+
"source": [
116+
"The contents of this header file are a bit complicated unless you are familiar with the C++ internals of Stan, so they are presented without comment:\n",
117+
"\n",
118+
"```c++\n",
119+
"#include <boost/math/tools/promotion.hpp>\n",
120+
"#include <ostream>\n",
121+
"\n",
122+
"namespace bernoulli_model_namespace {\n",
123+
" template <typename T0__> inline typename\n",
124+
" boost::math::tools::promote_args<T0__>::type \n",
125+
" make_odds(const T0__& theta, std::ostream* pstream__) {\n",
126+
" return theta / (1 - theta); \n",
127+
" }\n",
128+
"}\n",
129+
"```"
130+
]
131+
}
132+
],
133+
"metadata": {
134+
"interpreter": {
135+
"hash": "8765ce46b013071999fc1966b52035a7309a0da7551e066cc0f0fa23e83d4f60"
136+
},
137+
"kernelspec": {
138+
"display_name": "Python 3.9.5 64-bit ('stan': conda)",
139+
"name": "python3"
140+
},
141+
"language_info": {
142+
"codemirror_mode": {
143+
"name": "ipython",
144+
"version": 3
145+
},
146+
"file_extension": ".py",
147+
"mimetype": "text/x-python",
148+
"name": "python",
149+
"nbconvert_exporter": "python",
150+
"pygments_lexer": "ipython3",
151+
"version": "3.9.5"
152+
},
153+
"orig_nbformat": 4
154+
},
155+
"nbformat": 4,
156+
"nbformat_minor": 2
157+
}

0 commit comments

Comments
 (0)