Unify jax and jaxlib versions.

Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
This commit is contained in:
Peter Hawkins 2022-06-29 12:50:24 -07:00 committed by jax authors
parent eb0052bdf2
commit 1e171ccd10
7 changed files with 15 additions and 31 deletions

1
BUILD.bazel Normal file
View File

@ -0,0 +1 @@
exports_files(["jax/version.py"])

View File

@ -17,7 +17,7 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "if_windows")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_windows")
licenses(["notice"]) # Apache 2
@ -40,7 +40,7 @@ py_binary(
srcs = ["build_wheel.py"],
data = [
"LICENSE.txt",
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
"//:jax/version.py",
"//jaxlib",
"//jaxlib:setup.py",
"//jaxlib:setup.cfg",
@ -52,6 +52,7 @@ py_binary(
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:transforms",
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + select({

View File

@ -185,8 +185,9 @@ def prepare_wheel(sources_path):
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py")
copy_to_jaxlib("__main__/jaxlib/version.py")
# The same version.py file is distributed as part of both jax and jaxlib.
copy_to_jaxlib("__main__/jax/version.py")
cuda_dir = os.path.join(jaxlib_dir, "cuda")
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):

View File

@ -30,7 +30,8 @@ easy, even if at the cost of making C++ changes slightly harder.
## How are `jax` and `jaxlib` versioned?
Summary: `jax`'s version must be greater than or equal to `jaxlib`'s version,
Summary: `jax` and `jaxlib` share the same version number in the JAX source tree, but are released as separate Python packages.
When installed, the `jax` package version must be greater than or equal to `jaxlib`'s version,
and `jaxlib`'s version must be greater than or equal to the minimum `jaxlib`
version specified by `jax`.
@ -54,9 +55,7 @@ the following must hold:
These constraints imply the following rules for releases:
* `jax` may be released on its own at any time, without updating `jaxlib`.
* If a new `jaxlib` is released, a `jax` release whose version is equal to or
greater than version the `jaxlib`'s version number must be made at the same
time.
* If a new `jaxlib` is released, a `jax` release must be made at the same time.
These
[version constraints](https://github.com/google/jax/blob/main/jax/version.py)

View File

@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is included as part of both jax and jaxlib. It is also
# eval()-ed by setup.py, so it should not have any dependencies.
__version__ = "0.3.15"
_minimum_jaxlib_version = "0.3.10"
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
__version__ = "0.3.15"
__version_info__ = _version_as_tuple(__version__)
_minimum_jaxlib_version = "0.3.10"
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)

View File

@ -79,7 +79,6 @@ py_library(
"lapack.py",
"mhlo_helpers.py",
"pocketfft.py",
"version.py",
],
deps = [
":_lapack",

View File

@ -1,20 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# After a new jaxlib release, please remember to update the values of
# `_current_jaxlib_version` and `_available_cuda_versions` in setup.py to
# reflect the most recent available binaries.
# __version__ should be increased after releasing the current version
# (i.e. on main, this is always the next version to be released).
__version__ = "0.3.15"