diff --git a/.bazelrc b/.bazelrc index 2fb51664b..47cf3e766 100644 --- a/.bazelrc +++ b/.bazelrc @@ -253,12 +253,6 @@ build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -# Mac x86 CI configs -build:ci_darwin_x86_64 --macos_minimum_os=11.0 -build:ci_darwin_x86_64 --config=macos_cache_push -build:ci_darwin_x86_64 --verbose_failures=true -build:ci_darwin_x86_64 --color=yes - # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 build:ci_darwin_arm64 --config=macos_cache_push diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index 4dc8e55fb..a92e3cc19 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -3,6 +3,7 @@ # This job currently runs as a non-blocking presubmit. It is experimental and is currently being # tested to get to a stable state before we enable it as a blocking presubmit. name: CI - Cloud TPU (presubmit) + on: workflow_dispatch: inputs: @@ -33,64 +34,32 @@ concurrency: cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} jobs: - cloud-tpu-test: + build-jax-artifacts: if: github.event.repository.fork == false -# Begin Presubmit Naming Check - name modification requires internal check to be updated + uses: ./.github/workflows/build_artifacts.yml strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - tpu: [ - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} - ] - python-version: ["3.10"] - name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})" -# End Presubmit Naming Check github-tpu-presubmits - env: - JAXCI_PYTHON: python${{ matrix.python-version }} - JAXCI_TPU_CORES: ${{ matrix.tpu.cores }} + fail-fast: false # don't cancel all jobs on failure + matrix: + artifact: ["jax", "jaxlib"] + with: + runner: "linux-x86-n2-16" + artifact: ${{ matrix.artifact }} + python: "3.10" + clone_main_xla: 1 + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - runs-on: ${{ matrix.tpu.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" - - timeout-minutes: 60 - - defaults: - run: - shell: bash -ex {0} - steps: - # https://opensource.google/documentation/reference/github/services#actions - # mandates using a specific commit for non-Google actions. We use - # https://github.com/sethvargo/ratchet to pin specific versions. - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - # Checkout XLA at head, if we're building jaxlib at head. - - name: Checkout XLA at head - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: openxla/xla - path: xla - # We need to mark the GitHub workspace as safe as otherwise git commands will fail. - - name: Mark GitHub workspace as safe - run: | - git config --global --add safe.directory "$GITHUB_WORKSPACE" - - name: Install JAX test requirements - run: | - $JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt - - name: Build jaxlib at head with latest XLA - run: | - # Build and install jaxlib at head - $JAXCI_PYTHON build/build.py build --wheels=jaxlib \ - --python_version=${{ matrix.python-version }} \ - --bazel_options=--config=rbe_linux_x86_64 \ - --local_xla_path="$(pwd)/xla" \ - --verbose - - # Install libtpu - $JAXCI_PYTHON -m uv pip install --pre libtpu \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Install jaxlib wheel and run tests - run: ./ci/run_pytest_tpu.sh \ No newline at end of file + run-pytest-tpu: + if: github.event.repository.fork == false + needs: [build-jax-artifacts] + uses: ./.github/workflows/pytest_tpu.yml + # Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "TPU test (jaxlib=head, v5e-8)" + with: + runner: "linux-x86-ct5lp-224-8tpu" + cores: "8" + tpu-type: "v5e-8" + python: "3.10" + libtpu-version-type: "nightly" + gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }} + # End Presubmit Naming Check github-tpu-presubmits \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml index e64f81809..1cfc7a883 100644 --- a/.github/workflows/pytest_cpu.yml +++ b/.github/workflows/pytest_cpu.yml @@ -116,6 +116,9 @@ jobs: exit 1 - name: Install Python dependencies run: | + # Remove installation of NVIDIA wheels for CPU tests. + sed -i 's/-r gpu-test-requirements.txt/# -r gpu-test-requirements.txt/g' build/requirements.in + # TODO(srnitin): Remove after uv is installed in the Windows Dockerfile $JAXCI_PYTHON -m pip install uv~=0.5.30 # python 3.13t cannot compile zstandard 0.23.0 due to diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml new file mode 100644 index 000000000..a105a2feb --- /dev/null +++ b/.github/workflows/pytest_tpu.yml @@ -0,0 +1,151 @@ +# CI - Pytest TPU +# +# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via +# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jaxlib wheel from a GCS bucket. +# - Sets up the libtpu wheels. +# - Executes the `run_pytest_cpu.sh` script, which performs the following actions: +# - Installs the downloaded jaxlib wheel. +# - Runs the TPU tests with Pytest. +name: CI - Pytest TPU + +on: + workflow_call: + inputs: + # Note that the values for runners, cores, and tpu-type are linked to each other. + # For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the + # following mapping: + # {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + # {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-ct5lp-224-8tpu" + cores: + description: "How many TPU cores should the test use?" + type: string + required: true + default: "8" + tpu-type: + description: "Which TPU type is used for testing?" + type: string + required: true + default: "v5e-8" + python: + description: "Which Python version should be used for testing?" + type: string + required: true + default: "3.12" + run-full-tpu-test-suite: + description: "Should the full TPU test suite be run?" + type: string + required: false + default: "0" + libtpu-version-type: + description: "Which libtpu version should be used for testing?" + type: string + required: false + # Choices are: + # - "nightly": Use the nightly libtpu wheel. + # - "pypi_latest": Use the latest libtpu wheel from PyPI. + # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. + default: "nightly" + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + required: true + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: boolean + required: false + default: false + +jobs: + run-tests: + defaults: + run: + shell: bash + runs-on: ${{ inputs.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + # Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})" + # End Presubmit Naming Check github-tpu-presubmits + + env: + LIBTPU_OLDEST_VERSION_DATE: 20241205 + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}" + JAXCI_TPU_CORES: "${{ inputs.cores }}" + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set env vars for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t + python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + # Python wheels follow a naming convention: standard wheels use the pattern + # `*-cp-cp-*`, while free-threaded wheels use + # `*-cp-cpt-*`. + echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV + - name: Download JAX wheels from GCS + id: download-wheel-artifacts + # Set continue-on-error to true to prevent actions from failing the workflow if this step + # fails. Instead, we verify the outcome in the step below so that we can print a more + # informative error message. + continue-on-error: true + run: | + mkdir -p $(pwd)/dist + gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + - name: Skip the test run if the wheel artifacts were not downloaded successfully + if: steps.download-wheel-artifacts.outcome == 'failure' + run: | + echo "Failed to download wheel artifacts from GCS. Please check if the wheels were" + echo "built successfully by the artifact build jobs and are available in the GCS bucket." + echo "Skipping the test run." + exit 1 + - name: Install Python dependencies + run: | + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt + - name: Set up libtpu wheels + run: | + if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then + echo "Using nightly libtpu" + $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then + echo "Using latest libtpu from PyPI" + # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh` + # script will install the latest libtpu wheel from PyPI. + echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV + elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then + echo "Using oldest supported libtpu" + $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + + echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV + else + echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}" + exit 1 + fi + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Pytest TPU tests + timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }} + run: ./ci/run_pytest_tpu.sh diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 5c818bf56..f12d8a7f0 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -142,4 +142,30 @@ jobs: python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} # GCS upload URI is the same for both artifact build jobs + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + + run-pytest-tpu: + # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated + # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we + # still want to run the tests for other platforms. + if: ${{ !cancelled() }} + needs: [build-jax-artifact, build-jaxlib-artifact] + uses: ./.github/workflows/pytest_tpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.10",] + tpu-specs: [ + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + ] + name: "TPU tests (jax=head, jaxlib=head)" + with: + runner: ${{ matrix.tpu-specs.runner }} + cores: ${{ matrix.tpu-specs.cores }} + tpu-type: ${{ matrix.tpu-specs.type }} + python: ${{ matrix.python }} + run-full-tpu-test-suite: "1" + libtpu-version-type: "nightly" gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} \ No newline at end of file diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index b88b000e4..574cc0628 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -58,4 +58,42 @@ jobs: python: ${{ matrix.python }} cuda: ${{ matrix.cuda }} enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{inputs.gcs_download_uri}} + + run-pytest-tpu: + uses: ./.github/workflows/pytest_tpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for + # profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302) + python: ["3.10", "3.11", "3.12"] + tpu-specs: [ + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + ] + libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] + exclude: + - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} + - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} + # Run a single Python version for v4-8. + - tpu-specs: + type: "v4-8" + python: "3.10" + - tpu-specs: + type: "v4-8" + python: "3.11" + # Run min and max Python versions for v5e-8 + - tpu-specs: + type: "v5e-8" + python: "3.11" + name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + with: + runner: ${{ matrix.tpu-specs.runner }} + cores: ${{ matrix.tpu-specs.cores }} + tpu-type: ${{ matrix.tpu-specs.type }} + python: ${{ matrix.python }} + run-full-tpu-test-suite: "1" + libtpu-version-type: ${{ matrix.libtpu-version-type }} gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/README.md b/README.md index 0aca7cf58..c2e8657a6 100644 --- a/README.md +++ b/README.md @@ -456,3 +456,4 @@ For details about the JAX API, see the For getting started as a JAX developer, see the [developer documentation](https://jax.readthedocs.io/en/latest/developer.html). + diff --git a/build/BUILD.bazel b/build/BUILD.bazel index cf43fdab0..f088cd58a 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -29,7 +29,7 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, generate_hashes = True, - data = ["test-requirements.txt"] + data = ["test-requirements.txt", "gpu-test-requirements.txt"] ) compile_pip_requirements( @@ -44,7 +44,7 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, generate_hashes = False, - data = ["test-requirements.txt"] + data = ["test-requirements.txt", "gpu-test-requirements.txt"] ) compile_pip_requirements( @@ -58,7 +58,7 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, generate_hashes = False, - data = ["test-requirements.txt"] + data = ["test-requirements.txt", "gpu-test-requirements.txt"] ) py_library( diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt new file mode 100644 index 000000000..ff43f91ba --- /dev/null +++ b/build/gpu-test-requirements.txt @@ -0,0 +1,13 @@ +# NVIDIA CUDA dependencies +# Note that the wheels are downloaded only when the targets in bazel command +# contain dependencies on these wheels. +nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" +nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" +nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" +nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" +nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" +nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" +nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" +nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" +nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux" +nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux" diff --git a/build/requirements.in b/build/requirements.in index d4e13d943..ec7fc71b0 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -2,6 +2,7 @@ # test deps # -r test-requirements.txt +-r gpu-test-requirements.txt # # build deps diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index 290c7e732..6ed6b59aa 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -304,24 +304,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ @@ -380,6 +387,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # via -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # via -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # via -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index f73065950..8446e8361 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -299,24 +299,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ @@ -375,6 +382,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index feebc33dc..0436ab6dd 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -299,24 +299,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 \ - --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ - --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ - --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ - --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ - --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ - --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ - --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ - --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ - --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ - --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ - --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ - --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ - --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ - --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ - --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ - --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ - --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ @@ -375,6 +382,64 @@ numpy==2.0.0 ; python_version <= "3.12" \ # ml-dtypes # opt-einsum # scipy +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 0a32888f6..e74d40b79 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -347,28 +347,31 @@ mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.5.0 \ - --hash=sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe \ - --hash=sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f \ - --hash=sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128 \ - --hash=sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab \ - --hash=sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15 \ - --hash=sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215 \ - --hash=sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f \ - --hash=sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf \ - --hash=sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4 \ - --hash=sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8 \ - --hash=sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859 \ - --hash=sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a \ - --hash=sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff \ - --hash=sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8 \ - --hash=sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02 \ - --hash=sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856 \ - --hash=sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07 \ - --hash=sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f \ - --hash=sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0 \ - --hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \ - --hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599 +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 # via -r build/requirements.in mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ @@ -434,6 +437,64 @@ numpy==2.1.2 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index dfefaf042..e7a2968e9 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -390,6 +390,64 @@ numpy==2.2.1 ; python_version >= "3.13" \ # matplotlib # ml-dtypes # scipy +nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ + --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ + --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 + # via + # -r build/test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ + --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ + --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 + # via -r build/test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ + --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ + --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b + # via -r build/test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ + --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ + --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 + # via -r build/test-requirements.txt +nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ + --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ + --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ + --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef + # via -r build/test-requirements.txt +nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ + --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ + --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 + # via -r build/test-requirements.txt +nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ + --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ + --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac + # via -r build/test-requirements.txt +nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ + --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ + --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 + # via + # -r build/test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r build/test-requirements.txt +nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ + --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ + --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 + # via + # -r build/test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/ci/envs/default.env b/ci/envs/default.env index 7a2448944..a5a5d56eb 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -74,4 +74,14 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} # JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels # on the system. By default, it is set to match the version of the hermetic # Python used by Bazel for building the wheels. -export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} \ No newline at end of file +export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} + +# When set to 1, the full TPU test suite is run. Otherwise, a subset of tests +# is run. +export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0} + +# We use this environment variable to control which additional wheels to install +# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest +# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the +# list of valid values and their behavior. +export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""} \ No newline at end of file diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index feaccea8e..5d8aa9ed6 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -52,23 +52,46 @@ export JAX_SKIP_SLOW_TESTS=true echo "Running TPU tests..." -# Run single-accelerator tests in parallel -JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ - --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ - --maxfail=20 -m "not multiaccelerator" \ - tests/pallas/ops_test.py \ - tests/pallas/export_back_compat_pallas_test.py \ - tests/pallas/export_pallas_test.py \ - tests/pallas/tpu_ops_test.py \ - tests/pallas/tpu_pallas_test.py \ - tests/pallas/tpu_pallas_random_test.py \ - tests/pallas/tpu_pallas_async_test.py \ - tests/pallas/tpu_pallas_state_test.py +if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then + # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic + # TPU does not guarantee anything about forward compatibility (unless + # jax.export is used) and the 12 week compatibility window accumulates way + # too many failures. + IGNORE_FLAGS="" + if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then + IGNORE_FLAGS="--ignore=tests/pallas" + fi -# Run Pallas printing tests, which need to run with I/O capturing disabled. -TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + # Run single-accelerator tests in parallel + JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ + --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ + --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples -# Run multi-accelerator across all chips -"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \ - tests/pjit_test.py \ - tests/pallas/tpu_pallas_distributed_test.py \ No newline at end of file + # Run Pallas printing tests, which need to run with I/O capturing disabled. + TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \ + tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + + # Run multi-accelerator across all chips + "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests +else + # Run single-accelerator tests in parallel + JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ + --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ + --maxfail=20 -m "not multiaccelerator" \ + tests/pallas/ops_test.py \ + tests/pallas/export_back_compat_pallas_test.py \ + tests/pallas/export_pallas_test.py \ + tests/pallas/tpu_ops_test.py \ + tests/pallas/tpu_pallas_test.py \ + tests/pallas/tpu_pallas_random_test.py \ + tests/pallas/tpu_pallas_async_test.py \ + tests/pallas/tpu_pallas_state_test.py + + # Run Pallas printing tests, which need to run with I/O capturing disabled. + TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + + # Run multi-accelerator across all chips + "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \ + tests/pjit_test.py \ + tests/pallas/tpu_pallas_distributed_test.py +fi \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 41274b95f..f98f7658a 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -17,8 +17,19 @@ # Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python # binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to # avoid using the Windows version of `find` on Msys. + WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) +for i in "${!WHEELS[@]}"; do + if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then + if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then + # Append [tpu] to the jax wheel name to download the latest libtpu wheel + # from PyPI. + WHEELS[$i]="${WHEELS[$i]}[tpu]" + fi + fi +done + if [[ -z "${WHEELS[@]}" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" exit 1 diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 5f51cdb3b..9db79f591 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -58,6 +58,7 @@ Operators clz collapse complex + composite concatenate conj conv diff --git a/jax/BUILD b/jax/BUILD index fb2e59864..5e9cbd7a1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -14,7 +14,7 @@ # JAX is Autograd and XLA -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", @@ -45,17 +45,26 @@ package( licenses(["notice"]) -# If this flag is true, jaxlib should be built by bazel. If false, then we do not build jaxlib and -# assume it has been installed, e.g., by `pip`. -bool_flag( +# The flag controls whether jaxlib should be built by Bazel. +# If ":build_jaxlib=true", then jaxlib will be built. +# If ":build_jaxlib=false", then jaxlib is not built. It is assumed that the pre-built jaxlib wheel +# is available in the "dist" folder. +# If ":build_jaxlib=wheel", then jaxlib wheel will be built as a py_import rule attribute. +# The py_import rule unpacks the wheel and provides its content as a py_library. +string_flag( name = "build_jaxlib", - build_setting_default = True, + build_setting_default = "true", + values = [ + "true", + "false", + "wheel", + ], ) config_setting( name = "enable_jaxlib_build", flag_values = { - ":build_jaxlib": "True", + ":build_jaxlib": "true", }, ) @@ -681,6 +690,7 @@ pytype_strict_library( deps = [ ":pallas", # build_cleaner: keep "//jax/_src/pallas/fuser:block_spec", + "//jax/_src/pallas/fuser:custom_evaluate", "//jax/_src/pallas/fuser:fusable", "//jax/_src/pallas/fuser:fusion", "//jax/_src/pallas/fuser:jaxpr_fusion", diff --git a/jax/__init__.py b/jax/__init__.py index 950c3ed4b..ae3bac4ad 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -79,7 +79,7 @@ from jax._src.lib import xla_client as _xc Device = _xc.Device del _xc -from jax._src.core import get_ty as get_ty +from jax._src.core import typeof as typeof from jax._src.api import effects_barrier as effects_barrier from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 diff --git a/jax/_src/config.py b/jax/_src/config.py index 00f65726a..1e46fb8bd 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -235,6 +235,7 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, use_direct_linearize.value, + varying_axes_in_types.value, softmax_custom_jvp.value, disable_jit.value, debug_key_reuse.value, @@ -1084,6 +1085,14 @@ use_direct_linearize = bool_state( help=('Use direct linearization instead JVP followed by partial eval'), include_in_jit_key=True) +varying_axes_in_types = bool_state( + name='jax_varying_axes_in_types', + default=False, + help=('Adds varying manual axes to ShapedArray to track which mesh axes the' + ' array is varying over. This will help to remove the efficient' + ' transpose rewrite machinery in shard_map'), + include_in_jit_key=True) + data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index 9d8edeb8b..e53aec755 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1576,7 +1576,7 @@ def get_aval(x): return get_aval(x.__jax_array__()) raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type") -get_ty = get_aval +typeof = get_aval def is_concrete(x): return to_concrete_value(x) is not None @@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape): class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding'] # inherits slots from parent + __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent array_abstraction_level = 2 - def __init__(self, shape, dtype, weak_type=False, *, sharding=None): + def __init__(self, shape, dtype, weak_type=False, *, sharding=None, + varying_manual_axes: frozenset[AxisName] = frozenset()): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) + if config.varying_axes_in_types.value: + self.varying_manual_axes = varying_manual_axes def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1911,6 +1914,9 @@ class ShapedArray(UnshapedArray): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding + if 'varying_manual_axes' not in kwargs: + kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', + frozenset()) return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1927,17 +1933,22 @@ class ShapedArray(UnshapedArray): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type - and self.sharding == other.sharding) + and self.sharding == other.sharding + and (getattr(self, 'varying_manual_axes', frozenset()) == + getattr(other, 'varying_manual_axes', frozenset()))) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type, self.sharding)) + return hash((self.shape, self.dtype, self.weak_type, self.sharding, + getattr(self, 'varying_manual_axes', frozenset()))) def to_tangent_aval(self): - return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding) + return ShapedArray( + self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type, sharding=self.sharding, + varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) def str_short(self, short_dtypes=False, mesh_axis_types=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index c7dee3e71..3084fa722 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1364,9 +1364,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric], raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.") if config.disable_jit.value: try: - val = init_val + val = tree_map(lax.asarray, init_val) while cond_fun(val): - val = body_fun(val) + val = tree_map(lax.asarray, body_fun(val)) return val except core.ConcretizationTypeError: # Can't run this while_loop in Python (e.g. because there's a vmap diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 99760099d..12706426b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1489,14 +1489,14 @@ def composite( ): """Composite with semantics defined by the decomposition function. - A composite is a higher-order JAX function that encapsulates an operation mad + A composite is a higher-order JAX function that encapsulates an operation made up (composed) of other JAX functions. The semantics of the op are implemented by the ``decomposition`` function. In other words, the defined composite function can be replaced with its decomposed implementation without changing the semantics of the encapsulated operation. The compiler can recognize specific composite operations by their ``name``, - ``version``, ``kawargs``, and dtypes to emit more efficient code, potentially + ``version``, ``kwargs``, and dtypes to emit more efficient code, potentially leveraging hardware-specific instructions or optimizations. If the compiler doesn't recognize the composite, it falls back to compiling the ``decomposition`` function. @@ -1505,11 +1505,11 @@ def composite( be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could recognize the "tangent" composite and emit a single ``tangent`` instruction instead of three separate instructions (``sin``, ``divide``, and ``cos``). - With compilers for hardwares without dedicated tangent support, it would fall - back to compiling the decomposition. + For hardware without dedicated tangent support, it would fall back to + compiling the decomposition. - This is useful for preserving high level abstraction that would otherwise be - lost while lowering which allows for easier pattern-matching in low-level IR. + This is useful for preserving high-level abstractions that would otherwise be + lost while lowering, which allows for easier pattern-matching in low-level IR. Args: decomposition: function that implements the semantics of the composite op. @@ -1517,19 +1517,20 @@ def composite( version: optional int to indicate semantic changes to the composite. Returns: - out: callable composite function. Note that positional arguments to this - function should be interpreted as inputs and keyword arguments should be - interpreted as attributes of the op. Any keyword arguments that are passed - with ``None`` as a value will be omitted from the - ``composite_attributes``. + Callable: Returns a composite function. Note that positional arguments to + this function should be interpreted as inputs and keyword arguments should + be interpreted as attributes of the op. Any keyword arguments that are + passed with ``None`` as a value will be omitted from the + ``composite_attributes``. Examples: Tangent kernel: + >>> def my_tangent_composite(x): ... return lax.composite( - ... lambda x: lax.sin(x) / lax.cos(x), name='my.tangent' + ... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent" ... )(x) - ... + >>> >>> pi = jnp.pi >>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi]) >>> with jnp.printoptions(precision=3, suppress=True): @@ -1538,9 +1539,10 @@ def composite( [ 0. 1. -1. 0.] [ 0. 1. -1. 0.] - The recommended way to create composites is via a decorator. Use `/` and `*` - in the function signature to be explicit about positional and keyword - arguments respectively: + The recommended way to create composites is via a decorator. Use ``/`` and + ``*`` in the function signature to be explicit about positional and keyword + arguments, respectively: + >>> @partial(lax.composite, name="my.softmax") ... def my_softmax_composite(x, /, *, axis): ... return jax.nn.softmax(x, axis) @@ -3014,6 +3016,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete): broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) + shard = shard.addressable_data(0) return array.make_array_from_callback(shape, sharding, lambda _: shard) if sharding is not None and not sharding._is_concrete: @@ -8194,7 +8197,7 @@ _zeros: Callable = partial(full_like, fill_value=0) def _zero(x): x_aval = core.get_aval(x) return full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) + sharding=x_aval.sharding.with_spec(P())) _ones: Callable = partial(full_like, fill_value=1) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index b556042fe..764e4dcbe 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -22,6 +22,7 @@ from functools import partial import itertools import math +import jax from jax import tree_util from jax._src import core from jax._src import dispatch @@ -459,78 +460,135 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, def ragged_all_to_all( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *, axis_name, axis_index_groups = None): - """Ragged version of :func:`all_to_all`. + """Ragged version of :func:`all_to_all` collective. - For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent - and the outermost (ragged) dimension. ``axis_index_groups`` is default to all - replicas (e.g. there is only one group and covers all axis indices). + We say data are "ragged" when they can be represented as a list of arrays + whose shapes differ only in the size of the leading axis. For example, these + data are ragged, comprising four component arrays:: - Ragged arrays are defined by a set of three arrays: - * ``data``: the ``data`` array is "ragged" along its outermost dimension, - along which each indexed element has variable size. - * ``offsets``: the ``offsets`` array indexes the outermost dimension of the - ``data`` array, and represents the starting offset of each ragged element of - the ``data`` array. - * ``sizes``: the ``sizes`` array represents the size of each ragged element of - the ``data`` array, where the size is specified in units of sub-elements. A - sub-element is defined as the suffix of the ``data`` array shape obtained by - removing the outermost "ragged" dimension. - The ``offsets`` and ``sizes`` arrays must have the same size. + ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)] - # Example ragged tensor - data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}} - offsets: [3] = {0, 1, 4} - sizes: [3] = {1, 3, 4} + We often instead want a contiguous representation, e.g. for batching. But + because the shapes of the components differ, we can't apply ``jnp.stack`` to + represent these data by a single rectangular array with the leading axis + indexing the component arrays. So instead of stacking, we concatenate along + the leading axis and keep track of offsets and sizes. - # Index 'data' at 'offsets'[0], 'sizes'[0]' - {a,b,c} + That is, we can represent ragged data contiguously using a triple of dense + arrays ``(data, offsets, sizes)``: + * ``data``: the concatenated component arrays, + * ``offsets``: 1D array of indices into the leading axis of ``data`` + indicating where the data for each component array begins, + * ``sizes``: 1D array of sizes of the leading axis of each component array. + We refer to this triple as a ragged array. (Offsets can't be computed from + sizes in general to allow for internal padding.) - # Index 'data' at 'offsets'[1], 'sizes'[1]' - {d,e,f},{g,h,i},{j,k,l} + For example:: + data: f32[8,3] = jnp.array([ + [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x], + ]) + offsets: i32[3] = jnp.array([0, 1, 4]) + sizes: i32[3] = jnp.array([1, 3, 4]) - # Index 'data' at 'offsets'[2], 'sizes'[2]' - {m,n,o},{p,q,r},{s,t,u},{v,w,x} + # To extract the first component array, of type f32[1,3] + data[offsets[0]:offsets[0]+sizes[0]] + # To extract the second component array, of type f32[3,3] + data[offsets[1]:offsets[1]+sizes[1]] - ``output_offsets`` must be sharded in a way that each replica has offsets in - the target replica output perspective. + # To extract the third component array, of type f32[4,3] + data[offsets[2]:offsets[2]+sizes[2]] - For i-th output offset, the current replica will send - `operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th - replica that will be written to - `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th - replica ``output``. + The ``ragged_all_to_all`` collective operation communicates slices of ragged + arrays between devices. Each caller is both a sender and a receiver. The + ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the + caller's ``operand`` to be sent. Received results are returned in an array + that has the same value of the argument ``output`` except with received values + written at some slices. The ``output_offsets`` argument does *not* indicate + the offsets at which all the received results are written; instead, + ``output_offsets`` indicates the offsets at which the *sent* slices are + written on their corresponding receivers. The sizes of received slices are + indicated by ``recv_sizes``. See below for details. - For example, if we have 2 replicas: + The arrays ``input_offsets``, ``send_sizes``,``output_offsets``, and + ``recv_sizes`` must all be the same length, and that length must be divisible + by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and + ``recv_sizes`` must satisfy:: - replica 0: - operand: [1, 2, 2] - output: [0, 0, 0, 0] - input_offsets: [0, 1] - send_sizes: [1, 2] - output_offsets: [0, 0] - recv_sizes: [1, 1] + jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True)) - replica 1: - operand: [3, 4, 0] - output: [0, 0, 0, 0] - input_offsets: [0, 1] - send_sizes: [1, 1] - output_offsets: [1, 2] - recv_sizes: [2, 1] + Specifically, given a call:: - replica 0's result will be: [1, 3, 0, 0] - replica 1's result will be: [2, 2, 4, 0] + result = ragged_all_to_all(operand, output, input_offsets, send_sizes, + output_offsets, recv_sizes, axis_name) + + the caller sends data like:: + + assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes) + N = len(input_offsets) + slices_per_device, leftover = divmod(N, lax.axis_size(axis_name)) + assert not leftover + + for i in range(N): + dst_idx = i // slices_per_device + SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]], + axis_name=axis_name, to_axis_index=dst_idx) + + and receives data in ``result`` like:: + + result = output + output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True) + for i in range(N): + src_idx = i // slices_per_device + result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i] + ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx)) + + where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local + ``output_offsets`` does not indicate the offsets at which its local ``result`` + is updated; instead, it indicates where the corresponding sent slices are + written on their destination instances. To compute the local offsets at which + received data are written, we apply an ``all_to_all`` on ``output_offsets``. + + For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with + these arguments in each mapped function instance:: + + axis index 0: + operand = [1, 2, 2] + output = [0, 0, 0, 0] + input_offsets = [0, 1] + send_sizes = [1, 2] + output_offsets = [0, 0] + recv_sizes = [1, 1] + + axis index 1: + operand = [3, 4, 0] + output = [0, 0, 0, 0] + input_offsets = [0, 1] + send_sizes = [1, 1] + output_offsets = [1, 2] + recv_sizes = [2, 1] + + then:: + + axis index 0: + result = [1, 3, 0, 0] + + axis index 1: + result = [2, 2, 4, 0] Args: - operand: array with ragged dimension along its outermost dimension. - output: array of ragged input offsets. - input_offsets: array of ragged input send sizes. - send_sizes: array of ragged output data. - output_offsets: array of ragged offsets in the target replica output. - recv_sizes: array of ragged output receive sizes. - axis_name: hashable Python object used to name a pmapped axis (see the - :func:`jax.pmap` documentation for more details). + operand: data array of shape (N, A, B, ...) representing concatenated + (possibly padded) ragged data to be sent. + output: data array of shape (M, A, B, ...) to update with received data. + input_offsets: 1D integer array of shape (K,) representing the offsets of + leading-axis slices into ``operand`` to be sent. + send_sizes: 1D integer array array of shape (K,) representing the sizes of + leading-axis slices into ``operand`` to be sent. + output_offsets: 1D integer array of shape (K,) representing where the + corresponding sent data is written on each corresponding receiver. + recv_sizes: 1D integer array of shape (K,) representing sizes of + leading-axis slices into ``output`` to update with received data. + axis_name: name of the mapped axis over which to perform the communication. axis_index_groups: optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the first two and last two replicas). Groups must cover all axis indices @@ -538,7 +596,10 @@ def ragged_all_to_all( behavior is undefined. Returns: - array with shape equal to ``output``. + Array of shape (M, A, B, ...) with the same value as the ``output`` except + with received data written into slices starting at + ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size + ``recv_sizes``. """ if not isinstance(axis_name, (tuple, list)): @@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval( effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects +def _ragged_all_to_all_jvp(primals, tangents, **params): + operand, output, *sizes_and_offsets = primals + operand_dot, output_dot, *_ = tangents + result = ragged_all_to_all_p.bind( + operand, output, *sizes_and_offsets, **params) + if type(operand_dot) is type(output_dot) is ad.Zero: + result_dot = ad.Zero.from_primal_value(result) + else: + operand_dot = ad.instantiate_zeros(operand_dot) + output_dot = ad.instantiate_zeros(output_dot) + result_dot = ragged_all_to_all_p.bind( + operand_dot, output_dot, *sizes_and_offsets, **params) + return result, result_dot + +def _ragged_all_to_all_transpose( + t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, + *, axis_name, axis_index_groups): + if type(t) is ad.Zero: + operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None + output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None + else: + zero = ad.zeros_like_aval(operand.aval) + output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True) + input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True) + operand_t = ragged_all_to_all_p.bind( + t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes, + axis_name=axis_name, axis_index_groups=axis_index_groups) + mask = jax.numpy.cumsum( + jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ + .at[output_offsets_ + recv_sizes].add(-1)) + output_t = jax.numpy.where(mask, 0, t) + return [operand_t, output_t] + [None] * 4 + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) +ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp +ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index f6cd8bc0d..4a36bf186 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -303,15 +303,16 @@ def _igamma_series(ax, x, a, enabled, dtype, mode): def igamma_impl(a, x, *, dtype): is_nan = bitwise_or(_isnan(a), _isnan(x)) - x_is_zero = eq(x, _const(x, 0)) x_is_infinity = eq(x, _const(x, float('inf'))) - domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0))) - use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a)) + a_is_zero = eq(a, _const(a, 0)) + x_is_zero = eq(x, _const(x, 0)) + domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)]) + + use_igammac = bitwise_and(ge(x, _const(x, 1)), gt(x, a)) ax = a * log(x) - x - lgamma(a) underflow = lt(ax, -log(dtypes.finfo(dtype).max)) ax = exp(ax) - enabled = bitwise_not( - _reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan])) + enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan, x_is_infinity])) output = select( use_igammac, @@ -323,8 +324,7 @@ def igamma_impl(a, x, *, dtype): ) output = select(x_is_zero, full_like(a, 0), output) output = select(x_is_infinity, full_like(a, 1), output) - output = select(bitwise_or(domain_error, is_nan), - full_like(a, float('nan')), output) + output = select(domain_error, full_like(a, float('nan')), output) return output def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): @@ -433,11 +433,15 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode): raise ValueError(f"Invalid mode: {mode}") def igammac_impl(a, x, *, dtype): - out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0))) + is_nan = bitwise_or(_isnan(a), _isnan(x)) + a_is_zero = eq(a, _const(a, 0)) + x_is_zero = eq(x, _const(x, 0)) + x_is_infinity = eq(x, _const(x, float('inf'))) + domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)]) use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a)) ax = a * log(x) - x - lgamma(a) underflow = lt(ax, -log(dtypes.finfo(dtype).max)) - enabled = bitwise_not(bitwise_or(out_of_range, underflow)) + enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, is_nan, x_is_infinity, a_is_zero])) ax = exp(ax) igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma), @@ -445,10 +449,10 @@ def igammac_impl(a, x, *, dtype): igammac_cf_call = _igammac_continued_fraction(ax, x, a, bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE) - result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call) - x_is_infinity = eq(x, _const(x, float('inf'))) - result = select(x_is_infinity, full_like(result, 0), result) - return select(out_of_range, full_like(a, 1), result) + output = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call) + output = select(bitwise_or(x_is_infinity, a_is_zero), full_like(output, 0), output) + output = select(domain_error, full_like(a, float('nan')), output) + return output def igamma_grad_a_impl(a, x, *, dtype): is_nan = bitwise_or(_isnan(a), _isnan(x)) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index ad6ce6ab4..466f6037a 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -343,7 +343,7 @@ class BlockSpec: if self.block_shape is None: block_shape = array_aval.shape else: - block_shape = self.block_shape + block_shape = self.block_shape # type: ignore if len(array_aval.shape) != len(block_shape): raise ValueError( f"Block shape for {origin} (= {block_shape}) " diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 18e136623..66bbac33a 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -32,6 +32,7 @@ pytype_strict_library( ], deps = [ ":block_spec", + ":custom_evaluate", ":fusable", ":fusion", ":jaxpr_fusion", @@ -44,6 +45,7 @@ pytype_strict_library( "block_spec.py", ], deps = [ + ":fuser_utils", "//jax", "//jax:ad_util", "//jax:api_util", @@ -119,3 +121,27 @@ pytype_strict_library( "//jax/_src/pallas", ], ) + +pytype_strict_library( + name = "custom_evaluate", + srcs = ["custom_evaluate.py"], + deps = [ + ":fuser_utils", + "//jax", + "//jax:core", + "//jax:source_info_util", + "//jax:tree_util", + "//jax:util", + ], +) + +pytype_strict_library( + name = "fuser_utils", + srcs = ["fuser_utils.py"], + deps = [ + "//jax:api_util", + "//jax:core", + "//jax:partial_eval", + "//jax:tree_util", + ], +) diff --git a/jax/_src/pallas/fuser/__init__.py b/jax/_src/pallas/fuser/__init__.py index a9f6ce390..3295c8f10 100644 --- a/jax/_src/pallas/fuser/__init__.py +++ b/jax/_src/pallas/fuser/__init__.py @@ -16,6 +16,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec +from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate from jax._src.pallas.fuser.fusable import fusable as fusable from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index 83b485107..d0767aeeb 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -26,15 +26,14 @@ from typing import Any, Callable, Protocol, Sequence import jax from jax import lax from jax._src import ad_util -from jax._src import api_util from jax._src import core from jax._src import custom_derivatives -from jax._src import linear_util as lu from jax._src import pjit from jax._src import tree_util from jax._src import util from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core +from jax._src.pallas.fuser import fuser_utils import jax.numpy as jnp import numpy as np @@ -226,18 +225,6 @@ def _unwrap_block_spec_scalar_prefetch( return out_block_spec -def _make_jaxpr(f, *args, **kwargs): - flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - flat_avals = [core.get_aval(x) for x in flat_args] - debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs) - flat_fun, out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(f, debug_info=debug_info), in_tree - ) - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - return jaxpr, consts, in_tree, out_tree - - def pull_block_spec( f: Callable, out_block_specs: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...], @@ -246,7 +233,9 @@ def pull_block_spec( grid: tuple[int | jax.Array, ...] | None = None, ): def wrapped(*args, **kwargs): - jaxpr, consts, in_tree, out_tree_ = _make_jaxpr(f, *args, **kwargs) + jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr( + f, *args, **kwargs + ) # TODO(sharadmv): handle these consts better, they should correspond to # scalar prefetch. del consts, out_tree_ @@ -563,7 +552,9 @@ def make_kernel_function( def get_fusion_values( fusion: Callable, *args, **kwargs ) -> tuple[Callable, tuple[jax.Array, ...], tuple[jax.Array, ...]]: - jaxpr, values, in_tree, out_tree = _make_jaxpr(fusion, *args, **kwargs) + jaxpr, values, in_tree, out_tree = fuser_utils.make_jaxpr( + fusion, *args, **kwargs + ) assert len(values) == len(jaxpr.constvars), (jaxpr, values) out_usages = tuple({Usage.REGULAR} for _ in jaxpr.outvars) read_usage_env = compute_usage(jaxpr, out_usages) @@ -1325,7 +1316,7 @@ def push_block_spec( flat_block_specs, in_tree_ = tree_util.tree_flatten( (in_spec_args, in_spec_kwargs) ) - jaxpr, _, in_tree, out_tree = _make_jaxpr(f, *args, **kwargs) + jaxpr, _, in_tree, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs) if in_tree != in_tree_: raise ValueError(f'Expected {in_tree} PyTree, got {in_tree_}') out_bs = _push_block_spec_jaxpr(jaxpr, *flat_block_specs) diff --git a/jax/_src/pallas/fuser/custom_evaluate.py b/jax/_src/pallas/fuser/custom_evaluate.py new file mode 100644 index 000000000..fff0f7d7e --- /dev/null +++ b/jax/_src/pallas/fuser/custom_evaluate.py @@ -0,0 +1,82 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for evaluating functions under certain constraints.""" +import dataclasses +from typing import Any + +from jax import lax +from jax._src import core +from jax._src import source_info_util +from jax._src import tree_util +from jax._src import util +from jax._src.pallas.fuser import fuser_utils + + +@dataclasses.dataclass +class CustomEvaluateSettings: + allow_transpose: bool = True + + +def evaluate(f, *, allow_transpose: bool = True): + def wrapped(*args, **kwargs): + jaxpr, consts, _, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs) + settings = CustomEvaluateSettings(allow_transpose=allow_transpose) + flat_args = tree_util.tree_leaves(args) + out_flat = _custom_evaluate_jaxpr(settings, jaxpr, consts, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + + return wrapped + + +# Disallow most higher-order primitives for now. +disallowed_primitives = {lax.scan_p, lax.while_p, lax.cond_p} + + +def _custom_evaluate_jaxpr( + settings: CustomEvaluateSettings, jaxpr: core.Jaxpr, consts, *args +): + def read(v: core.Atom) -> Any: + return v.val if isinstance(v, core.Literal) else env[v] + + def write(v: core.Var, val: Any) -> None: + env[v] = val + + env: dict[core.Var, Any] = {} + util.safe_map(write, jaxpr.constvars, consts) + util.safe_map(write, jaxpr.invars, args) + lu = core.last_used(jaxpr) + for eqn in jaxpr.eqns: + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + + if eqn.primitive in disallowed_primitives: + raise NotImplementedError(f'Primitive {eqn.primitive} not supported.') + if not settings.allow_transpose and eqn.primitive is lax.transpose_p: + raise ValueError('Transpose not allowed.') + name_stack = ( + source_info_util.current_name_stack() + eqn.source_info.name_stack + ) + traceback = eqn.source_info.traceback + with source_info_util.user_context( + traceback, name_stack=name_stack + ), eqn.ctx.manager: + ans = eqn.primitive.bind( + *subfuns, *util.safe_map(read, eqn.invars), **bind_params + ) + if eqn.primitive.multiple_results: + util.safe_map(write, eqn.outvars, ans) + else: + write(eqn.outvars[0], ans) + core.clean_up_dead_vars(eqn, env, lu) + return util.safe_map(read, jaxpr.outvars) diff --git a/jax/_src/pallas/fuser/fuser_utils.py b/jax/_src/pallas/fuser/fuser_utils.py new file mode 100644 index 000000000..ff44725bb --- /dev/null +++ b/jax/_src/pallas/fuser/fuser_utils.py @@ -0,0 +1,33 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Basic utils for fuser internals.""" +from jax._src import api_util +from jax._src import core +from jax._src import linear_util as lu +from jax._src import tree_util +from jax._src.interpreters import partial_eval as pe + + + +def make_jaxpr(f, *args, **kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + flat_avals = [core.get_aval(x) for x in flat_args] + debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs) + flat_fun, out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(f, debug_info=debug_info), in_tree + ) + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree_thunk() + return jaxpr, consts, in_tree, out_tree diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 9bc20ed2c..762141270 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1853,7 +1853,13 @@ def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape): def _dot_general_lowering_rule( - ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_ + ctx: LoweringRuleContext, + x, + y, + dimension_numbers, + precision, + preferred_element_type, + **_, ): (lhs_dims, rhs_dims), _ = dimension_numbers (aval_out,) = ctx.avals_out @@ -1894,10 +1900,34 @@ def _dot_general_lowering_rule( x = vector.broadcast(bcast_shape, x) if ctx.avals_in[1].shape != bcast_shape: y = vector.broadcast(bcast_shape, y) + red_dtype = ( + preferred_element_type if preferred_element_type else lhs_aval.dtype + ) red_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, - lhs_aval.update(shape=(lhs_aval.shape[0],)), + lhs_aval.update(shape=(lhs_aval.shape[0],), dtype=red_dtype), ) + + if lhs_aval.dtype != red_dtype: + lhs_type = aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, + lhs_aval.update(shape=lhs_aval.shape, dtype=red_dtype), + ) + if red_dtype == jnp.float32: + x = arith.extf(lhs_type, x) + else: + raise NotImplementedError(f"Unsupported {preferred_element_type=}") + + if rhs_aval.dtype != red_dtype: + rhs_type = aval_to_ir_type( + ctx.lowering_context.dynamic_shape_replacement_fn, + rhs_aval.update(shape=rhs_aval.shape, dtype=red_dtype), + ) + if red_dtype == jnp.float32: + y = arith.extf(rhs_type, y) + else: + raise NotImplementedError(f"Unsupported {preferred_element_type=}") + acc = arith.ConstantOp( red_type, ir.DenseElementsAttr.get_splat(red_type, val) ) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 43b9008bb..0333a9b03 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1543,6 +1543,60 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") +def _reduce_lowering_rule_wg( + kind: vector_dialect.CombiningKind, + acc: object, + ctx: LoweringRuleContext, + x, + *, + axes, +) -> ir.OpView: + [x_aval] = ctx.avals_in + [out_aval] = ctx.avals_out + x = _ensure_ir_value(x, x_aval.dtype) + out_type = mgpu_utils.dtype_to_ir_type(out_aval.dtype) + if not out_aval.shape: + # Special-case: reducing to a scalar. + if x_aval.ndim != 1: + # TODO(slebedev): Flatten to 1D, since vector.reduction only supports + # 1D inputs. + raise NotImplementedError("Only 1D inputs are supported") + return vector_dialect.ReductionOp(out_type, kind, x) + acc = vector_dialect.splat( + ir.VectorType.get(out_aval.shape, out_type), + _ensure_ir_value(acc, out_aval.dtype), + ) + return vector_dialect.MultiDimReductionOp(kind, x, acc, axes) + + +@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup) +def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): + op = _reduce_lowering_rule_wg( + vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes + ) + op.attributes["offset"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), ctx.module_ctx.smem_used_bytes + ) + return op.result + + +@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup) +def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): + [x_aval] = ctx.avals_in + if jnp.issubdtype(x_aval.dtype, jnp.floating): + kind = vector_dialect.CombiningKind.MAXIMUMF + acc = float("-inf") + elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger): + kind = vector_dialect.CombiningKind.MAXSI + acc = np.iinfo(x_aval.dtype).max + elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger): + kind = vector_dialect.CombiningKind.MAXUI + acc = np.iinfo(x_aval.dtype).max + else: + raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") + return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result + + @register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): i32 = ir.IntegerType.get_signless(32) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 4343c0802..97d73a3ee 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -198,7 +198,8 @@ def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> - :func:`jax.scipy.stats.gamma.logsf` """ x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) - return gammaincc(a, lax.div(lax.sub(x, loc), scale)) + y = lax.div(lax.sub(x, loc), scale) + return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y)) def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index c3b9e96dc..63f019b31 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -865,15 +865,15 @@ class Jax2TfLimitation(test_harnesses.Limitation): def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg): # noqa: F811 arg1, arg2 = args - # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN + # lax.igammac returns nan. when arg1 <= 0; tf.math.igammac returns 1 special_cases = (arg1 <= 0.) | (arg2 <= 0) nr_special_cases = np.count_nonzero(special_cases) tst.assertAllClose( - np.full((nr_special_cases,), 1., dtype=dtype), + np.full((nr_special_cases,), np.nan, dtype=dtype), result_jax[special_cases], err_msg=err_msg) tst.assertAllClose( - np.full((nr_special_cases,), np.nan, dtype=dtype), + np.full((nr_special_cases,), 1, dtype=dtype), result_tf[special_cases], err_msg=err_msg) # non-special cases are equal @@ -892,12 +892,12 @@ class Jax2TfLimitation(test_harnesses.Limitation): custom_numeric(dtypes=[np.float64], tol=1e-9), custom_numeric(devices="gpu", tol=1e-3), custom_numeric( + modes=("compiled",), custom_assert=custom_assert, - devices=("cpu", "gpu"), + devices=("cpu", "gpu", "tpu"), description=( "May return different results at undefined points " - "(both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or " - "JAX returns 1 and TF returns `NaN`")), + "(both arguments less or equal 0). JAX returns `NaN` and TF returns 1")), ] @classmethod diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index c92074dc2..66e19bb5f 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -260,7 +260,7 @@ def _construct_smem_reftree( dynamic_smem, c(dynamic_smem_offset, index), [], ) if layout is None: - layout = tcgen05._infer_tmem_layout(shape) + layout = tcgen05._infer_tmem_layout(shape, collective) num_cols = layout.cols_in_shape(shape) delayed_warp_init.append( functools.partial( diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 368b47df4..55e8c5583 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -259,14 +259,15 @@ def _vector_load_op_lowering_rule( is_signed=is_signed, vec_size=strided_layout.vec_size, ) - elif layouts.is_wgmma_fragmented_layout(out_layout_attr): + elif layouts.from_layout_attr(out_layout_attr) == fa.TILED_LAYOUT_WGMMA: layout = ir.MemRefType(vector_load_op.base.type).layout swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout) transformed_ref = transform_memref(vector_load_op.base, transforms) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, swizzle=swizzle, - is_signed=is_signed + is_signed=is_signed, + layout=fa.TILED_LAYOUT_WGMMA, ) else: raise ValueError( @@ -319,6 +320,34 @@ def _vector_splat_op_lowering_rule( return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)] +@_register_lowering(vector.ReductionOp) +def _vector_reduction_op_lowering_rule( + ctx: LoweringContext, op: vector.ReductionOp +) -> Sequence[ir.Value]: + del ctx # Unused. + [layout] = inference_utils.in_layouts(op) + () = inference_utils.out_layouts(op) + element_type = ir.VectorType(op.vector.type).element_type + is_signed = False if ir.IntegerType.isinstance(element_type) else None + a = _fragmented_array_from_ir(op.vector, layout, is_signed) + match str(op.kind): + case "#vector.kind": + smem = ir.Attribute.parse("#gpu.address_space") + scratch = _slice_smem( + ir.MemRefType.get([4], element_type, memory_space=smem), + arith.constant(None, op.attributes["offset"]), + ) + result = a.reduce_sum(scratch) + case ( + "#vector.kind" | "#vector.kind" | "#vector.kind" + ): + # TODO(slebedev): Implement this and remove the raise below. + raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") + case _: + raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") + return [_fragmented_array_to_ir(result, op.result.type)] + + def memref_layout_to_swizzle_and_transforms( layout: ir.Attribute, ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]: @@ -634,7 +663,10 @@ def _mgpu_wgmma_op_lowering_rule( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), ) - if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)): + is_supported_layout = ( + lambda l: layouts.from_tiled_layout_attr(l) == fa.TILED_LAYOUT_WGMMA + ) + if not all(map(is_supported_layout, fa_layouts)): raise ValueError("Layout mismatch") wgmma_layout = fa_layouts[0] @@ -667,7 +699,12 @@ def _mgpu_wgmma_op_lowering_rule( new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) - return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)] + return [ + _fragmented_array_to_ir( + new_acc.value.to_layout(fa.TILED_LAYOUT_WGMMA), + wgmma_op.accumulator.type, + ) + ] @_register_lowering(mgpu.ArriveExpectTxOp) @@ -704,16 +741,17 @@ def _mgpu_slice_smem_op_lowering_rule( ctx: LoweringContext, op: SliceSMEMOp ) -> Sequence[ir.Value]: del ctx + return [_slice_smem(op.result.type, op.offset)] + + +def _slice_smem(result: ir.Type, offset: ir.Value): i8 = ir.IntegerType.get_signless(8) smem = ir.Attribute.parse("#gpu.address_space") - smem_base = gpu.dynamic_shared_memory( ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) ) - - offset = arith.index_cast(ir.IndexType.get(), op.offset) - - return [memref.view(op.result.type, smem_base, offset, [])] + offset = arith.index_cast(ir.IndexType.get(), offset) + return memref.view(result, smem_base, offset, []) @_register_lowering(scf.ForOp) @@ -857,7 +895,8 @@ def _should_lower(op: ir.OpView) -> bool: def lower_mgpu_dialect( - module: ir.Module, launch_context: launch_context.LaunchContext | None + module: ir.Module, + launch_context: launch_context.LaunchContext | None, ): # TODO(apaszke,bchetioui): Make sure the layouts match. # TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index d15cecbdc..6af394d00 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -230,8 +230,8 @@ def main(unused_argv): tile_n *= 2 if m < tile_m or n < tile_n: continue - if kwargs["collective"] and tile_n >= 512: - continue # TODO(apaszke): Support 512 + if tile_n > 512: + continue if (m // tile_m) % kwargs["grid_tile_m"]: continue try: diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index a52eb329d..5cacab511 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1389,7 +1389,7 @@ class FragmentedArray: if isinstance(self.layout, WGSplatFragLayout): [reg] = self.registers.flat if ir.FloatType.isinstance(self.mlir_dtype): - op = arith.mulf + op = mulf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.muli else: diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 044e7537d..f0f0998b8 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -63,7 +63,7 @@ def _choose_representative_layout( Given the input set of possible layouts, this function extracts a single representative layout. Currently, this function only works with strided, - splat, and WGMMA fragmented layouts. + splat, and tiled layouts. Returns: A single layout that can be used to annotate the operation, or None if the @@ -86,18 +86,18 @@ def _choose_representative_layout( ) ) - wgmma_layouts: list[fa.WGMMAFragLayout] = list( + tiled_layouts: list[fa.TiledLayout] = list( map( layouts_lib.from_layout_attr, - filter(layouts_lib.is_wgmma_fragmented_layout, layouts), + filter(layouts_lib.is_tiled_layout, layouts), ) ) - if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len( + if len(splat_layouts) + len(strided_layouts) + len(tiled_layouts) != len( layouts ): raise ValueError( - f"Expected only strided, splat, and wgmma layouts, got {layouts}" + f"Expected only strided, splat, and tiled layouts, got {layouts}" ) if len(splat_layouts) > 1: @@ -112,13 +112,19 @@ def _choose_representative_layout( "is not supported." ) - if (wgmma_layouts and strided_layouts): + if len(tiled_layouts) > 1: raise NotImplementedError( - "Mixing strided and WGMMA layouts is not supported." + "Finding a representative layout for several distinct tiled layouts " + "is not supported." ) - if wgmma_layouts: - return layouts_lib.to_layout_attr(wgmma_layouts[0]) + if tiled_layouts and strided_layouts: + raise NotImplementedError( + "Mixing strided and tiled layouts is not supported." + ) + + if tiled_layouts: + return layouts_lib.to_layout_attr(tiled_layouts[0]) if strided_layouts: [strided_layout] = strided_layouts @@ -330,10 +336,16 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: return [], [layout] +@partial(_add_layout_inference_rule, vector.ReductionOp) +def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts: + if layout := inference_utils.value_layout(op.vector): + return [layout], [] + return None + @partial(_add_layout_inference_rule, mgpu.WGMMAOp) def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: - layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout()) + layout = layouts_lib.to_layout_attr(fa.TILED_LAYOUT_WGMMA) if ir.VectorType.isinstance(wgmma_op.a.type): return [layout, layout], [layout] diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 334ebeddd..5c3b23119 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -94,11 +94,67 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: return bool(_strided_fragmented_layout_attr_pattern.search(str(attr))) +_tiled_layout_attr_pattern = re.compile( + r"^#mosaic_gpu.TiledLayout<\[(?P.*)\]," + r" warp_dim\s*=\s*(?P[-\d]+)," + r" lane_dims\s*=\s*\[(?P.*)\]," + r" vector_dim\s*=\s*(?P[-\d]+)>$" +) + + +def to_tiled_layout_attr( + layout: fa.TiledLayout, +) -> ir.Attribute: + """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" + + tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" + tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" + return ir.Attribute.parse( + f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," + f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>" + ) + + +_list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") + + +def from_tiled_layout_attr( + attr: ir.Attribute, +) -> fa.TiledLayout: + """Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute. + + Raises: + ValueError: If the attribute is not a #mosaic_gpu.TiledLayout + attribute. + """ + match = _tiled_layout_attr_pattern.fullmatch(str(attr)) + if not match: + raise ValueError( + f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" + ) + + tiling_str = match.group("tiling") + tile_strings = [] + if len(tiling_str) > 2: + tile_strings = _list_of_lists_delimiter.split(tiling_str[1:-1]) + tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) + return fa.TiledLayout( + tiling=fa.Tiling(tiles), + warp_dim=int(match.group("warp_dim")), + lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")), + vector_dim=int(match.group("vector_dim")) + ) + + +def is_tiled_layout(attr: ir.Attribute) -> bool: + return bool(_tiled_layout_attr_pattern.search(str(attr))) + + def to_layout_attr( layout: ( fa.WGSplatFragLayout | fa.WGStridedFragLayout - | fa.WGMMAFragLayout + | fa.TiledLayout | fa.WGMMARowFragLayout ), ) -> ir.Attribute: @@ -108,8 +164,8 @@ def to_layout_attr( return to_splat_fragmented_layout_attr(layout) case fa.WGStridedFragLayout(): return to_strided_fragmented_layout_attr(layout) - case fa.WGMMAFragLayout(): - return ir.Attribute.parse("#mosaic_gpu.WGMMAFragLayout") + case fa.TiledLayout(): + return to_tiled_layout_attr(layout) case fa.WGMMARowFragLayout(): return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout") case _: @@ -118,15 +174,6 @@ def to_layout_attr( ) -_wgmma_fragmented_layout_attr_pattern = re.compile( - r"^#mosaic_gpu.WGMMAFragLayout$" -) - - -def is_wgmma_fragmented_layout(attr: ir.Attribute) -> bool: - return bool(_wgmma_fragmented_layout_attr_pattern.search(str(attr))) - - _wgmma_row_fragmented_layout_attr_pattern = re.compile( r"^#mosaic_gpu.WGMMARowFragLayout$" ) @@ -141,7 +188,7 @@ def from_layout_attr( ) -> ( fa.WGSplatFragLayout | fa.WGStridedFragLayout - | fa.WGMMAFragLayout + | fa.TiledLayout | fa.WGMMARowFragLayout ): """Constructs a layout from an MLIR attribute.""" @@ -149,8 +196,8 @@ def from_layout_attr( return from_splat_fragmented_layout_attr(attr) elif is_strided_fragmented_layout(attr): return from_strided_fragmented_layout_attr(attr) - elif is_wgmma_fragmented_layout(attr): - return fa.WGMMAFragLayout() + elif is_tiled_layout(attr): + return from_tiled_layout_attr(attr) elif is_wgmma_row_fragmented_layout(attr): return fa.WGMMARowFragLayout() else: diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 7a349f50c..e5a2d3aa5 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -83,6 +83,7 @@ def mma( accumulate: ir.Value | bool = True, collective: bool = False, ): + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) if isinstance(accumulate, bool): accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate) @@ -112,6 +113,10 @@ def mma( raise ValueError( f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" ) + if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)): + raise ValueError( + f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}" + ) f32 = ir.F32Type.get() if element_type == f32 or element_type == ir.BF16Type.get(): if d.dtype != f32: @@ -136,11 +141,7 @@ def mma( raise ValueError(f"N must be a multiple of 8, got: {n}") elif n > 256 and n != 512: raise ValueError("Only N below 256 or N=512 are supported") - if num_cta == 2 and n > 256: - raise NotImplementedError( - "N is too big for collective MMA. Only up to 256 is supported." - ) - n_group_elems = min(n, 256) + n_group_elems = min(n, 256 // num_cta) if m % m_group_elems: raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}") if k % k_group_elems: @@ -179,6 +180,7 @@ def mma( # Step 4. Issue the instructions. true = arith.constant(ir.IntegerType.get_signless(1), 1) + n_collective_group_elems = n_group_elems * num_cta for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): a_offset = mi * a_m_group_stride + ki * a_k_group_stride a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) @@ -188,9 +190,9 @@ def mma( raise NotImplementedError("D needs to be sliced") acc = accumulate if ki == 0 else true _do_mma( - d.slice( - slice(None), utils.ds(ni * n_group_elems, n_group_elems) - ).address, + arith.addi( + d.address, arith.constant(i32, ni * n_collective_group_elems) + ), a_mk, b_nk, d_type=ir.F32Type.get(), @@ -377,8 +379,15 @@ class TMEMLayout: +------------------+------------------+ | [0:64, 64:128] | [64:128, 64:128] | +------------------+------------------+ + + The above is further complicated by column_tile_stride, which is used to + swizzle the ordering of column tiles. That is, if column_tile_stride is 2, + we will first lay out all tiles that have the column index 0, 2, 4, and so on + until we run out of tiles. Only then we lay out the tiles with column index + 1, 3, etc. """ elements_in_tile: tuple[int, int] + column_tile_stride: int = 1 def __post_init__(self): row_tiling = self.elements_in_tile[0] @@ -405,7 +414,7 @@ class TMEMLayout: return num_tiles // tiles_in_row * cols_in_tile -def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout: +def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout: if shape[0] > TMEM_ROWS: raise ValueError( "Can only infer TMEM layout for shapes with at most 128 rows, got:" @@ -421,7 +430,15 @@ def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout: "Can only infer TMEM layout for shapes with row count that's a power of" f" 2, got: {shape[0]}" ) - return TMEMLayout(elements_in_tile=(shape[0], 1)) + if shape[1] % 8: + raise ValueError( + "Can only infer TMEM layout for shapes with column count that's a" + f" multiple of 8, got: {shape[1]}" + ) + if collective and shape[1] == 512: + return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2) + else: + return TMEMLayout(elements_in_tile=(shape[0], 8)) @dataclasses.dataclass(frozen=True) @@ -432,7 +449,14 @@ class TMEMRef: layout: TMEMLayout @classmethod - def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None): + def from_alloc( + cls, + tmem_addr_ref: ir.Value, + shape: tuple[int, int], + dtype, + collective: bool | None = None, + layout: TMEMLayout | None = None, + ): i32 = ir.IntegerType.get_signless(32) if not ir.MemRefType.isinstance(tmem_addr_ref.type): raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}") @@ -449,7 +473,11 @@ class TMEMRef: if shape[0] < 32: raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}") if layout is None: - layout = _infer_tmem_layout(shape) + if collective is None: + raise ValueError( + "collective argument must be provided when TMEM layout is inferred" + ) + layout = _infer_tmem_layout(shape, collective) else: layout.check_shape(shape) # TODO: Do we have to do this?? @@ -461,12 +489,17 @@ class TMEMRef: base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) if any(is_squeezed): raise ValueError("TMEM can only be sliced, not indexed") - if self.layout.elements_in_tile[0] != TMEM_ROWS: + if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): raise NotImplementedError( - f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows" + "Slicing only implemented for refs with standard layout, got:" + f" {self.layout}" ) if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS: raise NotImplementedError("TMEM cannot be sliced along rows") + if slice_shape[1] % 8: + raise NotImplementedError( + "TMEM column slice length must be a multiple of 8" + ) col_idx = base_idx[1] if not isinstance(col_idx, ir.Value): col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx) @@ -484,48 +517,75 @@ class TMEMRef: raise ValueError("TMEM loads only support slicing") if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: raise NotImplementedError("Slicing of TMEM not impelmented yet") - if self.layout.elements_in_tile[0] != TMEM_ROWS: - raise NotImplementedError( - f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows" - ) if self.shape[1] % 8: raise NotImplementedError if self.dtype != ir.F32Type.get(): raise NotImplementedError(self.dtype) layout = _m128_256bit_32bit_layout(self.shape) regs_shape = layout.registers_shape(self.shape) - num = self.shape[1] // 8 - # TODO(apaszke): Make the tiling configurable through the args too. - if num <= 32: - num_tiling = num - elif num == 64: - num_tiling = 32 - else: - raise NotImplementedError(num) - registers = np.empty(regs_shape, dtype=object) - # We load 16 lanes at a time, but need 32 in total. - for row_group in range(2): - addr_row = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16)) - regs = [] - cols_per_num_tile = 8 # This depends on the 16x256b below. - for num_group in range(num // num_tiling): - addr_row_col = arith.addi( - addr_row, - arith.constant(i32, num_tiling * num_group * cols_per_num_tile), + if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): + # load_32xcols returns a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + registers = _load_32xcols( + self.address, self.shape[1], self.dtype + ).T.reshape(regs_shape) + elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2): + if self.shape[1] % 128 != 0: + raise ValueError( + f"TMEM layout {self.layout} is not compatible with shape {self.shape}" ) - regs += tmem_load(addr_row_col, "16x256b", num_tiling) - regs = [llvm.bitcast(self.dtype, r) for r in regs] - vector_regs = [] - undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype)) - for r_low, r_high in zip(regs[::2], regs[1::2]): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs.append(vreg) - # Dimension 4 is the one where we split 32 rows into tiles of 8. - regs_slice = (slice(None),) * 4 + (slice(row_group * 2, (row_group + 1) * 2),) - registers[regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[regs_slice].shape) + num_column_tiles = self.shape[1] // 128 + column_tile_stride = self.layout.column_tile_stride + num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) + tiles = [] + for col_tile_base in range(num_strided_col_groups): + for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): + tiles.append( + _load_32xcols( + arith.addi(self.address, arith.constant(i32, col_tile * 128)), + cols=128, + dtype=self.dtype, + ) + ) + registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) + else: + raise NotImplementedError( + f"Loads only implemented for refs with standard layout, got: {self.layout}" + ) return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) +def _load_32xcols(base_addr, cols, dtype): + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + i32 = ir.IntegerType.get_signless(32) + assert cols % 8 == 0 + cols_per_num_tile = 8 + load_shape = "16x256b" + num = cols // 8 + if num <= 32: + num_tiling = num + elif num == 64: + num_tiling = 32 + else: + raise NotImplementedError(num) + vector_regs = np.ndarray((4, num), dtype=object) + # We load 16 lanes at a time, but need 32 in total. + for row_group in range(2): + addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) + regs = [] + for num_group in range(num // num_tiling): + addr_row_col = arith.addi( + addr_row, + arith.constant(i32, num_tiling * num_group * cols_per_num_tile), + ) + regs += tmem_load(addr_row_col, load_shape, num_tiling) + regs = [llvm.bitcast(dtype, r) for r in regs] + undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) + for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)): + high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) + vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) + vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + return vector_regs + def _m128_256bit_32bit_layout(shape: tuple[int, ...]): if len(shape) != 2: diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index f90f7ff08..080397bbb 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1201,3 +1201,7 @@ def bitcast(x: ir.Value, new_type: ir.Type): assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape) return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x)) raise ValueError(f"Can't bitcast {x.type} to {new_type}") + + +def ceil_div(x: int, y: int): + return (x + y - 1) // y diff --git a/jax/experimental/pallas/fuser.py b/jax/experimental/pallas/fuser.py index 28b62f4f0..729a447b7 100644 --- a/jax/experimental/pallas/fuser.py +++ b/jax/experimental/pallas/fuser.py @@ -18,6 +18,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec +from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate from jax._src.pallas.fuser.fusable import fusable as fusable from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py index 30cb20733..6600d7650 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py @@ -43,14 +43,22 @@ class MultiPageAsyncCopyDescriptor: ): self._vmem_buf = vmem_buf seq_id, kv_pages_start = offset - self._async_copies = [ - pltpu.make_async_copy( - pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]], - vmem_buf.at[i], - sem, - ) - for i in range(vmem_buf.shape[0]) - ] + pages_per_seq = page_indices_ref.shape[1] + self._async_copies = [] + # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert + # a bunch of if-ops. Check the performance when we have benchmarking setup. + for i in range(vmem_buf.shape[0]): + page_idx = kv_pages_start + i + page_idx = jax.lax.select( + page_idx < pages_per_seq, page_idx, pages_per_seq - 1 + ) + self._async_copies.append( + pltpu.make_async_copy( + pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], + vmem_buf.at[i], + sem, + ) + ) def start(self): """Starts the async copies.""" diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index 79aebcd86..1f4e5a08d 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -49,7 +49,7 @@ py_library_providing_imports_info( config_setting( name = "disable_jaxlib_for_cpu_build", flag_values = { - "//jax:build_jaxlib": "False", + "//jax:build_jaxlib": "false", "@local_config_cuda//:enable_cuda": "False", }, ) @@ -57,7 +57,23 @@ config_setting( config_setting( name = "disable_jaxlib_for_cuda12_build", flag_values = { - "//jax:build_jaxlib": "False", + "//jax:build_jaxlib": "false", "@local_config_cuda//:enable_cuda": "True", }, -) \ No newline at end of file +) + +config_setting( + name = "enable_py_import_for_cpu_build", + flag_values = { + "//jax:build_jaxlib": "wheel", + "@local_config_cuda//:enable_cuda": "False", + }, +) + +config_setting( + name = "enable_py_import_for_cuda12_build", + flag_values = { + "//jax:build_jaxlib": "wheel", + "@local_config_cuda//:enable_cuda": "True", + }, +) diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index c88b164e6..eaa815d33 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "jaxlib/absl_status_casters.h" @@ -29,7 +31,7 @@ namespace nb = nanobind; nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32, - int workspace_size, int reserve_space_size) { + size_t workspace_size, size_t reserve_space_size) { return PackDescriptor(RnnDescriptor{ input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout, bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size}); diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 89a6d0a30..e9820bc31 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -15,6 +15,7 @@ limitations under the License. #include "jaxlib/gpu/rnn_kernels.h" +#include #include #include @@ -71,7 +72,7 @@ template <> namespace JAX_GPU_NAMESPACE { -static absl::StatusOr> +static absl::StatusOr> DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, @@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, return std::make_pair(workSpaceSize, reserveSpaceSize); } -absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( +absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32) { diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index 468c02eac..e95b77883 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef JAXLIB_GPU_RNN_KERNELS_H_ #define JAXLIB_GPU_RNN_KERNELS_H_ +#include + #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" @@ -34,12 +36,12 @@ struct RnnDescriptor { float dropout; int bidirectional; int cudnn_allow_tf32; - int workspace_size; - int reserve_space_size; + size_t workspace_size; + size_t reserve_space_size; }; // Return (workspace size, reserve space size). -absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( +absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32); diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 22397ff90..e8a72d44e 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -493,15 +493,7 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) { param.value))); } } - // Triton's kernel ABI expects an additional scratchpad global memory. - // For now it is only used for on-device creation of TMA descriptors, which - // we do not use yet, so we are just replacing this argument with a null - // pointer. - // TODO: b/381242007 - Allocate a proper buffer if we want to use - // device-side TMA APIs. - void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns. - params.push_back(&scratch_ptr); - + params.push_back(buffers++); // Scratch buffer. return kernel_.Launch(stream, grid_, params.data()); } diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 58a83d9b0..c4ac8d00f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -224,7 +224,15 @@ def if_building_jaxlib( "@pypi_jax_cuda12_plugin//:pkg", "@pypi_jax_cuda12_pjrt//:pkg", ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]): + if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], + if_py_import = [ + "//jaxlib/tools:jaxlib_py_import", + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ], + if_py_import_for_cpu = [ + "//jaxlib/tools:jaxlib_py_import", + ]): """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. @@ -234,12 +242,16 @@ def if_building_jaxlib( if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of gpu-enabled builds if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds + if_py_import: the py_import targets to depend on in case of gpu-enabled builds + if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds """ return select({ "//jax:enable_jaxlib_build": if_building, "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, + "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, + "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, }) # buildifier: disable=function-docstring diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index dbc829832..0882986fc 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -128,7 +128,6 @@ def MosaicGPU_WGStridedFragLayout : AttrDef { let summary = "Annotates an array that is the result of a splat."; let description = [{ @@ -143,20 +142,6 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { - let summary = "2D array that can be tiled by supported WGMMA shapes."; - let description = [{ - This layout annotates arrays that are fragmented across all threads in a - warpgroup that is executing a WGMMA operation. The shape of the array is - (m, n) where: - - m % 64 == 0 - - n % 8 == 0 - }]; - - let mnemonic = "WGMMAFragLayout"; - let assemblyFormat = ""; -} - def MosaicGPU_WGMMARowFragLayout : AttrDef { let summary = "1D array that is a row that can be tiled by supported WGMMA shapes."; let description = [{ @@ -169,6 +154,24 @@ def MosaicGPU_WGMMARowFragLayout : AttrDef { + let summary = "A layout derived from a tiling expression."; + let description = [{ + See mosaic/gpu/fragmented_array.py -> TiledLayout for more details. + }]; + + let parameters = (ins + "::mlir::ArrayAttr":$tiling, + "int":$warp_dim, + "::mlir::ArrayAttr":$lane_dims, + "int":$vector_dim + ); + let mnemonic = "TiledLayout"; + let assemblyFormat = "`<` $tiling `,` `warp_dim` `=` $warp_dim `,` " + "`lane_dims` `=` $lane_dims `,` `vector_dim` `=` $vector_dim `>`"; +} + + // Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td // but it was not possible to reuse that definition. Including that file // pulls in ops definitions that we don't want and they fail to compile. diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index baf996d50..5b24d2359 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -18,6 +18,10 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load( "@xla//third_party/py:py_manylinux_compliance_test.bzl", "verify_manylinux_compliance_test", @@ -228,6 +232,18 @@ string_flag( build_setting_default = "dist", ) +NVIDIA_WHEELS_DEPS = [ + "@pypi_nvidia_cublas_cu12//:whl", + "@pypi_nvidia_cuda_cupti_cu12//:whl", + "@pypi_nvidia_cuda_runtime_cu12//:whl", + "@pypi_nvidia_cudnn_cu12//:whl", + "@pypi_nvidia_cufft_cu12//:whl", + "@pypi_nvidia_cusolver_cu12//:whl", + "@pypi_nvidia_cusparse_cu12//:whl", + "@pypi_nvidia_nccl_cu12//:whl", + "@pypi_nvidia_nvjitlink_cu12//:whl", +] + jax_wheel( name = "jaxlib_wheel", no_abi = False, @@ -235,6 +251,11 @@ jax_wheel( wheel_name = "jaxlib", ) +py_import( + name = "jaxlib_py_import", + wheel = ":jaxlib_wheel", +) + jax_wheel( name = "jaxlib_wheel_editable", editable = True, @@ -252,6 +273,12 @@ jax_wheel( wheel_name = "jax_cuda12_plugin", ) +py_import( + name = "jax_cuda_plugin_py_import", + wheel = ":jax_cuda_plugin_wheel", + wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +) + jax_wheel( name = "jax_cuda_plugin_wheel_editable", editable = True, @@ -290,6 +317,12 @@ jax_wheel( wheel_name = "jax_cuda12_pjrt", ) +py_import( + name = "jax_cuda_pjrt_py_import", + wheel = ":jax_cuda_pjrt_wheel", + wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), +) + jax_wheel( name = "jax_cuda_pjrt_wheel_editable", editable = True, diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 376a9b1a1..7fa3b93f3 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -213,8 +213,36 @@ class RnnTest(jtu.JaxTestCase): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', - stablehlo) + if jtu.jaxlib_version() <= (0, 5, 2): + self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', + stablehlo) + else: + self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', + stablehlo) + + @jtu.run_on_devices("cuda") + def test_no_workspace_overflow(self): + if jtu.jaxlib_version() <= (0, 5, 2): + self.skipTest("Older versions fail because of integer overflow.") + + # Problem sizes known to cause overflows on older versions. + batch_size, max_seq_length, input_size = 256, 500, 512 + num_layers, hidden_size = 1, 256 + num_params = rnn.get_num_params_in_lstm( + input_size, hidden_size, num_layers, True) + x = jax.ShapeDtypeStruct( + (batch_size, max_seq_length, input_size), jnp.float32) + h_0 = jax.ShapeDtypeStruct( + (2 * num_layers, batch_size, hidden_size), jnp.float32) + c_0 = jax.ShapeDtypeStruct( + (2 * num_layers, batch_size, hidden_size), jnp.float32) + weights = jax.ShapeDtypeStruct((num_params,), jnp.float32) + seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32) + fun = jax.jit(partial( + rnn.lstm, input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, dropout=0.0, bidirectional=True)) + fun.lower(x, h_0, c_0, weights, seq_lengths) # Doesn't crash. + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 15fc37805..3871a87a7 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2445,7 +2445,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): assert b.shape == () return c, b - xs = jnp.ones((5, 3)) + xs = jnp.ones((20, 3)) c = jnp.ones(4) scan = lambda c, xs: lax.scan(f, c, xs) @@ -2502,6 +2502,28 @@ class LaxControlFlowTest(jtu.JaxTestCase): x, n = jnp.arange(3), jnp.arange(4) jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash + def test_disable_jit_while_loop_with_mutation(self): + # https://github.com/jax-ml/jax/issues/27019 + + def body_fun(carry): + x, y = carry + x += 1 # in-place if x is mutable + return x, y + x + + def cond_fun(carry): + x, _ = carry + return x < 10 + + def f(): + val = np.array(1.0) # mutable value + return jax.lax.while_loop(cond_fun, body_fun, (val, val))[1] + + with jax.disable_jit(False): + result_jit = f() + with jax.disable_jit(True): + result_nojit = f() + self.assertEqual(result_jit, result_nojit) + @parameterized.named_parameters( {"testcase_name": f"_{shape}_{axis=}", "shape": shape, "axis": axis} diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 96d48dcd3..2c09252f4 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -278,6 +278,35 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase): with jax.checking_leaks(): lsp_special.expi(jnp.ones(())) + def testExpiDisableJit(self): + # Regression test for https://github.com/jax-ml/jax/issues/27019 + x = jnp.array([-0.5]) + with jax.disable_jit(True): + result_nojit = lsp_special.expi(x) + with jax.disable_jit(False): + result_jit = lsp_special.expi(x) + self.assertAllClose(result_jit, result_nojit) + + def testGammaIncBoundaryValues(self): + dtype = jax.numpy.zeros(0).dtype # default float dtype. + nan = float('nan') + inf = float('inf') + args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan]).astype(dtype), + np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf]).astype(dtype)] + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 + self._CheckAgainstNumpy(osp_special.gammainc, lsp_special.gammainc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol) + + def testGammaIncCBoundaryValues(self): + dtype = jax.numpy.zeros(0).dtype # default float dtype. + nan = float('nan') + inf = float('inf') + args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan, 1]).astype(dtype), + np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf, -1]).astype(dtype)] + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 + self._CheckAgainstNumpy(osp_special.gammaincc, lsp_special.gammaincc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 4b231939a..4ec2dbf3b 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -39,7 +39,10 @@ jax_multiplatform_test( ], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, shard_count = 16, - tags = ["multiaccelerator"], + tags = [ + "multiaccelerator", + "noasan", # Times out. + ], deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 91debfe57..ea83d1583 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -210,7 +210,7 @@ class LayoutInferenceTest(parameterized.TestCase): for layout in [ mgpu.WGSplatFragLayout(shape), mgpu.WGStridedFragLayout(shape, vec_size=4), - mgpu.WGMMAFragLayout(), + mgpu.TILED_LAYOUT_WGMMA, ] ) def test_infer_layout_from_yield_op_in_layouts_for_for_op( @@ -278,7 +278,7 @@ class LayoutInferenceTest(parameterized.TestCase): mgpu.infer_layout(self.module) - wgmma_layout = layouts.to_layout_attr(mgpu.WGMMAFragLayout()) + wgmma_layout = layouts.to_layout_attr(mgpu.TILED_LAYOUT_WGMMA) self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout]) self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout]) @@ -312,7 +312,7 @@ class LayoutInferenceTest(parameterized.TestCase): @parameterized.parameters( mgpu.WGStridedFragLayout((32, 4), vec_size=1), - mgpu.WGMMAFragLayout(), + mgpu.TILED_LAYOUT_WGMMA, ) def test_infer_layout_picks_non_splat_layout_over_splat_layout( self, layout diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index cc654eb2b..f6b94b777 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1026,7 +1026,7 @@ class TCGen05Test(TestCase): in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation m=(256,), # TODO(apaszke): 64, 192, 256 - n=(128, 256), # TODO(apaszke): 512, 192, other non-power-of-2 + n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), swizzle=(32, 64, 128,), ) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index c510c2cfa..4c6a8eb7a 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -216,8 +216,8 @@ class MutableArrayTest(jtu.JaxTestCase): @jax.jit def f(x_ref): - self.assertEqual(core.get_ty(x_ref).sharding.spec, - core.get_ty(x_ref[...]).sharding.spec) + self.assertEqual(core.typeof(x_ref).sharding.spec, + core.typeof(x_ref[...]).sharding.spec) y = x_ref[...] + 1 return y diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0a3af26de..8a7b6c98b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -184,6 +184,23 @@ class PallasCallTest(PallasTest): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) + @parameterized.product( + shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics] + ) + def test_reduce_sum(self, shape, thread_semantics): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape) + + x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), jnp.sum(x)) + def test_reshape(self): shape1, shape2 = (128,), (2, 16, 4) @@ -200,10 +217,14 @@ class PallasCallTest(PallasTest): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - def test_add_xy_indexed(self): + @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) + def test_add_xy_indexed(self, thread_semantics): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), ) def kernel(x_ref, y_ref, o_ref): idx = _sum_same_dtype(y_ref[...]) @@ -1078,10 +1099,14 @@ class PallasCallTest(PallasTest): self.assertIn("acc % 2", output()) - def test_cond_returning_array(self): + @parameterized.parameters([*plgpu.ThreadSemantics]) + def test_cond_returning_array(self, thread_semantics): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + compiler_params=plgpu.GPUCompilerParams( + thread_semantics=thread_semantics + ), ) def kernel(x_ref, o_ref): acc = _sum_same_dtype(x_ref[...]) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 2011602d8..c4d600d23 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -470,6 +470,27 @@ class OpsTest(PallasBaseTest): expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x)) np.testing.assert_array_equal(out, expected) + def test_reduce_with_const(self): + m = 1 + d = 1024 + x = jnp.ones((m, d), jnp.bfloat16) + + def dot(x, y): + return jax.lax.dot_general( + x, + y, + (((1,), (1,)), ((), ())), + preferred_element_type=jnp.float32, + ) + + def kernel(x, out): + out[:] = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) + + run = pl.pallas_call(kernel, jax.ShapeDtypeStruct((m, 1), jnp.float32)) + output = run(x) + expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) + np.testing.assert_array_equal(output, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index cca8e3bc8..bffcebc52 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -64,10 +64,6 @@ class PagedAttentionKernelTest(jtu.JaxTestCase): max_num_seq = max(len(seq_lens), max_num_seq) max_kv_len = max(kv_lens) pages_per_seq = ceil_div(max_kv_len, page_size) - pages_per_seq = ( - ceil_div(pages_per_seq, num_kv_pages_per_block) - * num_kv_pages_per_block - ) num_q_heads, num_kv_heads = num_heads cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) @@ -130,8 +126,8 @@ class PagedAttentionKernelTest(jtu.JaxTestCase): num_seqs=num_seqs, ) tols = { - "float32": 1e-1, - "bfloat16": 2e-1, + "float32": 0.15, + "bfloat16": 0.2, } tol = tols[jnp.dtype(dtype).name] self.assertAllClose(output, expected, atol=tol, rtol=tol) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index bd7954d60..4cd1af9d3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4883,11 +4883,11 @@ class ShardingInTypesTest(jtu.JaxTestCase): arr = jax.device_put(np_inp, s) def f(x): - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) x = x * 2 - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) x = x * x - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) return x # Eager mode diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 48f3d062b..844892adc 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -125,6 +125,80 @@ class RaggedCollectiveTest(jtu.JaxTestCase): c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32) ) + @parameterized.named_parameters( + dict( + testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2) + ), + ) + def test_ragged_all_to_all_grad(self, axis_name, mesh_axes): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + operand = jax.device_put( + jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + output = jax.device_put( + jnp.zeros((2, 4), dtype=jnp.float32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + input_offsets = jax.device_put( + jnp.array([[0, 1], [0, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + send_sizes = jax.device_put( + jnp.array([[1, 2], [1, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + output_offsets = jax.device_put( + jnp.array([[0, 0], [1, 2]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + recv_sizes = jax.device_put( + jnp.array([[1, 1], [2, 1]], dtype=jnp.int32), + jax.sharding.NamedSharding(mesh, P(axis_name, None)), + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_rep=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + operand = operand.reshape(operand.shape[1:]) + output = output.reshape(output.shape[1:]) + input_offsets = input_offsets.reshape(input_offsets.shape[1:]) + send_sizes = send_sizes.reshape(send_sizes.shape[1:]) + output_offsets = output_offsets.reshape(output_offsets.shape[1:]) + recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:]) + return lax.ragged_all_to_all( + operand, + output, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=axis_name, + ) + + args = input_offsets, send_sizes, output_offsets, recv_sizes + jtu.check_grads(lambda op, out: fwd(op, out, *args), (operand, output), order=1) + @parameterized.named_parameters( dict( testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6710de12a..5bdf1f541 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fae64d49aa41e774922ca46e94cd754c800b6240" -XLA_SHA256 = "846ce8037cc0cba5135bff0bfd6fd02810e72b42ce0928002c595c97bf7b3603" +XLA_COMMIT = "c270a6ce45df7f7bb3024f2e4df56b688d76ebd6" +XLA_SHA256 = "b2f7d0293fc62bb670d0b58c5847108652eac4d9e6c7e420bed2029e74af6f2d" def repo(): tf_http_archive(