Skip to content

Commit 730742a

Browse files
authored
Merge pull request #679 from stan-dev/install/allow-git-versions
Allow git:TAG as a version in install_cmdstan
2 parents 9b01be8 + df5c3a4 commit 730742a

3 files changed

Lines changed: 48 additions & 6 deletions

File tree

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ jobs:
8282
run: python -m pip freeze
8383

8484
- name: CmdStan installation cacheing
85+
if: ${{ !startswith(needs.get-cmdstan-version.outputs.version, 'git:') }}
8586
uses: actions/cache@v3
8687
with:
8788
path: ~/.cmdstan

cmdstanpy/install_cmdstan.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import os
2323
import platform
2424
import re
25+
import shutil
2526
import sys
2627
import tarfile
2728
import urllib.error
@@ -426,6 +427,9 @@ def install_version(
426427

427428

428429
def is_version_available(version: str) -> bool:
430+
if 'git:' in version:
431+
return True # no good way in general to check if a git tag exists
432+
429433
is_available = True
430434
url = get_download_url(version)
431435
for i in range(6):
@@ -455,6 +459,27 @@ def retrieve_version(version: str, progress: bool = True) -> None:
455459
"""Download specified CmdStan version."""
456460
if version is None or version == '':
457461
raise ValueError('Argument "version" unspecified.')
462+
463+
if 'git:' in version:
464+
tag = version.split(':')[1]
465+
tag_folder = version.replace(':', '-').replace('/', '_')
466+
print(f"Cloning CmdStan branch '{tag}' from stan-dev/cmdstan on GitHub")
467+
do_command(
468+
[
469+
'git',
470+
'clone',
471+
'--depth',
472+
'1',
473+
'--branch',
474+
tag,
475+
'--recursive',
476+
'--shallow-submodules',
477+
'https://github.com/stan-dev/cmdstan.git',
478+
f'cmdstan-{tag_folder}',
479+
]
480+
)
481+
return
482+
458483
print('Downloading CmdStan version {}'.format(version))
459484
url = get_download_url(version)
460485
for i in range(6): # always retry to allow for transient URLErrors
@@ -578,9 +603,12 @@ def run_install(args: Union[InteractiveSettings, InstallationSettings]) -> None:
578603
if args.compiler:
579604
run_compiler_install(args.dir, args.verbose, args.progress)
580605

581-
cmdstan_version = f'cmdstan-{args.version}'
606+
if 'git:' in args.version:
607+
tag = args.version.replace(':', '-').replace('/', '_')
608+
cmdstan_version = f'cmdstan-{tag}'
609+
else:
610+
cmdstan_version = f'cmdstan-{args.version}'
582611
with pushd(args.dir):
583-
584612
already_installed = os.path.exists(cmdstan_version) and os.path.exists(
585613
os.path.join(
586614
cmdstan_version,
@@ -598,6 +626,7 @@ def run_install(args: Union[InteractiveSettings, InstallationSettings]) -> None:
598626
'Connection to GitHub failed. '
599627
'Check firewall settings or ensure this version exists.'
600628
)
629+
shutil.rmtree(cmdstan_version, ignore_errors=True)
601630
retrieve_version(args.version, args.progress)
602631
install_version(
603632
cmdstan_version=cmdstan_version,
@@ -620,7 +649,11 @@ def parse_cmdline_args() -> Dict[str, Any]:
620649
+ "interactive mode",
621650
)
622651
parser.add_argument(
623-
'--version', '-v', help="version, defaults to latest release version"
652+
'--version',
653+
'-v',
654+
help="version, defaults to latest release version. "
655+
"If git is installed, you can also specify a git tag or branch, "
656+
"e.g. git:develop",
624657
)
625658
parser.add_argument(
626659
'--dir', '-d', help="install directory, defaults to '$HOME/.cmdstan"

cmdstanpy/utils/cmdstan.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,13 @@ def get_latest_cmdstan(cmdstan_dir: str) -> Optional[str]:
9797
for name in os.listdir(cmdstan_dir)
9898
if os.path.isdir(os.path.join(cmdstan_dir, name))
9999
and name.startswith('cmdstan-')
100-
and name[8].isdigit()
101-
and len(name[8:].split('.')) == 3
102100
]
103101
if len(versions) == 0:
104102
return None
103+
if len(versions) == 1:
104+
return 'cmdstan-' + versions[0]
105+
# we can only compare numeric versions
106+
versions = [v for v in versions if v[0].isdigit() and v.count('.') == 2]
105107
# munge rc for sort, e.g. 2.25.0-rc1 -> 2.25.-99
106108
for i in range(len(versions)): # # pylint: disable=C0200
107109
if '-rc' in versions[i]:
@@ -442,6 +444,8 @@ def install_cmdstan(
442444
443445
:param version: CmdStan version string, e.g. "2.29.2".
444446
Defaults to latest CmdStan release.
447+
If ``git`` is installed, a git tag or branch of stan-dev/cmdstan
448+
can be specified, e.g. "git:develop".
445449
446450
:param dir: Path to install directory. Defaults to hidden directory
447451
``$HOME/.cmdstan``.
@@ -516,7 +520,11 @@ def install_cmdstan(
516520
logger.warning('CmdStan installation failed.\n%s', str(e))
517521
return False
518522

519-
set_cmdstan_path(os.path.join(args.dir, f"cmdstan-{args.version}"))
523+
if 'git:' in args.version:
524+
folder = f"cmdstan-{args.version.replace(':', '-').replace('/', '_')}"
525+
else:
526+
folder = f"cmdstan-{args.version}"
527+
set_cmdstan_path(os.path.join(args.dir, folder))
520528

521529
return True
522530

0 commit comments

Comments
 (0)