mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Relax jax dependency constraints to be able to install RC wheels
Also, add a job to the release test workflow that verifies that the release wheels can be installed. TESTED: 1. Full release: https://github.com/jax-ml/jax/actions/runs/14315832784 2. jax only release: https://github.com/jax-ml/jax/actions/runs/14316157252 PiperOrigin-RevId: 744857804
This commit is contained in:
parent
b18dc1dfd7
commit
64e4bf2632
@ -1,12 +1,14 @@
|
||||
# CI - Wheel Tests (Nightly/Release)
|
||||
#
|
||||
# This workflow builds JAX artifacts and runs CPU/CUDA tests.
|
||||
# This workflow is used to test the JAX wheels that was built by internal CI jobs.
|
||||
#
|
||||
# It orchestrates the following:
|
||||
# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was
|
||||
# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that was
|
||||
# built by internal CI jobs and runs CPU tests.
|
||||
# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
|
||||
# artifacts that were built by internal CI jobs and runs the CUDA tests.
|
||||
# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that was
|
||||
# built by internal CI jobs and runs CUDA tests.
|
||||
# 3. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that was
|
||||
# built by internal CI jobs and runs TPU tests.
|
||||
# 4. verify-release-wheels-install: Verifies that JAX's release wheels can be installed.
|
||||
name: CI - Wheel Tests (Nightly/Release)
|
||||
|
||||
on:
|
||||
@ -106,4 +108,88 @@ jobs:
|
||||
run-full-tpu-test-suite: "1"
|
||||
libtpu-version-type: ${{ matrix.libtpu-version-type }}
|
||||
download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}}
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
||||
|
||||
verify-release-wheels-install:
|
||||
if: ${{ startsWith(github.ref_name, 'release/')}}
|
||||
defaults:
|
||||
run:
|
||||
# Set the shell to bash as GitHub actions runs with /bin/sh by default
|
||||
shell: bash
|
||||
runs-on: linux-x86-n2-16
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
python: ["3.10", "3.13", "3.13-nogil"]
|
||||
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
|
||||
|
||||
# Verifies that JAX's release wheels can be installed
|
||||
name: "Verify release wheels install (Python ${{ matrix.python }})"
|
||||
|
||||
env:
|
||||
PYTHON: "python${{ matrix.python }}"
|
||||
|
||||
steps:
|
||||
- name: Download release wheels from GCS
|
||||
run: |
|
||||
mkdir -p $(pwd)/dist
|
||||
final_gcs_download_uri=${{ inputs.gcs_download_uri }}
|
||||
|
||||
# Get the major and minor version of Python.
|
||||
# E.g if python=3.10, then python_major_minor=310
|
||||
# E.g if python=3.13-nogil, then python_major_minor=313t
|
||||
python_major_minor=${{ matrix.python }}
|
||||
python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.')
|
||||
python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-"
|
||||
|
||||
gsutil -m cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/
|
||||
|
||||
jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null)
|
||||
echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV
|
||||
|
||||
if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then
|
||||
gsutil -m cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/
|
||||
gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/
|
||||
gsutil -m cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/
|
||||
|
||||
jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null)
|
||||
jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null)
|
||||
jax_cuda_pjrt_wheel=$(ls dist/jax*cuda*pjrt*linux*x86_64*.whl 2>/dev/null)
|
||||
|
||||
echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV
|
||||
echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV
|
||||
echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Verify JAX CPU packages can be installed
|
||||
run: |
|
||||
$PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate
|
||||
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
|
||||
uv pip install $JAX_WHEEL
|
||||
else
|
||||
uv pip install $JAX_WHEEL $JAXLIB_WHEEL
|
||||
fi
|
||||
- name: Verify JAX TPU packages can be installed
|
||||
run: |
|
||||
$PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate
|
||||
|
||||
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
|
||||
uv pip install $JAX_WHEEL[tpu]
|
||||
else
|
||||
uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL
|
||||
fi
|
||||
- name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages)
|
||||
run: |
|
||||
$PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate
|
||||
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
|
||||
uv pip install $JAX_WHEEL[cuda]
|
||||
else
|
||||
uv pip install $JAX_WHEEL[cuda] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda]
|
||||
fi
|
||||
- name: Verify JAX CUDA packages can be installed (CUDA local)
|
||||
run: |
|
||||
$PYTHON -m uv venv ~/test_cuda_local && source ~/test_cuda_local/bin/activate
|
||||
if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then
|
||||
uv pip install $JAX_WHEEL[cuda12-local]
|
||||
else
|
||||
uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL
|
||||
fi
|
@ -93,6 +93,12 @@ def _get_version_for_build() -> str:
|
||||
return _version_from_git_tree(_version) or _version_from_todays_date(_version)
|
||||
|
||||
|
||||
def _is_prerelease() -> bool:
|
||||
"""Determine if this is a pre-release ("rc" wheels) build."""
|
||||
rc_version = os.getenv("WHEEL_VERSION_SUFFIX", "")
|
||||
return True if rc_version.startswith("rc") else False
|
||||
|
||||
|
||||
def _write_version(fname: str) -> None:
|
||||
"""Used by setup.py to write the specified version info into the source tree."""
|
||||
release_version = _get_version_for_build()
|
||||
|
19
setup.py
19
setup.py
@ -38,6 +38,13 @@ _jax_version = _version_module._version # JAX version, with no .dev suffix.
|
||||
_cmdclass = _version_module._get_cmdclass(project_name)
|
||||
_minimum_jaxlib_version = _version_module._minimum_jaxlib_version
|
||||
|
||||
# If this is a pre-release ("rc" wheels), append "rc0" to
|
||||
# _minimum_jaxlib_version and _current_jaxlib_version so that we are able to
|
||||
# install the rc wheels.
|
||||
if _version_module._is_prerelease():
|
||||
_minimum_jaxlib_version += "rc0"
|
||||
_current_jaxlib_version += "rc0"
|
||||
|
||||
with open('README.md', encoding='utf-8') as f:
|
||||
_long_description = f.read()
|
||||
|
||||
@ -81,32 +88,32 @@ setup(
|
||||
],
|
||||
|
||||
'cuda': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
|
||||
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
|
||||
],
|
||||
|
||||
'cuda12': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
|
||||
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
|
||||
],
|
||||
|
||||
# Deprecated alias for cuda12, kept to avoid breaking users who wrote
|
||||
# cuda12_pip in their CI.
|
||||
'cuda12_pip': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
|
||||
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
|
||||
],
|
||||
|
||||
# Target that does not depend on the CUDA pip wheels, for those who want
|
||||
# to use a preinstalled CUDA.
|
||||
'cuda12_local': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jax-cuda12-plugin=={_current_jaxlib_version}",
|
||||
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
|
||||
f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}",
|
||||
],
|
||||
|
||||
# ROCm support for ROCm 6.0 and above.
|
||||
'rocm': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jaxlib>={_current_jaxlib_version},<={_jax_version}",
|
||||
f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}",
|
||||
],
|
||||
|
||||
|
@ -143,6 +143,7 @@ class JaxVersionTest(unittest.TestCase):
|
||||
JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None):
|
||||
with assert_no_subprocess_call():
|
||||
version = jax.version._get_version_for_build()
|
||||
self.assertFalse(jax.version._is_prerelease())
|
||||
self.assertEqual(version, base_version)
|
||||
self.assertValidVersion(version)
|
||||
|
||||
@ -150,6 +151,7 @@ class JaxVersionTest(unittest.TestCase):
|
||||
JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None):
|
||||
with assert_no_subprocess_call():
|
||||
version = jax.version._get_version_for_build()
|
||||
self.assertFalse(jax.version._is_prerelease())
|
||||
self.assertEqual(version, base_version)
|
||||
self.assertValidVersion(version)
|
||||
|
||||
@ -183,6 +185,20 @@ class JaxVersionTest(unittest.TestCase):
|
||||
):
|
||||
with assert_no_subprocess_call():
|
||||
version = jax.version._get_version_for_build()
|
||||
self.assertTrue(jax.version._is_prerelease())
|
||||
self.assertEqual(version, f"{base_version}rc0")
|
||||
self.assertValidVersion(version)
|
||||
|
||||
with jtu.set_env(
|
||||
JAX_RELEASE=None,
|
||||
JAXLIB_RELEASE="1",
|
||||
JAX_NIGHTLY=None,
|
||||
JAXLIB_NIGHTLY=None,
|
||||
WHEEL_VERSION_SUFFIX="rc0",
|
||||
):
|
||||
with assert_no_subprocess_call():
|
||||
version = jax.version._get_version_for_build()
|
||||
self.assertTrue(jax.version._is_prerelease())
|
||||
self.assertEqual(version, f"{base_version}rc0")
|
||||
self.assertValidVersion(version)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user