Skip to content

Commit 0e5c378

Browse files
committed
fix & test pushd instead of do_command
1 parent 2261265 commit 0e5c378

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

cmdstanpy/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,6 @@ def do_command(
10011001
10021002
"""
10031003
get_logger().debug('cmd: %s\ncwd: %s', ' '.join(cmd), cwd)
1004-
restore_cwd = os.getcwd() if cwd is not None else None
10051004
try:
10061005
# NB: Using this rather than cwd arg to Popen due to windows behavior
10071006
with pushd(cwd if cwd is not None else '.'):
@@ -1043,8 +1042,6 @@ def do_command(
10431042
raise RuntimeError(msg)
10441043
except OSError as e:
10451044
msg = 'Command: {}\nfailed with error {}\n'.format(cmd, str(e))
1046-
if restore_cwd is not None:
1047-
os.chdir(restore_cwd)
10481045
raise RuntimeError(msg) from e
10491046

10501047

@@ -1322,8 +1319,10 @@ def pushd(new_dir: str) -> Iterator[None]:
13221319
"""Acts like pushd/popd."""
13231320
previous_dir = os.getcwd()
13241321
os.chdir(new_dir)
1325-
yield
1326-
os.chdir(previous_dir)
1322+
try:
1323+
yield
1324+
finally:
1325+
os.chdir(previous_dir)
13271326

13281327

13291328
def report_signal(sig: int) -> None:

test/test_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
validate_dir,
4646
windows_short_path,
4747
write_stan_json,
48+
pushd,
4849
)
4950

5051
HERE = os.path.dirname(os.path.abspath(__file__))
@@ -806,17 +807,19 @@ def test_exit(self):
806807
with self.assertRaises(RuntimeError):
807808
do_command(args, HERE)
808809

810+
811+
class PushdTest(unittest.TestCase):
812+
809813
def test_restore_cwd(self):
810814
"Ensure do_command in a different cwd restores cwd after error."
811-
before = os.getcwd()
812-
# after = None
815+
cwd = os.getcwd()
813816
try:
814-
do_command(cmd=['ls /does-not-exist'], cwd=os.path.dirname(before))
817+
with self.assertRaises(RuntimeError):
818+
with pushd(os.path.dirname(cwd)):
819+
raise RuntimeError('error')
815820
except RuntimeError:
816821
pass
817-
finally:
818-
after = os.getcwd()
819-
self.assertEqual(before, after)
822+
self.assertEqual(cwd, os.getcwd())
820823

821824

822825
class FlattenTest(unittest.TestCase):

0 commit comments

Comments
 (0)