mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions are related.
This commit is contained in:
parent
287c476eec
commit
8be057de1f
@ -5,5 +5,6 @@ Design Notes
|
||||
:maxdepth: 1
|
||||
|
||||
custom_derivatives
|
||||
jax_versioning
|
||||
omnistaging
|
||||
prng
|
||||
|
189
docs/design_notes/jax_versioning.md
Normal file
189
docs/design_notes/jax_versioning.md
Normal file
@ -0,0 +1,189 @@
|
||||
# Jax and Jaxlib versioning
|
||||
|
||||
## Why are `jax` and `jaxlib` separate packages?
|
||||
|
||||
We publish JAX as two separate Python wheels, namely `jax`, which is a pure
|
||||
Python wheel, and `jaxlib`, which is a mostly-C++ wheel that contains libraries
|
||||
such as:
|
||||
* XLA,
|
||||
* pieces of LLVM used by XLA,
|
||||
* MLIR infrastructure, such as the MHLO Python bindings.
|
||||
* JAX-specific C++ libraries for fast JIT and PyTree manipulation.
|
||||
|
||||
We distribute separate `jax` and `jaxlib` packages because it makes it easy to
|
||||
work on the Python parts of JAX without having to build C++ code or even having
|
||||
a C++ toolchain installed. `jaxlib` is a large library that is not easy for
|
||||
many users to build, but most changes to JAX only touch Python code. By
|
||||
allowing the Python pieces to be updated independently of the C++ pieces, we
|
||||
improve the development velocity for Python changes.
|
||||
|
||||
In addition `jaxlib` is not cheap to build, but we want to be able to iterate on
|
||||
and run the JAX tests in environments without a lot of CPU, for example in
|
||||
Github Actions or on a laptop. Many of our CI builds simply use a prebuilt
|
||||
`jaxlib`, rather than needing to rebuild the C++ pieces of JAX on each PR.
|
||||
|
||||
As we will see, distributing `jax` and `jaxlib` separately comes with a cost, in
|
||||
that it requires that changes to `jaxlib` maintain a backward compatible API.
|
||||
However, we believe that on balance it is preferable to make Python changes
|
||||
easy, even if at the cost of making C++ changes slightly harder.
|
||||
|
||||
|
||||
## How are `jax` and `jaxlib` versioned?
|
||||
|
||||
Summary: `jax`'s version must be greater than or equal to `jaxlib`'s version,
|
||||
and `jaxlib`'s version must be greater than or equal to the minimum `jaxlib`
|
||||
version specified by `jax`.
|
||||
|
||||
Both `jax` and `jaxlib` releases are numbered `x.y.z`, where `x` is the major
|
||||
version, and `y` is the minor version, and `z` is an optional patch release.
|
||||
Version numbers must follow
|
||||
[PEP 440](https://www.python.org/dev/peps/pep-0440/). Version number comparisons
|
||||
are lexicographic comparisons on tuples of integers.
|
||||
|
||||
Each `jax` release has an associated minimum `jaxlib` version `mx.my.mz`. The
|
||||
minimum `jaxlib` version for `jax` version `x.y.z` must be no greater than
|
||||
`x.y.z`.
|
||||
|
||||
For `jax` version `x.y.z` and `jaxlib` version `lx.ly.lz` to be compatible,
|
||||
the following must hold:
|
||||
|
||||
* The jaxlib version (`lx.ly.lz`) must be greater than or equal to the minimum
|
||||
jaxlib version (`mx.my.mz`).
|
||||
* The jax version (`x.y.z`) must be greater than or equal to the jaxlib version
|
||||
(`lx.ly.lz`).
|
||||
|
||||
These constraints imply the following rules for releases:
|
||||
* `jax` may be released on its own at any time, without updating `jaxlib`.
|
||||
* If a new `jaxlib` is released, a `jax` release whose version is equal to or
|
||||
greater than version the `jaxlib`'s version number must be made at the same
|
||||
time.
|
||||
|
||||
These
|
||||
[version constraints](https://github.com/google/jax/blob/main/jax/version.py)
|
||||
are currently checked by `jax` at import time, instead of being expressed as
|
||||
Python package version constraints. `jax` checks the `jaxlib` version at
|
||||
runtime rather than using a `pip` package version constraint because we
|
||||
[provide separate `jaxlib` wheels](https://github.com/google/jax#installation)
|
||||
for a variety of hardware and software versions (e.g, GPU, TPU, etc.). Since we
|
||||
do not know which is the right choice for any given user, we do not want `pip`
|
||||
to install a `jaxlib` package for us automatically.
|
||||
|
||||
In the future, we hope to separate out the hardware-specific pieces of `jaxlib`
|
||||
into separate plugins, at which point the minimum version could be expressed as
|
||||
a Python package dependency. For now, we do provide
|
||||
platform-specific extra requirements that install a compatible jaxlib version,
|
||||
e.g., `jax[cuda]`.
|
||||
|
||||
## How can I safely make changes to the API of `jaxlib`?
|
||||
|
||||
* `jax` may drop compatibility with older `jaxlib` releases at any time, so long
|
||||
as the minimum `jaxlib` version is increased to a compatible version. However,
|
||||
note that the minimum `jaxlib`, even for unreleased versions of `jax`, must be
|
||||
a released version! This allows us to use released `jaxlib` wheels in our CI
|
||||
builds, and allows Python developers to work on `jax` at HEAD without ever
|
||||
needing to build `jaxlib`.
|
||||
|
||||
For example, to remove an old backwards compatibility path in the `jax` Python
|
||||
code, it is sufficient to bump the minimum jaxlib version and then delete the
|
||||
compatibility path.
|
||||
|
||||
* `jaxlib` may drop compatibility with older `jax` releases lower than
|
||||
its own release version number. The version constraints enforced by `jax`
|
||||
would forbid the use of an incompatible `jaxlib`.
|
||||
|
||||
For example, for `jaxlib` to drop a Python binding API used by an older `jax`
|
||||
version, the `jaxlib` minor or major version number must be incremented.
|
||||
|
||||
* If possible, changes to the `jaxlib` should be made in a backwards-compatible
|
||||
way.
|
||||
|
||||
In general `jaxlib` may freely change its API, so long
|
||||
as the rules about `jax` being compatible with all `jaxlib`s at least as new
|
||||
as the minimum version are followed. This implies that
|
||||
`jax` must always be compatible with at least two versions of `jaxlib`,
|
||||
namely, the last release, and the tip-of-tree version, effectively
|
||||
the next release. This is easier to do if compatibility is maintained,
|
||||
although incompatible changes can be made using version tests from `jax`; see
|
||||
below.
|
||||
|
||||
For example, it is usually safe to add a new function to `jaxlib`, but unsafe
|
||||
to remove an existing function or to change its signature if current `jax` is
|
||||
still using it. Changes to `jax` must work or degrade gracefully
|
||||
for all `jaxlib` releases greater than the minimum up to HEAD.
|
||||
|
||||
|
||||
Note that the compatibility rules here only apply to *released* versions of
|
||||
`jax` and `jaxlib`. They do not apply to unreleased versions; that is, it is ok
|
||||
to introduce and then remove an API from `jaxlib` if it is never released, or if
|
||||
no released `jax` version uses that API.
|
||||
|
||||
## How is the source to `jaxlib` laid out?
|
||||
|
||||
`jaxlib` is split across two main repositories, namely the
|
||||
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
|
||||
and in the
|
||||
[XLA source tree, which lives inside the TensorFlow repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla).
|
||||
The JAX-specific pieces inside XLA are primarily in the
|
||||
[`xla/python` subdirectory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/python).
|
||||
|
||||
|
||||
The reason that C++ pieces of JAX, such as Python bindings and runtime
|
||||
components, are inside the XLA tree is partially
|
||||
historical and partially technical.
|
||||
|
||||
The historical reason is that originally the
|
||||
`xla/python` bindings were envisaged as general purpose Python bindings that
|
||||
might be shared with other frameworks. In practice this is increasingly less
|
||||
true, and `xla/python` incorporates a number of JAX-specific pieces and is
|
||||
likely to incorporate more. So it is probably best to simply think of
|
||||
`xla/python` as part of JAX.
|
||||
|
||||
The technical reason is that the XLA C++ API is not stable. By keeping the
|
||||
XLA:Python bindings in the XLA tree, their C++ implementation can be updated
|
||||
atomically with the C++ API of XLA. It is easier to maintain backward and forward
|
||||
compatibility of Python APIs than C++ ones, so `xla/python` exposes Python APIs
|
||||
and is responsible for maintaining backward compatibility at the Python
|
||||
level.
|
||||
|
||||
`jaxlib` is built using Bazel out of the `jax` repository. The pieces of
|
||||
`jaxlib` from the XLA repository are incorporated into the build
|
||||
[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE).
|
||||
To update the version of XLA used during the build, one must update the pinned
|
||||
version in the Bazel `WORKSPACE`. This is done manually on an
|
||||
as-needed basis, but can be overriden on a build-by-build basis.
|
||||
|
||||
|
||||
## How do we make changes across the `jax` and `jaxlib` boundary between releases?
|
||||
|
||||
The jaxlib version is a coarse instrument: it only lets us reason about
|
||||
*releases*.
|
||||
|
||||
However, since the `jax` and `jaxlib` code is split across repositories that
|
||||
cannot be updated atomically in a single change, we need to manage compatibility
|
||||
at a finer granularity than our release cycle. To manage fine-grained
|
||||
compatibility, we have additional versioning that is independent of the `jaxlib`
|
||||
release version numbers.
|
||||
|
||||
We maintain an additional version number (`_version`) in
|
||||
[`xla_client.py` in the XLA repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py).
|
||||
The idea is that this version number, is defined in `xla/python`
|
||||
together with the C++ parts of JAX, is also accessible to JAX Python as
|
||||
`jax._src.lib.xla_extension_version`, and must
|
||||
be incremented every time that a change is made to the XLA/Python code that has
|
||||
backwards compatibility implications for `jax`. The JAX Python code can then use
|
||||
this version number to maintain backwards compatibility, e.g.:
|
||||
|
||||
```
|
||||
# 123 is the new version number for _version in xla_client.py
|
||||
if jax._src.lib.xla_extension_version >= 123:
|
||||
# Use new code path
|
||||
...
|
||||
else:
|
||||
# Use old code path.
|
||||
```
|
||||
|
||||
Note that this version number is in *addition* to the constraints on the
|
||||
released version numbers, that is, this version number exists to help manage
|
||||
compatibility during development for unreleased code. Releases must also
|
||||
follow the compatibility rules given above.
|
||||
|
@ -61,7 +61,6 @@ from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import version
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
|
@ -16,9 +16,10 @@
|
||||
# checking on import.
|
||||
|
||||
import platform
|
||||
import re
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
__all__ = [
|
||||
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
|
||||
@ -26,8 +27,8 @@ __all__ = [
|
||||
'xla_extension',
|
||||
]
|
||||
|
||||
# First, before attempting to from jax import jaxlib, warn about experimental machine
|
||||
# configurations.
|
||||
# Before attempting to import jaxlib, warn about experimental
|
||||
# machine configurations.
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
|
||||
"Please see https://github.com/google/jax/issues/5501 in the "
|
||||
@ -41,33 +42,58 @@ except ModuleNotFoundError as err:
|
||||
'https://github.com/google/jax#installation for installation instructions.'
|
||||
) from err
|
||||
|
||||
import jax.version
|
||||
from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
|
||||
try:
|
||||
import jaxlib.version as jaxlib_version
|
||||
import jaxlib.version
|
||||
except Exception as err:
|
||||
# jaxlib is too old to have version number.
|
||||
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
|
||||
raise ImportError(msg) from err
|
||||
|
||||
version = tuple(int(x) for x in jaxlib_version.__version__.split('.'))
|
||||
_minimum_jaxlib_version = tuple(int(x) for x in _minimum_jaxlib_version_str.split('.'))
|
||||
|
||||
# Check the jaxlib version before importing anything else from jaxlib.
|
||||
def _check_jaxlib_version():
|
||||
if version < _minimum_jaxlib_version:
|
||||
msg = (f'jaxlib is version {jaxlib_version.__version__}, '
|
||||
f'but this version of jax requires version {_minimum_jaxlib_version_str}.')
|
||||
# Checks the jaxlib version before importing anything else from jaxlib.
|
||||
# Returns the jaxlib version string.
|
||||
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
|
||||
minimum_jaxlib_version: str):
|
||||
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
|
||||
# PEP440 allows a number of non-numeric suffixes, which we allow also.
|
||||
# We currently do not allow an epoch.
|
||||
version_regex = re.compile(r"[0-9]+(?:\.[0-9]+)*")
|
||||
def _parse_version(v: str) -> Tuple[int, ...]:
|
||||
m = version_regex.match(v)
|
||||
if m is None:
|
||||
raise ValueError(f"Unable to parse jaxlib version '{v}'")
|
||||
return tuple(int(x) for x in m.group(0).split('.'))
|
||||
|
||||
if version == (0, 1, 23):
|
||||
msg += ('\n\nA common cause of this error is that you installed jaxlib '
|
||||
'using pip, but your version of pip is too old to support '
|
||||
'manylinux2010 wheels. Try running:\n\n'
|
||||
'pip install --upgrade pip\n'
|
||||
'pip install --upgrade jax jaxlib\n')
|
||||
raise ValueError(msg)
|
||||
_jax_version = _parse_version(jax_version)
|
||||
_minimum_jaxlib_version = _parse_version(minimum_jaxlib_version)
|
||||
_jaxlib_version = _parse_version(jaxlib_version)
|
||||
|
||||
_check_jaxlib_version()
|
||||
if _jaxlib_version < _minimum_jaxlib_version:
|
||||
msg = (f'jaxlib is version {jaxlib_version}, but this version '
|
||||
f'of jax requires version >= {minimum_jaxlib_version}.')
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if _jaxlib_version > _jax_version:
|
||||
msg = (f'jaxlib version {jaxlib_version} is newer than and '
|
||||
f'incompatible with jax version {jax_version}. Please '
|
||||
'update your jax and/or jaxlib packages.')
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return _jaxlib_version
|
||||
|
||||
version_str = jaxlib.version.__version__
|
||||
version = check_jaxlib_version(
|
||||
jax_version=jax.version.__version__,
|
||||
jaxlib_version=jaxlib.version.__version__,
|
||||
minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
|
||||
|
||||
|
||||
|
||||
# Before importing any C compiled modules from jaxlib, first import the CPU
|
||||
# feature guard module to verify that jaxlib was compiled in a way that only
|
||||
# uses instructions that are present on this machine.
|
||||
import jaxlib.cpu_feature_guard as cpu_feature_guard
|
||||
cpu_feature_guard.check_cpu_features()
|
||||
|
||||
|
@ -93,7 +93,7 @@ def get_cache_key(xla_computation, compile_options, backend) -> str:
|
||||
_hash_compile_options(hash_obj, compile_options)
|
||||
_log_cache_key_hash(hash_obj, "compile_options")
|
||||
|
||||
hash_obj.update(bytes(jax._src.lib.version))
|
||||
hash_obj.update(jax._src.lib.version_str.encode('utf-8'))
|
||||
_log_cache_key_hash(hash_obj, "jax_lib version")
|
||||
|
||||
_hash_platform(hash_obj, backend)
|
||||
|
@ -13,4 +13,5 @@
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "0.2.29"
|
||||
|
||||
_minimum_jaxlib_version = "0.1.74"
|
||||
|
44
tests/version_test.py
Normal file
44
tests/version_test.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax._src.lib import check_jaxlib_version
|
||||
|
||||
|
||||
class JaxVersionTest(unittest.TestCase):
|
||||
|
||||
def testVersions(self):
|
||||
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3",
|
||||
minimum_jaxlib_version="1.2.3")
|
||||
check_jaxlib_version(jax_version="1.2.3.4", jaxlib_version="1.2.3",
|
||||
minimum_jaxlib_version="1.2.3")
|
||||
check_jaxlib_version(jax_version="2.5.dev234", jaxlib_version="1.2.3",
|
||||
minimum_jaxlib_version="1.2.3")
|
||||
with self.assertRaisesRegex(RuntimeError, ".*jax requires version >=.*"):
|
||||
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.0",
|
||||
minimum_jaxlib_version="1.2.3")
|
||||
with self.assertRaisesRegex(RuntimeError, ".*jax requires version >=.*"):
|
||||
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.0",
|
||||
minimum_jaxlib_version="1.0.1")
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
".incompatible with jax version.*"):
|
||||
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.4",
|
||||
minimum_jaxlib_version="1.0.5")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user