Skip to content

Commit c625e65

Browse files
committed
checkpointing
1 parent 364d038 commit c625e65

9 files changed

Lines changed: 233 additions & 175 deletions

File tree

cmdstanpy/install_cmdstan.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232

3333
from cmdstanpy import _DOT_CMDSTAN, _DOT_CMDSTANPY
3434

35-
from cmdstanpy.progress import allow_show_progress, disable_progress
36-
3735
from cmdstanpy.utils import (
3836
cmdstan_path,
3937
do_command,
@@ -43,6 +41,8 @@
4341
wrap_url_progress_hook,
4442
)
4543

44+
from . import progress as progbar
45+
4646
MAKE = os.getenv(
4747
'MAKE', 'make' if platform.system() != 'Windows' else 'mingw32-make'
4848
)
@@ -112,7 +112,7 @@ def build(verbose: bool = False, progress: bool = True) -> None:
112112
try:
113113
if verbose:
114114
do_command(cmd)
115-
elif progress and allow_show_progress():
115+
elif progress and progbar.allow_show_progress():
116116
progress_hook: Any = _wrap_build_progress_hook()
117117
do_command(cmd, fd_out=None, pbar=progress_hook)
118118
else:
@@ -146,43 +146,33 @@ def build(verbose: bool = False, progress: bool = True) -> None:
146146
)
147147

148148

149-
# pylint: disable=no-self-use
149+
@progbar.wrap_callback
150150
def _wrap_build_progress_hook() -> Optional[Callable[[str], None]]:
151151
"""Sets up tqdm callback for CmdStan sampler console msgs."""
152152
pad = ' ' * 20
153153
msgs_expected = 150 # hack: 2.27 make build send ~140 msgs to console
154-
try:
155-
pbar: tqdm = tqdm(
156-
total=msgs_expected,
157-
bar_format="{desc} ({elapsed}) | {bar} | {postfix[0][value]}",
158-
postfix=[dict(value=f'Building CmdStan {pad}')],
159-
colour='blue',
160-
desc='',
161-
position=0,
162-
)
163-
# pylint: disable=broad-except
164-
except Exception as e:
165-
disable_progress(e)
166-
167-
# pylint: disable=unused-argument
168-
def build_progress_hook(line: str) -> None:
169-
return
170-
171-
else:
154+
pbar: tqdm = tqdm(
155+
total=msgs_expected,
156+
bar_format="{desc} ({elapsed}) | {bar} | {postfix[0][value]}",
157+
postfix=[dict(value=f'Building CmdStan {pad}')],
158+
colour='blue',
159+
desc='',
160+
position=0,
161+
)
172162

173-
def build_progress_hook(line: str) -> None:
174-
if line.startswith('--- CmdStan'):
175-
pbar.set_description('Done')
163+
def build_progress_hook(line: str) -> None:
164+
if line.startswith('--- CmdStan'):
165+
pbar.set_description('Done')
166+
pbar.postfix[0]["value"] = line
167+
pbar.update(msgs_expected - pbar.n)
168+
pbar.close()
169+
else:
170+
if line.startswith('--'):
176171
pbar.postfix[0]["value"] = line
177-
pbar.update(msgs_expected - pbar.n)
178-
pbar.close()
179172
else:
180-
if line.startswith('--'):
181-
pbar.postfix[0]["value"] = line
182-
else:
183-
pbar.postfix[0]["value"] = f'{line[:8]} ... {line[-20:]}'
184-
pbar.set_description('Compiling')
185-
pbar.update(1)
173+
pbar.postfix[0]["value"] = f'{line[:8]} ... {line[-20:]}'
174+
pbar.set_description('Compiling')
175+
pbar.update(1)
186176

187177
return build_progress_hook
188178

@@ -336,7 +326,7 @@ def retrieve_version(version: str, progress: bool = True) -> None:
336326
)
337327
for i in range(6): # always retry to allow for transient URLErrors
338328
try:
339-
if progress and allow_show_progress():
329+
if progress and progbar.allow_show_progress():
340330
progress_hook: Optional[
341331
Callable[[int, int, int], None]
342332
] = wrap_url_progress_hook()
@@ -386,7 +376,7 @@ def retrieve_version(version: str, progress: bool = True) -> None:
386376
# fixes long-path limitation on Windows
387377
target = r'\\?\{}'.format(target)
388378

389-
if progress and allow_show_progress():
379+
if progress and progbar.allow_show_progress():
390380
for member in tqdm(
391381
iterable=tar.getmembers(),
392382
total=len(tar.getmembers()),

cmdstanpy/model.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,22 @@
1616

1717
from tqdm.auto import tqdm # type: ignore
1818

19-
2019
from cmdstanpy import (
2120
_CMDSTAN_REFRESH,
2221
_CMDSTAN_SAMPLING,
2322
_CMDSTAN_WARMUP,
2423
)
2524

26-
from cmdstanpy.progress import allow_show_progress, disable_progress
27-
2825
from cmdstanpy.cmdstan_args import (
2926
CmdStanArgs,
3027
GenerateQuantitiesArgs,
3128
OptimizeArgs,
3229
SamplerArgs,
3330
VariationalArgs,
3431
)
32+
3533
from cmdstanpy.compiler_opts import CompilerOptions
34+
3635
from cmdstanpy.stanfit import (
3736
CmdStanGQ,
3837
CmdStanMCMC,
@@ -41,6 +40,7 @@
4140
RunSet,
4241
from_csv,
4342
)
43+
4444
from cmdstanpy.utils import (
4545
EXTENSION,
4646
MaybeDictToFilePath,
@@ -51,6 +51,8 @@
5151
returncode_msg,
5252
)
5353

54+
from . import progress as progbar
55+
5456

5557
class CmdStanModel:
5658
# overview, omitted from doc comment in order to improve Sphinx docs.
@@ -905,7 +907,7 @@ def sample(
905907
if show_console:
906908
show_progress = False
907909
else:
908-
show_progress = show_progress and allow_show_progress()
910+
show_progress = show_progress and progbar.allow_show_progress()
909911

910912
get_logger().info('sampling: %s', runset.cmds[0])
911913
with ThreadPoolExecutor(max_workers=parallel_chains) as executor:
@@ -918,7 +920,7 @@ def sample(
918920
show_console,
919921
iter_total,
920922
)
921-
if show_progress and allow_show_progress():
923+
if show_progress and progbar.allow_show_progress():
922924
# advance terminal window cursor past progress bars
923925
term_size: os.terminal_size = shutil.get_terminal_size(
924926
fallback=(80, 24)
@@ -1283,7 +1285,7 @@ def _run_cmdstan(
12831285
get_logger().debug(
12841286
'threads: %s', str(os.environ.get('STAN_NUM_THREADS'))
12851287
)
1286-
if show_progress and allow_show_progress():
1288+
if show_progress and progbar.allow_show_progress():
12871289
progress_hook: Optional[
12881290
Callable[[str], None]
12891291
] = self._wrap_sampler_progress_hook(idx + 1, iter_total)
@@ -1352,37 +1354,28 @@ def _run_cmdstan(
13521354
sampler_args.fixed_param = True
13531355

13541356
# pylint: disable=no-self-use
1357+
@progbar.wrap_callback
13551358
def _wrap_sampler_progress_hook(
13561359
self, chain_id: int, total: int
13571360
) -> Optional[Callable[[str], None]]:
13581361
"""Sets up tqdm callback for CmdStan sampler console msgs."""
1359-
try:
1360-
pbar: tqdm = tqdm(
1361-
total=total,
1362-
bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}",
1363-
postfix=[dict(value="Status")],
1364-
desc=f'chain {chain_id}',
1365-
colour='yellow',
1366-
)
1367-
# pylint: disable=broad-except
1368-
except Exception as e:
1369-
disable_progress(e)
1370-
1371-
# pylint: disable=unused-argument
1372-
def sampler_progress_hook(line: str) -> None:
1373-
return
1374-
1375-
else:
1362+
pbar: tqdm = tqdm(
1363+
total=total,
1364+
bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}",
1365+
postfix=[dict(value="Status")],
1366+
desc=f'chain {chain_id}',
1367+
colour='yellow',
1368+
)
13761369

1377-
def sampler_progress_hook(line: str) -> None:
1378-
if line == "Done":
1379-
pbar.postfix[0]["value"] = 'Sampling completed'
1380-
pbar.update(total - pbar.n)
1381-
pbar.close()
1382-
elif line.startswith("Iteration"):
1383-
if 'Sampling' in line:
1384-
pbar.colour = 'blue'
1385-
pbar.update(1)
1386-
pbar.postfix[0]["value"] = line
1370+
def sampler_progress_hook(line: str) -> None:
1371+
if line == "Done":
1372+
pbar.postfix[0]["value"] = 'Sampling completed'
1373+
pbar.update(total - pbar.n)
1374+
pbar.close()
1375+
elif line.startswith("Iteration"):
1376+
if 'Sampling' in line:
1377+
pbar.colour = 'blue'
1378+
pbar.update(1)
1379+
pbar.postfix[0]["value"] = line
13871380

13881381
return sampler_progress_hook

cmdstanpy/progress.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,49 @@
11
"""
22
Record tqdm progress bar fail during session
33
"""
4-
4+
import functools
55
import logging
66

7-
SHOW_PROGRESS: bool = True
7+
_SHOW_PROGRESS: bool = True
88

99

1010
def allow_show_progress() -> bool:
11-
return SHOW_PROGRESS
11+
"""Return False if any progressbar errors have occurred this session"""
12+
return _SHOW_PROGRESS
1213

1314

14-
def disable_progress(e: Exception) -> None:
15-
print("DISABLE")
15+
def _disable_progress(e: Exception) -> None:
16+
"""Print an exception and disable progress bars for this session"""
1617
# pylint: disable=global-statement
17-
global SHOW_PROGRESS
18-
if SHOW_PROGRESS:
18+
global _SHOW_PROGRESS
19+
if _SHOW_PROGRESS:
20+
_SHOW_PROGRESS = False
1921
logging.getLogger('cmdstanpy').error(
2022
'Error in progress bar initialization:\n'
2123
'\t%s\n'
2224
'Disabling progress bars for this session',
2325
str(e),
2426
)
25-
SHOW_PROGRESS = False
27+
28+
29+
def wrap_callback(func): # type: ignore
30+
"""Wrap a callback generator so it fails safely"""
31+
32+
@functools.wraps(func)
33+
def safe_progress(*args, **kwargs): # type: ignore
34+
# pylint: disable=unused-argument
35+
def callback(*args, **kwargs): # type: ignore
36+
# totally empty callback
37+
return None
38+
39+
if not allow_show_progress():
40+
return callback
41+
42+
try:
43+
return func(*args, **kwargs)
44+
# pylint: disable=broad-except
45+
except Exception as e:
46+
_disable_progress(e)
47+
return callback
48+
49+
return safe_progress

cmdstanpy/utils.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@
4242
_TMPDIR,
4343
)
4444

45-
from cmdstanpy.progress import disable_progress
46-
45+
from . import progress as progbar
4746

4847
EXTENSION = '.exe' if platform.system() == 'Windows' else ''
4948

@@ -1210,40 +1209,27 @@ def install_cmdstan(
12101209
return True
12111210

12121211

1212+
@progbar.wrap_callback
12131213
def wrap_url_progress_hook() -> Optional[Callable[[int, int, int], None]]:
12141214
"""Sets up tqdm callback for url downloads."""
1215-
try:
1216-
pbar: tqdm = tqdm(
1217-
unit='B',
1218-
unit_scale=True,
1219-
unit_divisor=1024,
1220-
colour='blue',
1221-
leave=False,
1222-
) # type: ignore
1223-
# pylint: disable=broad-except
1224-
except Exception as e:
1225-
disable_progress(e)
1226-
1227-
def download_progress_hook(
1228-
# pylint: disable=unused-argument
1229-
count: int,
1230-
block_size: int,
1231-
total_size: int,
1232-
) -> None:
1233-
return
1234-
1235-
else:
1215+
pbar: tqdm = tqdm(
1216+
unit='B',
1217+
unit_scale=True,
1218+
unit_divisor=1024,
1219+
colour='blue',
1220+
leave=False,
1221+
)
12361222

1237-
def download_progress_hook(
1238-
count: int, block_size: int, total_size: int
1239-
) -> None:
1240-
if pbar.total is None:
1241-
pbar.total = total_size
1242-
pbar.reset()
1243-
downloaded_size = count * block_size
1244-
pbar.update(downloaded_size - pbar.n)
1245-
if pbar.n >= total_size:
1246-
pbar.close()
1223+
def download_progress_hook(
1224+
count: int, block_size: int, total_size: int
1225+
) -> None:
1226+
if pbar.total is None:
1227+
pbar.total = total_size
1228+
pbar.reset()
1229+
downloaded_size = count * block_size
1230+
pbar.update(downloaded_size - pbar.n)
1231+
if pbar.n >= total_size:
1232+
pbar.close()
12471233

12481234
return download_progress_hook
12491235

0 commit comments

Comments
 (0)