mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source tree. It is tedious and confusing to bump both version numbers. However, there is a simpler way to think of things: it is the source tree that is versioned using a single version number, and jax/jaxlib releases are made using that unified source version number. PiperOrigin-RevId: 458041752
This commit is contained in:
parent
eb0052bdf2
commit
1e171ccd10
1
BUILD.bazel
Normal file
1
BUILD.bazel
Normal file
@ -0,0 +1 @@
|
||||
exports_files(["jax/version.py"])
|
@ -17,7 +17,7 @@
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "if_windows")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_windows")
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
@ -40,7 +40,7 @@ py_binary(
|
||||
srcs = ["build_wheel.py"],
|
||||
data = [
|
||||
"LICENSE.txt",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
"//:jax/version.py",
|
||||
"//jaxlib",
|
||||
"//jaxlib:setup.py",
|
||||
"//jaxlib:setup.cfg",
|
||||
@ -52,6 +52,7 @@ py_binary(
|
||||
"//jaxlib/mlir:sparse_tensor_dialect",
|
||||
"//jaxlib/mlir:pass_manager",
|
||||
"//jaxlib/mlir:transforms",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]) + select({
|
||||
|
@ -185,8 +185,9 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/version.py")
|
||||
|
||||
# The same version.py file is distributed as part of both jax and jaxlib.
|
||||
copy_to_jaxlib("__main__/jax/version.py")
|
||||
|
||||
cuda_dir = os.path.join(jaxlib_dir, "cuda")
|
||||
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
|
||||
|
@ -30,7 +30,8 @@ easy, even if at the cost of making C++ changes slightly harder.
|
||||
|
||||
## How are `jax` and `jaxlib` versioned?
|
||||
|
||||
Summary: `jax`'s version must be greater than or equal to `jaxlib`'s version,
|
||||
Summary: `jax` and `jaxlib` share the same version number in the JAX source tree, but are released as separate Python packages.
|
||||
When installed, the `jax` package version must be greater than or equal to `jaxlib`'s version,
|
||||
and `jaxlib`'s version must be greater than or equal to the minimum `jaxlib`
|
||||
version specified by `jax`.
|
||||
|
||||
@ -54,9 +55,7 @@ the following must hold:
|
||||
|
||||
These constraints imply the following rules for releases:
|
||||
* `jax` may be released on its own at any time, without updating `jaxlib`.
|
||||
* If a new `jaxlib` is released, a `jax` release whose version is equal to or
|
||||
greater than version the `jaxlib`'s version number must be made at the same
|
||||
time.
|
||||
* If a new `jaxlib` is released, a `jax` release must be made at the same time.
|
||||
|
||||
These
|
||||
[version constraints](https://github.com/google/jax/blob/main/jax/version.py)
|
||||
|
@ -12,11 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file is included as part of both jax and jaxlib. It is also
|
||||
# eval()-ed by setup.py, so it should not have any dependencies.
|
||||
|
||||
__version__ = "0.3.15"
|
||||
_minimum_jaxlib_version = "0.3.10"
|
||||
|
||||
def _version_as_tuple(version_str):
|
||||
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
|
||||
|
||||
__version__ = "0.3.15"
|
||||
__version_info__ = _version_as_tuple(__version__)
|
||||
|
||||
_minimum_jaxlib_version = "0.3.10"
|
||||
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)
|
||||
|
@ -79,7 +79,6 @@ py_library(
|
||||
"lapack.py",
|
||||
"mhlo_helpers.py",
|
||||
"pocketfft.py",
|
||||
"version.py",
|
||||
],
|
||||
deps = [
|
||||
":_lapack",
|
||||
|
@ -1,20 +0,0 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# After a new jaxlib release, please remember to update the values of
|
||||
# `_current_jaxlib_version` and `_available_cuda_versions` in setup.py to
|
||||
# reflect the most recent available binaries.
|
||||
# __version__ should be increased after releasing the current version
|
||||
# (i.e. on main, this is always the next version to be released).
|
||||
__version__ = "0.3.15"
|
Loading…
x
Reference in New Issue
Block a user