Skip to content

Commit 90dcbba

Browse files
committed
Fix environ clobbering issue
1 parent da836dd commit 90dcbba

4 files changed

Lines changed: 76 additions & 56 deletions

File tree

test/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,38 @@ def without_import(self, library, module):
3232
yield
3333
reload(module)
3434

35+
# recipe from https://stackoverflow.com/a/34333710
36+
# pylint: disable=no-self-use
37+
@contextlib.contextmanager
38+
def modified_environ(self, *remove, **update):
39+
"""
40+
Temporarily updates the ``os.environ`` dictionary in-place.
41+
42+
The ``os.environ`` dictionary is updated in-place so that the modification
43+
is sure to work in all situations.
44+
45+
:param remove: Environment variables to remove.
46+
:param update: Dictionary of environment variables and values to add/update.
47+
"""
48+
env = os.environ
49+
update = update or {}
50+
remove = remove or []
51+
52+
# List of environment variables being updated or removed.
53+
stomped = (set(update.keys()) | set(remove)) & set(env.keys())
54+
# Environment variables and values to restore on exit.
55+
update_after = {k: env[k] for k in stomped}
56+
# Environment variables and values to remove on exit.
57+
remove_after = frozenset(k for k in update if k not in env)
58+
59+
try:
60+
env.update(update)
61+
[env.pop(k, None) for k in remove]
62+
yield
63+
finally:
64+
env.update(update_after)
65+
[env.pop(k) for k in remove_after]
66+
3567
# pylint: disable=invalid-name
3668
def assertPathsEqual(self, path1, path2):
3769
"""Assert paths are equal after normalization"""

test/test_install_cmdstan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import unittest
5+
from test import CustomTestCase
56
from unittest.mock import patch
67

78
from cmdstanpy.install_cmdstan import (
@@ -14,7 +15,7 @@
1415
)
1516

1617

17-
class InstallCmdStanTest(unittest.TestCase):
18+
class InstallCmdStanTest(CustomTestCase):
1819
def test_is_version_available(self):
1920
# check http error for bad version
2021
self.assertFalse(is_version_available('2.222.222-rc222'))
@@ -39,7 +40,7 @@ def test_retrieve_version(self):
3940
retrieve_version('')
4041

4142
def test_rebuild_bad_path(self):
42-
with patch.dict(os.environ, {"CMDSTAN": "~/some/fake/path"}):
43+
with self.modified_environ(CMDSTAN="~/some/fake/path"):
4344
with self.assertRaisesRegex(
4445
CmdStanInstallError, "you sure it is installed"
4546
):

test/test_utils.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -54,45 +54,32 @@
5454
BERN_EXE = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
5555

5656

57-
class CmdStanPathTest(unittest.TestCase):
57+
class CmdStanPathTest(CustomTestCase):
5858
def test_default_path(self):
59-
cur_value = None
6059
if 'CMDSTAN' in os.environ:
61-
cur_value = os.environ['CMDSTAN']
62-
try:
63-
if 'CMDSTAN' in os.environ:
64-
self.assertEqual(cmdstan_path(), os.environ['CMDSTAN'])
65-
path = os.environ['CMDSTAN']
66-
del os.environ['CMDSTAN']
60+
self.assertPathsEqual(cmdstan_path(), os.environ['CMDSTAN'])
61+
path = os.environ['CMDSTAN']
62+
with self.modified_environ('CMDSTAN'):
6763
self.assertFalse('CMDSTAN' in os.environ)
6864
set_cmdstan_path(path)
69-
self.assertEqual(cmdstan_path(), path)
70-
self.assertTrue('CMDSTAN' in os.environ)
71-
else:
72-
cmdstan_dir = os.path.expanduser(
73-
os.path.join('~', _DOT_CMDSTAN)
74-
)
75-
install_version = os.path.join(
76-
cmdstan_dir, get_latest_cmdstan(cmdstan_dir)
77-
)
78-
self.assertTrue(
79-
os.path.samefile(cmdstan_path(), install_version)
80-
)
65+
self.assertPathsEqual(cmdstan_path(), path)
8166
self.assertTrue('CMDSTAN' in os.environ)
82-
finally:
83-
if cur_value is not None:
84-
os.environ['CMDSTAN'] = cur_value
85-
else:
86-
if 'CMDSTAN' in os.environ:
87-
del os.environ['CMDSTAN']
67+
else:
68+
cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
69+
install_version = os.path.join(
70+
cmdstan_dir, get_latest_cmdstan(cmdstan_dir)
71+
)
72+
self.assertTrue(os.path.samefile(cmdstan_path(), install_version))
73+
self.assertTrue('CMDSTAN' in os.environ)
8874

8975
def test_non_spaces_location(self):
9076
with tempfile.TemporaryDirectory(
9177
prefix="cmdstan_tests", dir=_TMPDIR
9278
) as tmpdir:
9379
good_path = os.path.join(tmpdir, 'good_dir')
80+
os.mkdir(good_path)
9481
with SanitizedOrTmpFilePath(good_path) as (pth, is_changed):
95-
self.assertEqual(pth, good_path)
82+
self.assertPathsEqual(pth, good_path)
9683
self.assertFalse(is_changed)
9784

9885
# prepare files for test
@@ -117,19 +104,20 @@ def test_non_spaces_location(self):
117104
self.assertFalse(os.path.exists(stan_copied))
118105

119106
# cleanup after test
107+
shutil.rmtree(good_path, ignore_errors=True)
120108
shutil.rmtree(bad_path, ignore_errors=True)
121109

122110
def test_set_path(self):
123111
if 'CMDSTAN' in os.environ:
124-
self.assertEqual(cmdstan_path(), os.environ['CMDSTAN'])
112+
self.assertPathsEqual(cmdstan_path(), os.environ['CMDSTAN'])
125113
else:
126114
cmdstan_dir = os.path.expanduser(os.path.join('~', _DOT_CMDSTAN))
127115
install_version = os.path.join(
128116
cmdstan_dir, get_latest_cmdstan(cmdstan_dir)
129117
)
130118
set_cmdstan_path(install_version)
131-
self.assertEqual(install_version, cmdstan_path())
132-
self.assertEqual(install_version, os.environ['CMDSTAN'])
119+
self.assertPathsEqual(install_version, cmdstan_path())
120+
self.assertPathsEqual(install_version, os.environ['CMDSTAN'])
133121

134122
def test_validate_path(self):
135123
if 'CMDSTAN' in os.environ:
@@ -209,29 +197,27 @@ def test_cmdstan_version(self):
209197
fake_bin = os.path.join(fake_path, 'bin')
210198
os.makedirs(fake_bin)
211199
Path(os.path.join(fake_bin, 'stanc' + EXTENSION)).touch()
212-
os.environ['CMDSTAN'] = fake_path
213-
self.assertTrue(fake_path == cmdstan_path())
214-
expect = (
215-
'CmdStan installation {} missing makefile, '
216-
'cannot get version.'.format(fake_path)
217-
)
218-
with LogCapture() as log:
219-
logging.getLogger()
220-
cmdstan_version()
221-
log.check_present(('cmdstanpy', 'INFO', expect))
222-
fake_makefile = os.path.join(fake_path, 'makefile')
223-
with open(fake_makefile, 'w') as fd:
224-
fd.write('... CMDSTAN_VERSION := dont_need_no_mmp\n\n')
225-
expect = (
226-
'Cannot parse version, expected "<major>.<minor>.<patch>", '
227-
'found: "dont_need_no_mmp".'
228-
)
229-
with LogCapture() as log:
230-
logging.getLogger()
231-
cmdstan_version()
232-
log.check_present(('cmdstanpy', 'INFO', expect))
233-
# cleanup
234-
del os.environ['CMDSTAN']
200+
with self.modified_environ(CMDSTAN=fake_path):
201+
self.assertTrue(fake_path == cmdstan_path())
202+
expect = (
203+
'CmdStan installation {} missing makefile, '
204+
'cannot get version.'.format(fake_path)
205+
)
206+
with LogCapture() as log:
207+
logging.getLogger()
208+
cmdstan_version()
209+
log.check_present(('cmdstanpy', 'INFO', expect))
210+
fake_makefile = os.path.join(fake_path, 'makefile')
211+
with open(fake_makefile, 'w') as fd:
212+
fd.write('... CMDSTAN_VERSION := dont_need_no_mmp\n\n')
213+
expect = (
214+
'Cannot parse version, expected "<major>.<minor>.<patch>", '
215+
'found: "dont_need_no_mmp".'
216+
)
217+
with LogCapture() as log:
218+
logging.getLogger()
219+
cmdstan_version()
220+
log.check_present(('cmdstanpy', 'INFO', expect))
235221
cmdstan_path()
236222

237223

test/test_variational.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ def test_exe_only(self):
246246
data=jdata,
247247
require_converged=False,
248248
seed=12345,
249-
algorithm='meanfield')
249+
algorithm='meanfield',
250+
)
250251
self.assertEqual(variational.variational_sample.shape, (1000, 4))
251252

252253

0 commit comments

Comments
 (0)