Introduce a new jax/jaxlib versioning scheme.

Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
This commit is contained in:
Peter Hawkins 2022-02-02 16:21:26 -05:00
parent 287c476eec
commit 8be057de1f
7 changed files with 281 additions and 21 deletions

View File

@ -5,5 +5,6 @@ Design Notes
:maxdepth: 1
custom_derivatives
jax_versioning
omnistaging
prng

View File

@ -0,0 +1,189 @@
# Jax and Jaxlib versioning
## Why are `jax` and `jaxlib` separate packages?
We publish JAX as two separate Python wheels, namely `jax`, which is a pure
Python wheel, and `jaxlib`, which is a mostly-C++ wheel that contains libraries
such as:
* XLA,
* pieces of LLVM used by XLA,
* MLIR infrastructure, such as the MHLO Python bindings.
* JAX-specific C++ libraries for fast JIT and PyTree manipulation.
We distribute separate `jax` and `jaxlib` packages because it makes it easy to
work on the Python parts of JAX without having to build C++ code or even having
a C++ toolchain installed. `jaxlib` is a large library that is not easy for
many users to build, but most changes to JAX only touch Python code. By
allowing the Python pieces to be updated independently of the C++ pieces, we
improve the development velocity for Python changes.
In addition `jaxlib` is not cheap to build, but we want to be able to iterate on
and run the JAX tests in environments without a lot of CPU, for example in
Github Actions or on a laptop. Many of our CI builds simply use a prebuilt
`jaxlib`, rather than needing to rebuild the C++ pieces of JAX on each PR.
As we will see, distributing `jax` and `jaxlib` separately comes with a cost, in
that it requires that changes to `jaxlib` maintain a backward compatible API.
However, we believe that on balance it is preferable to make Python changes
easy, even if at the cost of making C++ changes slightly harder.
## How are `jax` and `jaxlib` versioned?
Summary: `jax`'s version must be greater than or equal to `jaxlib`'s version,
and `jaxlib`'s version must be greater than or equal to the minimum `jaxlib`
version specified by `jax`.
Both `jax` and `jaxlib` releases are numbered `x.y.z`, where `x` is the major
version, and `y` is the minor version, and `z` is an optional patch release.
Version numbers must follow
[PEP 440](https://www.python.org/dev/peps/pep-0440/). Version number comparisons
are lexicographic comparisons on tuples of integers.
Each `jax` release has an associated minimum `jaxlib` version `mx.my.mz`. The
minimum `jaxlib` version for `jax` version `x.y.z` must be no greater than
`x.y.z`.
For `jax` version `x.y.z` and `jaxlib` version `lx.ly.lz` to be compatible,
the following must hold:
* The jaxlib version (`lx.ly.lz`) must be greater than or equal to the minimum
jaxlib version (`mx.my.mz`).
* The jax version (`x.y.z`) must be greater than or equal to the jaxlib version
(`lx.ly.lz`).
These constraints imply the following rules for releases:
* `jax` may be released on its own at any time, without updating `jaxlib`.
* If a new `jaxlib` is released, a `jax` release whose version is equal to or
greater than version the `jaxlib`'s version number must be made at the same
time.
These
[version constraints](https://github.com/google/jax/blob/main/jax/version.py)
are currently checked by `jax` at import time, instead of being expressed as
Python package version constraints. `jax` checks the `jaxlib` version at
runtime rather than using a `pip` package version constraint because we
[provide separate `jaxlib` wheels](https://github.com/google/jax#installation)
for a variety of hardware and software versions (e.g, GPU, TPU, etc.). Since we
do not know which is the right choice for any given user, we do not want `pip`
to install a `jaxlib` package for us automatically.
In the future, we hope to separate out the hardware-specific pieces of `jaxlib`
into separate plugins, at which point the minimum version could be expressed as
a Python package dependency. For now, we do provide
platform-specific extra requirements that install a compatible jaxlib version,
e.g., `jax[cuda]`.
## How can I safely make changes to the API of `jaxlib`?
* `jax` may drop compatibility with older `jaxlib` releases at any time, so long
as the minimum `jaxlib` version is increased to a compatible version. However,
note that the minimum `jaxlib`, even for unreleased versions of `jax`, must be
a released version! This allows us to use released `jaxlib` wheels in our CI
builds, and allows Python developers to work on `jax` at HEAD without ever
needing to build `jaxlib`.
For example, to remove an old backwards compatibility path in the `jax` Python
code, it is sufficient to bump the minimum jaxlib version and then delete the
compatibility path.
* `jaxlib` may drop compatibility with older `jax` releases lower than
its own release version number. The version constraints enforced by `jax`
would forbid the use of an incompatible `jaxlib`.
For example, for `jaxlib` to drop a Python binding API used by an older `jax`
version, the `jaxlib` minor or major version number must be incremented.
* If possible, changes to the `jaxlib` should be made in a backwards-compatible
way.
In general `jaxlib` may freely change its API, so long
as the rules about `jax` being compatible with all `jaxlib`s at least as new
as the minimum version are followed. This implies that
`jax` must always be compatible with at least two versions of `jaxlib`,
namely, the last release, and the tip-of-tree version, effectively
the next release. This is easier to do if compatibility is maintained,
although incompatible changes can be made using version tests from `jax`; see
below.
For example, it is usually safe to add a new function to `jaxlib`, but unsafe
to remove an existing function or to change its signature if current `jax` is
still using it. Changes to `jax` must work or degrade gracefully
for all `jaxlib` releases greater than the minimum up to HEAD.
Note that the compatibility rules here only apply to *released* versions of
`jax` and `jaxlib`. They do not apply to unreleased versions; that is, it is ok
to introduce and then remove an API from `jaxlib` if it is never released, or if
no released `jax` version uses that API.
## How is the source to `jaxlib` laid out?
`jaxlib` is split across two main repositories, namely the
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
and in the
[XLA source tree, which lives inside the TensorFlow repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla).
The JAX-specific pieces inside XLA are primarily in the
[`xla/python` subdirectory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/python).
The reason that C++ pieces of JAX, such as Python bindings and runtime
components, are inside the XLA tree is partially
historical and partially technical.
The historical reason is that originally the
`xla/python` bindings were envisaged as general purpose Python bindings that
might be shared with other frameworks. In practice this is increasingly less
true, and `xla/python` incorporates a number of JAX-specific pieces and is
likely to incorporate more. So it is probably best to simply think of
`xla/python` as part of JAX.
The technical reason is that the XLA C++ API is not stable. By keeping the
XLA:Python bindings in the XLA tree, their C++ implementation can be updated
atomically with the C++ API of XLA. It is easier to maintain backward and forward
compatibility of Python APIs than C++ ones, so `xla/python` exposes Python APIs
and is responsible for maintaining backward compatibility at the Python
level.
`jaxlib` is built using Bazel out of the `jax` repository. The pieces of
`jaxlib` from the XLA repository are incorporated into the build
[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE).
To update the version of XLA used during the build, one must update the pinned
version in the Bazel `WORKSPACE`. This is done manually on an
as-needed basis, but can be overriden on a build-by-build basis.
## How do we make changes across the `jax` and `jaxlib` boundary between releases?
The jaxlib version is a coarse instrument: it only lets us reason about
*releases*.
However, since the `jax` and `jaxlib` code is split across repositories that
cannot be updated atomically in a single change, we need to manage compatibility
at a finer granularity than our release cycle. To manage fine-grained
compatibility, we have additional versioning that is independent of the `jaxlib`
release version numbers.
We maintain an additional version number (`_version`) in
[`xla_client.py` in the XLA repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py).
The idea is that this version number, is defined in `xla/python`
together with the C++ parts of JAX, is also accessible to JAX Python as
`jax._src.lib.xla_extension_version`, and must
be incremented every time that a change is made to the XLA/Python code that has
backwards compatibility implications for `jax`. The JAX Python code can then use
this version number to maintain backwards compatibility, e.g.:
```
# 123 is the new version number for _version in xla_client.py
if jax._src.lib.xla_extension_version >= 123:
# Use new code path
...
else:
# Use old code path.
```
Note that this version number is in *addition* to the constraints on the
released version numbers, that is, this version number exists to help manage
compatibility during development for unreleased code. Releases must also
follow the compatibility rules given above.

View File

@ -61,7 +61,6 @@ from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
from jax._src import device_array
from jax._src import dispatch
from jax._src.lib import jax_jit
from jax._src.lib import version
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib

View File

@ -16,9 +16,10 @@
# checking on import.
import platform
import re
import os
import warnings
from typing import Optional
from typing import Optional, Tuple
__all__ = [
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
@ -26,8 +27,8 @@ __all__ = [
'xla_extension',
]
# First, before attempting to from jax import jaxlib, warn about experimental machine
# configurations.
# Before attempting to import jaxlib, warn about experimental
# machine configurations.
if platform.system() == "Darwin" and platform.machine() == "arm64":
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
"Please see https://github.com/google/jax/issues/5501 in the "
@ -41,33 +42,58 @@ except ModuleNotFoundError as err:
'https://github.com/google/jax#installation for installation instructions.'
) from err
import jax.version
from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
try:
import jaxlib.version as jaxlib_version
import jaxlib.version
except Exception as err:
# jaxlib is too old to have version number.
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
raise ImportError(msg) from err
version = tuple(int(x) for x in jaxlib_version.__version__.split('.'))
_minimum_jaxlib_version = tuple(int(x) for x in _minimum_jaxlib_version_str.split('.'))
# Check the jaxlib version before importing anything else from jaxlib.
def _check_jaxlib_version():
if version < _minimum_jaxlib_version:
msg = (f'jaxlib is version {jaxlib_version.__version__}, '
f'but this version of jax requires version {_minimum_jaxlib_version_str}.')
# Checks the jaxlib version before importing anything else from jaxlib.
# Returns the jaxlib version string.
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
minimum_jaxlib_version: str):
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
# PEP440 allows a number of non-numeric suffixes, which we allow also.
# We currently do not allow an epoch.
version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*")
def _parse_version(v: str) -> Tuple[int, ...]:
m = version_regex.match(v)
if m is None:
raise ValueError(f"Unable to parse jaxlib version '{v}'")
return tuple(int(x) for x in m.group(0).split('.'))
if version == (0, 1, 23):
msg += ('\n\nA common cause of this error is that you installed jaxlib '
'using pip, but your version of pip is too old to support '
'manylinux2010 wheels. Try running:\n\n'
'pip install --upgrade pip\n'
'pip install --upgrade jax jaxlib\n')
raise ValueError(msg)
_jax_version = _parse_version(jax_version)
_minimum_jaxlib_version = _parse_version(minimum_jaxlib_version)
_jaxlib_version = _parse_version(jaxlib_version)
_check_jaxlib_version()
if _jaxlib_version < _minimum_jaxlib_version:
msg = (f'jaxlib is version {jaxlib_version}, but this version '
f'of jax requires version >= {minimum_jaxlib_version}.')
raise RuntimeError(msg)
if _jaxlib_version > _jax_version:
msg = (f'jaxlib version {jaxlib_version} is newer than and '
f'incompatible with jax version {jax_version}. Please '
'update your jax and/or jaxlib packages.')
raise RuntimeError(msg)
return _jaxlib_version
version_str = jaxlib.version.__version__
version = check_jaxlib_version(
jax_version=jax.version.__version__,
jaxlib_version=jaxlib.version.__version__,
minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
# Before importing any C compiled modules from jaxlib, first import the CPU
# feature guard module to verify that jaxlib was compiled in a way that only
# uses instructions that are present on this machine.
import jaxlib.cpu_feature_guard as cpu_feature_guard
cpu_feature_guard.check_cpu_features()

View File

@ -93,7 +93,7 @@ def get_cache_key(xla_computation, compile_options, backend) -> str:
_hash_compile_options(hash_obj, compile_options)
_log_cache_key_hash(hash_obj, "compile_options")
hash_obj.update(bytes(jax._src.lib.version))
hash_obj.update(jax._src.lib.version_str.encode('utf-8'))
_log_cache_key_hash(hash_obj, "jax_lib version")
_hash_platform(hash_obj, backend)

View File

@ -13,4 +13,5 @@
# limitations under the License.
__version__ = "0.2.29"
_minimum_jaxlib_version = "0.1.74"

44
tests/version_test.py Normal file
View File

@ -0,0 +1,44 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
from jax._src.lib import check_jaxlib_version
class JaxVersionTest(unittest.TestCase):
def testVersions(self):
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3",
minimum_jaxlib_version="1.2.3")
check_jaxlib_version(jax_version="1.2.3.4", jaxlib_version="1.2.3",
minimum_jaxlib_version="1.2.3")
check_jaxlib_version(jax_version="2.5.dev234", jaxlib_version="1.2.3",
minimum_jaxlib_version="1.2.3")
with self.assertRaisesRegex(RuntimeError, ".*jax requires version >=.*"):
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.0",
minimum_jaxlib_version="1.2.3")
with self.assertRaisesRegex(RuntimeError, ".*jax requires version >=.*"):
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.0",
minimum_jaxlib_version="1.0.1")
with self.assertRaisesRegex(RuntimeError,
".incompatible with jax version.*"):
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.4",
minimum_jaxlib_version="1.0.5")
if __name__ == "__main__":
absltest.main()