From d72a7b405419a8b05ba1d00c749f689dc9bba4ed Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 5 Apr 2022 12:42:43 -0700 Subject: [PATCH] Add version int tuple `__version_info__` to JAX --- jax/__init__.py | 1 + jax/version.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/jax/__init__.py b/jax/__init__.py index 5f17acf05..7676c8169 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -123,6 +123,7 @@ from jax._src.api import ( ) from jax.experimental.maps import soft_pmap as soft_pmap from jax.version import __version__ as __version__ +from jax.version import __version_info__ as __version_info__ # These submodules are separate because they are in an import cycle with # jax and rely on the names imported above. diff --git a/jax/version.py b/jax/version.py index ee076c079..c01961201 100644 --- a/jax/version.py +++ b/jax/version.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +def _version_as_tuple(version_str): + return tuple(int(i) for i in version_str.split(".") if i.isdigit()) + __version__ = "0.3.5" +__version_info__ = _version_as_tuple(__version__) _minimum_jaxlib_version = "0.3.0" +_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)