diff --git a/WORKSPACE b/WORKSPACE index 130c9f804..8c4f49ecf 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -62,6 +62,21 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() +load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") +jax_python_wheel_repository( + name = "jax_wheel", + version_key = "_version", + version_source = "//jax:version.py", +) + +load( + "@tsl//third_party/py:python_wheel.bzl", + "python_wheel_version_suffix_repository", +) +python_wheel_version_suffix_repository( + name = "jax_wheel_version_suffix", +) + load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 80f757ca4..3e0a95029 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -16,7 +16,6 @@ load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "py_deps", - "pytype_strict_library", ) licenses(["notice"]) @@ -46,8 +45,3 @@ py_library( "//jax/experimental/jax2tf", ] + py_deps("tensorflow_core"), ) - -pytype_strict_library( - name = "build_utils", - srcs = ["build_utils.py"], -) diff --git a/jax/version.py b/jax/version.py index 484cd96ac..4c8d1798d 100644 --- a/jax/version.py +++ b/jax/version.py @@ -35,6 +35,8 @@ def _get_version_string() -> str: # In this case we return it directly. if _release_version is not None: return _release_version + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") return _version_from_git_tree(_version) or _version_from_todays_date(_version) @@ -71,16 +73,23 @@ def _get_version_for_build() -> str: """Determine the version at build time. The returned version string depends on which environment variables are set: + - if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc" + Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc". + Please note that the WHEEL_VERSION_SUFFIX value is not the same as the + JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by Bazel + wheel build rule. - if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16" - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906" - if none are set: version looks like "0.4.16.dev20230906+ge58560fdc """ if _release_version is not None: return _release_version - if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'): - return _version_from_todays_date(_version) - if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'): + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") + if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"): return _version + if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"): + return _version_from_todays_date(_version) return _version_from_git_tree(_version) or _version_from_todays_date(_version) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index e85a43883..394f9caef 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -14,7 +14,10 @@ """Bazel macros used by the JAX build.""" +load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library") +load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") +load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") @@ -50,6 +53,15 @@ jax_internal_test_harnesses_visibility = [] jax_test_util_visibility = [] loops_visibility = [] +PLATFORM_TAGS_DICT = { + ("Linux", "x86_64"): ("manylinux2014", "x86_64"), + ("Linux", "aarch64"): ("manylinux2014", "aarch64"), + ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), + ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"), + ("Darwin", "arm64"): ("macosx_11_0", "arm64"), + ("Windows", "AMD64"): ("win", "amd64"), +} + # TODO(vam): remove this once zstandard builds against Python 3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION == "3.13": @@ -268,7 +280,7 @@ def jax_multiplatform_test( ] test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): - test_tags += ["manual"] + test_tags.append("manual") if backend == "gpu": test_tags += tf_cuda_tests_tags() native.py_test( @@ -309,15 +321,60 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version): + if no_abi: + wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" + else: + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + python_version = HERMETIC_PYTHON_VERSION.replace(".", "") + return wheel_name_template.format( + package_name = package_name, + python_version = python_version, + major_python_version = python_version[0], + wheel_version = wheel_version, + wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]), + ) + def _jax_wheel_impl(ctx): + include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value + override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value + output_path = ctx.attr.output_path[BuildSettingInfo].value + git_hash = ctx.attr.git_hash[BuildSettingInfo].value executable = ctx.executable.wheel_binary - output = ctx.actions.declare_directory(ctx.label.name) + if include_cuda_libs and not override_include_cuda_libs: + fail("JAX wheel shouldn't be built directly against the CUDA libraries." + + " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." + + " If you absolutely need to build links directly against the CUDA libraries, provide" + + " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") + + env = {} args = ctx.actions.args() - args.add("--output_path", output.path) # required argument - args.add("--cpu", ctx.attr.platform_tag) # required argument - jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path - args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + + full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX) + env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX + if BUILD_TAG: + env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG) + full_wheel_version += env["WHEEL_VERSION_SUFFIX"] + if not WHEEL_VERSION_SUFFIX and not BUILD_TAG: + env["JAX_RELEASE"] = "1" + + cpu = ctx.attr.cpu + platform_name = ctx.attr.platform_name + wheel_name = _get_full_wheel_name( + package_name = ctx.attr.wheel_name, + no_abi = ctx.attr.no_abi, + platform_name = platform_name, + cpu_name = cpu, + wheel_version = full_wheel_version, + ) + output_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + wheel_dir = output_file.path[:output_file.path.rfind("/")] + + args.add("--output_path", wheel_dir) # required argument + args.add("--cpu", cpu) # required argument + args.add("--jaxlib_git_hash", git_hash) # required argument if ctx.attr.enable_cuda: args.add("--enable-cuda", "True") @@ -336,11 +393,13 @@ def _jax_wheel_impl(ctx): args.use_param_file("@%s", use_always = False) ctx.actions.run( arguments = [args], - inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], - outputs = [output], + inputs = [], + outputs = [output_file], executable = executable, + env = env, ) - return [DefaultInfo(files = depset(direct = [output]))] + + return [DefaultInfo(files = depset(direct = [output_file]))] _jax_wheel = rule( attrs = { @@ -350,19 +409,25 @@ _jax_wheel = rule( # b/365588895 Investigate cfg = "exec" for multi platform builds cfg = "target", ), - "platform_tag": attr.string(mandatory = True), - "git_hash": attr.label(allow_single_file = True), + "wheel_name": attr.string(mandatory = True), + "no_abi": attr.bool(default = False), + "cpu": attr.string(mandatory = True), + "platform_name": attr.string(mandatory = True), + "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), + "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. "platform_version": attr.string(mandatory = True, default = ""), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), + "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), + "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), }, implementation = _jax_wheel_impl, executable = False, ) -def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): +def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""): """Create jax artifact wheels. Common artifact attributes are grouped within a single macro. @@ -370,6 +435,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): Args: name: the name of the wheel wheel_binary: the binary to use to build the wheel + wheel_name: the name of the wheel + no_abi: whether to build a wheel without ABI enable_cuda: whether to build a cuda wheel platform_version: the cuda version to use for the wheel @@ -379,18 +446,20 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): _jax_wheel( name = name, wheel_binary = wheel_binary, + wheel_name = wheel_name, + no_abi = no_abi, enable_cuda = enable_cuda, platform_version = platform_version, - # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to - # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to - # the git hash file needs to be created first. - git_hash = select({ - "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", - "//conditions:default": None, + # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` + # flag in bazel command to pass the git hash for nightly or release builds. + platform_name = select({ + "@platforms//os:osx": "Darwin", + "@platforms//os:macos": "Darwin", + "@platforms//os:windows": "Windows", + "@platforms//os:linux": "Linux", }), - # Following the convention in jax/tools/build_utils.py. # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. - platform_tag = select({ + cpu = select({ "//jaxlib/tools:macos_arm64": "arm64", "//jaxlib/tools:win_amd64": "AMD64", "//jaxlib/tools:arm64": "aarch64", diff --git a/jaxlib/jax_python_wheel.bzl b/jaxlib/jax_python_wheel.bzl new file mode 100644 index 000000000..d5b5444fe --- /dev/null +++ b/jaxlib/jax_python_wheel.bzl @@ -0,0 +1,43 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Repository rule to generate a file with JAX wheel version. """ + +def _jax_python_wheel_repository_impl(repository_ctx): + version_source = repository_ctx.attr.version_source + version_key = repository_ctx.attr.version_key + + version_file_content = repository_ctx.read( + repository_ctx.path(version_source), + ) + version_start_index = version_file_content.find(version_key) + version_end_index = version_start_index + version_file_content[version_start_index:].find("\n") + + wheel_version = version_file_content[version_start_index:version_end_index].replace( + version_key, + "WHEEL_VERSION", + ) + repository_ctx.file( + "wheel.bzl", + wheel_version, + ) + repository_ctx.file("BUILD", "") + +jax_python_wheel_repository = repository_rule( + implementation = _jax_python_wheel_repository_impl, + attrs = { + "version_source": attr.label(mandatory = True, allow_single_file = True), + "version_key": attr.string(mandatory = True), + }, +) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 63f2643fe..318846381 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -18,12 +18,38 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") +load( + "@tsl//third_party/py:py_manylinux_compliance_test.bzl", + "verify_manylinux_compliance_test", +) +load( + "//jaxlib:jax.bzl", + "PLATFORM_TAGS_DICT", + "if_windows", + "jax_py_test", + "jax_wheel", + "pytype_strict_library", +) licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +genrule( + name = "platform_tags_py", + srcs = [], + outs = ["platform_tags.py"], + cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT, +) + +pytype_strict_library( + name = "build_utils", + srcs = [ + "build_utils.py", + ":platform_tags_py", + ], +) + py_binary( name = "build_wheel", srcs = ["build_wheel.py"], @@ -41,7 +67,7 @@ py_binary( "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]), deps = [ - "//jax/tools:build_utils", + ":build_utils", "@bazel_tools//tools/python/runfiles", "@pypi_build//:pkg", "@pypi_setuptools//:pkg", @@ -99,7 +125,7 @@ py_binary( "//jax_plugins/rocm:__init__.py", ]), deps = [ - "//jax/tools:build_utils", + ":build_utils", "@bazel_tools//tools/python/runfiles", "@pypi_build//:pkg", "@pypi_setuptools//:pkg", @@ -128,7 +154,7 @@ py_binary( "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ - "//jax/tools:build_utils", + ":build_utils", "@bazel_tools//tools/python/runfiles", "@pypi_build//:pkg", "@pypi_setuptools//:pkg", @@ -173,30 +199,73 @@ string_flag( build_setting_default = "", ) -config_setting( - name = "jaxlib_git_hash_nightly_or_release", - flag_values = { - ":jaxlib_git_hash": "nightly", - }, +string_flag( + name = "output_path", + build_setting_default = "dist", ) jax_wheel( name = "jaxlib_wheel", + no_abi = False, wheel_binary = ":build_wheel", + wheel_name = "jaxlib", ) jax_wheel( name = "jax_cuda_plugin_wheel", enable_cuda = True, + no_abi = False, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", wheel_binary = ":build_gpu_kernels_wheel", + wheel_name = "jax_cuda12_plugin", ) jax_wheel( name = "jax_cuda_pjrt_wheel", enable_cuda = True, + no_abi = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", wheel_binary = ":build_gpu_plugin_wheel", + wheel_name = "jax_cuda12_pjrt", +) + +AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) + +PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")]) + +X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")]) + +verify_manylinux_compliance_test( + name = "jaxlib_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":jaxlib_wheel", + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, +) + +verify_manylinux_compliance_test( + name = "jax_cuda_plugin_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":jax_cuda_plugin_wheel", + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, +) + +verify_manylinux_compliance_test( + name = "jax_cuda_pjrt_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":jax_cuda_pjrt_wheel", + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 65412f036..09a55d3c3 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -24,7 +24,7 @@ import pathlib import tempfile from bazel_tools.tools.python.runfiles import runfiles -from jax.tools import build_utils +from jaxlib.tools import build_utils parser = argparse.ArgumentParser() parser.add_argument( @@ -174,12 +174,11 @@ try: if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 08c2389c2..667807b51 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -24,7 +24,7 @@ import pathlib import tempfile from bazel_tools.tools.python.runfiles import runfiles -from jax.tools import build_utils +from jaxlib.tools import build_utils parser = argparse.ArgumentParser() parser.add_argument( @@ -167,12 +167,11 @@ try: if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: if tmpdir: diff --git a/jax/tools/build_utils.py b/jaxlib/tools/build_utils.py similarity index 86% rename from jax/tools/build_utils.py rename to jaxlib/tools/build_utils.py index 83d0b4b25..0db7c7072 100644 --- a/jax/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -24,6 +24,7 @@ import sys import subprocess import glob from collections.abc import Sequence +from jaxlib.tools import platform_tags def is_windows() -> bool: @@ -52,21 +53,11 @@ def copy_file( def platform_tag(cpu: str) -> str: - platform_name, cpu_name = { - ("Linux", "x86_64"): ("manylinux2014", "x86_64"), - ("Linux", "aarch64"): ("manylinux2014", "aarch64"), - ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), - ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"), - ("Darwin", "arm64"): ("macosx_11_0", "arm64"), - ("Windows", "AMD64"): ("win", "amd64"), - }[(platform.system(), cpu)] + platform_name, cpu_name = platform_tags.PLATFORM_TAGS_DICT[ + (platform.system(), cpu) + ] return f"{platform_name}_{cpu_name}" -def get_githash(jaxlib_git_hash): - if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): - with open(jaxlib_git_hash, "r") as f: - return f.readline().strip() - return jaxlib_git_hash def build_wheel( sources_path: str, output_path: str, package_name: str, git_hash: str = "" diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 4b71bd5de..2f4afae54 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -27,7 +27,7 @@ import subprocess import tempfile from bazel_tools.tools.python.runfiles import runfiles -from jax.tools import build_utils +from jaxlib.tools import build_utils parser = argparse.ArgumentParser() parser.add_argument( @@ -387,8 +387,12 @@ try: if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) - build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash) + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/tests/version_test.py b/tests/version_test.py index 51297a971..1036d958f 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -104,6 +104,7 @@ class JaxVersionTest(unittest.TestCase): self.assertEqual(version, "1.2.3.dev4567") self.assertValidVersion(version) + @jtu.thread_unsafe_test() # Setting environment variables is not thread-safe. @patch_jax_version("1.2.3", None) def testBuildVersionFromEnvironment(self): # This test covers build-time construction of version strings in the @@ -157,6 +158,18 @@ class JaxVersionTest(unittest.TestCase): self.assertTrue(version.endswith("test")) self.assertValidVersion(version) + with jtu.set_env( + JAX_RELEASE=None, + JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY="1", + WHEEL_VERSION_SUFFIX=".dev20250101+1c0f1076erc1", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1") + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3")