Merge pull request #276 from ROCm/ci-upstream-sync-144_1

CI: 03/12/25 upstream sync
This commit is contained in:
github-actions[bot] 2025-03-12 13:23:13 -05:00 committed by GitHub
commit 9cc545254c
68 changed files with 1780 additions and 444 deletions

View File

@ -253,12 +253,6 @@ build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base
build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc
build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
# Mac x86 CI configs
build:ci_darwin_x86_64 --macos_minimum_os=11.0
build:ci_darwin_x86_64 --config=macos_cache_push
build:ci_darwin_x86_64 --verbose_failures=true
build:ci_darwin_x86_64 --color=yes
# Mac Arm64 CI configs
build:ci_darwin_arm64 --macos_minimum_os=11.0
build:ci_darwin_arm64 --config=macos_cache_push

View File

@ -3,6 +3,7 @@
# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
# tested to get to a stable state before we enable it as a blocking presubmit.
name: CI - Cloud TPU (presubmit)
on:
workflow_dispatch:
inputs:
@ -33,64 +34,32 @@ concurrency:
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
cloud-tpu-test:
build-jax-artifacts:
if: github.event.repository.fork == false
# Begin Presubmit Naming Check - name modification requires internal check to be updated
uses: ./.github/workflows/build_artifacts.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
tpu: [
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
python-version: ["3.10"]
name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
# End Presubmit Naming Check github-tpu-presubmits
env:
JAXCI_PYTHON: python${{ matrix.python-version }}
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
fail-fast: false # don't cancel all jobs on failure
matrix:
artifact: ["jax", "jaxlib"]
with:
runner: "linux-x86-n2-16"
artifact: ${{ matrix.artifact }}
python: "3.10"
clone_main_xla: 1
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
runs-on: ${{ matrix.tpu.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
timeout-minutes: 60
defaults:
run:
shell: bash -ex {0}
steps:
# https://opensource.google/documentation/reference/github/services#actions
# mandates using a specific commit for non-Google actions. We use
# https://github.com/sethvargo/ratchet to pin specific versions.
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# Checkout XLA at head, if we're building jaxlib at head.
- name: Checkout XLA at head
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: openxla/xla
path: xla
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
- name: Mark GitHub workspace as safe
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Build jaxlib at head with latest XLA
run: |
# Build and install jaxlib at head
$JAXCI_PYTHON build/build.py build --wheels=jaxlib \
--python_version=${{ matrix.python-version }} \
--bazel_options=--config=rbe_linux_x86_64 \
--local_xla_path="$(pwd)/xla" \
--verbose
# Install libtpu
$JAXCI_PYTHON -m uv pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Install jaxlib wheel and run tests
run: ./ci/run_pytest_tpu.sh
run-pytest-tpu:
if: github.event.repository.fork == false
needs: [build-jax-artifacts]
uses: ./.github/workflows/pytest_tpu.yml
# Begin Presubmit Naming Check - name modification requires internal check to be updated
name: "TPU test (jaxlib=head, v5e-8)"
with:
runner: "linux-x86-ct5lp-224-8tpu"
cores: "8"
tpu-type: "v5e-8"
python: "3.10"
libtpu-version-type: "nightly"
gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }}
# End Presubmit Naming Check github-tpu-presubmits

View File

@ -116,6 +116,9 @@ jobs:
exit 1
- name: Install Python dependencies
run: |
# Remove installation of NVIDIA wheels for CPU tests.
sed -i 's/-r gpu-test-requirements.txt/# -r gpu-test-requirements.txt/g' build/requirements.in
# TODO(srnitin): Remove after uv is installed in the Windows Dockerfile
$JAXCI_PYTHON -m pip install uv~=0.5.30
# python 3.13t cannot compile zstandard 0.23.0 due to

.github/workflows/pytest_tpu.yml vendored Normal file
View File

@ -0,0 +1,151 @@
# CI - Pytest TPU
#
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests.
#
# It consists of the following job:
# run-tests:
# - Downloads the jax and jaxlib wheels from a GCS bucket.
# - Sets up the libtpu wheels.
# - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
# - Installs the downloaded wheels.
# - Runs the TPU tests with Pytest.
name: CI - Pytest TPU
on:
workflow_call:
inputs:
# Note that the values for runners, cores, and tpu-type are linked to each other.
# For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
# following mapping:
# {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
# {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
runner:
description: "Which runner should the workflow run on?"
type: string
required: true
default: "linux-x86-ct5lp-224-8tpu"
cores:
description: "How many TPU cores should the test use?"
type: string
required: true
default: "8"
tpu-type:
description: "Which TPU type is used for testing?"
type: string
required: true
default: "v5e-8"
python:
description: "Which Python version should be used for testing?"
type: string
required: true
default: "3.12"
run-full-tpu-test-suite:
description: "Should the full TPU test suite be run?"
type: string
required: false
default: "0"
libtpu-version-type:
description: "Which libtpu version should be used for testing?"
type: string
required: false
# Choices are:
# - "nightly": Use the nightly libtpu wheel.
# - "pypi_latest": Use the latest libtpu wheel from PyPI.
# - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
default: "nightly"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: boolean
required: false
default: false
jobs:
run-tests:
defaults:
run:
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
# Begin Presubmit Naming Check - name modification requires internal check to be updated
name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
# End Presubmit Naming Check github-tpu-presubmits
env:
LIBTPU_OLDEST_VERSION_DATE: 20241205
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
JAXCI_TPU_CORES: "${{ inputs.cores }}"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set env vars for use in artifact download URL
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)
# Get the major and minor version of Python.
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
echo "OS=${os}" >> $GITHUB_ENV
echo "ARCH=${arch}" >> $GITHUB_ENV
# Python wheels follow a naming convention: standard wheels use the pattern
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download JAX wheels from GCS
id: download-wheel-artifacts
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
# informative error message.
continue-on-error: true
run: |
mkdir -p $(pwd)/dist
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
echo "Skipping the test run."
exit 1
- name: Install Python dependencies
run: |
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Set up libtpu wheels
run: |
if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then
echo "Using nightly libtpu"
$JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then
echo "Using latest libtpu from PyPI"
# Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
# script will install the latest libtpu wheel from PyPI.
echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV
elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then
echo "Using oldest supported libtpu"
$JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV
else
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
exit 1
fi
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Pytest TPU tests
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
run: ./ci/run_pytest_tpu.sh
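The PYTHON_MAJOR_MINOR value computed in the "Set env vars for use in artifact download URL" step above encodes the wheel-tag convention described in its comments: standard wheels are tagged cp<ver>-cp<ver>-, free-threaded wheels cp<ver>-cp<ver>t-. A minimal Python sketch of that mapping (wheel_tag is a hypothetical helper name, not part of the workflow):
def wheel_tag(hermetic_python_version: str) -> str:
    # Mirror the bash above: "3.13-nogil" -> "313t", "3.10" -> "310".
    v = hermetic_python_version.replace("-nogil", "t").replace(".", "")
    # Standard wheels: cp<ver>-cp<ver>-; free-threaded wheels: cp<ver>-cp<ver>t-.
    return f"cp{v.removesuffix('t')}-cp{v}-"

assert wheel_tag("3.10") == "cp310-cp310-"
assert wheel_tag("3.13-nogil") == "cp313-cp313t-"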

View File

@ -142,4 +142,30 @@ jobs:
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-pytest-tpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g., the Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
python: ["3.10",]
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
name: "TPU tests (jax=head, jaxlib=head)"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: "nightly"
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

View File

@ -58,4 +58,42 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{inputs.gcs_download_uri}}
run-pytest-tpu:
uses: ./.github/workflows/pytest_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for
# profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302)
python: ["3.10", "3.11", "3.12"]
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"]
exclude:
- libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }}
- libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }}
# Run a single Python version for v4-8.
- tpu-specs:
type: "v4-8"
python: "3.10"
- tpu-specs:
type: "v4-8"
python: "3.11"
# Run min and max Python versions for v5e-8
- tpu-specs:
type: "v5e-8"
python: "3.11"
name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
gcs_download_uri: ${{inputs.gcs_download_uri}}
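The matrix `exclude` entries above use a GitHub-expression trick to pick libtpu flavours per branch: release branches drop "nightly", while every other branch drops "pypi_latest". A small Python sketch of the resulting selection (libtpu_version_types is a hypothetical helper, for illustration only):
def libtpu_version_types(ref_name: str) -> list[str]:
    # Mirrors the exclude rules above: release/* branches test pypi_latest and
    # oldest_supported_libtpu; all other branches test nightly and
    # oldest_supported_libtpu.
    choices = ["pypi_latest", "nightly", "oldest_supported_libtpu"]
    dropped = "nightly" if ref_name.startswith("release/") else "pypi_latest"
    return [c for c in choices if c != dropped]

assert libtpu_version_types("main") == ["nightly", "oldest_supported_libtpu"]
assert libtpu_version_types("release/0.5.1") == ["pypi_latest", "oldest_supported_libtpu"]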

View File

@ -456,3 +456,4 @@ For details about the JAX API, see the
For getting started as a JAX developer, see the
[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).

View File

@ -29,7 +29,7 @@ compile_pip_requirements(
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = True,
data = ["test-requirements.txt"]
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
)
compile_pip_requirements(
@ -44,7 +44,7 @@ compile_pip_requirements(
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = False,
data = ["test-requirements.txt"]
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
)
compile_pip_requirements(
@ -58,7 +58,7 @@ compile_pip_requirements(
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = False,
data = ["test-requirements.txt"]
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
)
py_library(

View File

@ -0,0 +1,13 @@
# NVIDIA CUDA dependencies
# Note that the wheels are downloaded only when the targets in the bazel command
# contain dependencies on these wheels.
nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux"
nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux"
nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux"
nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux"
nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux"
nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux"
nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux"
nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux"
nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux"
nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux"

View File

@ -2,6 +2,7 @@
# test deps
#
-r test-requirements.txt
-r gpu-test-requirements.txt
#
# build deps

View File

@ -304,24 +304,31 @@ mdurl==0.1.2 \
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
# via markdown-it-py
ml-dtypes==0.4.0 \
--hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \
--hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \
--hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \
--hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \
--hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \
--hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \
--hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \
--hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \
--hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \
--hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \
--hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \
--hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \
--hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \
--hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \
--hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
ml-dtypes==0.5.1 \
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
# via -r build/requirements.in
mpmath==1.4.0a1 \
--hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \
@ -380,6 +387,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
# ml-dtypes
# opt-einsum
# scipy
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549

View File

@ -299,24 +299,31 @@ mdurl==0.1.2 \
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
# via markdown-it-py
ml-dtypes==0.4.0 \
--hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \
--hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \
--hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \
--hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \
--hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \
--hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \
--hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \
--hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \
--hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \
--hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \
--hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \
--hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \
--hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \
--hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \
--hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
ml-dtypes==0.5.1 \
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
# via -r build/requirements.in
mpmath==1.4.0a1 \
--hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \
@ -375,6 +382,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
# ml-dtypes
# opt-einsum
# scipy
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549

View File

@ -299,24 +299,31 @@ mdurl==0.1.2 \
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
# via markdown-it-py
ml-dtypes==0.4.0 \
--hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \
--hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \
--hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \
--hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \
--hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \
--hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \
--hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \
--hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \
--hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \
--hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \
--hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \
--hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \
--hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \
--hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \
--hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
ml-dtypes==0.5.1 \
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
# via -r build/requirements.in
mpmath==1.4.0a1 \
--hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \
@ -375,6 +382,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
# ml-dtypes
# opt-einsum
# scipy
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549

View File

@ -347,28 +347,31 @@ mdurl==0.1.2 \
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
# via markdown-it-py
ml-dtypes==0.5.0 \
--hash=sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe \
--hash=sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f \
--hash=sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128 \
--hash=sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab \
--hash=sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15 \
--hash=sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215 \
--hash=sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f \
--hash=sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf \
--hash=sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4 \
--hash=sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8 \
--hash=sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859 \
--hash=sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a \
--hash=sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff \
--hash=sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8 \
--hash=sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02 \
--hash=sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856 \
--hash=sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07 \
--hash=sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f \
--hash=sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0 \
--hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \
--hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599
ml-dtypes==0.5.1 \
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
# via -r build/requirements.in
mpmath==1.3.0 \
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
@ -434,6 +437,64 @@ numpy==2.1.2 ; python_version >= "3.13" \
# matplotlib
# ml-dtypes
# scipy
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac

View File

@ -390,6 +390,64 @@ numpy==2.2.1 ; python_version >= "3.13" \
# matplotlib
# ml-dtypes
# scipy
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
# via
# -r build/test-requirements.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
# via -r build/test-requirements.txt
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
# via -r build/test-requirements.txt
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
# via -r build/test-requirements.txt
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
# via -r build/test-requirements.txt
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
# via -r build/test-requirements.txt
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
# via -r build/test-requirements.txt
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
# via
# -r build/test-requirements.txt
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
# via -r build/test-requirements.txt
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
# via
# -r build/test-requirements.txt
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac

View File

@ -74,4 +74,14 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
# on the system. By default, it is set to match the version of the hermetic
# Python used by Bazel for building the wheels.
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
# When set to 1, the full TPU test suite is run. Otherwise, a subset of tests
# is run.
export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0}
# We use this environment variable to control which additional wheels to install
# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest
# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the
# list of valid values and their behavior.
export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""}

View File

@ -52,23 +52,46 @@ export JAX_SKIP_SLOW_TESTS=true
echo "Running TPU tests..."
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" \
tests/pallas/ops_test.py \
tests/pallas/export_back_compat_pallas_test.py \
tests/pallas/export_pallas_test.py \
tests/pallas/tpu_ops_test.py \
tests/pallas/tpu_pallas_test.py \
tests/pallas/tpu_pallas_random_test.py \
tests/pallas/tpu_pallas_async_test.py \
tests/pallas/tpu_pallas_state_test.py
if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
# We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic
# TPU does not guarantee anything about forward compatibility (unless
# jax.export is used) and the 12 week compatibility window accumulates way
# too many failures.
IGNORE_FLAGS=""
if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then
IGNORE_FLAGS="--ignore=tests/pallas"
fi
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
tests/pjit_test.py \
tests/pallas/tpu_pallas_distributed_test.py
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
else
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" \
tests/pallas/ops_test.py \
tests/pallas/export_back_compat_pallas_test.py \
tests/pallas/export_pallas_test.py \
tests/pallas/tpu_ops_test.py \
tests/pallas/tpu_pallas_test.py \
tests/pallas/tpu_pallas_random_test.py \
tests/pallas/tpu_pallas_async_test.py \
tests/pallas/tpu_pallas_state_test.py
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
tests/pjit_test.py \
tests/pallas/tpu_pallas_distributed_test.py
fi

View File

@ -17,8 +17,19 @@
# Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python
# binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to
# avoid using the Windows version of `find` on Msys.
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
for i in "${!WHEELS[@]}"; do
if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then
if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then
# Append [tpu] to the jax wheel name to download the latest libtpu wheel
# from PyPI.
WHEELS[$i]="${WHEELS[$i]}[tpu]"
fi
fi
done
if [[ -z "${WHEELS[@]}" ]]; then
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
exit 1
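The loop above appends the `[tpu]` extra to the jax wheel filename when JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI is "tpu_pypi", so that installing the local wheel also resolves the latest libtpu release from PyPI. A rough Python sketch of that rewrite (add_tpu_extra is a hypothetical helper, not part of the CI scripts):
def add_tpu_extra(wheels: list[str], additional_from_pypi: str) -> list[str]:
    # Only the pure-Python jax wheel (jax-*-py3-none-any.whl) gets the [tpu]
    # extra; jaxlib and the CUDA plugin/pjrt wheels are left untouched.
    out = []
    for w in wheels:
        if additional_from_pypi == "tpu_pypi" and "none-any.whl" in w and "jax-" in w:
            w += "[tpu]"
        out.append(w)
    return out

assert add_tpu_extra(["dist/jax-0.5.2-py3-none-any.whl"], "tpu_pypi") == \
    ["dist/jax-0.5.2-py3-none-any.whl[tpu]"]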

View File

@ -58,6 +58,7 @@ Operators
clz
collapse
complex
composite
concatenate
conj
conv

View File

@ -14,7 +14,7 @@
# JAX is Autograd and XLA
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
@ -45,17 +45,26 @@ package(
licenses(["notice"])
# If this flag is true, jaxlib should be built by bazel. If false, then we do not build jaxlib and
# assume it has been installed, e.g., by `pip`.
bool_flag(
# The flag controls whether jaxlib should be built by Bazel.
# If ":build_jaxlib=true", then jaxlib will be built.
# If ":build_jaxlib=false", then jaxlib is not built. It is assumed that the pre-built jaxlib wheel
# is available in the "dist" folder.
# If ":build_jaxlib=wheel", then jaxlib wheel will be built as a py_import rule attribute.
# The py_import rule unpacks the wheel and provides its content as a py_library.
string_flag(
name = "build_jaxlib",
build_setting_default = True,
build_setting_default = "true",
values = [
"true",
"false",
"wheel",
],
)
config_setting(
name = "enable_jaxlib_build",
flag_values = {
":build_jaxlib": "True",
":build_jaxlib": "true",
},
)
@ -681,6 +690,7 @@ pytype_strict_library(
deps = [
":pallas", # build_cleaner: keep
"//jax/_src/pallas/fuser:block_spec",
"//jax/_src/pallas/fuser:custom_evaluate",
"//jax/_src/pallas/fuser:fusable",
"//jax/_src/pallas/fuser:fusion",
"//jax/_src/pallas/fuser:jaxpr_fusion",

View File

@ -79,7 +79,7 @@ from jax._src.lib import xla_client as _xc
Device = _xc.Device
del _xc
from jax._src.core import get_ty as get_ty
from jax._src.core import typeof as typeof
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
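With this export, `jax.typeof` becomes an alias of `jax.get_ty` (both map to `core.get_aval`), returning the abstract type of a value. A minimal usage sketch, assuming a jax build that includes this sync:
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.float32)
# typeof/get_ty return the abstract value (aval) of x, e.g. float32[2,3].
assert jax.typeof(x) == jax.get_ty(x)
print(jax.typeof(x))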

View File

@ -235,6 +235,7 @@ def trace_context():
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
use_direct_linearize.value,
varying_axes_in_types.value,
softmax_custom_jvp.value,
disable_jit.value,
debug_key_reuse.value,
@ -1084,6 +1085,14 @@ use_direct_linearize = bool_state(
help=('Use direct linearization instead JVP followed by partial eval'),
include_in_jit_key=True)
varying_axes_in_types = bool_state(
name='jax_varying_axes_in_types',
default=False,
help=('Adds varying manual axes to ShapedArray to track which mesh axes the'
' array is varying over. This will help to remove the efficient'
' transpose rewrite machinery in shard_map'),
include_in_jit_key=True)
data_dependent_tracing_fallback = bool_state(
name='jax_data_dependent_tracing_fallback',
default=False,

View File

@ -1576,7 +1576,7 @@ def get_aval(x):
return get_aval(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
get_ty = get_aval
typeof = get_aval
def is_concrete(x):
return to_concrete_value(x) is not None
@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape):
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'sharding'] # inherits slots from parent
__slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent
array_abstraction_level = 2
def __init__(self, shape, dtype, weak_type=False, *, sharding=None):
def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
varying_manual_axes: frozenset[AxisName] = frozenset()):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.sharding = get_sharding(sharding, self.shape)
if config.varying_axes_in_types.value:
self.varying_manual_axes = varying_manual_axes
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
@ -1911,6 +1914,9 @@ class ShapedArray(UnshapedArray):
weak_type = self.weak_type
if 'sharding' not in kwargs:
kwargs['sharding'] = self.sharding
if 'varying_manual_axes' not in kwargs:
kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes',
frozenset())
return ShapedArray(shape, dtype, weak_type, **kwargs)
ndim = property(lambda self: len(self.shape))
@ -1927,17 +1933,22 @@ class ShapedArray(UnshapedArray):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type
and self.sharding == other.sharding)
and self.sharding == other.sharding
and (getattr(self, 'varying_manual_axes', frozenset()) ==
getattr(other, 'varying_manual_axes', frozenset())))
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type, self.sharding))
return hash((self.shape, self.dtype, self.weak_type, self.sharding,
getattr(self, 'varying_manual_axes', frozenset())))
def to_tangent_aval(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, sharding=self.sharding)
return ShapedArray(
self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, sharding=self.sharding,
varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))
def str_short(self, short_dtypes=False, mesh_axis_types=False):
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else

View File

@ -1364,9 +1364,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
if config.disable_jit.value:
try:
val = init_val
val = tree_map(lax.asarray, init_val)
while cond_fun(val):
val = body_fun(val)
val = tree_map(lax.asarray, body_fun(val))
return val
except core.ConcretizationTypeError:
# Can't run this while_loop in Python (e.g. because there's a vmap
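
A minimal sketch of what the tree_map(lax.asarray, ...) wrapping above presumably guards against: with jit disabled the loop runs in Python, so body_fun can hand back plain Python numbers, and converting the carry back to arrays keeps it a JAX array with a stable dtype. Values here are illustrative.

import jax
import jax.numpy as jnp
from jax import lax

def cond_fun(v):
  return v < 10

def body_fun(v):
  return float(v) + 1.0  # deliberately returns a plain Python float

with jax.disable_jit():
  # The carry is re-wrapped with lax.asarray on every step, so the result is
  # still a JAX array even though body_fun returns Python floats.
  print(lax.while_loop(cond_fun, body_fun, jnp.float32(0.0)))  # 10.0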

View File

@ -1489,14 +1489,14 @@ def composite(
):
"""Composite with semantics defined by the decomposition function.
A composite is a higher-order JAX function that encapsulates an operation mad
A composite is a higher-order JAX function that encapsulates an operation made
up (composed) of other JAX functions. The semantics of the op are implemented
by the ``decomposition`` function. In other words, the defined composite
function can be replaced with its decomposed implementation without changing
the semantics of the encapsulated operation.
The compiler can recognize specific composite operations by their ``name``,
``version``, ``kawargs``, and dtypes to emit more efficient code, potentially
``version``, ``kwargs``, and dtypes to emit more efficient code, potentially
leveraging hardware-specific instructions or optimizations. If the compiler
doesn't recognize the composite, it falls back to compiling the
``decomposition`` function.
@ -1505,11 +1505,11 @@ def composite(
be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could
recognize the "tangent" composite and emit a single ``tangent`` instruction
instead of three separate instructions (``sin``, ``divide``, and ``cos``).
With compilers for hardwares without dedicated tangent support, it would fall
back to compiling the decomposition.
For hardware without dedicated tangent support, it would fall back to
compiling the decomposition.
This is useful for preserving high level abstraction that would otherwise be
lost while lowering which allows for easier pattern-matching in low-level IR.
This is useful for preserving high-level abstractions that would otherwise be
lost while lowering, which allows for easier pattern-matching in low-level IR.
Args:
decomposition: function that implements the semantics of the composite op.
@ -1517,19 +1517,20 @@ def composite(
version: optional int to indicate semantic changes to the composite.
Returns:
out: callable composite function. Note that positional arguments to this
function should be interpreted as inputs and keyword arguments should be
interpreted as attributes of the op. Any keyword arguments that are passed
with ``None`` as a value will be omitted from the
``composite_attributes``.
Callable: Returns a composite function. Note that positional arguments to
this function should be interpreted as inputs and keyword arguments should
be interpreted as attributes of the op. Any keyword arguments that are
passed with ``None`` as a value will be omitted from the
``composite_attributes``.
Examples:
Tangent kernel:
>>> def my_tangent_composite(x):
... return lax.composite(
... lambda x: lax.sin(x) / lax.cos(x), name='my.tangent'
... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
... )(x)
...
>>>
>>> pi = jnp.pi
>>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
@ -1538,9 +1539,10 @@ def composite(
[ 0. 1. -1. 0.]
[ 0. 1. -1. 0.]
The recommended way to create composites is via a decorator. Use `/` and `*`
in the function signature to be explicit about positional and keyword
arguments respectively:
The recommended way to create composites is via a decorator. Use ``/`` and
``*`` in the function signature to be explicit about positional and keyword
arguments, respectively:
>>> @partial(lax.composite, name="my.softmax")
... def my_softmax_composite(x, /, *, axis):
... return jax.nn.softmax(x, axis)
@ -3014,6 +3016,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete):
broadcast_shape = sharding.shard_shape(shape)
shard = broadcast(fill_value, broadcast_shape)
shard = shard.addressable_data(0)
return array.make_array_from_callback(shape, sharding, lambda _: shard)
if sharding is not None and not sharding._is_concrete:
@ -8194,7 +8197,7 @@ _zeros: Callable = partial(full_like, fill_value=0)
def _zero(x):
x_aval = core.get_aval(x)
return full_like(x, shape=(), fill_value=0,
sharding=x_aval.sharding.with_spec(P()))
sharding=x_aval.sharding.with_spec(P()))
_ones: Callable = partial(full_like, fill_value=1)

View File

@ -22,6 +22,7 @@ from functools import partial
import itertools
import math
import jax
from jax import tree_util
from jax._src import core
from jax._src import dispatch
@ -459,78 +460,135 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
def ragged_all_to_all(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
axis_name, axis_index_groups = None):
"""Ragged version of :func:`all_to_all`.
"""Ragged version of :func:`all_to_all` collective.
For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
and the outermost (ragged) dimension. ``axis_index_groups`` is default to all
replicas (e.g. there is only one group and covers all axis indices).
We say data are "ragged" when they can be represented as a list of arrays
whose shapes differ only in the size of the leading axis. For example, these
data are ragged, comprising four component arrays::
Ragged arrays are defined by a set of three arrays:
* ``data``: the ``data`` array is "ragged" along its outermost dimension,
along which each indexed element has variable size.
* ``offsets``: the ``offsets`` array indexes the outermost dimension of the
``data`` array, and represents the starting offset of each ragged element of
the ``data`` array.
* ``sizes``: the ``sizes`` array represents the size of each ragged element of
the ``data`` array, where the size is specified in units of sub-elements. A
sub-element is defined as the suffix of the ``data`` array shape obtained by
removing the outermost "ragged" dimension.
The ``offsets`` and ``sizes`` arrays must have the same size.
ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)]
# Example ragged tensor
data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
offsets: [3] = {0, 1, 4}
sizes: [3] = {1, 3, 4}
We often instead want a contiguous representation, e.g. for batching. But
because the shapes of the components differ, we can't apply ``jnp.stack`` to
represent these data by a single rectangular array with the leading axis
indexing the component arrays. So instead of stacking, we concatenate along
the leading axis and keep track of offsets and sizes.
# Index 'data' at 'offsets'[0], 'sizes'[0]'
{a,b,c}
That is, we can represent ragged data contiguously using a triple of dense
arrays ``(data, offsets, sizes)``:
* ``data``: the concatenated component arrays,
* ``offsets``: 1D array of indices into the leading axis of ``data``
indicating where the data for each component array begins,
* ``sizes``: 1D array of sizes of the leading axis of each component array.
We refer to this triple as a ragged array. (Offsets can't be computed from
sizes in general to allow for internal padding.)
# Index 'data' at 'offsets'[1], 'sizes'[1]'
{d,e,f},{g,h,i},{j,k,l}
For example::
data: f32[8,3] = jnp.array([
[a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x],
])
offsets: i32[3] = jnp.array([0, 1, 4])
sizes: i32[3] = jnp.array([1, 3, 4])
# Index 'data' at 'offsets'[2], 'sizes'[2]'
{m,n,o},{p,q,r},{s,t,u},{v,w,x}
# To extract the first component array, of type f32[1,3]
data[offsets[0]:offsets[0]+sizes[0]]
# To extract the second component array, of type f32[3,3]
data[offsets[1]:offsets[1]+sizes[1]]
``output_offsets`` must be sharded in a way that each replica has offsets in
the target replica output perspective.
# To extract the third component array, of type f32[4,3]
data[offsets[2]:offsets[2]+sizes[2]]
For i-th output offset, the current replica will send
`operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th
replica that will be written to
`output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th
replica ``output``.
The ``ragged_all_to_all`` collective operation communicates slices of ragged
arrays between devices. Each caller is both a sender and a receiver. The
``input_offsets`` and ``send_sizes`` arguments indicate the slices of the
caller's ``operand`` to be sent. Received results are returned in an array
that has the same value as the argument ``output`` except with received values
written at some slices. The ``output_offsets`` argument does *not* indicate
the offsets at which all the received results are written; instead,
``output_offsets`` indicates the offsets at which the *sent* slices are
written on their corresponding receivers. The sizes of received slices are
indicated by ``recv_sizes``. See below for details.
For example, if we have 2 replicas:
The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
``recv_sizes`` must all be the same length, and that length must be divisible
by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and
``recv_sizes`` must satisfy::
replica 0:
operand: [1, 2, 2]
output: [0, 0, 0, 0]
input_offsets: [0, 1]
send_sizes: [1, 2]
output_offsets: [0, 0]
recv_sizes: [1, 1]
jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True))
replica 1:
operand: [3, 4, 0]
output: [0, 0, 0, 0]
input_offsets: [0, 1]
send_sizes: [1, 1]
output_offsets: [1, 2]
recv_sizes: [2, 1]
Specifically, given a call::
replica 0's result will be: [1, 3, 0, 0]
replica 1's result will be: [2, 2, 4, 0]
result = ragged_all_to_all(operand, output, input_offsets, send_sizes,
output_offsets, recv_sizes, axis_name)
the caller sends data like::
assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes)
N = len(input_offsets)
slices_per_device, leftover = divmod(N, lax.axis_size(axis_name))
assert not leftover
for i in range(N):
dst_idx = i // slices_per_device
SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]],
axis_name=axis_name, to_axis_index=dst_idx)
and receives data in ``result`` like::
result = output
output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
for i in range(N):
src_idx = i // slices_per_device
result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i]
].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx))
where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local
``output_offsets`` does not indicate the offsets at which its local ``result``
is updated; instead, it indicates where the corresponding sent slices are
written on their destination instances. To compute the local offsets at which
received data are written, we apply an ``all_to_all`` on ``output_offsets``.
For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with
these arguments in each mapped function instance::
axis index 0:
operand = [1, 2, 2]
output = [0, 0, 0, 0]
input_offsets = [0, 1]
send_sizes = [1, 2]
output_offsets = [0, 0]
recv_sizes = [1, 1]
axis index 1:
operand = [3, 4, 0]
output = [0, 0, 0, 0]
input_offsets = [0, 1]
send_sizes = [1, 1]
output_offsets = [1, 2]
recv_sizes = [2, 1]
then::
axis index 0:
result = [1, 3, 0, 0]
axis index 1:
result = [2, 2, 4, 0]
Args:
operand: array with ragged dimension along its outermost dimension.
output: array of ragged input offsets.
input_offsets: array of ragged input send sizes.
send_sizes: array of ragged output data.
output_offsets: array of ragged offsets in the target replica output.
recv_sizes: array of ragged output receive sizes.
axis_name: hashable Python object used to name a pmapped axis (see the
:func:`jax.pmap` documentation for more details).
operand: data array of shape (N, A, B, ...) representing concatenated
(possibly padded) ragged data to be sent.
output: data array of shape (M, A, B, ...) to update with received data.
input_offsets: 1D integer array of shape (K,) representing the offsets of
leading-axis slices into ``operand`` to be sent.
send_sizes: 1D integer array of shape (K,) representing the sizes of
leading-axis slices into ``operand`` to be sent.
output_offsets: 1D integer array of shape (K,) representing where the
corresponding sent data is written on each corresponding receiver.
recv_sizes: 1D integer array of shape (K,) representing sizes of
leading-axis slices into ``output`` to update with received data.
axis_name: name of the mapped axis over which to perform the communication.
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
first two and last two replicas). Groups must cover all axis indices
@ -538,7 +596,10 @@ def ragged_all_to_all(
behavior is undefined.
Returns:
array with shape equal to ``output``.
Array of shape (M, A, B, ...) with the same value as the ``output`` except
with received data written into slices starting at
``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size
``recv_sizes``.
"""
if not isinstance(axis_name, (tuple, list)):
@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval(
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
def _ragged_all_to_all_jvp(primals, tangents, **params):
operand, output, *sizes_and_offsets = primals
operand_dot, output_dot, *_ = tangents
result = ragged_all_to_all_p.bind(
operand, output, *sizes_and_offsets, **params)
if type(operand_dot) is type(output_dot) is ad.Zero:
result_dot = ad.Zero.from_primal_value(result)
else:
operand_dot = ad.instantiate_zeros(operand_dot)
output_dot = ad.instantiate_zeros(output_dot)
result_dot = ragged_all_to_all_p.bind(
operand_dot, output_dot, *sizes_and_offsets, **params)
return result, result_dot
def _ragged_all_to_all_transpose(
t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
*, axis_name, axis_index_groups):
if type(t) is ad.Zero:
operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None
else:
zero = ad.zeros_like_aval(operand.aval)
output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True)
operand_t = ragged_all_to_all_p.bind(
t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes,
axis_name=axis_name, axis_index_groups=axis_index_groups)
mask = jax.numpy.cumsum(
jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
.at[output_offsets_ + recv_sizes].add(-1))
output_t = jax.numpy.where(mask, 0, t)
return [operand_t, output_t] + [None] * 4
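
# Editor's sketch (not part of this change): the cumsum masking trick used in
# _ragged_all_to_all_transpose above, replayed in standalone NumPy. Each
# received span gets +1 at its start and -1 just past its end, and a running
# sum recovers a 0/1 mask over the output rows. Offsets and sizes below are
# made-up illustrative values.
import numpy as np
n, offsets, sizes = 8, np.array([1, 5]), np.array([2, 1])
delta = np.zeros(n + 1, np.int32)
np.add.at(delta, offsets, 1)           # +1 at the start of each received span
np.add.at(delta, offsets + sizes, -1)  # -1 just past the end of each span
print(np.cumsum(delta)[:n])            # [0 1 1 0 0 1 0 0]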
ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp
ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')
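
The two-instance example in the docstring above can be replayed with a small pure-NumPy simulation of the SEND/RECEIVE pseudocode. This is an illustrative sketch only; it does not call the real collective, and the arrays are copied straight from that example.

import numpy as np

operands = [np.array([1, 2, 2]), np.array([3, 4, 0])]
input_offsets = [np.array([0, 1]), np.array([0, 1])]
send_sizes = [np.array([1, 2]), np.array([1, 1])]
output_offsets = [np.array([0, 0]), np.array([1, 2])]

results = [np.zeros(4, np.int32), np.zeros(4, np.int32)]
for src in range(2):  # every instance sends...
  for i in range(2):  # ...one slice per destination (slices_per_device == 1, so dst == i)
    start = input_offsets[src][i]
    chunk = operands[src][start:start + send_sizes[src][i]]
    off = output_offsets[src][i]  # offset in the *destination* output
    results[i][off:off + len(chunk)] = chunk

print(results[0])  # [1 3 0 0]
print(results[1])  # [2 2 4 0]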

View File

@ -303,15 +303,16 @@ def _igamma_series(ax, x, a, enabled, dtype, mode):
def igamma_impl(a, x, *, dtype):
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, _const(x, 0))
x_is_infinity = eq(x, _const(x, float('inf')))
domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
a_is_zero = eq(a, _const(a, 0))
x_is_zero = eq(x, _const(x, 0))
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
use_igammac = bitwise_and(ge(x, _const(x, 1)), gt(x, a))
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
ax = exp(ax)
enabled = bitwise_not(
_reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan, x_is_infinity]))
output = select(
use_igammac,
@ -323,8 +324,7 @@ def igamma_impl(a, x, *, dtype):
)
output = select(x_is_zero, full_like(a, 0), output)
output = select(x_is_infinity, full_like(a, 1), output)
output = select(bitwise_or(domain_error, is_nan),
full_like(a, float('nan')), output)
output = select(domain_error, full_like(a, float('nan')), output)
return output
def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
@ -433,11 +433,15 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
raise ValueError(f"Invalid mode: {mode}")
def igammac_impl(a, x, *, dtype):
out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
is_nan = bitwise_or(_isnan(a), _isnan(x))
a_is_zero = eq(a, _const(a, 0))
x_is_zero = eq(x, _const(x, 0))
x_is_infinity = eq(x, _const(x, float('inf')))
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
enabled = bitwise_not(bitwise_or(out_of_range, underflow))
enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, is_nan, x_is_infinity, a_is_zero]))
ax = exp(ax)
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
@ -445,10 +449,10 @@ def igammac_impl(a, x, *, dtype):
igammac_cf_call = _igammac_continued_fraction(ax, x, a,
bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)
result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
x_is_infinity = eq(x, _const(x, float('inf')))
result = select(x_is_infinity, full_like(result, 0), result)
return select(out_of_range, full_like(a, 1), result)
output = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
output = select(bitwise_or(x_is_infinity, a_is_zero), full_like(output, 0), output)
output = select(domain_error, full_like(a, float('nan')), output)
return output
def igamma_grad_a_impl(a, x, *, dtype):
is_nan = bitwise_or(_isnan(a), _isnan(x))
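
The rewritten domain handling above can be probed at a few edge points. This is a hedged sketch: the expected values in the comments are read off the new select logic in this change, not verified backend output, and results may differ across backends.

import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaincc

print(gammainc(0.0, 0.0))       # a == 0 and x == 0 is a domain error -> nan
print(gammainc(2.0, jnp.inf))   # x == inf -> 1.0
print(gammaincc(2.0, jnp.inf))  # x == inf -> 0.0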

View File

@ -343,7 +343,7 @@ class BlockSpec:
if self.block_shape is None:
block_shape = array_aval.shape
else:
block_shape = self.block_shape
block_shape = self.block_shape # type: ignore
if len(array_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {origin} (= {block_shape}) "

View File

@ -32,6 +32,7 @@ pytype_strict_library(
],
deps = [
":block_spec",
":custom_evaluate",
":fusable",
":fusion",
":jaxpr_fusion",
@ -44,6 +45,7 @@ pytype_strict_library(
"block_spec.py",
],
deps = [
":fuser_utils",
"//jax",
"//jax:ad_util",
"//jax:api_util",
@ -119,3 +121,27 @@ pytype_strict_library(
"//jax/_src/pallas",
],
)
pytype_strict_library(
name = "custom_evaluate",
srcs = ["custom_evaluate.py"],
deps = [
":fuser_utils",
"//jax",
"//jax:core",
"//jax:source_info_util",
"//jax:tree_util",
"//jax:util",
],
)
pytype_strict_library(
name = "fuser_utils",
srcs = ["fuser_utils.py"],
deps = [
"//jax:api_util",
"//jax:core",
"//jax:partial_eval",
"//jax:tree_util",
],
)

View File

@ -16,6 +16,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
from jax._src.pallas.fuser.fusable import fusable as fusable
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

View File

@ -26,15 +26,14 @@ from typing import Any, Callable, Protocol, Sequence
import jax
from jax import lax
from jax._src import ad_util
from jax._src import api_util
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas.fuser import fuser_utils
import jax.numpy as jnp
import numpy as np
@ -226,18 +225,6 @@ def _unwrap_block_spec_scalar_prefetch(
return out_block_spec
def _make_jaxpr(f, *args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
flat_avals = [core.get_aval(x) for x in flat_args]
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f, debug_info=debug_info), in_tree
)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree_thunk()
return jaxpr, consts, in_tree, out_tree
def pull_block_spec(
f: Callable,
out_block_specs: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...],
@ -246,7 +233,9 @@ def pull_block_spec(
grid: tuple[int | jax.Array, ...] | None = None,
):
def wrapped(*args, **kwargs):
jaxpr, consts, in_tree, out_tree_ = _make_jaxpr(f, *args, **kwargs)
jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr(
f, *args, **kwargs
)
# TODO(sharadmv): handle these consts better, they should correspond to
# scalar prefetch.
del consts, out_tree_
@ -563,7 +552,9 @@ def make_kernel_function(
def get_fusion_values(
fusion: Callable, *args, **kwargs
) -> tuple[Callable, tuple[jax.Array, ...], tuple[jax.Array, ...]]:
jaxpr, values, in_tree, out_tree = _make_jaxpr(fusion, *args, **kwargs)
jaxpr, values, in_tree, out_tree = fuser_utils.make_jaxpr(
fusion, *args, **kwargs
)
assert len(values) == len(jaxpr.constvars), (jaxpr, values)
out_usages = tuple({Usage.REGULAR} for _ in jaxpr.outvars)
read_usage_env = compute_usage(jaxpr, out_usages)
@ -1325,7 +1316,7 @@ def push_block_spec(
flat_block_specs, in_tree_ = tree_util.tree_flatten(
(in_spec_args, in_spec_kwargs)
)
jaxpr, _, in_tree, out_tree = _make_jaxpr(f, *args, **kwargs)
jaxpr, _, in_tree, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
if in_tree != in_tree_:
raise ValueError(f'Expected {in_tree} PyTree, got {in_tree_}')
out_bs = _push_block_spec_jaxpr(jaxpr, *flat_block_specs)

View File

@ -0,0 +1,82 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for evaluating functions under certain constraints."""
import dataclasses
from typing import Any
from jax import lax
from jax._src import core
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
from jax._src.pallas.fuser import fuser_utils
@dataclasses.dataclass
class CustomEvaluateSettings:
allow_transpose: bool = True
def evaluate(f, *, allow_transpose: bool = True):
def wrapped(*args, **kwargs):
jaxpr, consts, _, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
settings = CustomEvaluateSettings(allow_transpose=allow_transpose)
flat_args = tree_util.tree_leaves(args)
out_flat = _custom_evaluate_jaxpr(settings, jaxpr, consts, *flat_args)
return tree_util.tree_unflatten(out_tree, out_flat)
return wrapped
# Disallow most higher-order primitives for now.
disallowed_primitives = {lax.scan_p, lax.while_p, lax.cond_p}
def _custom_evaluate_jaxpr(
settings: CustomEvaluateSettings, jaxpr: core.Jaxpr, consts, *args
):
def read(v: core.Atom) -> Any:
return v.val if isinstance(v, core.Literal) else env[v]
def write(v: core.Var, val: Any) -> None:
env[v] = val
env: dict[core.Var, Any] = {}
util.safe_map(write, jaxpr.constvars, consts)
util.safe_map(write, jaxpr.invars, args)
lu = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
if eqn.primitive in disallowed_primitives:
raise NotImplementedError(f'Primitive {eqn.primitive} not supported.')
if not settings.allow_transpose and eqn.primitive is lax.transpose_p:
raise ValueError('Transpose not allowed.')
name_stack = (
source_info_util.current_name_stack() + eqn.source_info.name_stack
)
traceback = eqn.source_info.traceback
with source_info_util.user_context(
traceback, name_stack=name_stack
), eqn.ctx.manager:
ans = eqn.primitive.bind(
*subfuns, *util.safe_map(read, eqn.invars), **bind_params
)
if eqn.primitive.multiple_results:
util.safe_map(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
core.clean_up_dead_vars(eqn, env, lu)
return util.safe_map(read, jaxpr.outvars)
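
A minimal usage sketch for the evaluate() wrapper defined above, assuming the module is importable as jax._src.pallas.fuser.custom_evaluate (consistent with the BUILD rule earlier in this change). It re-traces f to a jaxpr and evaluates it equation by equation, and would raise if a transpose appeared while allow_transpose=False.

import jax.numpy as jnp
from jax._src.pallas.fuser import custom_evaluate

def f(x):
  return jnp.tanh(x) + 1.0

# No transpose in f, so this evaluates normally and returns tanh(x) + 1.
y = custom_evaluate.evaluate(f, allow_transpose=False)(jnp.ones((8,)))
print(y)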

View File

@ -0,0 +1,33 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic utils for fuser internals."""
from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import tree_util
from jax._src.interpreters import partial_eval as pe
def make_jaxpr(f, *args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
flat_avals = [core.get_aval(x) for x in flat_args]
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f, debug_info=debug_info), in_tree
)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree_thunk()
return jaxpr, consts, in_tree, out_tree
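
A small usage sketch for the helper above, assuming it is importable as jax._src.pallas.fuser.fuser_utils; it simply traces a function to a jaxpr plus its flattening metadata.

import jax.numpy as jnp
from jax._src.pallas.fuser import fuser_utils

def f(x, y):
  return jnp.sin(x) * y

jaxpr, consts, in_tree, out_tree = fuser_utils.make_jaxpr(
    f, jnp.ones((4,)), jnp.ones((4,)))
print(jaxpr)    # the traced jaxpr containing sin and mul equations
print(in_tree)  # PyTreeDef for ((x, y), {})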

View File

@ -1853,7 +1853,13 @@ def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
def _dot_general_lowering_rule(
ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
ctx: LoweringRuleContext,
x,
y,
dimension_numbers,
precision,
preferred_element_type,
**_,
):
(lhs_dims, rhs_dims), _ = dimension_numbers
(aval_out,) = ctx.avals_out
@ -1894,10 +1900,34 @@ def _dot_general_lowering_rule(
x = vector.broadcast(bcast_shape, x)
if ctx.avals_in[1].shape != bcast_shape:
y = vector.broadcast(bcast_shape, y)
red_dtype = (
preferred_element_type if preferred_element_type else lhs_aval.dtype
)
red_type = aval_to_ir_type(
ctx.lowering_context.dynamic_shape_replacement_fn,
lhs_aval.update(shape=(lhs_aval.shape[0],)),
lhs_aval.update(shape=(lhs_aval.shape[0],), dtype=red_dtype),
)
if lhs_aval.dtype != red_dtype:
lhs_type = aval_to_ir_type(
ctx.lowering_context.dynamic_shape_replacement_fn,
lhs_aval.update(shape=lhs_aval.shape, dtype=red_dtype),
)
if red_dtype == jnp.float32:
x = arith.extf(lhs_type, x)
else:
raise NotImplementedError(f"Unsupported {preferred_element_type=}")
if rhs_aval.dtype != red_dtype:
rhs_type = aval_to_ir_type(
ctx.lowering_context.dynamic_shape_replacement_fn,
rhs_aval.update(shape=rhs_aval.shape, dtype=red_dtype),
)
if red_dtype == jnp.float32:
y = arith.extf(rhs_type, y)
else:
raise NotImplementedError(f"Unsupported {preferred_element_type=}")
acc = arith.ConstantOp(
red_type, ir.DenseElementsAttr.get_splat(red_type, val)
)

View File

@ -1543,6 +1543,60 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
raise NotImplementedError(f"Unsupported layout {x.layout}")
def _reduce_lowering_rule_wg(
kind: vector_dialect.CombiningKind,
acc: object,
ctx: LoweringRuleContext,
x,
*,
axes,
) -> ir.OpView:
[x_aval] = ctx.avals_in
[out_aval] = ctx.avals_out
x = _ensure_ir_value(x, x_aval.dtype)
out_type = mgpu_utils.dtype_to_ir_type(out_aval.dtype)
if not out_aval.shape:
# Special-case: reducing to a scalar.
if x_aval.ndim != 1:
# TODO(slebedev): Flatten to 1D, since vector.reduction only supports
# 1D inputs.
raise NotImplementedError("Only 1D inputs are supported")
return vector_dialect.ReductionOp(out_type, kind, x)
acc = vector_dialect.splat(
ir.VectorType.get(out_aval.shape, out_type),
_ensure_ir_value(acc, out_aval.dtype),
)
return vector_dialect.MultiDimReductionOp(kind, x, acc, axes)
@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup)
def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
op = _reduce_lowering_rule_wg(
vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes
)
op.attributes["offset"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), ctx.module_ctx.smem_used_bytes
)
return op.result
@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup)
def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
[x_aval] = ctx.avals_in
if jnp.issubdtype(x_aval.dtype, jnp.floating):
kind = vector_dialect.CombiningKind.MAXIMUMF
acc = float("-inf")
elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
kind = vector_dialect.CombiningKind.MAXSI
acc = np.iinfo(x_aval.dtype).max
elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
kind = vector_dialect.CombiningKind.MAXUI
acc = np.iinfo(x_aval.dtype).max
else:
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result
@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane)
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
i32 = ir.IntegerType.get_signless(32)

View File

@ -198,7 +198,8 @@ def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) ->
- :func:`jax.scipy.stats.gamma.logsf`
"""
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
y = lax.div(lax.sub(x, loc), scale)
return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y))
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
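
A short sketch of the behavioral fix above: for x below loc the survival function is 1 by definition, and the added jnp.where returns that directly instead of evaluating gammaincc at a negative argument. Values are illustrative.

import jax.numpy as jnp
from jax.scipy.stats import gamma

print(gamma.sf(-1.0, 2.0))  # x < loc -> expected 1.0 after this change
print(gamma.sf(3.0, 2.0))   # x >= loc is unchanged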

View File

@ -865,15 +865,15 @@ class Jax2TfLimitation(test_harnesses.Limitation):
def custom_assert(tst, result_jax, result_tf, *, args, tol,
err_msg): # noqa: F811
arg1, arg2 = args
# lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
# lax.igammac returns nan. when arg1 <= 0; tf.math.igammac returns 1
special_cases = (arg1 <= 0.) | (arg2 <= 0)
nr_special_cases = np.count_nonzero(special_cases)
tst.assertAllClose(
np.full((nr_special_cases,), 1., dtype=dtype),
np.full((nr_special_cases,), np.nan, dtype=dtype),
result_jax[special_cases],
err_msg=err_msg)
tst.assertAllClose(
np.full((nr_special_cases,), np.nan, dtype=dtype),
np.full((nr_special_cases,), 1, dtype=dtype),
result_tf[special_cases],
err_msg=err_msg)
# non-special cases are equal
@ -892,12 +892,12 @@ class Jax2TfLimitation(test_harnesses.Limitation):
custom_numeric(dtypes=[np.float64], tol=1e-9),
custom_numeric(devices="gpu", tol=1e-3),
custom_numeric(
modes=("compiled",),
custom_assert=custom_assert,
devices=("cpu", "gpu"),
devices=("cpu", "gpu", "tpu"),
description=(
"May return different results at undefined points "
"(both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or "
"JAX returns 1 and TF returns `NaN`")),
"(both arguments less or equal 0). JAX returns `NaN` and TF returns 1")),
]
@classmethod

View File

@ -260,7 +260,7 @@ def _construct_smem_reftree(
dynamic_smem, c(dynamic_smem_offset, index), [],
)
if layout is None:
layout = tcgen05._infer_tmem_layout(shape)
layout = tcgen05._infer_tmem_layout(shape, collective)
num_cols = layout.cols_in_shape(shape)
delayed_warp_init.append(
functools.partial(

View File

@ -259,14 +259,15 @@ def _vector_load_op_lowering_rule(
is_signed=is_signed,
vec_size=strided_layout.vec_size,
)
elif layouts.is_wgmma_fragmented_layout(out_layout_attr):
elif layouts.from_layout_attr(out_layout_attr) == fa.TILED_LAYOUT_WGMMA:
layout = ir.MemRefType(vector_load_op.base.type).layout
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
transformed_ref = transform_memref(vector_load_op.base, transforms)
fragmented_array = fa.FragmentedArray.load_tiled(
transformed_ref,
swizzle=swizzle,
is_signed=is_signed
is_signed=is_signed,
layout=fa.TILED_LAYOUT_WGMMA,
)
else:
raise ValueError(
@ -319,6 +320,34 @@ def _vector_splat_op_lowering_rule(
return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]
@_register_lowering(vector.ReductionOp)
def _vector_reduction_op_lowering_rule(
ctx: LoweringContext, op: vector.ReductionOp
) -> Sequence[ir.Value]:
del ctx # Unused.
[layout] = inference_utils.in_layouts(op)
() = inference_utils.out_layouts(op)
element_type = ir.VectorType(op.vector.type).element_type
is_signed = False if ir.IntegerType.isinstance(element_type) else None
a = _fragmented_array_from_ir(op.vector, layout, is_signed)
match str(op.kind):
case "#vector.kind<add>":
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
scratch = _slice_smem(
ir.MemRefType.get([4], element_type, memory_space=smem),
arith.constant(None, op.attributes["offset"]),
)
result = a.reduce_sum(scratch)
case (
"#vector.kind<maxsi>" | "#vector.kind<maxui>" | "#vector.kind<maximumf>"
):
# TODO(slebedev): Implement this and remove the raise below.
raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
case _:
raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
return [_fragmented_array_to_ir(result, op.result.type)]
def memref_layout_to_swizzle_and_transforms(
layout: ir.Attribute,
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
@ -634,7 +663,10 @@ def _mgpu_wgmma_op_lowering_rule(
*inference_utils.in_layouts(wgmma_op),
*inference_utils.out_layouts(wgmma_op),
)
if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
is_supported_layout = (
lambda l: layouts.from_tiled_layout_attr(l) == fa.TILED_LAYOUT_WGMMA
)
if not all(map(is_supported_layout, fa_layouts)):
raise ValueError("Layout mismatch")
wgmma_layout = fa_layouts[0]
@ -667,7 +699,12 @@ def _mgpu_wgmma_op_lowering_rule(
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]
return [
_fragmented_array_to_ir(
new_acc.value.to_layout(fa.TILED_LAYOUT_WGMMA),
wgmma_op.accumulator.type,
)
]
@_register_lowering(mgpu.ArriveExpectTxOp)
@ -704,16 +741,17 @@ def _mgpu_slice_smem_op_lowering_rule(
ctx: LoweringContext, op: SliceSMEMOp
) -> Sequence[ir.Value]:
del ctx
return [_slice_smem(op.result.type, op.offset)]
def _slice_smem(result: ir.Type, offset: ir.Value):
i8 = ir.IntegerType.get_signless(8)
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
smem_base = gpu.dynamic_shared_memory(
ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
)
offset = arith.index_cast(ir.IndexType.get(), op.offset)
return [memref.view(op.result.type, smem_base, offset, [])]
offset = arith.index_cast(ir.IndexType.get(), offset)
return memref.view(result, smem_base, offset, [])
@_register_lowering(scf.ForOp)
@ -857,7 +895,8 @@ def _should_lower(op: ir.OpView) -> bool:
def lower_mgpu_dialect(
module: ir.Module, launch_context: launch_context.LaunchContext | None
module: ir.Module,
launch_context: launch_context.LaunchContext | None,
):
# TODO(apaszke,bchetioui): Make sure the layouts match.
# TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full

View File

@ -230,8 +230,8 @@ def main(unused_argv):
tile_n *= 2
if m < tile_m or n < tile_n:
continue
if kwargs["collective"] and tile_n >= 512:
continue # TODO(apaszke): Support 512
if tile_n > 512:
continue
if (m // tile_m) % kwargs["grid_tile_m"]:
continue
try:

View File

@ -1389,7 +1389,7 @@ class FragmentedArray:
if isinstance(self.layout, WGSplatFragLayout):
[reg] = self.registers.flat
if ir.FloatType.isinstance(self.mlir_dtype):
op = arith.mulf
op = mulf
elif ir.IntegerType.isinstance(self.mlir_dtype):
op = arith.muli
else:

View File

@ -63,7 +63,7 @@ def _choose_representative_layout(
Given the input set of possible layouts, this function extracts a single
representative layout. Currently, this function only works with strided,
splat, and WGMMA fragmented layouts.
splat, and tiled layouts.
Returns:
A single layout that can be used to annotate the operation, or None if the
@ -86,18 +86,18 @@ def _choose_representative_layout(
)
)
wgmma_layouts: list[fa.WGMMAFragLayout] = list(
tiled_layouts: list[fa.TiledLayout] = list(
map(
layouts_lib.from_layout_attr,
filter(layouts_lib.is_wgmma_fragmented_layout, layouts),
filter(layouts_lib.is_tiled_layout, layouts),
)
)
if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len(
if len(splat_layouts) + len(strided_layouts) + len(tiled_layouts) != len(
layouts
):
raise ValueError(
f"Expected only strided, splat, and wgmma layouts, got {layouts}"
f"Expected only strided, splat, and tiled layouts, got {layouts}"
)
if len(splat_layouts) > 1:
@ -112,13 +112,19 @@ def _choose_representative_layout(
"is not supported."
)
if (wgmma_layouts and strided_layouts):
if len(tiled_layouts) > 1:
raise NotImplementedError(
"Mixing strided and WGMMA layouts is not supported."
"Finding a representative layout for several distinct tiled layouts "
"is not supported."
)
if wgmma_layouts:
return layouts_lib.to_layout_attr(wgmma_layouts[0])
if tiled_layouts and strided_layouts:
raise NotImplementedError(
"Mixing strided and tiled layouts is not supported."
)
if tiled_layouts:
return layouts_lib.to_layout_attr(tiled_layouts[0])
if strided_layouts:
[strided_layout] = strided_layouts
@ -330,10 +336,16 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
return [], [layout]
@partial(_add_layout_inference_rule, vector.ReductionOp)
def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
if layout := inference_utils.value_layout(op.vector):
return [layout], []
return None
@partial(_add_layout_inference_rule, mgpu.WGMMAOp)
def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout())
layout = layouts_lib.to_layout_attr(fa.TILED_LAYOUT_WGMMA)
if ir.VectorType.isinstance(wgmma_op.a.type):
return [layout, layout], [layout]

View File

@ -94,11 +94,67 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
return bool(_strided_fragmented_layout_attr_pattern.search(str(attr)))
_tiled_layout_attr_pattern = re.compile(
r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\],"
r" warp_dim\s*=\s*(?P<warp_dim>[-\d]+),"
r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
)
def to_tiled_layout_attr(
layout: fa.TiledLayout,
) -> ir.Attribute:
"""Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""
tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]"
tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]"
return ir.Attribute.parse(
f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim},"
f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>"
)
_list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[")
def from_tiled_layout_attr(
attr: ir.Attribute,
) -> fa.TiledLayout:
"""Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute.
Raises:
ValueError: If the attribute is not a #mosaic_gpu.TiledLayout
attribute.
"""
match = _tiled_layout_attr_pattern.fullmatch(str(attr))
if not match:
raise ValueError(
f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}"
)
tiling_str = match.group("tiling")
tile_strings = []
if len(tiling_str) > 2:
tile_strings = _list_of_lists_delimiter.split(tiling_str[1:-1])
tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings)
return fa.TiledLayout(
tiling=fa.Tiling(tiles),
warp_dim=int(match.group("warp_dim")),
lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")),
vector_dim=int(match.group("vector_dim"))
)
def is_tiled_layout(attr: ir.Attribute) -> bool:
return bool(_tiled_layout_attr_pattern.search(str(attr)))
def to_layout_attr(
layout: (
fa.WGSplatFragLayout
| fa.WGStridedFragLayout
| fa.WGMMAFragLayout
| fa.TiledLayout
| fa.WGMMARowFragLayout
),
) -> ir.Attribute:
@ -108,8 +164,8 @@ def to_layout_attr(
return to_splat_fragmented_layout_attr(layout)
case fa.WGStridedFragLayout():
return to_strided_fragmented_layout_attr(layout)
case fa.WGMMAFragLayout():
return ir.Attribute.parse("#mosaic_gpu.WGMMAFragLayout")
case fa.TiledLayout():
return to_tiled_layout_attr(layout)
case fa.WGMMARowFragLayout():
return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout")
case _:
@ -118,15 +174,6 @@ def to_layout_attr(
)
_wgmma_fragmented_layout_attr_pattern = re.compile(
r"^#mosaic_gpu.WGMMAFragLayout$"
)
def is_wgmma_fragmented_layout(attr: ir.Attribute) -> bool:
return bool(_wgmma_fragmented_layout_attr_pattern.search(str(attr)))
_wgmma_row_fragmented_layout_attr_pattern = re.compile(
r"^#mosaic_gpu.WGMMARowFragLayout$"
)
@ -141,7 +188,7 @@ def from_layout_attr(
) -> (
fa.WGSplatFragLayout
| fa.WGStridedFragLayout
| fa.WGMMAFragLayout
| fa.TiledLayout
| fa.WGMMARowFragLayout
):
"""Constructs a layout from an MLIR attribute."""
@ -149,8 +196,8 @@ def from_layout_attr(
return from_splat_fragmented_layout_attr(attr)
elif is_strided_fragmented_layout(attr):
return from_strided_fragmented_layout_attr(attr)
elif is_wgmma_fragmented_layout(attr):
return fa.WGMMAFragLayout()
elif is_tiled_layout(attr):
return from_tiled_layout_attr(attr)
elif is_wgmma_row_fragmented_layout(attr):
return fa.WGMMARowFragLayout()
else:

View File

@ -83,6 +83,7 @@ def mma(
accumulate: ir.Value | bool = True,
collective: bool = False,
):
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
if isinstance(accumulate, bool):
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
@ -112,6 +113,10 @@ def mma(
raise ValueError(
f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
)
if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)):
raise ValueError(
f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}"
)
f32 = ir.F32Type.get()
if element_type == f32 or element_type == ir.BF16Type.get():
if d.dtype != f32:
@ -136,11 +141,7 @@ def mma(
raise ValueError(f"N must be a multiple of 8, got: {n}")
elif n > 256 and n != 512:
raise ValueError("Only N below 256 or N=512 are supported")
if num_cta == 2 and n > 256:
raise NotImplementedError(
"N is too big for collective MMA. Only up to 256 is supported."
)
n_group_elems = min(n, 256)
n_group_elems = min(n, 256 // num_cta)
if m % m_group_elems:
raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
if k % k_group_elems:
@ -179,6 +180,7 @@ def mma(
# Step 4. Issue the instructions.
true = arith.constant(ir.IntegerType.get_signless(1), 1)
n_collective_group_elems = n_group_elems * num_cta
for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
@ -188,9 +190,9 @@ def mma(
raise NotImplementedError("D needs to be sliced")
acc = accumulate if ki == 0 else true
_do_mma(
d.slice(
slice(None), utils.ds(ni * n_group_elems, n_group_elems)
).address,
arith.addi(
d.address, arith.constant(i32, ni * n_collective_group_elems)
),
a_mk,
b_nk,
d_type=ir.F32Type.get(),
@ -377,8 +379,15 @@ class TMEMLayout:
+------------------+------------------+
| [0:64, 64:128] | [64:128, 64:128] |
+------------------+------------------+
The above is further complicated by column_tile_stride, which is used to
swizzle the ordering of column tiles. That is, if column_tile_stride is 2,
we will first lay out all tiles that have the column index 0, 2, 4, and so on
until we run out of tiles. Only then do we lay out the tiles with column index
1, 3, etc.
"""
elements_in_tile: tuple[int, int]
column_tile_stride: int = 1
def __post_init__(self):
row_tiling = self.elements_in_tile[0]
@ -405,7 +414,7 @@ class TMEMLayout:
return num_tiles // tiles_in_row * cols_in_tile
def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout:
if shape[0] > TMEM_ROWS:
raise ValueError(
"Can only infer TMEM layout for shapes with at most 128 rows, got:"
@ -421,7 +430,15 @@ def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
"Can only infer TMEM layout for shapes with row count that's a power of"
f" 2, got: {shape[0]}"
)
return TMEMLayout(elements_in_tile=(shape[0], 1))
if shape[1] % 8:
raise ValueError(
"Can only infer TMEM layout for shapes with column count that's a"
f" multiple of 8, got: {shape[1]}"
)
if collective and shape[1] == 512:
return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2)
else:
return TMEMLayout(elements_in_tile=(shape[0], 8))
@dataclasses.dataclass(frozen=True)
@ -432,7 +449,14 @@ class TMEMRef:
layout: TMEMLayout
@classmethod
def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None):
def from_alloc(
cls,
tmem_addr_ref: ir.Value,
shape: tuple[int, int],
dtype,
collective: bool | None = None,
layout: TMEMLayout | None = None,
):
i32 = ir.IntegerType.get_signless(32)
if not ir.MemRefType.isinstance(tmem_addr_ref.type):
raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
@ -449,7 +473,11 @@ class TMEMRef:
if shape[0] < 32:
raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
if layout is None:
layout = _infer_tmem_layout(shape)
if collective is None:
raise ValueError(
"collective argument must be provided when TMEM layout is inferred"
)
layout = _infer_tmem_layout(shape, collective)
else:
layout.check_shape(shape)
# TODO: Do we have to do this??
@ -461,12 +489,17 @@ class TMEMRef:
base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
if any(is_squeezed):
raise ValueError("TMEM can only be sliced, not indexed")
if self.layout.elements_in_tile[0] != TMEM_ROWS:
if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
raise NotImplementedError(
f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows"
"Slicing only implemented for refs with standard layout, got:"
f" {self.layout}"
)
if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
raise NotImplementedError("TMEM cannot be sliced along rows")
if slice_shape[1] % 8:
raise NotImplementedError(
"TMEM column slice length must be a multiple of 8"
)
col_idx = base_idx[1]
if not isinstance(col_idx, ir.Value):
col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx)
@ -484,48 +517,75 @@ class TMEMRef:
raise ValueError("TMEM loads only support slicing")
if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
raise NotImplementedError("Slicing of TMEM not impelmented yet")
if self.layout.elements_in_tile[0] != TMEM_ROWS:
raise NotImplementedError(
f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows"
)
if self.shape[1] % 8:
raise NotImplementedError
if self.dtype != ir.F32Type.get():
raise NotImplementedError(self.dtype)
layout = _m128_256bit_32bit_layout(self.shape)
regs_shape = layout.registers_shape(self.shape)
num = self.shape[1] // 8
# TODO(apaszke): Make the tiling configurable through the args too.
if num <= 32:
num_tiling = num
elif num == 64:
num_tiling = 32
else:
raise NotImplementedError(num)
registers = np.empty(regs_shape, dtype=object)
# We load 16 lanes at a time, but need 32 in total.
for row_group in range(2):
addr_row = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16))
regs = []
cols_per_num_tile = 8 # This depends on the 16x256b below.
for num_group in range(num // num_tiling):
addr_row_col = arith.addi(
addr_row,
arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
# load_32xcols returns a 4xN array, but the FA tiling we use here tiles
# columns before rows, and so it is Nx4 (after ignoring all 1 dims).
registers = _load_32xcols(
self.address, self.shape[1], self.dtype
).T.reshape(regs_shape)
elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2):
if self.shape[1] % 128 != 0:
raise ValueError(
f"TMEM layout {self.layout} is not compatible with shape {self.shape}"
)
regs += tmem_load(addr_row_col, "16x256b", num_tiling)
regs = [llvm.bitcast(self.dtype, r) for r in regs]
vector_regs = []
undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype))
for r_low, r_high in zip(regs[::2], regs[1::2]):
high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
vector_regs.append(vreg)
# Dimension 4 is the one where we split 32 rows into tiles of 8.
regs_slice = (slice(None),) * 4 + (slice(row_group * 2, (row_group + 1) * 2),)
registers[regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[regs_slice].shape)
num_column_tiles = self.shape[1] // 128
column_tile_stride = self.layout.column_tile_stride
num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride)
tiles = []
for col_tile_base in range(num_strided_col_groups):
for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride):
tiles.append(
_load_32xcols(
arith.addi(self.address, arith.constant(i32, col_tile * 128)),
cols=128,
dtype=self.dtype,
)
)
registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape)
else:
raise NotImplementedError(
f"Loads only implemented for refs with standard layout, got: {self.layout}"
)
return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)
def _load_32xcols(base_addr, cols, dtype):
# See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
i32 = ir.IntegerType.get_signless(32)
assert cols % 8 == 0
cols_per_num_tile = 8
load_shape = "16x256b"
num = cols // 8
if num <= 32:
num_tiling = num
elif num == 64:
num_tiling = 32
else:
raise NotImplementedError(num)
vector_regs = np.ndarray((4, num), dtype=object)
# We load 16 lanes at a time, but need 32 in total.
for row_group in range(2):
addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16))
regs = []
for num_group in range(num // num_tiling):
addr_row_col = arith.addi(
addr_row,
arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
)
regs += tmem_load(addr_row_col, load_shape, num_tiling)
regs = [llvm.bitcast(dtype, r) for r in regs]
undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)):
high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
return vector_regs
def _m128_256bit_32bit_layout(shape: tuple[int, ...]):
if len(shape) != 2:
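
The column_tile_stride ordering described in the TMEMLayout docstring above, and followed in a strided fashion by the loading loop in TMEMRef, can be sketched in a few lines; the tile count and stride here are assumed numbers for illustration.

column_tile_stride = 2
num_column_tiles = 6
order = [t
         for base in range(column_tile_stride)
         for t in range(base, num_column_tiles, column_tile_stride)]
print(order)  # [0, 2, 4, 1, 3, 5]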

View File

@ -1201,3 +1201,7 @@ def bitcast(x: ir.Value, new_type: ir.Type):
assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
raise ValueError(f"Can't bitcast {x.type} to {new_type}")
def ceil_div(x: int, y: int):
return (x + y - 1) // y

View File

@ -18,6 +18,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
from jax._src.pallas.fuser.fusable import fusable as fusable
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

View File

@ -43,14 +43,22 @@ class MultiPageAsyncCopyDescriptor:
):
self._vmem_buf = vmem_buf
seq_id, kv_pages_start = offset
self._async_copies = [
pltpu.make_async_copy(
pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]],
vmem_buf.at[i],
sem,
)
for i in range(vmem_buf.shape[0])
]
pages_per_seq = page_indices_ref.shape[1]
self._async_copies = []
# TODO(jevinjiang): Only fetch the dynamic shape when needed! This will insert
# a bunch of if-ops. Check the performance when we have benchmarking setup.
for i in range(vmem_buf.shape[0]):
page_idx = kv_pages_start + i
page_idx = jax.lax.select(
page_idx < pages_per_seq, page_idx, pages_per_seq - 1
)
self._async_copies.append(
pltpu.make_async_copy(
pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
vmem_buf.at[i],
sem,
)
)
def start(self):
"""Starts the async copies."""

View File

@ -49,7 +49,7 @@ py_library_providing_imports_info(
config_setting(
name = "disable_jaxlib_for_cpu_build",
flag_values = {
"//jax:build_jaxlib": "False",
"//jax:build_jaxlib": "false",
"@local_config_cuda//:enable_cuda": "False",
},
)
@ -57,7 +57,23 @@ config_setting(
config_setting(
name = "disable_jaxlib_for_cuda12_build",
flag_values = {
"//jax:build_jaxlib": "False",
"//jax:build_jaxlib": "false",
"@local_config_cuda//:enable_cuda": "True",
},
)
config_setting(
name = "enable_py_import_for_cpu_build",
flag_values = {
"//jax:build_jaxlib": "wheel",
"@local_config_cuda//:enable_cuda": "False",
},
)
config_setting(
name = "enable_py_import_for_cuda12_build",
flag_values = {
"//jax:build_jaxlib": "wheel",
"@local_config_cuda//:enable_cuda": "True",
},
)

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include "nanobind/nanobind.h"
#include "nanobind/stl/pair.h"
#include "jaxlib/absl_status_casters.h"
@ -29,7 +31,7 @@ namespace nb = nanobind;
nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
int batch_size, int max_seq_length, float dropout,
bool bidirectional, bool cudnn_allow_tf32,
int workspace_size, int reserve_space_size) {
size_t workspace_size, size_t reserve_space_size) {
return PackDescriptor(RnnDescriptor{
input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "jaxlib/gpu/rnn_kernels.h"
#include <cstddef>
#include <utility>
#include <vector>
@ -71,7 +72,7 @@ template <>
namespace JAX_GPU_NAMESPACE {
static absl::StatusOr<std::pair<int, int>>
static absl::StatusOr<std::pair<size_t, size_t>>
DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
int num_layers, int batch_size,
int max_seq_length, float dropout,
@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
return std::make_pair(workSpaceSize, reserveSpaceSize);
}
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
int input_size, int hidden_size, int num_layers, int batch_size,
int max_seq_length, float dropout, bool bidirectional,
bool cudnn_allow_tf32) {

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef JAXLIB_GPU_RNN_KERNELS_H_
#define JAXLIB_GPU_RNN_KERNELS_H_
#include <cstddef>
#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
@ -34,12 +36,12 @@ struct RnnDescriptor {
float dropout;
int bidirectional;
int cudnn_allow_tf32;
int workspace_size;
int reserve_space_size;
size_t workspace_size;
size_t reserve_space_size;
};
// Return (workspace size, reserve space size).
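// Sizes are size_t rather than int: 32-bit values overflowed for large RNN workspaces.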
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
int input_size, int hidden_size, int num_layers, int batch_size,
int max_seq_length, float dropout, bool bidirectional,
bool cudnn_allow_tf32);

View File

@ -493,15 +493,7 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
param.value)));
}
}
// Triton's kernel ABI expects an additional scratchpad global memory.
// For now it is only used for on-device creation of TMA descriptors, which
// we do not use yet, so we are just replacing this argument with a null
// pointer.
// TODO: b/381242007 - Allocate a proper buffer if we want to use
// device-side TMA APIs.
void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns.
params.push_back(&scratch_ptr);
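// The scratch argument Triton's kernel ABI expects now comes from the
// caller-provided buffers instead of a null pointer.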
params.push_back(buffers++); // Scratch buffer.
return kernel_.Launch(stream, grid_, params.data());
}

View File

@ -224,7 +224,15 @@ def if_building_jaxlib(
"@pypi_jax_cuda12_plugin//:pkg",
"@pypi_jax_cuda12_pjrt//:pkg",
],
if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]):
if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"],
if_py_import = [
"//jaxlib/tools:jaxlib_py_import",
"//jaxlib/tools:jax_cuda_plugin_py_import",
"//jaxlib/tools:jax_cuda_pjrt_py_import",
],
if_py_import_for_cpu = [
"//jaxlib/tools:jaxlib_py_import",
]):
"""Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.
This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase.
@ -234,12 +242,16 @@ def if_building_jaxlib(
if_not_building: the jaxlib wheels, including gpu-specific plugins, to depend on for
gpu-enabled builds
if_not_building_for_cpu: the jaxlib wheels to depend on for cpu-only builds
if_py_import: the py_import targets to depend on for gpu-enabled builds
if_py_import_for_cpu: the py_import targets to depend on for cpu-only builds
"""
return select({
"//jax:enable_jaxlib_build": if_building,
"//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu,
"//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building,
"//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu,
"//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import,
})
# buildifier: disable=function-docstring

View File

@ -128,7 +128,6 @@ def MosaicGPU_WGStridedFragLayout : AttrDef<MosaicGPU_Dialect, "WGStridedFragLay
let assemblyFormat = "`<` $shape`,` $vector_size `>`";
}
def MosaicGPU_WGSplatFragLayout : AttrDef<MosaicGPU_Dialect, "WGSplatFragLayout", []> {
let summary = "Annotates an array that is the result of a splat.";
let description = [{
@ -143,20 +142,6 @@ def MosaicGPU_WGSplatFragLayout : AttrDef<MosaicGPU_Dialect, "WGSplatFragLayout"
let assemblyFormat = "`<` $shape `>`";
}
def MosaicGPU_WGMMAFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMAFragLayout", []> {
let summary = "2D array that can be tiled by supported WGMMA shapes.";
let description = [{
This layout annotates arrays that are fragmented across all threads in a
warpgroup that is executing a WGMMA operation. The shape of the array is
(m, n) where:
- m % 64 == 0
- n % 8 == 0
}];
let mnemonic = "WGMMAFragLayout";
let assemblyFormat = "";
}
def MosaicGPU_WGMMARowFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMARowFragLayout", []> {
let summary = "1D array that is a row that can be tiled by supported WGMMA shapes.";
let description = [{
@ -169,6 +154,24 @@ def MosaicGPU_WGMMARowFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMARowFragLayou
let assemblyFormat = "";
}
def MosaicGPU_TiledLayout : AttrDef<MosaicGPU_Dialect, "TiledLayout", []> {
let summary = "A layout derived from a tiling expression.";
let description = [{
See mosaic/gpu/fragmented_array.py -> TiledLayout for more details.
}];
let parameters = (ins
"::mlir::ArrayAttr":$tiling,
"int":$warp_dim,
"::mlir::ArrayAttr":$lane_dims,
"int":$vector_dim
);
let mnemonic = "TiledLayout";
let assemblyFormat = "`<` $tiling `,` `warp_dim` `=` $warp_dim `,` "
"`lane_dims` `=` $lane_dims `,` `vector_dim` `=` $vector_dim `>`";
}
// Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td
// but it was not possible to reuse that definition. Including that file
// pulls in ops definitions that we don't want and they fail to compile.

View File

@ -18,6 +18,10 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load(
"@xla//third_party/py:py_import.bzl",
"py_import",
)
load(
"@xla//third_party/py:py_manylinux_compliance_test.bzl",
"verify_manylinux_compliance_test",
@ -228,6 +232,18 @@ string_flag(
build_setting_default = "dist",
)
NVIDIA_WHEELS_DEPS = [
"@pypi_nvidia_cublas_cu12//:whl",
"@pypi_nvidia_cuda_cupti_cu12//:whl",
"@pypi_nvidia_cuda_runtime_cu12//:whl",
"@pypi_nvidia_cudnn_cu12//:whl",
"@pypi_nvidia_cufft_cu12//:whl",
"@pypi_nvidia_cusolver_cu12//:whl",
"@pypi_nvidia_cusparse_cu12//:whl",
"@pypi_nvidia_nccl_cu12//:whl",
"@pypi_nvidia_nvjitlink_cu12//:whl",
]
jax_wheel(
name = "jaxlib_wheel",
no_abi = False,
@ -235,6 +251,11 @@ jax_wheel(
wheel_name = "jaxlib",
)
py_import(
name = "jaxlib_py_import",
wheel = ":jaxlib_wheel",
)
jax_wheel(
name = "jaxlib_wheel_editable",
editable = True,
@ -252,6 +273,12 @@ jax_wheel(
wheel_name = "jax_cuda12_plugin",
)
py_import(
name = "jax_cuda_plugin_py_import",
wheel = ":jax_cuda_plugin_wheel",
wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)
jax_wheel(
name = "jax_cuda_plugin_wheel_editable",
editable = True,
@ -290,6 +317,12 @@ jax_wheel(
wheel_name = "jax_cuda12_pjrt",
)
py_import(
name = "jax_cuda_pjrt_py_import",
wheel = ":jax_cuda_pjrt_wheel",
wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)
jax_wheel(
name = "jax_cuda_pjrt_wheel_editable",
editable = True,

View File

@ -213,8 +213,36 @@ class RnnTest(jtu.JaxTestCase):
k = jax.random.split(jax.random.PRNGKey(1), 4)
stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
stablehlo)
if jtu.jaxlib_version() <= (0, 5, 2):
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
stablehlo)
else:
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"',
stablehlo)
@jtu.run_on_devices("cuda")
def test_no_workspace_overflow(self):
if jtu.jaxlib_version() <= (0, 5, 2):
self.skipTest("Older versions fail because of integer overflow.")
# Problem sizes known to cause overflows on older versions.
batch_size, max_seq_length, input_size = 256, 500, 512
num_layers, hidden_size = 1, 256
num_params = rnn.get_num_params_in_lstm(
input_size, hidden_size, num_layers, True)
x = jax.ShapeDtypeStruct(
(batch_size, max_seq_length, input_size), jnp.float32)
h_0 = jax.ShapeDtypeStruct(
(2 * num_layers, batch_size, hidden_size), jnp.float32)
c_0 = jax.ShapeDtypeStruct(
(2 * num_layers, batch_size, hidden_size), jnp.float32)
weights = jax.ShapeDtypeStruct((num_params,), jnp.float32)
seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32)
fun = jax.jit(partial(
rnn.lstm, input_size=input_size, hidden_size=hidden_size,
num_layers=num_layers, dropout=0.0, bidirectional=True))
fun.lower(x, h_0, c_0, weights, seq_lengths) # Doesn't crash.
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -2445,7 +2445,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
assert b.shape == ()
return c, b
xs = jnp.ones((5, 3))
xs = jnp.ones((20, 3))
c = jnp.ones(4)
scan = lambda c, xs: lax.scan(f, c, xs)
@ -2502,6 +2502,28 @@ class LaxControlFlowTest(jtu.JaxTestCase):
x, n = jnp.arange(3), jnp.arange(4)
jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash
def test_disable_jit_while_loop_with_mutation(self):
# https://github.com/jax-ml/jax/issues/27019
def body_fun(carry):
x, y = carry
x += 1 # in-place if x is mutable
return x, y + x
def cond_fun(carry):
x, _ = carry
return x < 10
def f():
val = np.array(1.0) # mutable value
return jax.lax.while_loop(cond_fun, body_fun, (val, val))[1]
with jax.disable_jit(False):
result_jit = f()
with jax.disable_jit(True):
result_nojit = f()
self.assertEqual(result_jit, result_nojit)
@parameterized.named_parameters(
{"testcase_name": f"_{shape}_{axis=}",
"shape": shape, "axis": axis}

View File

@ -278,6 +278,35 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
with jax.checking_leaks():
lsp_special.expi(jnp.ones(()))
def testExpiDisableJit(self):
# Regression test for https://github.com/jax-ml/jax/issues/27019
x = jnp.array([-0.5])
with jax.disable_jit(True):
result_nojit = lsp_special.expi(x)
with jax.disable_jit(False):
result_jit = lsp_special.expi(x)
self.assertAllClose(result_jit, result_nojit)
def testGammaIncBoundaryValues(self):
dtype = jax.numpy.zeros(0).dtype # default float dtype.
nan = float('nan')
inf = float('inf')
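# (a, x) argument pairs covering the 0, 1, nan and inf boundaries.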
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan]).astype(dtype),
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf]).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.gammainc, lsp_special.gammainc, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol)
def testGammaIncCBoundaryValues(self):
dtype = jax.numpy.zeros(0).dtype # default float dtype.
nan = float('nan')
inf = float('inf')
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan, 1]).astype(dtype),
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf, -1]).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.gammaincc, lsp_special.gammaincc, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -39,7 +39,10 @@ jax_multiplatform_test(
],
env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
shard_count = 16,
tags = ["multiaccelerator"],
tags = [
"multiaccelerator",
"noasan", # Times out.
],
deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),

View File

@ -210,7 +210,7 @@ class LayoutInferenceTest(parameterized.TestCase):
for layout in [
mgpu.WGSplatFragLayout(shape),
mgpu.WGStridedFragLayout(shape, vec_size=4),
mgpu.WGMMAFragLayout(),
mgpu.TILED_LAYOUT_WGMMA,
]
)
def test_infer_layout_from_yield_op_in_layouts_for_for_op(
@ -278,7 +278,7 @@ class LayoutInferenceTest(parameterized.TestCase):
mgpu.infer_layout(self.module)
wgmma_layout = layouts.to_layout_attr(mgpu.WGMMAFragLayout())
wgmma_layout = layouts.to_layout_attr(mgpu.TILED_LAYOUT_WGMMA)
self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout])
self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout])
@ -312,7 +312,7 @@ class LayoutInferenceTest(parameterized.TestCase):
@parameterized.parameters(
mgpu.WGStridedFragLayout((32, 4), vec_size=1),
mgpu.WGMMAFragLayout(),
mgpu.TILED_LAYOUT_WGMMA,
)
def test_infer_layout_picks_non_splat_layout_over_splat_layout(
self, layout

View File

@ -1026,7 +1026,7 @@ class TCGen05Test(TestCase):
in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32
out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation
m=(256,), # TODO(apaszke): 64, 192, 256
n=(128, 256), # TODO(apaszke): 512, 192, other non-power-of-2
n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2
k_steps=(1, 2),
swizzle=(32, 64, 128,),
)

View File

@ -216,8 +216,8 @@ class MutableArrayTest(jtu.JaxTestCase):
@jax.jit
def f(x_ref):
self.assertEqual(core.get_ty(x_ref).sharding.spec,
core.get_ty(x_ref[...]).sharding.spec)
self.assertEqual(core.typeof(x_ref).sharding.spec,
core.typeof(x_ref[...]).sharding.spec)
y = x_ref[...] + 1
return y

View File

@ -184,6 +184,23 @@ class PallasCallTest(PallasTest):
y = jnp.flip(x).reshape(1, 256)
np.testing.assert_array_equal(kernel(x, y), x + y[0])
@parameterized.product(
shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics]
)
def test_reduce_sum(self, shape, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), jnp.sum(x))
def test_reshape(self):
shape1, shape2 = (128,), (2, 16, 4)
@ -200,10 +217,14 @@ class PallasCallTest(PallasTest):
x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x.reshape(shape2))
def test_add_xy_indexed(self):
@parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
def test_add_xy_indexed(self, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
def kernel(x_ref, y_ref, o_ref):
idx = _sum_same_dtype(y_ref[...])
@ -1078,10 +1099,14 @@ class PallasCallTest(PallasTest):
self.assertIn("acc % 2", output())
def test_cond_returning_array(self):
@parameterized.parameters([*plgpu.ThreadSemantics])
def test_cond_returning_array(self, thread_semantics):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
compiler_params=plgpu.GPUCompilerParams(
thread_semantics=thread_semantics
),
)
def kernel(x_ref, o_ref):
acc = _sum_same_dtype(x_ref[...])

View File

@ -470,6 +470,27 @@ class OpsTest(PallasBaseTest):
expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
np.testing.assert_array_equal(out, expected)
def test_reduce_with_const(self):
m = 1
d = 1024
x = jnp.ones((m, d), jnp.bfloat16)
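# Reduce along d by contracting x with a (1, d) vector of ones; the output keeps shape (m, 1).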
def dot(x, y):
return jax.lax.dot_general(
x,
y,
(((1,), (1,)), ((), ())),
preferred_element_type=jnp.float32,
)
def kernel(x, out):
out[:] = dot(x[:], jnp.ones((1, d), jnp.bfloat16))
run = pl.pallas_call(kernel, jax.ShapeDtypeStruct((m, 1), jnp.float32))
output = run(x)
expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16))
np.testing.assert_array_equal(output, expected)
class OpsInterpretTest(OpsTest):
INTERPRET = True

View File

@ -64,10 +64,6 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
max_num_seq = max(len(seq_lens), max_num_seq)
max_kv_len = max(kv_lens)
pages_per_seq = ceil_div(max_kv_len, page_size)
pages_per_seq = (
ceil_div(pages_per_seq, num_kv_pages_per_block)
* num_kv_pages_per_block
)
num_q_heads, num_kv_heads = num_heads
cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
@ -130,8 +126,8 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
num_seqs=num_seqs,
)
tols = {
"float32": 1e-1,
"bfloat16": 2e-1,
"float32": 0.15,
"bfloat16": 0.2,
}
tol = tols[jnp.dtype(dtype).name]
self.assertAllClose(output, expected, atol=tol, rtol=tol)

View File

@ -4883,11 +4883,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, s)
def f(x):
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
x = x * 2
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
x = x * x
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
return x
# Eager mode

View File

@ -125,6 +125,80 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32)
)
@parameterized.named_parameters(
dict(
testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
),
)
def test_ragged_all_to_all_grad(self, axis_name, mesh_axes):
device_type = jax.devices()[0].platform
if device_type == 'tpu' and jtu.get_tpu_version() < 4:
raise unittest.SkipTest(
'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
f' v{jtu.get_tpu_version()}'
)
mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
operand = jax.device_put(
jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
output = jax.device_put(
jnp.zeros((2, 4), dtype=jnp.float32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
input_offsets = jax.device_put(
jnp.array([[0, 1], [0, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
send_sizes = jax.device_put(
jnp.array([[1, 2], [1, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
output_offsets = jax.device_put(
jnp.array([[0, 0], [1, 2]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
recv_sizes = jax.device_put(
jnp.array([[1, 1], [2, 1]], dtype=jnp.int32),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
)
@partial(
shard_map,
mesh=mesh,
in_specs=(
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
P(axis_name, None),
),
out_specs=P(axis_name),
check_rep=False,
)
def fwd(
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
):
operand = operand.reshape(operand.shape[1:])
output = output.reshape(output.shape[1:])
input_offsets = input_offsets.reshape(input_offsets.shape[1:])
send_sizes = send_sizes.reshape(send_sizes.shape[1:])
output_offsets = output_offsets.reshape(output_offsets.shape[1:])
recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
return lax.ragged_all_to_all(
operand,
output,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=axis_name,
)
args = input_offsets, send_sizes, output_offsets, recv_sizes
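# Check first-order gradients of ragged_all_to_all with respect to the operand and output buffers.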
jtu.check_grads(lambda op, out: fwd(op, out, *args), (operand, output), order=1)
@parameterized.named_parameters(
dict(
testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)

View File

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "fae64d49aa41e774922ca46e94cd754c800b6240"
XLA_SHA256 = "846ce8037cc0cba5135bff0bfd6fd02810e72b42ce0928002c595c97bf7b3603"
XLA_COMMIT = "c270a6ce45df7f7bb3024f2e4df56b688d76ebd6"
XLA_SHA256 = "b2f7d0293fc62bb670d0b58c5847108652eac4d9e6c7e420bed2029e74af6f2d"
def repo():
tf_http_archive(