Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 19:06:07 +00:00)
Merge pull request #276 from ROCm/ci-upstream-sync-144_1
CI: 03/12/25 upstream sync
This commit is contained in: commit 9cc545254c
6 .bazelrc
@@ -253,12 +253,6 @@ build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base
build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc
build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"

# Mac x86 CI configs
build:ci_darwin_x86_64 --macos_minimum_os=11.0
build:ci_darwin_x86_64 --config=macos_cache_push
build:ci_darwin_x86_64 --verbose_failures=true
build:ci_darwin_x86_64 --color=yes

# Mac Arm64 CI configs
build:ci_darwin_arm64 --macos_minimum_os=11.0
build:ci_darwin_arm64 --config=macos_cache_push
85 .github/workflows/cloud-tpu-ci-presubmit.yml vendored
@@ -3,6 +3,7 @@
# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
# tested to get to a stable state before we enable it as a blocking presubmit.
name: CI - Cloud TPU (presubmit)

on:
workflow_dispatch:
inputs:
@@ -33,64 +34,32 @@ concurrency:
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

jobs:
cloud-tpu-test:
build-jax-artifacts:
if: github.event.repository.fork == false
# Begin Presubmit Naming Check - name modification requires internal check to be updated
uses: ./.github/workflows/build_artifacts.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
tpu: [
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
python-version: ["3.10"]
name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
# End Presubmit Naming Check github-tpu-presubmits
env:
JAXCI_PYTHON: python${{ matrix.python-version }}
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
fail-fast: false # don't cancel all jobs on failure
matrix:
artifact: ["jax", "jaxlib"]
with:
runner: "linux-x86-n2-16"
artifact: ${{ matrix.artifact }}
python: "3.10"
clone_main_xla: 1
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

runs-on: ${{ matrix.tpu.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"

timeout-minutes: 60

defaults:
run:
shell: bash -ex {0}
steps:
# https://opensource.google/documentation/reference/github/services#actions
# mandates using a specific commit for non-Google actions. We use
# https://github.com/sethvargo/ratchet to pin specific versions.
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# Checkout XLA at head, if we're building jaxlib at head.
- name: Checkout XLA at head
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: openxla/xla
path: xla
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
- name: Mark GitHub workspace as safe
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Build jaxlib at head with latest XLA
run: |
# Build and install jaxlib at head
$JAXCI_PYTHON build/build.py build --wheels=jaxlib \
--python_version=${{ matrix.python-version }} \
--bazel_options=--config=rbe_linux_x86_64 \
--local_xla_path="$(pwd)/xla" \
--verbose

# Install libtpu
$JAXCI_PYTHON -m uv pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Install jaxlib wheel and run tests
run: ./ci/run_pytest_tpu.sh
run-pytest-tpu:
if: github.event.repository.fork == false
needs: [build-jax-artifacts]
uses: ./.github/workflows/pytest_tpu.yml
# Begin Presubmit Naming Check - name modification requires internal check to be updated
name: "TPU test (jaxlib=head, v5e-8)"
with:
runner: "linux-x86-ct5lp-224-8tpu"
cores: "8"
tpu-type: "v5e-8"
python: "3.10"
libtpu-version-type: "nightly"
gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }}
# End Presubmit Naming Check github-tpu-presubmits
3 .github/workflows/pytest_cpu.yml vendored
@@ -116,6 +116,9 @@ jobs:
exit 1
- name: Install Python dependencies
run: |
# Remove installation of NVIDIA wheels for CPU tests.
sed -i 's/-r gpu-test-requirements.txt/# -r gpu-test-requirements.txt/g' build/requirements.in

# TODO(srnitin): Remove after uv is installed in the Windows Dockerfile
$JAXCI_PYTHON -m pip install uv~=0.5.30
# python 3.13t cannot compile zstandard 0.23.0 due to
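The sed call in this hunk comments out the GPU requirements include that this PR adds to build/requirements.in (see the requirements.in and build/gpu-test-requirements.txt hunks further down), so CPU-only jobs never resolve the NVIDIA wheels. A minimal, hedged way to sanity-check that substitution locally (the /tmp path and the two-line input are illustration only):

# Recreate the include lines from build/requirements.in and apply the same sed.
printf '%s\n' '-r test-requirements.txt' '-r gpu-test-requirements.txt' > /tmp/requirements.in
sed -i 's/-r gpu-test-requirements.txt/# -r gpu-test-requirements.txt/g' /tmp/requirements.in
cat /tmp/requirements.in   # the gpu-test-requirements.txt include is now commented out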
151 .github/workflows/pytest_tpu.yml vendored Normal file
@@ -0,0 +1,151 @@
# CI - Pytest TPU
#
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests.
#
# It consists of the following job:
# run-tests:
# - Downloads the jaxlib wheel from a GCS bucket.
# - Sets up the libtpu wheels.
# - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
# - Installs the downloaded jaxlib wheel.
# - Runs the TPU tests with Pytest.
name: CI - Pytest TPU

on:
workflow_call:
inputs:
# Note that the values for runners, cores, and tpu-type are linked to each other.
# For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
# following mapping:
# {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
# {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
runner:
description: "Which runner should the workflow run on?"
type: string
required: true
default: "linux-x86-ct5lp-224-8tpu"
cores:
description: "How many TPU cores should the test use?"
type: string
required: true
default: "8"
tpu-type:
description: "Which TPU type is used for testing?"
type: string
required: true
default: "v5e-8"
python:
description: "Which Python version should be used for testing?"
type: string
required: true
default: "3.12"
run-full-tpu-test-suite:
description: "Should the full TPU test suite be run?"
type: string
required: false
default: "0"
libtpu-version-type:
description: "Which libtpu version should be used for testing?"
type: string
required: false
# Choices are:
# - "nightly": Use the nightly libtpu wheel.
# - "pypi_latest": Use the latest libtpu wheel from PyPI.
# - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
default: "nightly"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: boolean
required: false
default: false

jobs:
run-tests:
defaults:
run:
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
# Begin Presubmit Naming Check - name modification requires internal check to be updated
name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
# End Presubmit Naming Check github-tpu-presubmits

env:
LIBTPU_OLDEST_VERSION_DATE: 20241205
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
JAXCI_TPU_CORES: "${{ inputs.cores }}"

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set env vars for use in artifact download URL
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)

# Get the major and minor version of Python.
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

echo "OS=${os}" >> $GITHUB_ENV
echo "ARCH=${arch}" >> $GITHUB_ENV
# Python wheels follow a naming convention: standard wheels use the pattern
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download JAX wheels from GCS
id: download-wheel-artifacts
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
# informative error message.
continue-on-error: true
run: |
mkdir -p $(pwd)/dist
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
echo "Skipping the test run."
exit 1
- name: Install Python dependencies
run: |
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Set up libtpu wheels
run: |
if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then
echo "Using nightly libtpu"
$JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then
echo "Using latest libtpu from PyPI"
# Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
# script will install the latest libtpu wheel from PyPI.
echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV
elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then
echo "Using oldest supported libtpu"
$JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html

echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV
else
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
exit 1
fi
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Pytest TPU tests
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
run: ./ci/run_pytest_tpu.sh
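The wheel-tag derivation in the "Set env vars for use in artifact download URL" step above is terse; the following standalone bash sketch shows the same transformation in isolation (the loop and echo are illustration only, not part of the workflow):

# "3.10"       -> python_major_minor=310  -> wheel tag prefix "cp310-cp310-"
# "3.13-nogil" -> python_major_minor=313t -> wheel tag prefix "cp313-cp313t-"
for JAXCI_HERMETIC_PYTHON_VERSION in "3.10" "3.13-nogil"; do
  python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
  echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-"
done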
26 .github/workflows/wheel_tests_continuous.yml vendored
@@ -142,4 +142,30 @@ jobs:
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-pytest-tpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
python: ["3.10",]
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
name: "TPU tests (jax=head, jaxlib=head)"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: "nightly"
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
@@ -58,4 +58,42 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{inputs.gcs_download_uri}}

run-pytest-tpu:
uses: ./.github/workflows/pytest_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for
# profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302)
python: ["3.10", "3.11", "3.12"]
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"]
exclude:
- libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }}
- libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }}
# Run a single Python version for v4-8.
- tpu-specs:
type: "v4-8"
python: "3.10"
- tpu-specs:
type: "v4-8"
python: "3.11"
# Run min and max Python versions for v5e-8
- tpu-specs:
type: "v5e-8"
python: "3.11"
name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
gcs_download_uri: ${{inputs.gcs_download_uri}}
@@ -456,3 +456,4 @@ For details about the JAX API, see the

For getting started as a JAX developer, see the
[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
@@ -29,7 +29,7 @@ compile_pip_requirements(
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = True,
data = ["test-requirements.txt"]
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
)

compile_pip_requirements(
@@ -44,7 +44,7 @@ compile_pip_requirements(
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = False,
data = ["test-requirements.txt"]
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
)

compile_pip_requirements(
@@ -58,7 +58,7 @@ compile_pip_requirements(
requirements_in = "requirements.in",
requirements_txt = REQUIREMENTS,
generate_hashes = False,
data = ["test-requirements.txt"]
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
)

py_library(
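Adding gpu-test-requirements.txt to the data attribute makes the new file visible to the compile_pip_requirements targets above, so the -r include in requirements.in resolves when the lock files are regenerated. A hedged sketch of the regeneration step (compile_pip_requirements conventionally exposes a <name>.update target; the exact //build:requirements.update label is an assumption, since the rule names are not shown in these hunks):

# Re-lock Python deps after editing requirements.in or gpu-test-requirements.txt
# (target name assumed; see build/BUILD.bazel for the actual rule name).
bazel run //build:requirements.update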
13 build/gpu-test-requirements.txt Normal file
@@ -0,0 +1,13 @@
# NVIDIA CUDA dependencies
# Note that the wheels are downloaded only when the targets in bazel command
# contain dependencies on these wheels.
nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux"
nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux"
nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux"
nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux"
nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux"
nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux"
nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux"
nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux"
nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux"
nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux"
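Every pin above carries a `; sys_platform == "linux"` environment marker, so the CUDA wheels only resolve on Linux and the file is a no-op elsewhere. A hedged usage sketch (CI consumes this file through requirements.in and the lock files rather than installing it directly):

# On Linux this pulls the NVIDIA CUDA wheels; on macOS/Windows every line is skipped.
python -m pip install -r build/gpu-test-requirements.txt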
@@ -2,6 +2,7 @@
# test deps
#
-r test-requirements.txt
-r gpu-test-requirements.txt

#
# build deps
@ -304,24 +304,31 @@ mdurl==0.1.2 \
|
||||
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
|
||||
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
|
||||
# via markdown-it-py
|
||||
ml-dtypes==0.4.0 \
|
||||
--hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \
|
||||
--hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \
|
||||
--hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \
|
||||
--hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \
|
||||
--hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \
|
||||
--hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \
|
||||
--hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \
|
||||
--hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \
|
||||
--hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \
|
||||
--hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \
|
||||
--hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \
|
||||
--hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \
|
||||
--hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \
|
||||
--hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \
|
||||
--hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \
|
||||
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
|
||||
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
|
||||
ml-dtypes==0.5.1 \
|
||||
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
|
||||
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
|
||||
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
|
||||
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
|
||||
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
|
||||
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
|
||||
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
|
||||
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
|
||||
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
|
||||
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
|
||||
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
|
||||
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
|
||||
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
|
||||
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
|
||||
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
|
||||
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
|
||||
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
|
||||
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
|
||||
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
|
||||
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
|
||||
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
|
||||
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
|
||||
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
|
||||
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
|
||||
# via -r build/requirements.in
|
||||
mpmath==1.4.0a1 \
|
||||
--hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \
|
||||
@ -380,6 +387,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
|
||||
# ml-dtypes
|
||||
# opt-einsum
|
||||
# scipy
|
||||
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
|
||||
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
|
||||
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
|
||||
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
|
||||
# via
|
||||
# via -r build/test-requirements.txt
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
|
||||
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
|
||||
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
|
||||
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
|
||||
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
|
||||
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
|
||||
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
|
||||
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
|
||||
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
|
||||
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
|
||||
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
|
||||
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
|
||||
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
|
||||
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
|
||||
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
|
||||
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
|
||||
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
|
||||
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
|
||||
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
|
||||
# via
|
||||
# via -r build/test-requirements.txt
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
|
||||
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
|
||||
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
|
||||
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
|
||||
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
|
||||
# via
|
||||
# via -r build/test-requirements.txt
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
opt-einsum==3.3.0 \
|
||||
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
|
||||
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
|
||||
|
@ -299,24 +299,31 @@ mdurl==0.1.2 \
|
||||
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
|
||||
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
|
||||
# via markdown-it-py
|
||||
ml-dtypes==0.4.0 \
|
||||
--hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \
|
||||
--hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \
|
||||
--hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \
|
||||
--hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \
|
||||
--hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \
|
||||
--hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \
|
||||
--hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \
|
||||
--hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \
|
||||
--hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \
|
||||
--hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \
|
||||
--hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \
|
||||
--hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \
|
||||
--hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \
|
||||
--hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \
|
||||
--hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \
|
||||
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
|
||||
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
|
||||
ml-dtypes==0.5.1 \
|
||||
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
|
||||
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
|
||||
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
|
||||
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
|
||||
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
|
||||
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
|
||||
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
|
||||
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
|
||||
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
|
||||
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
|
||||
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
|
||||
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
|
||||
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
|
||||
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
|
||||
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
|
||||
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
|
||||
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
|
||||
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
|
||||
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
|
||||
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
|
||||
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
|
||||
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
|
||||
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
|
||||
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
|
||||
# via -r build/requirements.in
|
||||
mpmath==1.4.0a1 \
|
||||
--hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \
|
||||
@ -375,6 +382,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
|
||||
# ml-dtypes
|
||||
# opt-einsum
|
||||
# scipy
|
||||
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
|
||||
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
|
||||
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
|
||||
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
|
||||
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
|
||||
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
|
||||
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
|
||||
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
|
||||
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
|
||||
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
|
||||
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
|
||||
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
|
||||
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
|
||||
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
|
||||
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
|
||||
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
|
||||
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
|
||||
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
|
||||
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
|
||||
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
|
||||
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
|
||||
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
|
||||
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
|
||||
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
|
||||
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
|
||||
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
opt-einsum==3.3.0 \
|
||||
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
|
||||
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
|
||||
|
@ -299,24 +299,31 @@ mdurl==0.1.2 \
|
||||
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
|
||||
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
|
||||
# via markdown-it-py
|
||||
ml-dtypes==0.4.0 \
|
||||
--hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \
|
||||
--hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \
|
||||
--hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \
|
||||
--hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \
|
||||
--hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \
|
||||
--hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \
|
||||
--hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \
|
||||
--hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \
|
||||
--hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \
|
||||
--hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \
|
||||
--hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \
|
||||
--hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \
|
||||
--hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \
|
||||
--hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \
|
||||
--hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \
|
||||
--hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \
|
||||
--hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1
|
||||
ml-dtypes==0.5.1 \
|
||||
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
|
||||
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
|
||||
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
|
||||
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
|
||||
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
|
||||
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
|
||||
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
|
||||
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
|
||||
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
|
||||
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
|
||||
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
|
||||
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
|
||||
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
|
||||
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
|
||||
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
|
||||
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
|
||||
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
|
||||
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
|
||||
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
|
||||
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
|
||||
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
|
||||
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
|
||||
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
|
||||
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
|
||||
# via -r build/requirements.in
|
||||
mpmath==1.4.0a1 \
|
||||
--hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \
|
||||
@ -375,6 +382,64 @@ numpy==2.0.0 ; python_version <= "3.12" \
|
||||
# ml-dtypes
|
||||
# opt-einsum
|
||||
# scipy
|
||||
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
|
||||
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
|
||||
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
|
||||
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
|
||||
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
|
||||
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
|
||||
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
|
||||
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
|
||||
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
|
||||
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
|
||||
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
|
||||
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
|
||||
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
|
||||
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
|
||||
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
|
||||
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
|
||||
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
|
||||
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
|
||||
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
|
||||
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
|
||||
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
|
||||
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
|
||||
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
|
||||
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
|
||||
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
|
||||
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
opt-einsum==3.3.0 \
|
||||
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
|
||||
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
|
||||
|
@ -347,28 +347,31 @@ mdurl==0.1.2 \
|
||||
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
|
||||
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
|
||||
# via markdown-it-py
|
||||
ml-dtypes==0.5.0 \
|
||||
--hash=sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe \
|
||||
--hash=sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f \
|
||||
--hash=sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128 \
|
||||
--hash=sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab \
|
||||
--hash=sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15 \
|
||||
--hash=sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215 \
|
||||
--hash=sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f \
|
||||
--hash=sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf \
|
||||
--hash=sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4 \
|
||||
--hash=sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8 \
|
||||
--hash=sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859 \
|
||||
--hash=sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a \
|
||||
--hash=sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff \
|
||||
--hash=sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8 \
|
||||
--hash=sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02 \
|
||||
--hash=sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856 \
|
||||
--hash=sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07 \
|
||||
--hash=sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f \
|
||||
--hash=sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0 \
|
||||
--hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \
|
||||
--hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599
|
||||
ml-dtypes==0.5.1 \
|
||||
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
|
||||
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
|
||||
--hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \
|
||||
--hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \
|
||||
--hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \
|
||||
--hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \
|
||||
--hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \
|
||||
--hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \
|
||||
--hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \
|
||||
--hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \
|
||||
--hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \
|
||||
--hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \
|
||||
--hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \
|
||||
--hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \
|
||||
--hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \
|
||||
--hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \
|
||||
--hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \
|
||||
--hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \
|
||||
--hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \
|
||||
--hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \
|
||||
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \
|
||||
--hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \
|
||||
--hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \
|
||||
--hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1
|
||||
# via -r build/requirements.in
|
||||
mpmath==1.3.0 \
|
||||
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
|
||||
@ -434,6 +437,64 @@ numpy==2.1.2 ; python_version >= "3.13" \
|
||||
# matplotlib
|
||||
# ml-dtypes
|
||||
# scipy
|
||||
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
|
||||
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
|
||||
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
|
||||
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
|
||||
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
|
||||
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
|
||||
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
|
||||
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
|
||||
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
|
||||
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
|
||||
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
|
||||
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
|
||||
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
|
||||
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
|
||||
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
|
||||
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
|
||||
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
|
||||
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
|
||||
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
|
||||
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
|
||||
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
|
||||
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
|
||||
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
|
||||
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
|
||||
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
|
||||
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
opt-einsum==3.4.0 \
|
||||
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
|
||||
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
|
||||
|
@ -390,6 +390,64 @@ numpy==2.2.1 ; python_version >= "3.13" \
|
||||
# matplotlib
|
||||
# ml-dtypes
|
||||
# scipy
|
||||
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
|
||||
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
|
||||
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \
|
||||
--hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \
|
||||
--hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \
|
||||
--hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \
|
||||
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \
|
||||
--hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \
|
||||
--hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \
|
||||
--hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \
|
||||
--hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \
|
||||
--hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \
|
||||
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \
|
||||
--hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \
|
||||
--hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \
|
||||
--hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \
|
||||
--hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \
|
||||
--hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \
|
||||
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \
|
||||
--hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \
|
||||
--hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \
|
||||
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \
|
||||
--hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cusolver-cu12
|
||||
nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \
|
||||
--hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \
|
||||
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8
|
||||
# via -r build/test-requirements.txt
|
||||
nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \
|
||||
--hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \
|
||||
--hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \
|
||||
--hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0
|
||||
# via
|
||||
# -r build/test-requirements.txt
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
opt-einsum==3.4.0 \
|
||||
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
|
||||
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
|
||||
|
@@ -74,4 +74,14 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
# on the system. By default, it is set to match the version of the hermetic
# Python used by Bazel for building the wheels.
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}

# When set to 1, the full TPU test suite is run. Otherwise, a subset of tests
# is run.
export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0}

# We use this environment variable to control which additional wheels to install
# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest
# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the
# list of valid values and their behavior.
export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""}
@@ -52,23 +52,46 @@ export JAX_SKIP_SLOW_TESTS=true

echo "Running TPU tests..."

# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
  --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
  --maxfail=20 -m "not multiaccelerator" \
  tests/pallas/ops_test.py \
  tests/pallas/export_back_compat_pallas_test.py \
  tests/pallas/export_pallas_test.py \
  tests/pallas/tpu_ops_test.py \
  tests/pallas/tpu_pallas_test.py \
  tests/pallas/tpu_pallas_random_test.py \
  tests/pallas/tpu_pallas_async_test.py \
  tests/pallas/tpu_pallas_state_test.py
if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
  # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic
  # TPU does not guarantee anything about forward compatibility (unless
  # jax.export is used) and the 12 week compatibility window accumulates way
  # too many failures.
  IGNORE_FLAGS=""
  if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then
    IGNORE_FLAGS="--ignore=tests/pallas"
  fi

  # Run Pallas printing tests, which need to run with I/O capturing disabled.
  TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
  # Run single-accelerator tests in parallel
  JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
    --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
    --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples

  # Run multi-accelerator across all chips
  "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
    tests/pjit_test.py \
    tests/pallas/tpu_pallas_distributed_test.py
  # Run Pallas printing tests, which need to run with I/O capturing disabled.
  TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \
    tests/pallas/tpu_pallas_test.py::PallasCallPrintTest

  # Run multi-accelerator across all chips
  "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
else
  # Run single-accelerator tests in parallel
  JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
    --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
    --maxfail=20 -m "not multiaccelerator" \
    tests/pallas/ops_test.py \
    tests/pallas/export_back_compat_pallas_test.py \
    tests/pallas/export_pallas_test.py \
    tests/pallas/tpu_ops_test.py \
    tests/pallas/tpu_pallas_test.py \
    tests/pallas/tpu_pallas_random_test.py \
    tests/pallas/tpu_pallas_async_test.py \
    tests/pallas/tpu_pallas_state_test.py

  # Run Pallas printing tests, which need to run with I/O capturing disabled.
  TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest

  # Run multi-accelerator across all chips
  "$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
    tests/pjit_test.py \
    tests/pallas/tpu_pallas_distributed_test.py
fi
@@ -17,8 +17,19 @@
# Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python
# binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to
# avoid using the Windows version of `find` on Msys.

WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )

for i in "${!WHEELS[@]}"; do
  if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then
    if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then
      # Append [tpu] to the jax wheel name to download the latest libtpu wheel
      # from PyPI.
      WHEELS[$i]="${WHEELS[$i]}[tpu]"
    fi
  fi
done

if [[ -z "${WHEELS[@]}" ]]; then
  echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
  exit 1
@@ -58,6 +58,7 @@ Operators
    clz
    collapse
    complex
    composite
    concatenate
    conj
    conv

22
jax/BUILD
@ -14,7 +14,7 @@
|
||||
|
||||
# JAX is Autograd and XLA
|
||||
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
@ -45,17 +45,26 @@ package(
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# If this flag is true, jaxlib should be built by bazel. If false, then we do not build jaxlib and
|
||||
# assume it has been installed, e.g., by `pip`.
|
||||
bool_flag(
|
||||
# The flag controls whether jaxlib should be built by Bazel.
|
||||
# If ":build_jaxlib=true", then jaxlib will be built.
|
||||
# If ":build_jaxlib=false", then jaxlib is not built. It is assumed that the pre-built jaxlib wheel
|
||||
# is available in the "dist" folder.
|
||||
# If ":build_jaxlib=wheel", then jaxlib wheel will be built as a py_import rule attribute.
|
||||
# The py_import rule unpacks the wheel and provides its content as a py_library.
|
||||
string_flag(
|
||||
name = "build_jaxlib",
|
||||
build_setting_default = True,
|
||||
build_setting_default = "true",
|
||||
values = [
|
||||
"true",
|
||||
"false",
|
||||
"wheel",
|
||||
],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "enable_jaxlib_build",
|
||||
flag_values = {
|
||||
":build_jaxlib": "True",
|
||||
":build_jaxlib": "true",
|
||||
},
|
||||
)
|
||||
|
||||
@ -681,6 +690,7 @@ pytype_strict_library(
|
||||
deps = [
|
||||
":pallas", # build_cleaner: keep
|
||||
"//jax/_src/pallas/fuser:block_spec",
|
||||
"//jax/_src/pallas/fuser:custom_evaluate",
|
||||
"//jax/_src/pallas/fuser:fusable",
|
||||
"//jax/_src/pallas/fuser:fusion",
|
||||
"//jax/_src/pallas/fuser:jaxpr_fusion",
|
||||
|
@@ -79,7 +79,7 @@ from jax._src.lib import xla_client as _xc
Device = _xc.Device
del _xc

from jax._src.core import get_ty as get_ty
from jax._src.core import typeof as typeof
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint  # noqa: F401

@@ -235,6 +235,7 @@ def trace_context():
          threefry_partitionable.value,
          threefry_gpu_kernel_lowering.value,
          use_direct_linearize.value,
          varying_axes_in_types.value,
          softmax_custom_jvp.value,
          disable_jit.value,
          debug_key_reuse.value,
@@ -1084,6 +1085,14 @@ use_direct_linearize = bool_state(
    help=('Use direct linearization instead JVP followed by partial eval'),
    include_in_jit_key=True)

varying_axes_in_types = bool_state(
    name='jax_varying_axes_in_types',
    default=False,
    help=('Adds varying manual axes to ShapedArray to track which mesh axes the'
          ' array is varying over. This will help to remove the efficient'
          ' transpose rewrite machinery in shard_map'),
    include_in_jit_key=True)

data_dependent_tracing_fallback = bool_state(
    name='jax_data_dependent_tracing_fallback',
    default=False,

@@ -1576,7 +1576,7 @@ def get_aval(x):
    return get_aval(x.__jax_array__())
  raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")

get_ty = get_aval
typeof = get_aval

def is_concrete(x):
  return to_concrete_value(x) is not None
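For reference, the new `get_ty`/`typeof` aliases simply return the abstract value of their argument. A minimal sketch, assuming the aliases are re-exported as `jax.typeof` and `jax.get_ty` as the `jax/__init__.py` hunk above suggests::

    import jax
    import jax.numpy as jnp

    x = jnp.ones((2, 3), dtype=jnp.float32)
    aval = jax.typeof(x)           # same function object as jax.get_ty
    print(aval.shape, aval.dtype)  # (2, 3) float32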
@@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape):


class ShapedArray(UnshapedArray):
  __slots__ = ['shape', 'sharding']  # inherits slots from parent
  __slots__ = ['shape', 'sharding', 'varying_manual_axes']  # inherits slots from parent
  array_abstraction_level = 2

  def __init__(self, shape, dtype, weak_type=False, *, sharding=None):
  def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
               varying_manual_axes: frozenset[AxisName] = frozenset()):
    self.shape = canonicalize_shape(shape)
    self.dtype = _dtype_object(dtype)
    self.weak_type = weak_type
    self.sharding = get_sharding(sharding, self.shape)
    if config.varying_axes_in_types.value:
      self.varying_manual_axes = varying_manual_axes

  def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
    if shape is None:

@@ -1911,6 +1914,9 @@ class ShapedArray(UnshapedArray):
      weak_type = self.weak_type
    if 'sharding' not in kwargs:
      kwargs['sharding'] = self.sharding
    if 'varying_manual_axes' not in kwargs:
      kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes',
                                              frozenset())
    return ShapedArray(shape, dtype, weak_type, **kwargs)

  ndim = property(lambda self: len(self.shape))

@@ -1927,17 +1933,22 @@ class ShapedArray(UnshapedArray):
    return (type(self) is type(other)
            and self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type
            and self.sharding == other.sharding)
            and self.sharding == other.sharding
            and (getattr(self, 'varying_manual_axes', frozenset()) ==
                 getattr(other, 'varying_manual_axes', frozenset())))

  def __hash__(self):
    # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
    # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
    # the unique character code via hash(self.dtype.char)
    return hash((self.shape, self.dtype, self.weak_type, self.sharding))
    return hash((self.shape, self.dtype, self.weak_type, self.sharding,
                 getattr(self, 'varying_manual_axes', frozenset())))

  def to_tangent_aval(self):
    return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                       self.weak_type, sharding=self.sharding)
    return ShapedArray(
        self.shape, primal_dtype_to_tangent_dtype(self.dtype),
        self.weak_type, sharding=self.sharding,
        varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))

  def str_short(self, short_dtypes=False, mesh_axis_types=False):
    dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
@@ -1364,9 +1364,9 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
    raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
  if config.disable_jit.value:
    try:
      val = init_val
      val = tree_map(lax.asarray, init_val)
      while cond_fun(val):
        val = body_fun(val)
        val = tree_map(lax.asarray, body_fun(val))
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
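The `tree_map(lax.asarray, ...)` calls above keep the carried value an array between iterations when the loop runs eagerly. A small sketch of `lax.while_loop` under `jax.disable_jit()`::

    import jax
    import jax.numpy as jnp

    def cond(c):
      return c < 10

    def body(c):
      return c + 1  # the carry is re-wrapped with lax.asarray each step

    with jax.disable_jit():
      out = jax.lax.while_loop(cond, body, jnp.int32(0))
    print(out)  # 10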
@@ -1489,14 +1489,14 @@ def composite(
):
  """Composite with semantics defined by the decomposition function.

  A composite is a higher-order JAX function that encapsulates an operation mad
  A composite is a higher-order JAX function that encapsulates an operation made
  up (composed) of other JAX functions. The semantics of the op are implemented
  by the ``decomposition`` function. In other words, the defined composite
  function can be replaced with its decomposed implementation without changing
  the semantics of the encapsulated operation.

  The compiler can recognize specific composite operations by their ``name``,
  ``version``, ``kawargs``, and dtypes to emit more efficient code, potentially
  ``version``, ``kwargs``, and dtypes to emit more efficient code, potentially
  leveraging hardware-specific instructions or optimizations. If the compiler
  doesn't recognize the composite, it falls back to compiling the
  ``decomposition`` function.

@@ -1505,11 +1505,11 @@ def composite(
  be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could
  recognize the "tangent" composite and emit a single ``tangent`` instruction
  instead of three separate instructions (``sin``, ``divide``, and ``cos``).
  With compilers for hardwares without dedicated tangent support, it would fall
  back to compiling the decomposition.
  For hardware without dedicated tangent support, it would fall back to
  compiling the decomposition.

  This is useful for preserving high level abstraction that would otherwise be
  lost while lowering which allows for easier pattern-matching in low-level IR.
  This is useful for preserving high-level abstractions that would otherwise be
  lost while lowering, which allows for easier pattern-matching in low-level IR.

  Args:
    decomposition: function that implements the semantics of the composite op.

@@ -1517,19 +1517,20 @@ def composite(
    version: optional int to indicate semantic changes to the composite.

  Returns:
    out: callable composite function. Note that positional arguments to this
      function should be interpreted as inputs and keyword arguments should be
      interpreted as attributes of the op. Any keyword arguments that are passed
      with ``None`` as a value will be omitted from the
      ``composite_attributes``.
    Callable: Returns a composite function. Note that positional arguments to
      this function should be interpreted as inputs and keyword arguments should
      be interpreted as attributes of the op. Any keyword arguments that are
      passed with ``None`` as a value will be omitted from the
      ``composite_attributes``.

  Examples:
    Tangent kernel:

    >>> def my_tangent_composite(x):
    ...   return lax.composite(
    ...       lambda x: lax.sin(x) / lax.cos(x), name='my.tangent'
    ...       lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
    ...   )(x)
    ...
    >>>
    >>> pi = jnp.pi
    >>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
    >>> with jnp.printoptions(precision=3, suppress=True):

@@ -1538,9 +1539,10 @@ def composite(
    [ 0.  1. -1.  0.]
    [ 0.  1. -1.  0.]

    The recommended way to create composites is via a decorator. Use `/` and `*`
    in the function signature to be explicit about positional and keyword
    arguments respectively:
    The recommended way to create composites is via a decorator. Use ``/`` and
    ``*`` in the function signature to be explicit about positional and keyword
    arguments, respectively:

    >>> @partial(lax.composite, name="my.softmax")
    ... def my_softmax_composite(x, /, *, axis):
    ...   return jax.nn.softmax(x, axis)
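A short usage sketch of `lax.composite` in decorator form, mirroring the docstring above; the composite name "my.tangent" is just an example label::

    from functools import partial
    import jax
    import jax.numpy as jnp
    from jax import lax

    @partial(lax.composite, name="my.tangent")
    def my_tangent(x):
      return lax.sin(x) / lax.cos(x)

    x = jnp.array([0.0, jnp.pi / 4])
    print(jax.jit(my_tangent)(x))  # approximately [0., 1.]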
@@ -3014,6 +3016,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
      isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete):
    broadcast_shape = sharding.shard_shape(shape)
    shard = broadcast(fill_value, broadcast_shape)
    shard = shard.addressable_data(0)
    return array.make_array_from_callback(shape, sharding, lambda _: shard)

  if sharding is not None and not sharding._is_concrete:

@@ -8194,7 +8197,7 @@ _zeros: Callable = partial(full_like, fill_value=0)
def _zero(x):
  x_aval = core.get_aval(x)
  return full_like(x, shape=(), fill_value=0,
                   sharding=x_aval.sharding.with_spec(P()))
                   sharding=x_aval.sharding.with_spec(P()))

_ones: Callable = partial(full_like, fill_value=1)


@@ -22,6 +22,7 @@ from functools import partial
import itertools
import math

import jax
from jax import tree_util
from jax._src import core
from jax._src import dispatch
@@ -459,78 +460,135 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
def ragged_all_to_all(
    operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
    axis_name, axis_index_groups = None):
  """Ragged version of :func:`all_to_all`.
  """Ragged version of :func:`all_to_all` collective.

  For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent
  and the outermost (ragged) dimension. ``axis_index_groups`` is default to all
  replicas (e.g. there is only one group and covers all axis indices).
  We say data are "ragged" when they can be represented as a list of arrays
  whose shapes differ only in the size of the leading axis. For example, these
  data are ragged, comprising four component arrays::

  Ragged arrays are defined by a set of three arrays:
  * ``data``: the ``data`` array is "ragged" along its outermost dimension,
    along which each indexed element has variable size.
  * ``offsets``: the ``offsets`` array indexes the outermost dimension of the
    ``data`` array, and represents the starting offset of each ragged element of
    the ``data`` array.
  * ``sizes``: the ``sizes`` array represents the size of each ragged element of
    the ``data`` array, where the size is specified in units of sub-elements. A
    sub-element is defined as the suffix of the ``data`` array shape obtained by
    removing the outermost "ragged" dimension.
  The ``offsets`` and ``sizes`` arrays must have the same size.
    ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)]

  # Example ragged tensor
  data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
  offsets: [3] = {0, 1, 4}
  sizes: [3] = {1, 3, 4}
  We often instead want a contiguous representation, e.g. for batching. But
  because the shapes of the components differ, we can't apply ``jnp.stack`` to
  represent these data by a single rectangular array with the leading axis
  indexing the component arrays. So instead of stacking, we concatenate along
  the leading axis and keep track of offsets and sizes.

  # Index 'data' at 'offsets'[0], 'sizes'[0]'
  {a,b,c}
  That is, we can represent ragged data contiguously using a triple of dense
  arrays ``(data, offsets, sizes)``:
  * ``data``: the concatenated component arrays,
  * ``offsets``: 1D array of indices into the leading axis of ``data``
    indicating where the data for each component array begins,
  * ``sizes``: 1D array of sizes of the leading axis of each component array.
  We refer to this triple as a ragged array. (Offsets can't be computed from
  sizes in general to allow for internal padding.)

  # Index 'data' at 'offsets'[1], 'sizes'[1]'
  {d,e,f},{g,h,i},{j,k,l}
  For example::

    data: f32[8,3] = jnp.array([
        [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x],
    ])
    offsets: i32[3] = jnp.array([0, 1, 4])
    sizes: i32[3] = jnp.array([1, 3, 4])

  # Index 'data' at 'offsets'[2], 'sizes'[2]'
  {m,n,o},{p,q,r},{s,t,u},{v,w,x}
    # To extract the first component array, of type f32[1,3]
    data[offsets[0]:offsets[0]+sizes[0]]

    # To extract the second component array, of type f32[3,3]
    data[offsets[1]:offsets[1]+sizes[1]]

  ``output_offsets`` must be sharded in a way that each replica has offsets in
  the target replica output perspective.
    # To extract the third component array, of type f32[4,3]
    data[offsets[2]:offsets[2]+sizes[2]]

  For i-th output offset, the current replica will send
  `operand[input_offsets[i]:input_offsets[i]+input_sizes[i]]` update to `i`-th
  replica that will be written to
  `output_i[output_offsets[i]:output_offsets[i]+send_sizes[i]]` in `i`-th
  replica ``output``.
  The ``ragged_all_to_all`` collective operation communicates slices of ragged
  arrays between devices. Each caller is both a sender and a receiver. The
  ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the
  caller's ``operand`` to be sent. Received results are returned in an array
  that has the same value of the argument ``output`` except with received values
  written at some slices. The ``output_offsets`` argument does *not* indicate
  the offsets at which all the received results are written; instead,
  ``output_offsets`` indicates the offsets at which the *sent* slices are
  written on their corresponding receivers. The sizes of received slices are
  indicated by ``recv_sizes``. See below for details.

  For example, if we have 2 replicas:
  The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
  ``recv_sizes`` must all be the same length, and that length must be divisible
  by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and
  ``recv_sizes`` must satisfy::

  replica 0:
    operand: [1, 2, 2]
    output: [0, 0, 0, 0]
    input_offsets: [0, 1]
    send_sizes: [1, 2]
    output_offsets: [0, 0]
    recv_sizes: [1, 1]
    jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True))

  replica 1:
    operand: [3, 4, 0]
    output: [0, 0, 0, 0]
    input_offsets: [0, 1]
    send_sizes: [1, 1]
    output_offsets: [1, 2]
    recv_sizes: [2, 1]
  Specifically, given a call::

  replica 0's result will be: [1, 3, 0, 0]
  replica 1's result will be: [2, 2, 4, 0]
    result = ragged_all_to_all(operand, output, input_offsets, send_sizes,
                               output_offsets, recv_sizes, axis_name)

  the caller sends data like::

    assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes)
    N = len(input_offsets)
    slices_per_device, leftover = divmod(N, lax.axis_size(axis_name))
    assert not leftover

    for i in range(N):
      dst_idx = i // slices_per_device
      SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]],
           axis_name=axis_name, to_axis_index=dst_idx)

  and receives data in ``result`` like::

    result = output
    output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
    for i in range(N):
      src_idx = i // slices_per_device
      result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i]
                         ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx))

  where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local
  ``output_offsets`` does not indicate the offsets at which its local ``result``
  is updated; instead, it indicates where the corresponding sent slices are
  written on their destination instances. To compute the local offsets at which
  received data are written, we apply an ``all_to_all`` on ``output_offsets``.

  For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with
  these arguments in each mapped function instance::

    axis index 0:
      operand = [1, 2, 2]
      output = [0, 0, 0, 0]
      input_offsets = [0, 1]
      send_sizes = [1, 2]
      output_offsets = [0, 0]
      recv_sizes = [1, 1]

    axis index 1:
      operand = [3, 4, 0]
      output = [0, 0, 0, 0]
      input_offsets = [0, 1]
      send_sizes = [1, 1]
      output_offsets = [1, 2]
      recv_sizes = [2, 1]

  then::

    axis index 0:
      result = [1, 3, 0, 0]

    axis index 1:
      result = [2, 2, 4, 0]

  Args:
    operand: array with ragged dimension along its outermost dimension.
    output: array of ragged input offsets.
    input_offsets: array of ragged input send sizes.
    send_sizes: array of ragged output data.
    output_offsets: array of ragged offsets in the target replica output.
    recv_sizes: array of ragged output receive sizes.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    operand: data array of shape (N, A, B, ...) representing concatenated
      (possibly padded) ragged data to be sent.
    output: data array of shape (M, A, B, ...) to update with received data.
    input_offsets: 1D integer array of shape (K,) representing the offsets of
      leading-axis slices into ``operand`` to be sent.
    send_sizes: 1D integer array array of shape (K,) representing the sizes of
      leading-axis slices into ``operand`` to be sent.
    output_offsets: 1D integer array of shape (K,) representing where the
      corresponding sent data is written on each corresponding receiver.
    recv_sizes: 1D integer array of shape (K,) representing sizes of
      leading-axis slices into ``output`` to update with received data.
    axis_name: name of the mapped axis over which to perform the communication.
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
      first two and last two replicas). Groups must cover all axis indices

@@ -538,7 +596,10 @@ def ragged_all_to_all(
      behavior is undefined.

  Returns:
    array with shape equal to ``output``.
    Array of shape (M, A, B, ...) with the same value as the ``output`` except
      with received data written into slices starting at
      ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size
      ``recv_sizes``.
  """

  if not isinstance(axis_name, (tuple, list)):
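A single-process NumPy restatement of the SEND/RECEIVE pseudocode above can serve as a reference model for the documented semantics; it is an illustration, not the actual collective::

    import numpy as np

    def ragged_all_to_all_reference(operands, outputs, input_offsets,
                                    send_sizes, output_offsets, recv_sizes):
      # recv_sizes is implied by send_sizes (see the all_to_all constraint
      # above); it is accepted only for signature parity.
      del recv_sizes
      n_dev = len(operands)
      results = [np.array(o, copy=True) for o in outputs]
      slices_per_device = len(input_offsets[0]) // n_dev
      for src in range(n_dev):
        for i in range(len(input_offsets[src])):
          dst = i // slices_per_device
          lo, sz = input_offsets[src][i], send_sizes[src][i]
          out_lo = output_offsets[src][i]  # offset in the destination output
          results[dst][out_lo:out_lo + sz] = operands[src][lo:lo + sz]
      return results

    # The two-instance example from the docstring:
    res = ragged_all_to_all_reference(
        operands=[np.array([1, 2, 2]), np.array([3, 4, 0])],
        outputs=[np.zeros(4, int), np.zeros(4, int)],
        input_offsets=[[0, 1], [0, 1]],
        send_sizes=[[1, 2], [1, 1]],
        output_offsets=[[0, 0], [1, 2]],
        recv_sizes=[[1, 1], [2, 1]])
    print(res)  # [array([1, 3, 0, 0]), array([2, 2, 4, 0])]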
@ -1210,8 +1271,43 @@ def _ragged_all_to_all_effectful_abstract_eval(
|
||||
effects = {*map(core.NamedAxisEffect, axis_name)}
|
||||
return out_aval, effects
|
||||
|
||||
def _ragged_all_to_all_jvp(primals, tangents, **params):
|
||||
operand, output, *sizes_and_offsets = primals
|
||||
operand_dot, output_dot, *_ = tangents
|
||||
result = ragged_all_to_all_p.bind(
|
||||
operand, output, *sizes_and_offsets, **params)
|
||||
if type(operand_dot) is type(output_dot) is ad.Zero:
|
||||
result_dot = ad.Zero.from_primal_value(result)
|
||||
else:
|
||||
operand_dot = ad.instantiate_zeros(operand_dot)
|
||||
output_dot = ad.instantiate_zeros(output_dot)
|
||||
result_dot = ragged_all_to_all_p.bind(
|
||||
operand_dot, output_dot, *sizes_and_offsets, **params)
|
||||
return result, result_dot
|
||||
|
||||
def _ragged_all_to_all_transpose(
|
||||
t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
|
||||
*, axis_name, axis_index_groups):
|
||||
if type(t) is ad.Zero:
|
||||
operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
|
||||
output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None
|
||||
else:
|
||||
zero = ad.zeros_like_aval(operand.aval)
|
||||
output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
|
||||
input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True)
|
||||
operand_t = ragged_all_to_all_p.bind(
|
||||
t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes,
|
||||
axis_name=axis_name, axis_index_groups=axis_index_groups)
|
||||
mask = jax.numpy.cumsum(
|
||||
jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\
|
||||
.at[output_offsets_ + recv_sizes].add(-1))
|
||||
output_t = jax.numpy.where(mask, 0, t)
|
||||
return [operand_t, output_t] + [None] * 4
|
||||
|
||||
ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
|
||||
ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
|
||||
ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp
|
||||
ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose
|
||||
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
|
||||
batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name')
|
||||
|
||||
|
@@ -303,15 +303,16 @@ def _igamma_series(ax, x, a, enabled, dtype, mode):

def igamma_impl(a, x, *, dtype):
  is_nan = bitwise_or(_isnan(a), _isnan(x))
  x_is_zero = eq(x, _const(x, 0))
  x_is_infinity = eq(x, _const(x, float('inf')))
  domain_error = bitwise_or(lt(x, _const(x, 0)), le(a, _const(a, 0)))
  use_igammac = bitwise_and(gt(x, _const(x, 1)), gt(x, a))
  a_is_zero = eq(a, _const(a, 0))
  x_is_zero = eq(x, _const(x, 0))
  domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])

  use_igammac = bitwise_and(ge(x, _const(x, 1)), gt(x, a))
  ax = a * log(x) - x - lgamma(a)
  underflow = lt(ax, -log(dtypes.finfo(dtype).max))
  ax = exp(ax)
  enabled = bitwise_not(
      _reduce(bitwise_or,[x_is_zero, domain_error, underflow, is_nan]))
  enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan, x_is_infinity]))

  output = select(
      use_igammac,

@@ -323,8 +324,7 @@ def igamma_impl(a, x, *, dtype):
  )
  output = select(x_is_zero, full_like(a, 0), output)
  output = select(x_is_infinity, full_like(a, 1), output)
  output = select(bitwise_or(domain_error, is_nan),
                  full_like(a, float('nan')), output)
  output = select(domain_error, full_like(a, float('nan')), output)
  return output

def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):

@@ -433,11 +433,15 @@ def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
      raise ValueError(f"Invalid mode: {mode}")

def igammac_impl(a, x, *, dtype):
  out_of_range = bitwise_or(le(x, _const(x, 0)), le(a, _const(a, 0)))
  is_nan = bitwise_or(_isnan(a), _isnan(x))
  a_is_zero = eq(a, _const(a, 0))
  x_is_zero = eq(x, _const(x, 0))
  x_is_infinity = eq(x, _const(x, float('inf')))
  domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
  use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
  ax = a * log(x) - x - lgamma(a)
  underflow = lt(ax, -log(dtypes.finfo(dtype).max))
  enabled = bitwise_not(bitwise_or(out_of_range, underflow))
  enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, is_nan, x_is_infinity, a_is_zero]))
  ax = exp(ax)

  igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),

@@ -445,10 +449,10 @@ def igammac_impl(a, x, *, dtype):
  igammac_cf_call = _igammac_continued_fraction(ax, x, a,
      bitwise_and(enabled, bitwise_not(use_igamma)), dtype, IgammaMode.VALUE)

  result = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
  x_is_infinity = eq(x, _const(x, float('inf')))
  result = select(x_is_infinity, full_like(result, 0), result)
  return select(out_of_range, full_like(a, 1), result)
  output = select(use_igamma, _const(a, 1) - igamma_call, igammac_cf_call)
  output = select(bitwise_or(x_is_infinity, a_is_zero), full_like(output, 0), output)
  output = select(domain_error, full_like(a, float('nan')), output)
  return output

def igamma_grad_a_impl(a, x, *, dtype):
  is_nan = bitwise_or(_isnan(a), _isnan(x))
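A hedged spot-check of the special-value conventions introduced above, through the public `jax.scipy.special` wrappers that lower to `igamma`/`igammac`; exact outputs depend on the installed JAX version::

    import jax.numpy as jnp
    from jax.scipy.special import gammainc, gammaincc

    print(gammainc(2.0, jnp.inf))   # expected 1.0: P(a, inf) = 1
    print(gammaincc(2.0, jnp.inf))  # expected 0.0: Q(a, inf) = 0
    print(gammainc(0.0, 0.0))       # expected nan: a == x == 0 is a domain error
    print(gammainc(-1.0, 2.0))      # expected nan: a < 0 is a domain error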
@@ -343,7 +343,7 @@ class BlockSpec:
    if self.block_shape is None:
      block_shape = array_aval.shape
    else:
      block_shape = self.block_shape
      block_shape = self.block_shape  # type: ignore
    if len(array_aval.shape) != len(block_shape):
      raise ValueError(
          f"Block shape for {origin} (= {block_shape}) "
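The rank check above requires `block_shape` to have one entry per array dimension. A hypothetical Pallas sketch, assuming the standard `pallas_call`/`BlockSpec` API and interpreter mode for portability::

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    x = jnp.arange(64 * 128, dtype=jnp.float32).reshape(64, 128)
    y = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(1,),
        in_specs=[pl.BlockSpec((64, 128), lambda i: (0, 0))],  # rank 2, like x
        out_specs=pl.BlockSpec((64, 128), lambda i: (0, 0)),
        interpret=True,  # run on the interpreter so the sketch is portable
    )(x)
    assert (y == x).all()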
@ -32,6 +32,7 @@ pytype_strict_library(
|
||||
],
|
||||
deps = [
|
||||
":block_spec",
|
||||
":custom_evaluate",
|
||||
":fusable",
|
||||
":fusion",
|
||||
":jaxpr_fusion",
|
||||
@ -44,6 +45,7 @@ pytype_strict_library(
|
||||
"block_spec.py",
|
||||
],
|
||||
deps = [
|
||||
":fuser_utils",
|
||||
"//jax",
|
||||
"//jax:ad_util",
|
||||
"//jax:api_util",
|
||||
@ -119,3 +121,27 @@ pytype_strict_library(
|
||||
"//jax/_src/pallas",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "custom_evaluate",
|
||||
srcs = ["custom_evaluate.py"],
|
||||
deps = [
|
||||
":fuser_utils",
|
||||
"//jax",
|
||||
"//jax:core",
|
||||
"//jax:source_info_util",
|
||||
"//jax:tree_util",
|
||||
"//jax:util",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "fuser_utils",
|
||||
srcs = ["fuser_utils.py"],
|
||||
deps = [
|
||||
"//jax:api_util",
|
||||
"//jax:core",
|
||||
"//jax:partial_eval",
|
||||
"//jax:tree_util",
|
||||
],
|
||||
)
|
||||
|
@ -16,6 +16,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val
|
||||
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
|
||||
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
|
||||
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
|
||||
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
|
||||
from jax._src.pallas.fuser.fusable import fusable as fusable
|
||||
from jax._src.pallas.fuser.fusion import Fusion as Fusion
|
||||
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
|
||||
|
@ -26,15 +26,14 @@ from typing import Any, Callable, Protocol, Sequence
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.fuser import fuser_utils
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -226,18 +225,6 @@ def _unwrap_block_spec_scalar_prefetch(
|
||||
return out_block_spec
|
||||
|
||||
|
||||
def _make_jaxpr(f, *args, **kwargs):
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
flat_avals = [core.get_aval(x) for x in flat_args]
|
||||
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f, debug_info=debug_info), in_tree
|
||||
)
|
||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
out_tree = out_tree_thunk()
|
||||
return jaxpr, consts, in_tree, out_tree
|
||||
|
||||
|
||||
def pull_block_spec(
|
||||
f: Callable,
|
||||
out_block_specs: pallas_core.BlockSpec | tuple[pallas_core.BlockSpec, ...],
|
||||
@ -246,7 +233,9 @@ def pull_block_spec(
|
||||
grid: tuple[int | jax.Array, ...] | None = None,
|
||||
):
|
||||
def wrapped(*args, **kwargs):
|
||||
jaxpr, consts, in_tree, out_tree_ = _make_jaxpr(f, *args, **kwargs)
|
||||
jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr(
|
||||
f, *args, **kwargs
|
||||
)
|
||||
# TODO(sharadmv): handle these consts better, they should correspond to
|
||||
# scalar prefetch.
|
||||
del consts, out_tree_
|
||||
@ -563,7 +552,9 @@ def make_kernel_function(
|
||||
def get_fusion_values(
|
||||
fusion: Callable, *args, **kwargs
|
||||
) -> tuple[Callable, tuple[jax.Array, ...], tuple[jax.Array, ...]]:
|
||||
jaxpr, values, in_tree, out_tree = _make_jaxpr(fusion, *args, **kwargs)
|
||||
jaxpr, values, in_tree, out_tree = fuser_utils.make_jaxpr(
|
||||
fusion, *args, **kwargs
|
||||
)
|
||||
assert len(values) == len(jaxpr.constvars), (jaxpr, values)
|
||||
out_usages = tuple({Usage.REGULAR} for _ in jaxpr.outvars)
|
||||
read_usage_env = compute_usage(jaxpr, out_usages)
|
||||
@ -1325,7 +1316,7 @@ def push_block_spec(
|
||||
flat_block_specs, in_tree_ = tree_util.tree_flatten(
|
||||
(in_spec_args, in_spec_kwargs)
|
||||
)
|
||||
jaxpr, _, in_tree, out_tree = _make_jaxpr(f, *args, **kwargs)
|
||||
jaxpr, _, in_tree, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
|
||||
if in_tree != in_tree_:
|
||||
raise ValueError(f'Expected {in_tree} PyTree, got {in_tree_}')
|
||||
out_bs = _push_block_spec_jaxpr(jaxpr, *flat_block_specs)
|
||||
|
82
jax/_src/pallas/fuser/custom_evaluate.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helpers for evaluating functions under certain constraints."""
|
||||
import dataclasses
|
||||
from typing import Any
|
||||
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.pallas.fuser import fuser_utils
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomEvaluateSettings:
|
||||
allow_transpose: bool = True
|
||||
|
||||
|
||||
def evaluate(f, *, allow_transpose: bool = True):
|
||||
def wrapped(*args, **kwargs):
|
||||
jaxpr, consts, _, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
|
||||
settings = CustomEvaluateSettings(allow_transpose=allow_transpose)
|
||||
flat_args = tree_util.tree_leaves(args)
|
||||
out_flat = _custom_evaluate_jaxpr(settings, jaxpr, consts, *flat_args)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# Disallow most higher-order primitives for now.
|
||||
disallowed_primitives = {lax.scan_p, lax.while_p, lax.cond_p}
|
||||
|
||||
|
||||
def _custom_evaluate_jaxpr(
|
||||
settings: CustomEvaluateSettings, jaxpr: core.Jaxpr, consts, *args
|
||||
):
|
||||
def read(v: core.Atom) -> Any:
|
||||
return v.val if isinstance(v, core.Literal) else env[v]
|
||||
|
||||
def write(v: core.Var, val: Any) -> None:
|
||||
env[v] = val
|
||||
|
||||
env: dict[core.Var, Any] = {}
|
||||
util.safe_map(write, jaxpr.constvars, consts)
|
||||
util.safe_map(write, jaxpr.invars, args)
|
||||
lu = core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
|
||||
if eqn.primitive in disallowed_primitives:
|
||||
raise NotImplementedError(f'Primitive {eqn.primitive} not supported.')
|
||||
if not settings.allow_transpose and eqn.primitive is lax.transpose_p:
|
||||
raise ValueError('Transpose not allowed.')
|
||||
name_stack = (
|
||||
source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
)
|
||||
traceback = eqn.source_info.traceback
|
||||
with source_info_util.user_context(
|
||||
traceback, name_stack=name_stack
|
||||
), eqn.ctx.manager:
|
||||
ans = eqn.primitive.bind(
|
||||
*subfuns, *util.safe_map(read, eqn.invars), **bind_params
|
||||
)
|
||||
if eqn.primitive.multiple_results:
|
||||
util.safe_map(write, eqn.outvars, ans)
|
||||
else:
|
||||
write(eqn.outvars[0], ans)
|
||||
core.clean_up_dead_vars(eqn, env, lu)
|
||||
return util.safe_map(read, jaxpr.outvars)
|
33
jax/_src/pallas/fuser/fuser_utils.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Basic utils for fuser internals."""
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
|
||||
|
||||
def make_jaxpr(f, *args, **kwargs):
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
flat_avals = [core.get_aval(x) for x in flat_args]
|
||||
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f, debug_info=debug_info), in_tree
|
||||
)
|
||||
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
out_tree = out_tree_thunk()
|
||||
return jaxpr, consts, in_tree, out_tree
|
@ -1853,7 +1853,13 @@ def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
|
||||
|
||||
|
||||
def _dot_general_lowering_rule(
|
||||
ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
|
||||
ctx: LoweringRuleContext,
|
||||
x,
|
||||
y,
|
||||
dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type,
|
||||
**_,
|
||||
):
|
||||
(lhs_dims, rhs_dims), _ = dimension_numbers
|
||||
(aval_out,) = ctx.avals_out
|
||||
@ -1894,10 +1900,34 @@ def _dot_general_lowering_rule(
|
||||
x = vector.broadcast(bcast_shape, x)
|
||||
if ctx.avals_in[1].shape != bcast_shape:
|
||||
y = vector.broadcast(bcast_shape, y)
|
||||
red_dtype = (
|
||||
preferred_element_type if preferred_element_type else lhs_aval.dtype
|
||||
)
|
||||
red_type = aval_to_ir_type(
|
||||
ctx.lowering_context.dynamic_shape_replacement_fn,
|
||||
lhs_aval.update(shape=(lhs_aval.shape[0],)),
|
||||
lhs_aval.update(shape=(lhs_aval.shape[0],), dtype=red_dtype),
|
||||
)
|
||||
|
||||
if lhs_aval.dtype != red_dtype:
|
||||
lhs_type = aval_to_ir_type(
|
||||
ctx.lowering_context.dynamic_shape_replacement_fn,
|
||||
lhs_aval.update(shape=lhs_aval.shape, dtype=red_dtype),
|
||||
)
|
||||
if red_dtype == jnp.float32:
|
||||
x = arith.extf(lhs_type, x)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported {preferred_element_type=}")
|
||||
|
||||
if rhs_aval.dtype != red_dtype:
|
||||
rhs_type = aval_to_ir_type(
|
||||
ctx.lowering_context.dynamic_shape_replacement_fn,
|
||||
rhs_aval.update(shape=rhs_aval.shape, dtype=red_dtype),
|
||||
)
|
||||
if red_dtype == jnp.float32:
|
||||
y = arith.extf(rhs_type, y)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported {preferred_element_type=}")
|
||||
|
||||
acc = arith.ConstantOp(
|
||||
red_type, ir.DenseElementsAttr.get_splat(red_type, val)
|
||||
)
|
||||
|
@ -1543,6 +1543,60 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
|
||||
raise NotImplementedError(f"Unsupported layout {x.layout}")
|
||||
|
||||
|
||||
def _reduce_lowering_rule_wg(
|
||||
kind: vector_dialect.CombiningKind,
|
||||
acc: object,
|
||||
ctx: LoweringRuleContext,
|
||||
x,
|
||||
*,
|
||||
axes,
|
||||
) -> ir.OpView:
|
||||
[x_aval] = ctx.avals_in
|
||||
[out_aval] = ctx.avals_out
|
||||
x = _ensure_ir_value(x, x_aval.dtype)
|
||||
out_type = mgpu_utils.dtype_to_ir_type(out_aval.dtype)
|
||||
if not out_aval.shape:
|
||||
# Special-case: reducing to a scalar.
|
||||
if x_aval.ndim != 1:
|
||||
# TODO(slebedev): Flatten to 1D, since vector.reduction only supports
|
||||
# 1D inputs.
|
||||
raise NotImplementedError("Only 1D inputs are supported")
|
||||
return vector_dialect.ReductionOp(out_type, kind, x)
|
||||
acc = vector_dialect.splat(
|
||||
ir.VectorType.get(out_aval.shape, out_type),
|
||||
_ensure_ir_value(acc, out_aval.dtype),
|
||||
)
|
||||
return vector_dialect.MultiDimReductionOp(kind, x, acc, axes)
|
||||
|
||||
|
||||
@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
|
||||
op = _reduce_lowering_rule_wg(
|
||||
vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes
|
||||
)
|
||||
op.attributes["offset"] = ir.IntegerAttr.get(
|
||||
ir.IntegerType.get_signless(32), ctx.module_ctx.smem_used_bytes
|
||||
)
|
||||
return op.result
|
||||
|
||||
|
||||
@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup)
|
||||
def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
|
||||
[x_aval] = ctx.avals_in
|
||||
if jnp.issubdtype(x_aval.dtype, jnp.floating):
|
||||
kind = vector_dialect.CombiningKind.MAXIMUMF
|
||||
acc = float("-inf")
|
||||
elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
|
||||
kind = vector_dialect.CombiningKind.MAXSI
|
||||
acc = np.iinfo(x_aval.dtype).max
|
||||
elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
|
||||
kind = vector_dialect.CombiningKind.MAXUI
|
||||
acc = np.iinfo(x_aval.dtype).max
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
|
||||
return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result
|
||||
|
||||
|
||||
@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane)
|
||||
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
|
@@ -198,7 +198,8 @@ def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) ->
  - :func:`jax.scipy.stats.gamma.logsf`
  """
  x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
  return gammaincc(a, lax.div(lax.sub(x, loc), scale))
  y = lax.div(lax.sub(x, loc), scale)
  return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y))


def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
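With this change, the survival function clamps to 1 when the standardized value is negative instead of calling `gammaincc` outside its domain. A small sketch::

    from jax.scipy.stats import gamma

    print(gamma.sf(-1.0, a=2.0))  # expected 1.0: x below loc, so y < 0
    print(gamma.sf(3.0, a=2.0))   # a value strictly between 0 and 1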
@@ -865,15 +865,15 @@ class Jax2TfLimitation(test_harnesses.Limitation):
    def custom_assert(tst, result_jax, result_tf, *, args, tol,
                      err_msg):  # noqa: F811
      arg1, arg2 = args
      # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
      # lax.igammac returns nan. when arg1 <= 0; tf.math.igammac returns 1
      special_cases = (arg1 <= 0.) | (arg2 <= 0)
      nr_special_cases = np.count_nonzero(special_cases)
      tst.assertAllClose(
          np.full((nr_special_cases,), 1., dtype=dtype),
          np.full((nr_special_cases,), np.nan, dtype=dtype),
          result_jax[special_cases],
          err_msg=err_msg)
      tst.assertAllClose(
          np.full((nr_special_cases,), np.nan, dtype=dtype),
          np.full((nr_special_cases,), 1, dtype=dtype),
          result_tf[special_cases],
          err_msg=err_msg)
      # non-special cases are equal

@@ -892,12 +892,12 @@ class Jax2TfLimitation(test_harnesses.Limitation):
        custom_numeric(dtypes=[np.float64], tol=1e-9),
        custom_numeric(devices="gpu", tol=1e-3),
        custom_numeric(
            modes=("compiled",),
            custom_assert=custom_assert,
            devices=("cpu", "gpu"),
            devices=("cpu", "gpu", "tpu"),
            description=(
                "May return different results at undefined points "
                "(both arguments less or equal 0). JAX returns `NaN` and TF returns 0 or "
                "JAX returns 1 and TF returns `NaN`")),
                "(both arguments less or equal 0). JAX returns `NaN` and TF returns 1")),
    ]

    @classmethod
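The updated limitation expects JAX to return `NaN` and TF to return 1 at the undefined points. A NumPy-only illustration of the masking logic used in `custom_assert`; the result arrays here are hypothetical::

    import numpy as np

    arg1 = np.array([-1.0, 0.0, 2.0])
    arg2 = np.array([3.0, 0.5, 4.0])
    result_jax = np.array([np.nan, np.nan, 0.9])  # hypothetical outputs
    result_tf = np.array([1.0, 1.0, 0.9])

    special_cases = (arg1 <= 0.) | (arg2 <= 0)
    n = special_cases.sum()
    np.testing.assert_allclose(result_jax[special_cases], np.full(n, np.nan))
    np.testing.assert_allclose(result_tf[special_cases], np.full(n, 1.0))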
@@ -260,7 +260,7 @@ def _construct_smem_reftree(
          dynamic_smem, c(dynamic_smem_offset, index), [],
      )
      if layout is None:
        layout = tcgen05._infer_tmem_layout(shape)
        layout = tcgen05._infer_tmem_layout(shape, collective)
      num_cols = layout.cols_in_shape(shape)
      delayed_warp_init.append(
          functools.partial(
@ -259,14 +259,15 @@ def _vector_load_op_lowering_rule(
|
||||
is_signed=is_signed,
|
||||
vec_size=strided_layout.vec_size,
|
||||
)
|
||||
elif layouts.is_wgmma_fragmented_layout(out_layout_attr):
|
||||
elif layouts.from_layout_attr(out_layout_attr) == fa.TILED_LAYOUT_WGMMA:
|
||||
layout = ir.MemRefType(vector_load_op.base.type).layout
|
||||
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
|
||||
transformed_ref = transform_memref(vector_load_op.base, transforms)
|
||||
fragmented_array = fa.FragmentedArray.load_tiled(
|
||||
transformed_ref,
|
||||
swizzle=swizzle,
|
||||
is_signed=is_signed
|
||||
is_signed=is_signed,
|
||||
layout=fa.TILED_LAYOUT_WGMMA,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -319,6 +320,34 @@ def _vector_splat_op_lowering_rule(
|
||||
return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]
|
||||
|
||||
|
||||
@_register_lowering(vector.ReductionOp)
|
||||
def _vector_reduction_op_lowering_rule(
|
||||
ctx: LoweringContext, op: vector.ReductionOp
|
||||
) -> Sequence[ir.Value]:
|
||||
del ctx # Unused.
|
||||
[layout] = inference_utils.in_layouts(op)
|
||||
() = inference_utils.out_layouts(op)
|
||||
element_type = ir.VectorType(op.vector.type).element_type
|
||||
is_signed = False if ir.IntegerType.isinstance(element_type) else None
|
||||
a = _fragmented_array_from_ir(op.vector, layout, is_signed)
|
||||
match str(op.kind):
|
||||
case "#vector.kind<add>":
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
scratch = _slice_smem(
|
||||
ir.MemRefType.get([4], element_type, memory_space=smem),
|
||||
arith.constant(None, op.attributes["offset"]),
|
||||
)
|
||||
result = a.reduce_sum(scratch)
|
||||
case (
|
||||
"#vector.kind<maxsi>" | "#vector.kind<maxui>" | "#vector.kind<maximumf>"
|
||||
):
|
||||
# TODO(slebedev): Implement this and remove the raise below.
|
||||
raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
|
||||
case _:
|
||||
raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
|
||||
return [_fragmented_array_to_ir(result, op.result.type)]
|
||||
|
||||
|
||||
def memref_layout_to_swizzle_and_transforms(
|
||||
layout: ir.Attribute,
|
||||
) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:
|
||||
@ -634,7 +663,10 @@ def _mgpu_wgmma_op_lowering_rule(
|
||||
*inference_utils.in_layouts(wgmma_op),
|
||||
*inference_utils.out_layouts(wgmma_op),
|
||||
)
|
||||
if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
|
||||
is_supported_layout = (
|
||||
lambda l: layouts.from_tiled_layout_attr(l) == fa.TILED_LAYOUT_WGMMA
|
||||
)
|
||||
if not all(map(is_supported_layout, fa_layouts)):
|
||||
raise ValueError("Layout mismatch")
|
||||
wgmma_layout = fa_layouts[0]
|
||||
|
||||
@ -667,7 +699,12 @@ def _mgpu_wgmma_op_lowering_rule(
|
||||
|
||||
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
|
||||
|
||||
return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]
|
||||
return [
|
||||
_fragmented_array_to_ir(
|
||||
new_acc.value.to_layout(fa.TILED_LAYOUT_WGMMA),
|
||||
wgmma_op.accumulator.type,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@_register_lowering(mgpu.ArriveExpectTxOp)
|
||||
@ -704,16 +741,17 @@ def _mgpu_slice_smem_op_lowering_rule(
|
||||
ctx: LoweringContext, op: SliceSMEMOp
|
||||
) -> Sequence[ir.Value]:
|
||||
del ctx
|
||||
return [_slice_smem(op.result.type, op.offset)]
|
||||
|
||||
|
||||
def _slice_smem(result: ir.Type, offset: ir.Value):
|
||||
i8 = ir.IntegerType.get_signless(8)
|
||||
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
|
||||
|
||||
smem_base = gpu.dynamic_shared_memory(
|
||||
ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
|
||||
)
|
||||
|
||||
offset = arith.index_cast(ir.IndexType.get(), op.offset)
|
||||
|
||||
return [memref.view(op.result.type, smem_base, offset, [])]
|
||||
offset = arith.index_cast(ir.IndexType.get(), offset)
|
||||
return memref.view(result, smem_base, offset, [])
|
||||
|
||||
|
||||
@_register_lowering(scf.ForOp)
|
||||
@ -857,7 +895,8 @@ def _should_lower(op: ir.OpView) -> bool:
|
||||
|
||||
|
||||
def lower_mgpu_dialect(
|
||||
module: ir.Module, launch_context: launch_context.LaunchContext | None
|
||||
module: ir.Module,
|
||||
launch_context: launch_context.LaunchContext | None,
|
||||
):
|
||||
# TODO(apaszke,bchetioui): Make sure the layouts match.
|
||||
# TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full
|
||||
|
@ -230,8 +230,8 @@ def main(unused_argv):
|
||||
tile_n *= 2
|
||||
if m < tile_m or n < tile_n:
|
||||
continue
|
||||
if kwargs["collective"] and tile_n >= 512:
|
||||
continue # TODO(apaszke): Support 512
|
||||
if tile_n > 512:
|
||||
continue
|
||||
if (m // tile_m) % kwargs["grid_tile_m"]:
|
||||
continue
|
||||
try:
|
||||
|
@ -1389,7 +1389,7 @@ class FragmentedArray:
|
||||
if isinstance(self.layout, WGSplatFragLayout):
|
||||
[reg] = self.registers.flat
|
||||
if ir.FloatType.isinstance(self.mlir_dtype):
|
||||
op = arith.mulf
|
||||
op = mulf
|
||||
elif ir.IntegerType.isinstance(self.mlir_dtype):
|
||||
op = arith.muli
|
||||
else:
|
||||
|
@ -63,7 +63,7 @@ def _choose_representative_layout(
|
||||
|
||||
Given the input set of possible layouts, this function extracts a single
|
||||
representative layout. Currently, this function only works with strided,
|
||||
splat, and WGMMA fragmented layouts.
|
||||
splat, and tiled layouts.
|
||||
|
||||
Returns:
|
||||
A single layout that can be used to annotate the operation, or None if the
|
||||
@ -86,18 +86,18 @@ def _choose_representative_layout(
|
||||
)
|
||||
)
|
||||
|
||||
wgmma_layouts: list[fa.WGMMAFragLayout] = list(
|
||||
tiled_layouts: list[fa.TiledLayout] = list(
|
||||
map(
|
||||
layouts_lib.from_layout_attr,
|
||||
filter(layouts_lib.is_wgmma_fragmented_layout, layouts),
|
||||
filter(layouts_lib.is_tiled_layout, layouts),
|
||||
)
|
||||
)
|
||||
|
||||
if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len(
|
||||
if len(splat_layouts) + len(strided_layouts) + len(tiled_layouts) != len(
|
||||
layouts
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected only strided, splat, and wgmma layouts, got {layouts}"
|
||||
f"Expected only strided, splat, and tiled layouts, got {layouts}"
|
||||
)
|
||||
|
||||
if len(splat_layouts) > 1:
|
||||
@ -112,13 +112,19 @@ def _choose_representative_layout(
|
||||
"is not supported."
|
||||
)
|
||||
|
||||
if (wgmma_layouts and strided_layouts):
|
||||
if len(tiled_layouts) > 1:
|
||||
raise NotImplementedError(
|
||||
"Mixing strided and WGMMA layouts is not supported."
|
||||
"Finding a representative layout for several distinct tiled layouts "
|
||||
"is not supported."
|
||||
)
|
||||
|
||||
if wgmma_layouts:
|
||||
return layouts_lib.to_layout_attr(wgmma_layouts[0])
|
||||
if tiled_layouts and strided_layouts:
|
||||
raise NotImplementedError(
|
||||
"Mixing strided and tiled layouts is not supported."
|
||||
)
|
||||
|
||||
if tiled_layouts:
|
||||
return layouts_lib.to_layout_attr(tiled_layouts[0])
|
||||
|
||||
if strided_layouts:
|
||||
[strided_layout] = strided_layouts
|
||||
@ -330,10 +336,16 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
|
||||
|
||||
return [], [layout]
|
||||
|
||||
@partial(_add_layout_inference_rule, vector.ReductionOp)
|
||||
def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
|
||||
if layout := inference_utils.value_layout(op.vector):
|
||||
return [layout], []
|
||||
return None
|
||||
|
||||
|
||||
@partial(_add_layout_inference_rule, mgpu.WGMMAOp)
|
||||
def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
|
||||
layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout())
|
||||
layout = layouts_lib.to_layout_attr(fa.TILED_LAYOUT_WGMMA)
|
||||
|
||||
if ir.VectorType.isinstance(wgmma_op.a.type):
|
||||
return [layout, layout], [layout]
|
||||
|
@ -94,11 +94,67 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
  return bool(_strided_fragmented_layout_attr_pattern.search(str(attr)))


_tiled_layout_attr_pattern = re.compile(
    r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\],"
    r" warp_dim\s*=\s*(?P<warp_dim>[-\d]+),"
    r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
    r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
)


def to_tiled_layout_attr(
    layout: fa.TiledLayout,
) -> ir.Attribute:
  """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""

  tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]"
  tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]"
  return ir.Attribute.parse(
      f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim},"
      f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>"
  )


_list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[")


def from_tiled_layout_attr(
    attr: ir.Attribute,
) -> fa.TiledLayout:
  """Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute.

  Raises:
    ValueError: If the attribute is not a #mosaic_gpu.TiledLayout
      attribute.
  """
  match = _tiled_layout_attr_pattern.fullmatch(str(attr))
  if not match:
    raise ValueError(
        f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}"
    )

  tiling_str = match.group("tiling")
  tile_strings = []
  if len(tiling_str) > 2:
    tile_strings = _list_of_lists_delimiter.split(tiling_str[1:-1])
  tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings)
  return fa.TiledLayout(
      tiling=fa.Tiling(tiles),
      warp_dim=int(match.group("warp_dim")),
      lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")),
      vector_dim=int(match.group("vector_dim"))
  )


def is_tiled_layout(attr: ir.Attribute) -> bool:
  return bool(_tiled_layout_attr_pattern.search(str(attr)))

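For concreteness, here is how the parsing pattern above behaves on a hand-written attribute string. The regex is copied from the hunk; the tiling and dimension values are made up for the example and are not necessarily a layout the library uses.

import re

_tiled_layout_attr_pattern = re.compile(
    r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\],"
    r" warp_dim\s*=\s*(?P<warp_dim>[-\d]+),"
    r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
    r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
)

attr = ("#mosaic_gpu.TiledLayout<[[64, 8], [16, 8], [8, 8]],"
        " warp_dim=-4, lane_dims=[-2], vector_dim=-1>")
match = _tiled_layout_attr_pattern.fullmatch(attr)
assert match is not None
assert match.group("tiling") == "[64, 8], [16, 8], [8, 8]"
assert match.group("warp_dim") == "-4"
assert match.group("lane_dims") == "-2"
assert match.group("vector_dim") == "-1"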
def to_layout_attr(
|
||||
layout: (
|
||||
fa.WGSplatFragLayout
|
||||
| fa.WGStridedFragLayout
|
||||
| fa.WGMMAFragLayout
|
||||
| fa.TiledLayout
|
||||
| fa.WGMMARowFragLayout
|
||||
),
|
||||
) -> ir.Attribute:
|
||||
@ -108,8 +164,8 @@ def to_layout_attr(
|
||||
return to_splat_fragmented_layout_attr(layout)
|
||||
case fa.WGStridedFragLayout():
|
||||
return to_strided_fragmented_layout_attr(layout)
|
||||
case fa.WGMMAFragLayout():
|
||||
return ir.Attribute.parse("#mosaic_gpu.WGMMAFragLayout")
|
||||
case fa.TiledLayout():
|
||||
return to_tiled_layout_attr(layout)
|
||||
case fa.WGMMARowFragLayout():
|
||||
return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout")
|
||||
case _:
|
||||
@ -118,15 +174,6 @@ def to_layout_attr(
|
||||
)
|
||||
|
||||
|
||||
_wgmma_fragmented_layout_attr_pattern = re.compile(
|
||||
r"^#mosaic_gpu.WGMMAFragLayout$"
|
||||
)
|
||||
|
||||
|
||||
def is_wgmma_fragmented_layout(attr: ir.Attribute) -> bool:
|
||||
return bool(_wgmma_fragmented_layout_attr_pattern.search(str(attr)))
|
||||
|
||||
|
||||
_wgmma_row_fragmented_layout_attr_pattern = re.compile(
|
||||
r"^#mosaic_gpu.WGMMARowFragLayout$"
|
||||
)
|
||||
@ -141,7 +188,7 @@ def from_layout_attr(
|
||||
) -> (
|
||||
fa.WGSplatFragLayout
|
||||
| fa.WGStridedFragLayout
|
||||
| fa.WGMMAFragLayout
|
||||
| fa.TiledLayout
|
||||
| fa.WGMMARowFragLayout
|
||||
):
|
||||
"""Constructs a layout from an MLIR attribute."""
|
||||
@ -149,8 +196,8 @@ def from_layout_attr(
|
||||
return from_splat_fragmented_layout_attr(attr)
|
||||
elif is_strided_fragmented_layout(attr):
|
||||
return from_strided_fragmented_layout_attr(attr)
|
||||
elif is_wgmma_fragmented_layout(attr):
|
||||
return fa.WGMMAFragLayout()
|
||||
elif is_tiled_layout(attr):
|
||||
return from_tiled_layout_attr(attr)
|
||||
elif is_wgmma_row_fragmented_layout(attr):
|
||||
return fa.WGMMARowFragLayout()
|
||||
else:
|
||||
|
@ -83,6 +83,7 @@ def mma(
|
||||
accumulate: ir.Value | bool = True,
|
||||
collective: bool = False,
|
||||
):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
if isinstance(accumulate, bool):
|
||||
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
|
||||
@ -112,6 +113,10 @@ def mma(
|
||||
raise ValueError(
|
||||
f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
|
||||
)
|
||||
if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)):
|
||||
raise ValueError(
|
||||
f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}"
|
||||
)
|
||||
f32 = ir.F32Type.get()
|
||||
if element_type == f32 or element_type == ir.BF16Type.get():
|
||||
if d.dtype != f32:
|
||||
@ -136,11 +141,7 @@ def mma(
    raise ValueError(f"N must be a multiple of 8, got: {n}")
  elif n > 256 and n != 512:
    raise ValueError("Only N below 256 or N=512 are supported")
  if num_cta == 2 and n > 256:
    raise NotImplementedError(
        "N is too big for collective MMA. Only up to 256 is supported."
    )
  n_group_elems = min(n, 256)
  n_group_elems = min(n, 256 // num_cta)
  if m % m_group_elems:
    raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
  if k % k_group_elems:
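The net effect of the changed line is that a single instruction group now covers at most 256 // num_cta columns; the surrounding checks still reject N above 256 for the 2-CTA collective case. A quick standalone sketch of just that arithmetic (not the full validation that mma() performs):

def n_group_elems(n: int, num_cta: int) -> int:
  # Mirrors the min() above only.
  return min(n, 256 // num_cta)

assert n_group_elems(256, 1) == 256  # one CTA: full 256-wide groups
assert n_group_elems(512, 1) == 256  # N=512 is processed as two 256-wide groups
assert n_group_elems(256, 2) == 128  # two CTAs: 128-wide groups per CTA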
@ -179,6 +180,7 @@ def mma(
|
||||
|
||||
# Step 4. Issue the instructions.
|
||||
true = arith.constant(ir.IntegerType.get_signless(1), 1)
|
||||
n_collective_group_elems = n_group_elems * num_cta
|
||||
for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
|
||||
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
|
||||
a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
|
||||
@ -188,9 +190,9 @@ def mma(
|
||||
raise NotImplementedError("D needs to be sliced")
|
||||
acc = accumulate if ki == 0 else true
|
||||
_do_mma(
|
||||
d.slice(
|
||||
slice(None), utils.ds(ni * n_group_elems, n_group_elems)
|
||||
).address,
|
||||
arith.addi(
|
||||
d.address, arith.constant(i32, ni * n_collective_group_elems)
|
||||
),
|
||||
a_mk,
|
||||
b_nk,
|
||||
d_type=ir.F32Type.get(),
|
||||
@ -377,8 +379,15 @@ class TMEMLayout:
  +------------------+------------------+
  | [0:64, 64:128]   | [64:128, 64:128] |
  +------------------+------------------+

  The above is further complicated by column_tile_stride, which is used to
  swizzle the ordering of column tiles. That is, if column_tile_stride is 2,
  we first lay out all tiles that have column index 0, 2, 4, and so on, until
  we run out of tiles. Only then do we lay out the tiles with column index
  1, 3, etc.
  """
  elements_in_tile: tuple[int, int]
  column_tile_stride: int = 1

  def __post_init__(self):
    row_tiling = self.elements_in_tile[0]
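A quick standalone sketch of the visit order described in that docstring (an illustration of the ordering only, not the library's loading code):

def column_tile_order(num_column_tiles: int, column_tile_stride: int) -> list[int]:
  order = []
  for base in range(column_tile_stride):
    order.extend(range(base, num_column_tiles, column_tile_stride))
  return order

# 8 column tiles with a stride of 2: even-indexed tiles first, then the odd ones.
assert column_tile_order(8, 2) == [0, 2, 4, 6, 1, 3, 5, 7]
# A stride of 1 is the plain left-to-right order.
assert column_tile_order(8, 1) == list(range(8))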
@ -405,7 +414,7 @@ class TMEMLayout:
|
||||
return num_tiles // tiles_in_row * cols_in_tile
|
||||
|
||||
|
||||
def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
|
||||
def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout:
|
||||
if shape[0] > TMEM_ROWS:
|
||||
raise ValueError(
|
||||
"Can only infer TMEM layout for shapes with at most 128 rows, got:"
|
||||
@ -421,7 +430,15 @@ def _infer_tmem_layout(shape: tuple[int, int]) -> TMEMLayout:
|
||||
"Can only infer TMEM layout for shapes with row count that's a power of"
|
||||
f" 2, got: {shape[0]}"
|
||||
)
|
||||
return TMEMLayout(elements_in_tile=(shape[0], 1))
|
||||
if shape[1] % 8:
|
||||
raise ValueError(
|
||||
"Can only infer TMEM layout for shapes with column count that's a"
|
||||
f" multiple of 8, got: {shape[1]}"
|
||||
)
|
||||
if collective and shape[1] == 512:
|
||||
return TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2)
|
||||
else:
|
||||
return TMEMLayout(elements_in_tile=(shape[0], 8))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -432,7 +449,14 @@ class TMEMRef:
|
||||
layout: TMEMLayout
|
||||
|
||||
@classmethod
|
||||
def from_alloc(cls, tmem_addr_ref: ir.Value, shape: tuple[int, int], dtype, layout: TMEMLayout | None = None):
|
||||
def from_alloc(
|
||||
cls,
|
||||
tmem_addr_ref: ir.Value,
|
||||
shape: tuple[int, int],
|
||||
dtype,
|
||||
collective: bool | None = None,
|
||||
layout: TMEMLayout | None = None,
|
||||
):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
if not ir.MemRefType.isinstance(tmem_addr_ref.type):
|
||||
raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
|
||||
@ -449,7 +473,11 @@ class TMEMRef:
|
||||
if shape[0] < 32:
|
||||
raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
|
||||
if layout is None:
|
||||
layout = _infer_tmem_layout(shape)
|
||||
if collective is None:
|
||||
raise ValueError(
|
||||
"collective argument must be provided when TMEM layout is inferred"
|
||||
)
|
||||
layout = _infer_tmem_layout(shape, collective)
|
||||
else:
|
||||
layout.check_shape(shape)
|
||||
# TODO: Do we have to do this??
|
||||
@ -461,12 +489,17 @@ class TMEMRef:
|
||||
base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
|
||||
if any(is_squeezed):
|
||||
raise ValueError("TMEM can only be sliced, not indexed")
|
||||
if self.layout.elements_in_tile[0] != TMEM_ROWS:
|
||||
if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
|
||||
raise NotImplementedError(
|
||||
f"Slicing only implemented for refs with tiling of {TMEM_ROWS} rows"
|
||||
"Slicing only implemented for refs with standard layout, got:"
|
||||
f" {self.layout}"
|
||||
)
|
||||
if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
|
||||
raise NotImplementedError("TMEM cannot be sliced along rows")
|
||||
if slice_shape[1] % 8:
|
||||
raise NotImplementedError(
|
||||
"TMEM column slice length must be a multiple of 8"
|
||||
)
|
||||
col_idx = base_idx[1]
|
||||
if not isinstance(col_idx, ir.Value):
|
||||
col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx)
|
||||
@ -484,48 +517,75 @@ class TMEMRef:
|
||||
raise ValueError("TMEM loads only support slicing")
|
||||
if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
|
||||
raise NotImplementedError("Slicing of TMEM not impelmented yet")
|
||||
if self.layout.elements_in_tile[0] != TMEM_ROWS:
|
||||
raise NotImplementedError(
|
||||
f"Loads only implemented for refs with tiling of {TMEM_ROWS} rows"
|
||||
)
|
||||
if self.shape[1] % 8:
|
||||
raise NotImplementedError
|
||||
if self.dtype != ir.F32Type.get():
|
||||
raise NotImplementedError(self.dtype)
|
||||
layout = _m128_256bit_32bit_layout(self.shape)
|
||||
regs_shape = layout.registers_shape(self.shape)
|
||||
num = self.shape[1] // 8
|
||||
# TODO(apaszke): Make the tiling configurable through the args too.
|
||||
if num <= 32:
|
||||
num_tiling = num
|
||||
elif num == 64:
|
||||
num_tiling = 32
|
||||
else:
|
||||
raise NotImplementedError(num)
|
||||
registers = np.empty(regs_shape, dtype=object)
|
||||
# We load 16 lanes at a time, but need 32 in total.
|
||||
for row_group in range(2):
|
||||
addr_row = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16))
|
||||
regs = []
|
||||
cols_per_num_tile = 8 # This depends on the 16x256b below.
|
||||
for num_group in range(num // num_tiling):
|
||||
addr_row_col = arith.addi(
|
||||
addr_row,
|
||||
arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
|
||||
if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)):
|
||||
# load_32xcols returns a 4xN array, but the FA tiling we use here tiles
|
||||
# columns before rows, and so it is Nx4 (after ignoring all 1 dims).
|
||||
registers = _load_32xcols(
|
||||
self.address, self.shape[1], self.dtype
|
||||
).T.reshape(regs_shape)
|
||||
elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2):
|
||||
if self.shape[1] % 128 != 0:
|
||||
raise ValueError(
|
||||
f"TMEM layout {self.layout} is not compatible with shape {self.shape}"
|
||||
)
|
||||
regs += tmem_load(addr_row_col, "16x256b", num_tiling)
|
||||
regs = [llvm.bitcast(self.dtype, r) for r in regs]
|
||||
vector_regs = []
|
||||
undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype))
|
||||
for r_low, r_high in zip(regs[::2], regs[1::2]):
|
||||
high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
|
||||
vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
|
||||
vector_regs.append(vreg)
|
||||
# Dimension 4 is the one where we split 32 rows into tiles of 8.
|
||||
regs_slice = (slice(None),) * 4 + (slice(row_group * 2, (row_group + 1) * 2),)
|
||||
registers[regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[regs_slice].shape)
|
||||
num_column_tiles = self.shape[1] // 128
|
||||
column_tile_stride = self.layout.column_tile_stride
|
||||
num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride)
|
||||
tiles = []
|
||||
for col_tile_base in range(num_strided_col_groups):
|
||||
for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride):
|
||||
tiles.append(
|
||||
_load_32xcols(
|
||||
arith.addi(self.address, arith.constant(i32, col_tile * 128)),
|
||||
cols=128,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
)
|
||||
registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Loads only implemented for refs with standard layout, got: {self.layout}"
|
||||
)
|
||||
return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)
|
||||
|
||||
def _load_32xcols(base_addr, cols, dtype):
|
||||
# See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
assert cols % 8 == 0
|
||||
cols_per_num_tile = 8
|
||||
load_shape = "16x256b"
|
||||
num = cols // 8
|
||||
if num <= 32:
|
||||
num_tiling = num
|
||||
elif num == 64:
|
||||
num_tiling = 32
|
||||
else:
|
||||
raise NotImplementedError(num)
|
||||
vector_regs = np.ndarray((4, num), dtype=object)
|
||||
# We load 16 lanes at a time, but need 32 in total.
|
||||
for row_group in range(2):
|
||||
addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16))
|
||||
regs = []
|
||||
for num_group in range(num // num_tiling):
|
||||
addr_row_col = arith.addi(
|
||||
addr_row,
|
||||
arith.constant(i32, num_tiling * num_group * cols_per_num_tile),
|
||||
)
|
||||
regs += tmem_load(addr_row_col, load_shape, num_tiling)
|
||||
regs = [llvm.bitcast(dtype, r) for r in regs]
|
||||
undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
|
||||
for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)):
|
||||
high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
|
||||
vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
|
||||
vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg
|
||||
return vector_regs
|
||||
|
||||
|
||||
def _m128_256bit_32bit_layout(shape: tuple[int, ...]):
|
||||
if len(shape) != 2:
|
||||
|
@ -1201,3 +1201,7 @@ def bitcast(x: ir.Value, new_type: ir.Type):
    assert x_ty.width == bitwidth(new_type.element_type) * math.prod(new_type.shape)
    return vector.bitcast(new_type, vector.splat(ir.VectorType.get((1,), x_ty), x))
  raise ValueError(f"Can't bitcast {x.type} to {new_type}")


def ceil_div(x: int, y: int):
  return (x + y - 1) // y
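ceil_div is ordinary round-up integer division; for instance, using the definition just above:

assert ceil_div(10, 4) == 3   # ten items in chunks of four need three chunks
assert ceil_div(8, 4) == 2    # exact multiples are unchanged
assert ceil_div(1, 128) == 1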
@ -18,6 +18,7 @@ from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_val
|
||||
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
|
||||
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
|
||||
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
|
||||
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
|
||||
from jax._src.pallas.fuser.fusable import fusable as fusable
|
||||
from jax._src.pallas.fuser.fusion import Fusion as Fusion
|
||||
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
|
||||
|
@ -43,14 +43,22 @@ class MultiPageAsyncCopyDescriptor:
  ):
    self._vmem_buf = vmem_buf
    seq_id, kv_pages_start = offset
    self._async_copies = [
        pltpu.make_async_copy(
            pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]],
            vmem_buf.at[i],
            sem,
        )
        for i in range(vmem_buf.shape[0])
    ]
    pages_per_seq = page_indices_ref.shape[1]
    self._async_copies = []
    # TODO(jevinjiang): Only fetch dynamic shapes when needed! This will insert
    # a bunch of if-ops. Check the performance when we have a benchmarking setup.
    for i in range(vmem_buf.shape[0]):
      page_idx = kv_pages_start + i
      page_idx = jax.lax.select(
          page_idx < pages_per_seq, page_idx, pages_per_seq - 1
      )
      self._async_copies.append(
          pltpu.make_async_copy(
              pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
              vmem_buf.at[i],
              sem,
          )
      )

  def start(self):
    """Starts the async copies."""
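The new loop clamps speculative page indices so that a fetch past the end of a sequence's page list is redirected to the last valid page instead of indexing out of range. A standalone sketch of that clamping pattern (the shapes and values here are made up for illustration):

import jax
import jax.numpy as jnp

def clamp_page_index(page_idx, pages_per_seq):
  # Out-of-range slots fall back to the last valid page.
  return jax.lax.select(page_idx < pages_per_seq, page_idx, pages_per_seq - 1)

print(clamp_page_index(jnp.int32(3), jnp.int32(8)))   # 3: in range, unchanged
print(clamp_page_index(jnp.int32(11), jnp.int32(8)))  # 7: clamped to the last page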
@ -49,7 +49,7 @@ py_library_providing_imports_info(
|
||||
config_setting(
|
||||
name = "disable_jaxlib_for_cpu_build",
|
||||
flag_values = {
|
||||
"//jax:build_jaxlib": "False",
|
||||
"//jax:build_jaxlib": "false",
|
||||
"@local_config_cuda//:enable_cuda": "False",
|
||||
},
|
||||
)
|
||||
@ -57,7 +57,23 @@ config_setting(
|
||||
config_setting(
|
||||
name = "disable_jaxlib_for_cuda12_build",
|
||||
flag_values = {
|
||||
"//jax:build_jaxlib": "False",
|
||||
"//jax:build_jaxlib": "false",
|
||||
"@local_config_cuda//:enable_cuda": "True",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "enable_py_import_for_cpu_build",
|
||||
flag_values = {
|
||||
"//jax:build_jaxlib": "wheel",
|
||||
"@local_config_cuda//:enable_cuda": "False",
|
||||
},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "enable_py_import_for_cuda12_build",
|
||||
flag_values = {
|
||||
"//jax:build_jaxlib": "wheel",
|
||||
"@local_config_cuda//:enable_cuda": "True",
|
||||
},
|
||||
)
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "nanobind/stl/pair.h"
|
||||
#include "jaxlib/absl_status_casters.h"
|
||||
@ -29,7 +31,7 @@ namespace nb = nanobind;
|
||||
nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
|
||||
int batch_size, int max_seq_length, float dropout,
|
||||
bool bidirectional, bool cudnn_allow_tf32,
|
||||
int workspace_size, int reserve_space_size) {
|
||||
size_t workspace_size, size_t reserve_space_size) {
|
||||
return PackDescriptor(RnnDescriptor{
|
||||
input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
|
||||
bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "jaxlib/gpu/rnn_kernels.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -71,7 +72,7 @@ template <>
|
||||
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
static absl::StatusOr<std::pair<int, int>>
|
||||
static absl::StatusOr<std::pair<size_t, size_t>>
|
||||
DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout,
|
||||
@ -174,7 +175,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
return std::make_pair(workSpaceSize, reserveSpaceSize);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
|
||||
absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
|
||||
int input_size, int hidden_size, int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout, bool bidirectional,
|
||||
bool cudnn_allow_tf32) {
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef JAXLIB_GPU_RNN_KERNELS_H_
|
||||
#define JAXLIB_GPU_RNN_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
@ -34,12 +36,12 @@ struct RnnDescriptor {
|
||||
float dropout;
|
||||
int bidirectional;
|
||||
int cudnn_allow_tf32;
|
||||
int workspace_size;
|
||||
int reserve_space_size;
|
||||
size_t workspace_size;
|
||||
size_t reserve_space_size;
|
||||
};
|
||||
|
||||
// Return (workspace size, reserve space size).
|
||||
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
|
||||
absl::StatusOr<std::pair<size_t, size_t>> RnnComputeWorkspaceReserveSpaceSizes(
|
||||
int input_size, int hidden_size, int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout, bool bidirectional,
|
||||
bool cudnn_allow_tf32);
|
||||
|
@ -493,15 +493,7 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
|
||||
param.value)));
|
||||
}
|
||||
}
|
||||
// Triton's kernel ABI expects an additional scratchpad global memory.
|
||||
// For now it is only used for on-device creation of TMA descriptors, which
|
||||
// we do not use yet, so we are just replacing this argument with a null
|
||||
// pointer.
|
||||
// TODO: b/381242007 - Allocate a proper buffer if we want to use
|
||||
// device-side TMA APIs.
|
||||
void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns.
|
||||
params.push_back(&scratch_ptr);
|
||||
|
||||
params.push_back(buffers++); // Scratch buffer.
|
||||
return kernel_.Launch(stream, grid_, params.data());
|
||||
}
|
||||
|
||||
|
@ -224,7 +224,15 @@ def if_building_jaxlib(
|
||||
"@pypi_jax_cuda12_plugin//:pkg",
|
||||
"@pypi_jax_cuda12_pjrt//:pkg",
|
||||
],
|
||||
if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]):
|
||||
if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"],
|
||||
if_py_import = [
|
||||
"//jaxlib/tools:jaxlib_py_import",
|
||||
"//jaxlib/tools:jax_cuda_plugin_py_import",
|
||||
"//jaxlib/tools:jax_cuda_pjrt_py_import",
|
||||
],
|
||||
if_py_import_for_cpu = [
|
||||
"//jaxlib/tools:jaxlib_py_import",
|
||||
]):
|
||||
"""Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.
|
||||
|
||||
This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase.
|
||||
@ -234,12 +242,16 @@ def if_building_jaxlib(
|
||||
if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of
|
||||
gpu-enabled builds
|
||||
if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds
|
||||
if_py_import: the py_import targets to depend on in case of gpu-enabled builds
|
||||
if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds
|
||||
"""
|
||||
|
||||
return select({
|
||||
"//jax:enable_jaxlib_build": if_building,
|
||||
"//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu,
|
||||
"//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building,
|
||||
"//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu,
|
||||
"//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import,
|
||||
})
|
||||
|
||||
# buildifier: disable=function-docstring
|
||||
|
@ -128,7 +128,6 @@ def MosaicGPU_WGStridedFragLayout : AttrDef<MosaicGPU_Dialect, "WGStridedFragLay
|
||||
let assemblyFormat = "`<` $shape`,` $vector_size `>`";
|
||||
}
|
||||
|
||||
|
||||
def MosaicGPU_WGSplatFragLayout : AttrDef<MosaicGPU_Dialect, "WGSplatFragLayout", []> {
|
||||
let summary = "Annotates an array that is the result of a splat.";
|
||||
let description = [{
|
||||
@ -143,20 +142,6 @@ def MosaicGPU_WGSplatFragLayout : AttrDef<MosaicGPU_Dialect, "WGSplatFragLayout"
|
||||
let assemblyFormat = "`<` $shape `>`";
|
||||
}
|
||||
|
||||
def MosaicGPU_WGMMAFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMAFragLayout", []> {
|
||||
let summary = "2D array that can be tiled by supported WGMMA shapes.";
|
||||
let description = [{
|
||||
This layout annotates arrays that are fragmented across all threads in a
|
||||
warpgroup that is executing a WGMMA operation. The shape of the array is
|
||||
(m, n) where:
|
||||
- m % 64 == 0
|
||||
- n % 8 == 0
|
||||
}];
|
||||
|
||||
let mnemonic = "WGMMAFragLayout";
|
||||
let assemblyFormat = "";
|
||||
}
|
||||
|
||||
def MosaicGPU_WGMMARowFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMARowFragLayout", []> {
|
||||
let summary = "1D array that is a row that can be tiled by supported WGMMA shapes.";
|
||||
let description = [{
|
||||
@ -169,6 +154,24 @@ def MosaicGPU_WGMMARowFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMARowFragLayou
|
||||
let assemblyFormat = "";
|
||||
}
|
||||
|
||||
def MosaicGPU_TiledLayout : AttrDef<MosaicGPU_Dialect, "TiledLayout", []> {
|
||||
let summary = "A layout derived from a tiling expression.";
|
||||
let description = [{
|
||||
See mosaic/gpu/fragmented_array.py -> TiledLayout for more details.
|
||||
}];
|
||||
|
||||
let parameters = (ins
|
||||
"::mlir::ArrayAttr":$tiling,
|
||||
"int":$warp_dim,
|
||||
"::mlir::ArrayAttr":$lane_dims,
|
||||
"int":$vector_dim
|
||||
);
|
||||
let mnemonic = "TiledLayout";
|
||||
let assemblyFormat = "`<` $tiling `,` `warp_dim` `=` $warp_dim `,` "
|
||||
"`lane_dims` `=` $lane_dims `,` `vector_dim` `=` $vector_dim `>`";
|
||||
}
|
||||
|
||||
|
||||
// Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td
|
||||
// but it was not possible to reuse that definition. Including that file
|
||||
// pulls in ops definitions that we don't want and they fail to compile.
|
||||
|
@ -18,6 +18,10 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load(
|
||||
"@xla//third_party/py:py_import.bzl",
|
||||
"py_import",
|
||||
)
|
||||
load(
|
||||
"@xla//third_party/py:py_manylinux_compliance_test.bzl",
|
||||
"verify_manylinux_compliance_test",
|
||||
@ -228,6 +232,18 @@ string_flag(
|
||||
build_setting_default = "dist",
|
||||
)
|
||||
|
||||
NVIDIA_WHEELS_DEPS = [
|
||||
"@pypi_nvidia_cublas_cu12//:whl",
|
||||
"@pypi_nvidia_cuda_cupti_cu12//:whl",
|
||||
"@pypi_nvidia_cuda_runtime_cu12//:whl",
|
||||
"@pypi_nvidia_cudnn_cu12//:whl",
|
||||
"@pypi_nvidia_cufft_cu12//:whl",
|
||||
"@pypi_nvidia_cusolver_cu12//:whl",
|
||||
"@pypi_nvidia_cusparse_cu12//:whl",
|
||||
"@pypi_nvidia_nccl_cu12//:whl",
|
||||
"@pypi_nvidia_nvjitlink_cu12//:whl",
|
||||
]
|
||||
|
||||
jax_wheel(
|
||||
name = "jaxlib_wheel",
|
||||
no_abi = False,
|
||||
@ -235,6 +251,11 @@ jax_wheel(
|
||||
wheel_name = "jaxlib",
|
||||
)
|
||||
|
||||
py_import(
|
||||
name = "jaxlib_py_import",
|
||||
wheel = ":jaxlib_wheel",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jaxlib_wheel_editable",
|
||||
editable = True,
|
||||
@ -252,6 +273,12 @@ jax_wheel(
|
||||
wheel_name = "jax_cuda12_plugin",
|
||||
)
|
||||
|
||||
py_import(
|
||||
name = "jax_cuda_plugin_py_import",
|
||||
wheel = ":jax_cuda_plugin_wheel",
|
||||
wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_plugin_wheel_editable",
|
||||
editable = True,
|
||||
@ -290,6 +317,12 @@ jax_wheel(
|
||||
wheel_name = "jax_cuda12_pjrt",
|
||||
)
|
||||
|
||||
py_import(
|
||||
name = "jax_cuda_pjrt_py_import",
|
||||
wheel = ":jax_cuda_pjrt_wheel",
|
||||
wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_pjrt_wheel_editable",
|
||||
editable = True,
|
||||
|
@ -213,8 +213,36 @@ class RnnTest(jtu.JaxTestCase):
|
||||
|
||||
k = jax.random.split(jax.random.PRNGKey(1), 4)
|
||||
stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
|
||||
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
|
||||
stablehlo)
|
||||
if jtu.jaxlib_version() <= (0, 5, 2):
|
||||
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
|
||||
stablehlo)
|
||||
else:
|
||||
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"',
|
||||
stablehlo)
|
||||
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_no_workspace_overflow(self):
|
||||
if jtu.jaxlib_version() <= (0, 5, 2):
|
||||
self.skipTest("Older versions fail because of integer overflow.")
|
||||
|
||||
# Problem sizes known to cause overflows on older versions.
|
||||
batch_size, max_seq_length, input_size = 256, 500, 512
|
||||
num_layers, hidden_size = 1, 256
|
||||
num_params = rnn.get_num_params_in_lstm(
|
||||
input_size, hidden_size, num_layers, True)
|
||||
x = jax.ShapeDtypeStruct(
|
||||
(batch_size, max_seq_length, input_size), jnp.float32)
|
||||
h_0 = jax.ShapeDtypeStruct(
|
||||
(2 * num_layers, batch_size, hidden_size), jnp.float32)
|
||||
c_0 = jax.ShapeDtypeStruct(
|
||||
(2 * num_layers, batch_size, hidden_size), jnp.float32)
|
||||
weights = jax.ShapeDtypeStruct((num_params,), jnp.float32)
|
||||
seq_lengths = jax.ShapeDtypeStruct((batch_size,), jnp.int32)
|
||||
fun = jax.jit(partial(
|
||||
rnn.lstm, input_size=input_size, hidden_size=hidden_size,
|
||||
num_layers=num_layers, dropout=0.0, bidirectional=True))
|
||||
fun.lower(x, h_0, c_0, weights, seq_lengths) # Doesn't crash.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -2445,7 +2445,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
assert b.shape == ()
|
||||
return c, b
|
||||
|
||||
xs = jnp.ones((5, 3))
|
||||
xs = jnp.ones((20, 3))
|
||||
c = jnp.ones(4)
|
||||
|
||||
scan = lambda c, xs: lax.scan(f, c, xs)
|
||||
@ -2502,6 +2502,28 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
x, n = jnp.arange(3), jnp.arange(4)
|
||||
jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash
|
||||
|
||||
def test_disable_jit_while_loop_with_mutation(self):
|
||||
# https://github.com/jax-ml/jax/issues/27019
|
||||
|
||||
def body_fun(carry):
|
||||
x, y = carry
|
||||
x += 1 # in-place if x is mutable
|
||||
return x, y + x
|
||||
|
||||
def cond_fun(carry):
|
||||
x, _ = carry
|
||||
return x < 10
|
||||
|
||||
def f():
|
||||
val = np.array(1.0) # mutable value
|
||||
return jax.lax.while_loop(cond_fun, body_fun, (val, val))[1]
|
||||
|
||||
with jax.disable_jit(False):
|
||||
result_jit = f()
|
||||
with jax.disable_jit(True):
|
||||
result_nojit = f()
|
||||
self.assertEqual(result_jit, result_nojit)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{shape}_{axis=}",
|
||||
"shape": shape, "axis": axis}
|
||||
|
@ -278,6 +278,35 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
|
||||
with jax.checking_leaks():
|
||||
lsp_special.expi(jnp.ones(()))
|
||||
|
||||
def testExpiDisableJit(self):
|
||||
# Regression test for https://github.com/jax-ml/jax/issues/27019
|
||||
x = jnp.array([-0.5])
|
||||
with jax.disable_jit(True):
|
||||
result_nojit = lsp_special.expi(x)
|
||||
with jax.disable_jit(False):
|
||||
result_jit = lsp_special.expi(x)
|
||||
self.assertAllClose(result_jit, result_nojit)
|
||||
|
||||
def testGammaIncBoundaryValues(self):
|
||||
dtype = jax.numpy.zeros(0).dtype # default float dtype.
|
||||
nan = float('nan')
|
||||
inf = float('inf')
|
||||
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan]).astype(dtype),
|
||||
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf]).astype(dtype)]
|
||||
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
|
||||
self._CheckAgainstNumpy(osp_special.gammainc, lsp_special.gammainc, args_maker, rtol=rtol)
|
||||
self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol)
|
||||
|
||||
def testGammaIncCBoundaryValues(self):
|
||||
dtype = jax.numpy.zeros(0).dtype # default float dtype.
|
||||
nan = float('nan')
|
||||
inf = float('inf')
|
||||
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan, 1]).astype(dtype),
|
||||
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf, -1]).astype(dtype)]
|
||||
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
|
||||
self._CheckAgainstNumpy(osp_special.gammaincc, lsp_special.gammaincc, args_maker, rtol=rtol)
|
||||
self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -39,7 +39,10 @@ jax_multiplatform_test(
|
||||
],
|
||||
env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
|
||||
shard_count = 16,
|
||||
tags = ["multiaccelerator"],
|
||||
tags = [
|
||||
"multiaccelerator",
|
||||
"noasan", # Times out.
|
||||
],
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
|
@ -210,7 +210,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
for layout in [
|
||||
mgpu.WGSplatFragLayout(shape),
|
||||
mgpu.WGStridedFragLayout(shape, vec_size=4),
|
||||
mgpu.WGMMAFragLayout(),
|
||||
mgpu.TILED_LAYOUT_WGMMA,
|
||||
]
|
||||
)
|
||||
def test_infer_layout_from_yield_op_in_layouts_for_for_op(
|
||||
@ -278,7 +278,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
|
||||
mgpu.infer_layout(self.module)
|
||||
|
||||
wgmma_layout = layouts.to_layout_attr(mgpu.WGMMAFragLayout())
|
||||
wgmma_layout = layouts.to_layout_attr(mgpu.TILED_LAYOUT_WGMMA)
|
||||
self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout])
|
||||
self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
|
||||
self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout])
|
||||
@ -312,7 +312,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
mgpu.WGStridedFragLayout((32, 4), vec_size=1),
|
||||
mgpu.WGMMAFragLayout(),
|
||||
mgpu.TILED_LAYOUT_WGMMA,
|
||||
)
|
||||
def test_infer_layout_picks_non_splat_layout_over_splat_layout(
|
||||
self, layout
|
||||
|
@ -1026,7 +1026,7 @@ class TCGen05Test(TestCase):
|
||||
in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32
|
||||
out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation
|
||||
m=(256,), # TODO(apaszke): 64, 192, 256
|
||||
n=(128, 256), # TODO(apaszke): 512, 192, other non-power-of-2
|
||||
n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2
|
||||
k_steps=(1, 2),
|
||||
swizzle=(32, 64, 128,),
|
||||
)
|
||||
|
@ -216,8 +216,8 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x_ref):
|
||||
self.assertEqual(core.get_ty(x_ref).sharding.spec,
|
||||
core.get_ty(x_ref[...]).sharding.spec)
|
||||
self.assertEqual(core.typeof(x_ref).sharding.spec,
|
||||
core.typeof(x_ref[...]).sharding.spec)
|
||||
y = x_ref[...] + 1
|
||||
return y
|
||||
|
||||
|
@ -184,6 +184,23 @@ class PallasCallTest(PallasTest):
|
||||
y = jnp.flip(x).reshape(1, 256)
|
||||
np.testing.assert_array_equal(kernel(x, y), x + y[0])
|
||||
|
||||
@parameterized.product(
|
||||
shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics]
|
||||
)
|
||||
def test_reduce_sum(self, shape, thread_semantics):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
|
||||
compiler_params=plgpu.GPUCompilerParams(
|
||||
thread_semantics=thread_semantics
|
||||
),
|
||||
)
|
||||
def kernel(x_ref, o_ref):
|
||||
o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape)
|
||||
|
||||
x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32)
|
||||
np.testing.assert_array_equal(kernel(x), jnp.sum(x))
|
||||
|
||||
def test_reshape(self):
|
||||
shape1, shape2 = (128,), (2, 16, 4)
|
||||
|
||||
@ -200,10 +217,14 @@ class PallasCallTest(PallasTest):
|
||||
x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
|
||||
np.testing.assert_array_equal(kernel(x), x.reshape(shape2))
|
||||
|
||||
def test_add_xy_indexed(self):
|
||||
@parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
|
||||
def test_add_xy_indexed(self, thread_semantics):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
|
||||
compiler_params=plgpu.GPUCompilerParams(
|
||||
thread_semantics=thread_semantics
|
||||
),
|
||||
)
|
||||
def kernel(x_ref, y_ref, o_ref):
|
||||
idx = _sum_same_dtype(y_ref[...])
|
||||
@ -1078,10 +1099,14 @@ class PallasCallTest(PallasTest):
|
||||
|
||||
self.assertIn("acc % 2", output())
|
||||
|
||||
def test_cond_returning_array(self):
|
||||
@parameterized.parameters([*plgpu.ThreadSemantics])
|
||||
def test_cond_returning_array(self, thread_semantics):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
|
||||
compiler_params=plgpu.GPUCompilerParams(
|
||||
thread_semantics=thread_semantics
|
||||
),
|
||||
)
|
||||
def kernel(x_ref, o_ref):
|
||||
acc = _sum_same_dtype(x_ref[...])
|
||||
|
@ -470,6 +470,27 @@ class OpsTest(PallasBaseTest):
|
||||
expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
|
||||
np.testing.assert_array_equal(out, expected)
|
||||
|
||||
def test_reduce_with_const(self):
|
||||
m = 1
|
||||
d = 1024
|
||||
x = jnp.ones((m, d), jnp.bfloat16)
|
||||
|
||||
def dot(x, y):
|
||||
return jax.lax.dot_general(
|
||||
x,
|
||||
y,
|
||||
(((1,), (1,)), ((), ())),
|
||||
preferred_element_type=jnp.float32,
|
||||
)
|
||||
|
||||
def kernel(x, out):
|
||||
out[:] = dot(x[:], jnp.ones((1, d), jnp.bfloat16))
|
||||
|
||||
run = pl.pallas_call(kernel, jax.ShapeDtypeStruct((m, 1), jnp.float32))
|
||||
output = run(x)
|
||||
expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16))
|
||||
np.testing.assert_array_equal(output, expected)
|
||||
|
||||
|
||||
class OpsInterpretTest(OpsTest):
|
||||
INTERPRET = True
|
||||
|
@ -64,10 +64,6 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
|
||||
max_num_seq = max(len(seq_lens), max_num_seq)
|
||||
max_kv_len = max(kv_lens)
|
||||
pages_per_seq = ceil_div(max_kv_len, page_size)
|
||||
pages_per_seq = (
|
||||
ceil_div(pages_per_seq, num_kv_pages_per_block)
|
||||
* num_kv_pages_per_block
|
||||
)
|
||||
num_q_heads, num_kv_heads = num_heads
|
||||
|
||||
cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
|
||||
@ -130,8 +126,8 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
|
||||
num_seqs=num_seqs,
|
||||
)
|
||||
tols = {
|
||||
"float32": 1e-1,
|
||||
"bfloat16": 2e-1,
|
||||
"float32": 0.15,
|
||||
"bfloat16": 0.2,
|
||||
}
|
||||
tol = tols[jnp.dtype(dtype).name]
|
||||
self.assertAllClose(output, expected, atol=tol, rtol=tol)
|
||||
|
@ -4883,11 +4883,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
def f(x):
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
|
||||
x = x * 2
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
|
||||
x = x * x
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
|
||||
return x
|
||||
|
||||
# Eager mode
|
||||
|
@ -125,6 +125,80 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
|
||||
c, jnp.array([[1, 3, 0, 0], [2, 2, 4, 0]], dtype=jnp.int32)
|
||||
)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=2)
|
||||
),
|
||||
)
|
||||
def test_ragged_all_to_all_grad(self, axis_name, mesh_axes):
|
||||
device_type = jax.devices()[0].platform
|
||||
if device_type == 'tpu' and jtu.get_tpu_version() < 4:
|
||||
raise unittest.SkipTest(
|
||||
'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU'
|
||||
f' v{jtu.get_tpu_version()}'
|
||||
)
|
||||
mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
|
||||
operand = jax.device_put(
|
||||
jnp.array([[1, 2, 2], [3, 4, 0]], dtype=jnp.float32),
|
||||
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
|
||||
)
|
||||
output = jax.device_put(
|
||||
jnp.zeros((2, 4), dtype=jnp.float32),
|
||||
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
|
||||
)
|
||||
input_offsets = jax.device_put(
|
||||
jnp.array([[0, 1], [0, 1]], dtype=jnp.int32),
|
||||
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
|
||||
)
|
||||
send_sizes = jax.device_put(
|
||||
jnp.array([[1, 2], [1, 1]], dtype=jnp.int32),
|
||||
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
|
||||
)
|
||||
output_offsets = jax.device_put(
|
||||
jnp.array([[0, 0], [1, 2]], dtype=jnp.int32),
|
||||
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
|
||||
)
|
||||
recv_sizes = jax.device_put(
|
||||
jnp.array([[1, 1], [2, 1]], dtype=jnp.int32),
|
||||
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
|
||||
)
|
||||
|
||||
@partial(
|
||||
shard_map,
|
||||
mesh=mesh,
|
||||
in_specs=(
|
||||
P(axis_name, None),
|
||||
P(axis_name, None),
|
||||
P(axis_name, None),
|
||||
P(axis_name, None),
|
||||
P(axis_name, None),
|
||||
P(axis_name, None),
|
||||
),
|
||||
out_specs=P(axis_name),
|
||||
check_rep=False,
|
||||
)
|
||||
def fwd(
|
||||
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
|
||||
):
|
||||
operand = operand.reshape(operand.shape[1:])
|
||||
output = output.reshape(output.shape[1:])
|
||||
input_offsets = input_offsets.reshape(input_offsets.shape[1:])
|
||||
send_sizes = send_sizes.reshape(send_sizes.shape[1:])
|
||||
output_offsets = output_offsets.reshape(output_offsets.shape[1:])
|
||||
recv_sizes = recv_sizes.reshape(recv_sizes.shape[1:])
|
||||
return lax.ragged_all_to_all(
|
||||
operand,
|
||||
output,
|
||||
input_offsets,
|
||||
send_sizes,
|
||||
output_offsets,
|
||||
recv_sizes,
|
||||
axis_name=axis_name,
|
||||
)
|
||||
|
||||
args = input_offsets, send_sizes, output_offsets, recv_sizes
|
||||
jtu.check_grads(lambda op, out: fwd(op, out, *args), (operand, output), order=1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=4)
|
||||
|
4
third_party/xla/workspace.bzl
vendored
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.

XLA_COMMIT = "fae64d49aa41e774922ca46e94cd754c800b6240"
XLA_SHA256 = "846ce8037cc0cba5135bff0bfd6fd02810e72b42ce0928002c595c97bf7b3603"
XLA_COMMIT = "c270a6ce45df7f7bb3024f2e4df56b688d76ebd6"
XLA_SHA256 = "b2f7d0293fc62bb670d0b58c5847108652eac4d9e6c7e420bed2029e74af6f2d"

def repo():
  tf_http_archive(