1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

Relax jax dependency constraints to be able to install RC wheels

Also, add a job to the release test workflow that verifies that the release wheels can be installed.

TESTED:
1. Full release: https://github.com/jax-ml/jax/actions/runs/14315832784

2. jax only release: https://github.com/jax-ml/jax/actions/runs/14316157252

PiperOrigin-RevId: 744857804
This commit is contained in:
Nitin Srinivasan 2025-04-07 14:49:09 -07:00 committed by jax authors
parent b18dc1dfd7
commit 64e4bf2632
4 changed files with 127 additions and 12 deletions

@ -1,12 +1,14 @@
# CI - Wheel Tests (Nightly/Release)
#
# This workflow builds JAX artifacts and runs CPU/CUDA tests.
# This workflow is used to test the JAX wheels that was built by internal CI jobs.
#
# It orchestrates the following:
# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was
# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that was
# built by internal CI jobs and runs CPU tests.
# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
# artifacts that were built by internal CI jobs and runs the CUDA tests.
# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that was
# built by internal CI jobs and runs CUDA tests.
# 3. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that was
# built by internal CI jobs and runs TPU tests.
# 4. verify-release-wheels-install: Verifies that JAX's release wheels can be installed.
name: CI - Wheel Tests (Nightly/Release)
on:
@ -106,4 +108,88 @@ jobs:
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}}
gcs_download_uri: ${{inputs.gcs_download_uri}}
gcs_download_uri: ${{inputs.gcs_download_uri}}
verify-release-wheels-install:
if: ${{ startsWith(github.ref_name, 'release/')}}
defaults:
run:
# Set the shell to bash as GitHub actions runs with /bin/sh by default
shell: bash
runs-on: linux-x86-n2-16
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
python: ["3.10", "3.13", "3.13-nogil"]
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
# Verifies that JAX's release wheels can be installed
name: "Verify release wheels install (Python ${{ matrix.python }})"
env:
PYTHON: "python${{ matrix.python }}"
steps:
- name: Download release wheels from GCS
run: |
mkdir -p $(pwd)/dist
final_gcs_download_uri=${{ inputs.gcs_download_uri }}
# Get the major and minor version of Python.
# E.g if python=3.10, then python_major_minor=310
# E.g if python=3.13-nogil, then python_major_minor=313t
python_major_minor=${{ matrix.python }}
python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.')
python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-"
gsutil -m cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/
jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null)
echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV
if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then
gsutil -m cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/
gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/
gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/
jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null)
jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null)
jax_cuda_pjrt_wheel=$(ls dist/jax*cuda*pjrt*linux*x86_64*.whl 2>/dev/null)
echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV
echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV
echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV
fi
- name: Verify JAX CPU packages can be installed
run: |
$PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
uv pip install $JAX_WHEEL
else
uv pip install $JAX_WHEEL $JAXLIB_WHEEL
fi
- name: Verify JAX TPU packages can be installed
run: |
$PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
uv pip install $JAX_WHEEL[tpu]
else
uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL
fi
- name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages)
run: |
$PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
uv pip install $JAX_WHEEL[cuda]
else
uv pip install $JAX_WHEEL[cuda] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda]
fi
- name: Verify JAX CUDA packages can be installed (CUDA local)
run: |
$PYTHON -m uv venv ~/test_cuda_local && source ~/test_cuda_local/bin/activate
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
uv pip install $JAX_WHEEL[cuda12-local]
else
uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL
fi

@ -93,6 +93,12 @@ def _get_version_for_build() -> str:
return _version_from_git_tree(_version) or _version_from_todays_date(_version)
def _is_prerelease() -> bool:
"""Determine if this is a pre-release ("rc" wheels) build."""
rc_version = os.getenv("WHEEL_VERSION_SUFFIX", "")
return True if rc_version.startswith("rc") else False
def _write_version(fname: str) -> None:
"""Used by setup.py to write the specified version info into the source tree."""
release_version = _get_version_for_build()

@ -38,6 +38,13 @@ _jax_version = _version_module._version # JAX version, with no .dev suffix.
_cmdclass = _version_module._get_cmdclass(project_name)
_minimum_jaxlib_version = _version_module._minimum_jaxlib_version
# If this is a pre-release ("rc" wheels), append "rc0" to
# _minimum_jaxlib_version and _current_jaxlib_version so that we are able to
# install the rc wheels.
if _version_module._is_prerelease():
_minimum_jaxlib_version += "rc0"
_current_jaxlib_version += "rc0"
with open('README.md', encoding='utf-8') as f:
_long_description = f.read()
@ -81,32 +88,32 @@ setup(
],
'cuda': [
f"jaxlib=={_current_jaxlib_version}",
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
],
'cuda12': [
f"jaxlib=={_current_jaxlib_version}",
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
],
# Deprecated alias for cuda12, kept to avoid breaking users who wrote
# cuda12_pip in their CI.
'cuda12_pip': [
f"jaxlib=={_current_jaxlib_version}",
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
],
# Target that does not depend on the CUDA pip wheels, for those who want
# to use a preinstalled CUDA.
'cuda12_local': [
f"jaxlib=={_current_jaxlib_version}",
f"jax-cuda12-plugin=={_current_jaxlib_version}",
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}",
],
# ROCm support for ROCm 6.0 and above.
'rocm': [
f"jaxlib=={_current_jaxlib_version}",
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}",
],

@ -143,6 +143,7 @@ class JaxVersionTest(unittest.TestCase):
JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None):
with assert_no_subprocess_call():
version = jax.version._get_version_for_build()
self.assertFalse(jax.version._is_prerelease())
self.assertEqual(version, base_version)
self.assertValidVersion(version)
@ -150,6 +151,7 @@ class JaxVersionTest(unittest.TestCase):
JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None):
with assert_no_subprocess_call():
version = jax.version._get_version_for_build()
self.assertFalse(jax.version._is_prerelease())
self.assertEqual(version, base_version)
self.assertValidVersion(version)
@ -183,6 +185,20 @@ class JaxVersionTest(unittest.TestCase):
):
with assert_no_subprocess_call():
version = jax.version._get_version_for_build()
self.assertTrue(jax.version._is_prerelease())
self.assertEqual(version, f"{base_version}rc0")
self.assertValidVersion(version)
with jtu.set_env(
JAX_RELEASE=None,
JAXLIB_RELEASE="1",
JAX_NIGHTLY=None,
JAXLIB_NIGHTLY=None,
WHEEL_VERSION_SUFFIX="rc0",
):
with assert_no_subprocess_call():
version = jax.version._get_version_for_build()
self.assertTrue(jax.version._is_prerelease())
self.assertEqual(version, f"{base_version}rc0")
self.assertValidVersion(version)