From bc91f2d18266be881ec42406cab2ae938f07ec17 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Sep 2023 11:33:49 -0700 Subject: [PATCH] Add more extensive tests for version strings --- tests/version_test.py | 138 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 113 insertions(+), 25 deletions(-) diff --git a/tests/version_test.py b/tests/version_test.py index b65cbe346..7ce98c858 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import datetime import unittest +from unittest import mock +import re from absl.testing import absltest @@ -21,43 +24,128 @@ import jax from jax._src.lib import check_jaxlib_version from jax._src import test_util as jtu +# This is a subset of the full PEP440 pattern; for example we skip pre & post releases +VERSION_PATTERN = re.compile(r""" + ^ # start of string + (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' + (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + (?:\+(?P[a-zA-Z0-9_]+))? # optional local version; like '+g6643af3c3' + $ # end of string +""", re.VERBOSE) + + +@contextlib.contextmanager +def patch_jax_version(version, release_version): + """ + Patch jax.version._version & jax.version._release_version in order to + test the version construction logic. + """ + original_version = jax.version._version + original_release_version = jax.version._release_version + + jax.version._version = version + jax.version._release_version = release_version + try: + yield + finally: + jax.version._version = original_version + jax.version._release_version = original_release_version + + +@contextlib.contextmanager +def assert_no_subprocess_call(): + """Run code, asserting that subprocess.Popen *is not* called.""" + with mock.patch("subprocess.Popen") as mock_Popen: + yield + mock_Popen.assert_not_called() + + +@contextlib.contextmanager +def assert_subprocess_call(): + """Run code, asserting that subprocess.Popen *is* called at least once.""" + with mock.patch("subprocess.Popen") as mock_Popen: + yield + mock_Popen.assert_called() + class JaxVersionTest(unittest.TestCase): - def testBuildVersion(self): - base_version = jax.version._version + def assertValidVersion(self, version): + self.assertIsNotNone(VERSION_PATTERN.match(version)) - if jax.version._release_version is not None: + def testVersionValidity(self): + self.assertValidVersion(jax.__version__) + self.assertValidVersion(jax._src.lib.version_str) + + @patch_jax_version("1.2.3", "1.2.3.dev4567") + def testVersionInRelease(self): + # If the release version is set, subprocess should not be called. + with assert_no_subprocess_call(): + version = jax.version._get_version_string() + self.assertEqual(version, "1.2.3.dev4567") + self.assertValidVersion(version) + + @patch_jax_version("1.2.3", None) + def testVersionInNonRelease(self): + # If the release version is not set, we expect subprocess to be called + # in order to attempt accessing git commit information. + with assert_subprocess_call(): + version = jax.version._get_version_string() + self.assertTrue(version.startswith("1.2.3.dev")) + self.assertValidVersion(version) + + @patch_jax_version("1.2.3", "1.2.3.dev4567") + def testBuildVersionInRelease(self): + # If building from a source tree with release version set, subprocess + # should not be called. + with assert_no_subprocess_call(): version = jax.version._get_version_for_build() - self.assertEqual(version, jax.version._release_version) - else: - with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): - version = jax.version._get_version_for_build() - # TODO(jakevdp): confirm that this includes a date string & commit hash? - self.assertTrue(version.startswith(f"{base_version}.dev")) + self.assertEqual(version, "1.2.3.dev4567") + self.assertValidVersion(version) - with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, - JAX_NIGHTLY="1", JAXLIB_NIGHTLY=None): - version = jax.version._get_version_for_build() - datestring = datetime.date.today().strftime("%Y%m%d") - self.assertEqual(version, f"{base_version}.dev{datestring}") + @patch_jax_version("1.2.3", None) + def testBuildVersionFromEnvironment(self): + # This test covers build-time construction of version strings in the + # presence of several environment variables. + base_version = "1.2.3" - with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + with assert_subprocess_call(): version = jax.version._get_version_for_build() - datestring = datetime.date.today().strftime("%Y%m%d") - self.assertEqual(version, f"{base_version}.dev{datestring}") + # TODO(jakevdp): confirm that this includes a date string & commit hash? + self.assertTrue(version.startswith(f"{base_version}.dev")) + self.assertValidVersion(version) - with jtu.set_env(JAX_RELEASE="1", JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, + JAX_NIGHTLY="1", JAXLIB_NIGHTLY=None): + with assert_no_subprocess_call(): version = jax.version._get_version_for_build() - self.assertEqual(version, base_version) + datestring = datetime.date.today().strftime("%Y%m%d") + self.assertEqual(version, f"{base_version}.dev{datestring}") + self.assertValidVersion(version) - with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE="1", - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): + with assert_no_subprocess_call(): version = jax.version._get_version_for_build() - self.assertEqual(version, base_version) + datestring = datetime.date.today().strftime("%Y%m%d") + self.assertEqual(version, f"{base_version}.dev{datestring}") + self.assertValidVersion(version) + + with jtu.set_env(JAX_RELEASE="1", JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, base_version) + self.assertValidVersion(version) + + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE="1", + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, base_version) + self.assertValidVersion(version) def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3",