diff --git a/tests/version_test.py b/tests/version_test.py index 1036d958f..b78e61ae0 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -24,11 +24,15 @@ import jax from jax._src.lib import check_jaxlib_version from jax._src import test_util as jtu -# This is a subset of the full PEP440 pattern; for example we skip pre & post releases +# This is a subset of the full PEP440 pattern; for example we skip post releases VERSION_PATTERN = re.compile(r""" ^ # start of string (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' - (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + (?: + (?:rc(?P[0-9]+))? # optional rc version; like 'rc1' + | # or + (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + )? (?:\+(?P[a-zA-Z0-9_.]+))? # optional local version; like '+g6643af3c3' $ # end of string """, re.VERBOSE) @@ -170,6 +174,18 @@ class JaxVersionTest(unittest.TestCase): self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1") self.assertValidVersion(version) + with jtu.set_env( + JAX_RELEASE="1", + JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY=None, + WHEEL_VERSION_SUFFIX="rc0", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, f"{base_version}rc0") + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3")