Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00

Merge branch 'main' into scipy-expon

Commit 7fc605f783
Changed files:

.github/workflows: bazel_cpu_rbe.yml, bazel_cuda_rbe.yml, ci-build.yaml, cloud-tpu-ci-presubmit.yml, jax-array-api.yml, tsan.yaml, upstream-nightly.yml, wheel_win_x64.yml, windows_ci.yml
CHANGELOG.md, WORKSPACE, build
docs
examples/ffi
jax/_src: abstract_arrays.py, ad_checkpoint.py, api.py, api_util.py, array.py, checkify.py, config.py, core.py, cudnn, custom_batching.py, custom_dce.py, custom_derivatives.py, custom_partitioning.py, custom_transpose.py, export, internal_test_util/export_back_compat_test_data, interpreters, lax, lib, linear_util.py, mesh.py, numpy, pallas, pjit.py, random.py, state, test_util.py, xla_bridge.py
jax: experimental, interpreters, tools, version.py
jax_plugins/cuda
jaxlib
19  .github/workflows/bazel_cpu_rbe.yml (vendored)

@@ -14,10 +14,15 @@ on:
   pull_request:
+    branches:
+      - main
   push:
+    branches:
+      - main
+      - 'release/**'

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
-  cancel-in-progress: true
+  # Don't cancel in-progress jobs for main/release branches.
+  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

 jobs:
   run_tests:
@@ -26,14 +31,22 @@ jobs:
     container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
                    (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
     env:
-      JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
+      JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
       JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
     # Begin Presubmit Naming Check - name modification requires internal check to be updated
     strategy:
       matrix:
+        python: ["3.10", "3.13"]
         runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
         enable-x_64: [1, 0]
-    name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
+        exclude:
+          # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have
+          # coverage for one of each, we don't need to run both.
+          - python: "3.10"
+            enable-x_64: 1
+          - python: "3.13"
+            enable-x_64: 0
+    name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
     # End Presubmit Naming Check github-cpu-presubmits
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
19  .github/workflows/bazel_cuda_rbe.yml (vendored)

@@ -14,10 +14,15 @@ on:
   pull_request:
+    branches:
+      - main
   push:
+    branches:
+      - main
+      - 'release/**'

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
-  cancel-in-progress: true
+  # Don't cancel in-progress jobs for main/release branches.
+  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

 jobs:
   run_tests:
@@ -25,14 +30,22 @@ jobs:
     runs-on: ${{ matrix.runner }}
     container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
     env:
-      JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
+      JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
       JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
     # Begin Presubmit Naming Check - name modification requires internal check to be updated
     strategy:
       matrix:
+        python: ["3.10", "3.13"]
         runner: ["linux-x86-n2-16"]
         enable-x_64: [1, 0]
-    name: "Bazel single accelerator CUDA tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
+        exclude:
+          # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have
+          # coverage for one of each, we don't need to run both.
+          - python: "3.10"
+            enable-x_64: 1
+          - python: "3.13"
+            enable-x_64: 0
+    name: "Bazel single accelerator CUDA tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
     # End Presubmit Naming Check github-cuda-presubmits
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
82  .github/workflows/ci-build.yaml (vendored)

@@ -31,7 +31,7 @@ jobs:
     steps:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
     - name: Set up Python 3.11
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: 3.11
    - run: python -m pip install pre-commit
@@ -70,22 +70,13 @@ jobs:
         apt update
         apt install -y libssl-dev
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Get pip cache dir
-      id: pip-cache
-      run: |
-        python -m pip install --upgrade pip wheel
-        echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-    - name: pip cache
-      uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
-      with:
-        path: ${{ steps.pip-cache.outputs.dir }}
-        key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
     - name: Install dependencies
       run: |
-        pip install .[minimum-jaxlib] -r build/test-requirements.txt
+        pip install uv
+        uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt

    - name: Run tests
      env:
@@ -117,22 +108,13 @@ jobs:
     steps:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Get pip cache dir
-      id: pip-cache
-      run: |
-        python -m pip install --upgrade pip wheel
-        echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-    - name: pip cache
-      uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
-      with:
-        path: ${{ steps.pip-cache.outputs.dir }}
-        key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
     - name: Install dependencies
       run: |
-        pip install -r docs/requirements.txt
+        pip install uv
+        uv pip install --system -r docs/requirements.txt
     - name: Test documentation
       env:
         XLA_FLAGS: "--xla_force_host_platform_device_count=8"
@@ -140,7 +122,7 @@ jobs:
         JAX_ARRAY: 1
         PY_COLORS: 1
       run: |
         pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
         pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py

@@ -160,22 +142,13 @@ jobs:
         apt update
         apt install -y libssl-dev libsqlite3-dev
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Get pip cache dir
-      id: pip-cache
-      run: |
-        python -m pip install --upgrade pip wheel
-        echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-    - name: pip cache
-      uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
-      with:
-        path: ${{ steps.pip-cache.outputs.dir }}
-        key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
     - name: Install dependencies
       run: |
-        pip install -r docs/requirements.txt
+        pip install uv
+        uv pip install --system -r docs/requirements.txt
     - name: Render documentation
       run: |
         sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html
@@ -195,22 +168,13 @@ jobs:
     steps:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Get pip cache dir
-      id: pip-cache
-      run: |
-        python -m pip install --upgrade pip wheel
-        echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-    - name: pip cache
-      uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
-      with:
-        path: ${{ steps.pip-cache.outputs.dir }}
-        key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
     - name: Install dependencies
       run: |
-        pip install .[minimum-jaxlib] tensorflow -r build/test-requirements.txt
+        pip install uv
+        uv pip install --system .[minimum-jaxlib] tensorflow -r build/test-requirements.txt

    - name: Run tests
      env:
@@ -236,23 +200,15 @@ jobs:
     steps:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
     - name: Set up Python
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: 3.12
-    - name: Get pip cache dir
-      id: pip-cache
-      run: |
-        python -m pip install --upgrade pip wheel
-        echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-    - name: pip cache
-      uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
-      with:
-        path: ${{ steps.pip-cache.outputs.dir }}
-        key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
     - name: Install JAX
-      run: pip install .[cuda12]
+      run: |
+        pip install uv
+        uv pip install --system .[cuda12]
     - name: Build and install example project
-      run: python -m pip install -v ./examples/ffi[test]
+      run: uv pip install --system ./examples/ffi[test]
       env:
         # We test building using GCC instead of clang. All other JAX builds use
         # clang, but it is useful to make sure that FFI users can compile using
7  .github/workflows/cloud-tpu-ci-presubmit.yml (vendored)

@@ -17,6 +17,10 @@ on:
   pull_request:
     branches:
       - main
+  push:
+    branches:
+      - main
+      - 'release/**'

 # This should also be set to read-only in the project settings, but it's nice to
 # document and enforce the permissions here.
@@ -25,7 +29,8 @@ permissions:

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
-  cancel-in-progress: true
+  # Don't cancel in-progress jobs for main/release branches.
+  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

 jobs:
   cloud-tpu-test:
7  .github/workflows/jax-array-api.yml (vendored)

@@ -32,13 +32,14 @@ jobs:
         submodules: 'true'
         path: 'array-api-tests'
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
       run: |
-        python -m pip install .[ci]
-        python -m pip install pytest-xdist -r array-api-tests/requirements.txt
+        pip install uv
+        uv pip install --system .[ci]
+        uv pip install --system pytest-xdist -r array-api-tests/requirements.txt
     - name: Run the test suite
       env:
         ARRAY_API_TESTS_MODULE: jax.numpy
9  .github/workflows/tsan.yaml (vendored)

@@ -6,7 +6,7 @@ concurrency:

 on:
   schedule:
-    - cron: "0 12 * * *" # Daily at 12:00 UTC
+    - cron: "0 5 * * *" # Daily at 05:00 UTC == 00:00 EST == 21:00 PST
   workflow_dispatch: # allows triggering the workflow run manually
   pull_request: # Automatically trigger on pull requests affecting this file
     branches:
@@ -72,7 +72,7 @@ jobs:
           # Create archive to be used with bazel as hermetic python:
          cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan

-      - name: Save CPython with TSAN
+      - name: Save TSAN CPython
        id: cache-cpython-tsan-save
        if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true'
        uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
@@ -102,9 +102,11 @@ jobs:
           # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz
           if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then
             echo "Extract cpython from python-tsan.tgz"
-            cd ${GITHUB_WORKSPACE} && tar -xvzf python-tsan.tgz
+            pushd .
+            ls ${GITHUB_WORKSPACE}/python-tsan.tgz
+            cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz
+            ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/
+            popd
           fi

           export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH
@@ -172,7 +174,6 @@ jobs:
             --clang_path=/usr/bin/clang-18

           # Update the patch to use TSAN instrumented numpy
-
           sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch
           cat .github/workflows/requirements_lock_3_13_ft.patch
2  .github/workflows/upstream-nightly.yml (vendored)

@@ -33,7 +33,7 @@ jobs:
     steps:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.python-version }}
     - name: Install JAX test requirements
2  .github/workflows/wheel_win_x64.yml (vendored)

@@ -27,7 +27,7 @@ jobs:

     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

-    - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+    - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.pyver }}
         cache: 'pip'
2  .github/workflows/windows_ci.yml (vendored)

@@ -35,7 +35,7 @@ jobs:
       with:
         path: jax

-    - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+    - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
       with:
         python-version: ${{ matrix.pyver }}
         cache: 'pip'
CHANGELOG.md

@@ -22,6 +22,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
     details.

+* Changes
+  * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
+    env vars. Before they could only be specified via jax.config or flags.
+  * The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
+    This package may safely be removed if it is present on your machine; JAX now
+    uses `libtpu` instead.
+
 ## jax 0.5.0 (Jan 17, 2025)

 As of this release, JAX now uses
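As a companion to the first changelog item, a minimal sketch of the new env-var path (a hedged illustration: the specific values are arbitrary, and both variables must be set before JAX initializes its backends):

```python
import os

# Per the changelog entry above, these now work as env vars; previously they
# could only be set via jax.config or command-line flags.
os.environ["JAX_NUM_CPU_DEVICES"] = "4"
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "gloo"

import jax  # reads the env vars during initialization

print(jax.devices("cpu"))  # expect four CPU devices
```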
15  WORKSPACE

@@ -62,6 +62,21 @@ xla_workspace0()
 load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
 flatbuffers()

+load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")
+jax_python_wheel_repository(
+    name = "jax_wheel",
+    version_key = "_version",
+    version_source = "//jax:version.py",
+)
+
+load(
+    "@tsl//third_party/py:python_wheel.bzl",
+    "python_wheel_version_suffix_repository",
+)
+python_wheel_version_suffix_repository(
+    name = "jax_wheel_version_suffix",
+)
+
 load(
     "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
     "cuda_json_init_repository",
build/test-requirements.txt

@@ -7,7 +7,8 @@ flatbuffers
 hypothesis
 mpmath>=1.3
 pillow>=10.4.0
-portpicker
+# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
+portpicker; python_version<"3.13"
 pytest-xdist
 wheel
 rich
docs/autodidax.ipynb

@@ -146,7 +146,7 @@
     "around calls to `bind`. These wrappers let us control how arguments are passed\n",
     "to `bind`, and in particular we follow a handy internal convention: when we\n",
     "call `bind`, we pass values representing array data as positional arguments,\n",
-    "and we pass metadata like the `axis` argument to `sum_p` via keyword. This\n",
+    "and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This\n",
     "calling convention simplifies some core logic (since e.g. instances of the\n",
     "`Tracer` class to be defined below can only occur in positional arguments to\n",
     "`bind`). The wrappers can also provide docstrings!\n",
docs/autodidax.md

@@ -133,7 +133,7 @@ The functions that user code calls, like `add` and `sin`, are just wrappers
 around calls to `bind`. These wrappers let us control how arguments are passed
 to `bind`, and in particular we follow a handy internal convention: when we
 call `bind`, we pass values representing array data as positional arguments,
-and we pass metadata like the `axis` argument to `sum_p` via keyword. This
+and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
 calling convention simplifies some core logic (since e.g. instances of the
 `Tracer` class to be defined below can only occur in positional arguments to
 `bind`). The wrappers can also provide docstrings!
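To make the convention in this passage concrete, here is a minimal sketch of one such wrapper in the autodidax style (assumptions: `bind1`, the `reduce_sum_p` primitive, and `np` for numpy are as defined elsewhere in autodidax; the axis normalization is illustrative):

```python
def reduce_sum(x, *, axis=None):
  """Sum `x` over `axis` -- a thin wrapper around a bind call."""
  if axis is None:
    axis = tuple(range(np.ndim(x)))  # default: reduce over every axis
  elif type(axis) is int:
    axis = (axis,)
  # Array data is passed positionally; metadata like `axis` goes by keyword.
  return bind1(reduce_sum_p, x, axis=axis)
```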
docs/autodidax.py

@@ -123,7 +123,7 @@ def bind1(prim, *args, **params):
 # around calls to `bind`. These wrappers let us control how arguments are passed
 # to `bind`, and in particular we follow a handy internal convention: when we
 # call `bind`, we pass values representing array data as positional arguments,
-# and we pass metadata like the `axis` argument to `sum_p` via keyword. This
+# and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
 # calling convention simplifies some core logic (since e.g. instances of the
 # `Tracer` class to be defined below can only occur in positional arguments to
 # `bind`). The wrappers can also provide docstrings!
docs/persistent_compilation_cache.md

@@ -168,6 +168,36 @@ so it is important for the persistent cache to be in a shared file system (eg: N
 If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
 in subsequent runs as a result of a compilation cache miss.

+### Pre-compiling multi-node programs on single node
+
+JAX can populate the compilation cache with compiled programs for multiple nodes
+on a single node. Preparing the cache on a single node helps to decrease the costly
+compilation time on a cluster. To compile and run multi-node programs on a single
+node, users can create fake remote devices using
+the `jax_mock_gpu_topology` configuration option.
+
+For instance, the snippet below instructs JAX to mock a cluster with four
+nodes, each node running eight processes with each process attached to one GPU.
+
+```python
+jax.config.update("jax_mock_gpu_topology", "4x8x1")
+```
+
+After populating the cache with this config, users can run the program
+without recompilation on four nodes, eight processes per node,
+one GPU per process.
+
+Important notes:
+
+* The process running the mocked program must have the same amount of GPUs
+  and the same GPU model as the nodes that would use the cache. For instance,
+  a mocked topology `8x4x2` must run in a process with two GPUs.
+
+* When running programs with mocked topology, the results of communications
+  with other nodes are undefined, so the outputs of JAX programs running
+  in mocked environments will likely be incorrect.
+
 ## Logging cache activity

 It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.
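As a companion to the docs change above, a minimal sketch of the cache-population flow it describes (assumptions: `jax_compilation_cache_dir` as the standard persistent-cache option and a placeholder jitted function; outputs under a mocked topology are unreliable, so only the compilation side effect matters here):

```python
import jax
import jax.numpy as jnp

# Mock a 4-node cluster: 8 processes per node, 1 GPU per process. This process
# must have one GPU of the same model as the real cluster (see notes above).
jax.config.update("jax_mock_gpu_topology", "4x8x1")
# Write compiled programs somewhere the real cluster can also read.
jax.config.update("jax_compilation_cache_dir", "/shared/jax_cache")

@jax.jit
def step(x):  # placeholder for the real multi-node program
  return jnp.sin(x) * 2.0

step(jnp.ones((8, 8)))  # compiling here populates the persistent cache
```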
examples/ffi/CMakeLists.txt

@@ -13,12 +13,12 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
 find_package(nanobind CONFIG REQUIRED)

 set(
-  JAX_FFI_EXAMPLE_PROJECTS
+  JAX_FFI_EXAMPLE_CPU_PROJECTS
   "rms_norm"
   "cpu_examples"
 )

-foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
+foreach(PROJECT ${JAX_FFI_EXAMPLE_CPU_PROJECTS})
   nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
   target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
   install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
@@ -26,9 +26,16 @@ endforeach()

 if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
   enable_language(CUDA)
+  find_package(CUDAToolkit REQUIRED)
+
   add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
   set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
                                                   CUDA_STANDARD 17)
   target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
   install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
+
+  nanobind_add_module(_gpu_examples NB_STATIC "src/jax_ffi_example/gpu_examples.cc")
+  target_include_directories(_gpu_examples PUBLIC ${XLA_DIR})
+  target_link_libraries(_gpu_examples PRIVATE CUDA::cudart)
+  install(TARGETS _gpu_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
 endif()
62  examples/ffi/src/jax_ffi_example/gpu_examples.cc (new file)

@@ -0,0 +1,62 @@
/* Copyright 2025 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>

#include "nanobind/nanobind.h"
#include "cuda_runtime_api.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

struct State {
  static xla::ffi::TypeId id;
  explicit State(int32_t value) : value(value) {}
  int32_t value;
};
ffi::TypeId State::id = {};

static ffi::ErrorOr<std::unique_ptr<State>> StateInstantiate() {
  return std::make_unique<State>(42);
}

static ffi::Error StateExecute(cudaStream_t stream, State* state,
                               ffi::ResultBufferR0<ffi::S32> out) {
  cudaMemcpyAsync(out->typed_data(), &state->value, sizeof(int32_t),
                  cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
                       ffi::Ffi::BindInstantiate());
XLA_FFI_DEFINE_HANDLER(kStateExecute, StateExecute,
                       ffi::Ffi::Bind()
                           .Ctx<ffi::PlatformStream<cudaStream_t>>()
                           .Ctx<ffi::State<State>>()
                           .Ret<ffi::BufferR0<ffi::S32>>());

NB_MODULE(_gpu_examples, m) {
  m.def("type_id",
        []() { return nb::capsule(reinterpret_cast<void*>(&State::id)); });
  m.def("handler", []() {
    nb::dict d;
    d["instantiate"] = nb::capsule(reinterpret_cast<void*>(kStateInstantiate));
    d["execute"] = nb::capsule(reinterpret_cast<void*>(kStateExecute));
    return d;
  });
}
24  examples/ffi/src/jax_ffi_example/gpu_examples.py (new file)

@@ -0,0 +1,24 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax_ffi_example import _gpu_examples
import jax.numpy as jnp

jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA")
jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA")


def read_state():
  return jax.ffi.ffi_call("state", jax.ShapeDtypeStruct((), jnp.int32))()
41  examples/ffi/tests/gpu_examples_test.py (new file)

@@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import jax
from jax._src import test_util as jtu

jax.config.parse_flags_with_absl()


class GpuExamplesTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cuda"]):
      self.skipTest("Unsupported platform")

    # Import here to avoid trying to load the library when it's not built.
    from jax_ffi_example import gpu_examples  # pylint: disable=g-import-not-at-top

    self.read_state = gpu_examples.read_state

  def test_basic(self):
    self.assertEqual(self.read_state(), 42)
    self.assertEqual(jax.jit(self.read_state)(), 42)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
jax/_src/abstract_arrays.py

@@ -45,8 +45,9 @@ array_types: set[type] = {np.ndarray} | numpy_scalar_types  # pylint: disable=g-


 def masked_array_error(*args, **kwargs):
-  raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
-                   "Use arr.filled() to convert the value to a standard numpy array.")
+  raise ValueError(
+      "numpy masked arrays are not supported as direct inputs to JAX functions."
+      " Use arr.filled() to convert the value to a standard numpy array.")

 core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error

@@ -54,7 +55,8 @@ core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
 def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
   dtype = x.dtype
   dtypes.check_valid_dtype(dtype)
-  return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
+  return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype),
+                     sharding=core.get_cur_mesh_sharding(core.P(*[None] * x.ndim)))

 core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array

@@ -62,7 +64,9 @@ core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
 def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
   dtype = np.dtype(x)
   dtypes.check_valid_dtype(dtype)
-  return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
+  shape = np.shape(x)
+  return ShapedArray(shape, dtypes.canonicalize_dtype(dtype),
+                     sharding=core.get_cur_mesh_sharding(core.P(*[None] * len(shape))))

 for t in numpy_scalar_types:
   core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
@@ -74,7 +78,8 @@ def _make_abstract_python_scalar(typ, val):
   # Note: all python scalar types are weak except bool, because bool only
   # comes in a single width.
   return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
-                     weak_type=typ is not bool)
+                     weak_type=typ is not bool,
+                     sharding=core.get_cur_mesh_sharding())

 for t in dtypes.python_scalar_dtypes:
   core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
jax/_src/ad_checkpoint.py

@@ -323,7 +323,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
   @wraps(fun)
   @api_boundary
   def fun_remat(*args, **kwargs):
-    debug = api_util.tracing_debug_info(
+    debug = api_util.debug_info(
         "checkpoint / remat", fun,
         args, kwargs, static_argnums=static_argnums)
     fun_, args = _remat_static_argnums(fun, static_argnums, args)
@@ -418,11 +418,11 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
 def _trace_to_jaxpr(fun: Callable,
                     in_tree: PyTreeDef,
                     in_avals: Sequence[core.AbstractValue],
-                    debug: lu.TracingDebugInfo
+                    debug: core.DebugInfo
                     ) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
-  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
+  flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
   try:
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
   except core.ConcretizationTypeError as e:
     msg, = e.args
     if 'for checkpoint' in msg:
@@ -447,7 +447,7 @@ def saved_residuals(f: Callable,
     args, kwargs = tree_unflatten(in_tree, args)
     return f(*args, **kwargs)

-  debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs)
+  debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
   out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
                        return_shape=True)(*in_leaves)
   assert isinstance(out, tuple)
@@ -699,7 +699,8 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
   assert next(ins_iter, None) is None
   with source_info_util.extend_name_stack('rematted_computation'):
     lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
-        lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
+        lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info),
+        in_pvals, False)

   # Transpose the linear jaxpr (which only has linear inputs).
   out_cts_iter = iter(out_cts_flat)
jax/_src/api.py

@@ -57,11 +57,11 @@ from jax._src import pjit
 from jax._src import xla_bridge as xb
 from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray
 from jax._src.api_util import (
     flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
     flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
-    apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
-    result_paths, flat_out_axes)
+    apply_flat_fun_nokwargs, check_callable, debug_info,
+    flat_out_axes)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -452,7 +452,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
     raise TypeError(f"differentiating with respect to {argnums=} requires at least "
                     f"{max_argnum + 1} positional arguments to be passed by the caller, "
                     f"but got only {len(args)} positional arguments.")
-  dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
+  dbg = debug_info('value_and_grad', fun, args, kwargs)

   f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
   f_partial, dyn_args = argnums_partial(f, argnums, args,
@@ -1021,7 +1021,7 @@ def _mapped_axis_spec(args_flat, in_axes_flat):
     try:
       # Duck type arrays like BCOO arrays can be passed to vmap.
       return shaped_abstractify(arg).sharding.spec[i]
-    except TypeError:
+    except (IndexError, TypeError):
       return None

   temp_spec = None
@@ -1426,11 +1426,12 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")

-  dbg = tracing_debug_info(
+  dbg = debug_info(
       "pmap", fun, args, kwargs,
       static_argnums=static_broadcasted_tuple)

-  f = lu.wrap_init(fun)
+  f = lu.wrap_init(fun, debug_info=dbg)
+  del dbg
   if static_broadcasted_tuple:
     if max(static_broadcasted_tuple) >= len(args):
       raise ValueError(
@@ -1477,9 +1478,6 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
       raise ValueError(msg) from None
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

-  f, res_paths = result_paths(f)
-  dbg = dbg.add_result_paths(res_paths)
-  f = lu.add_debug_info(f, dbg)
   f, out_axes_thunk = flat_out_axes(f, out_axes)
   flat_fun, out_tree = flatten_fun(f, in_tree)

@@ -2235,7 +2233,7 @@ def _check_sharding(aval, s):
                      f" invalid value: {s}")
   if isinstance(s, Sharding):
     if isinstance(aval, core.AbstractToken):
-      aval = core.token_shaped_array
+      aval = core.get_token_aval()
     if not isinstance(s, PmapSharding):
       pjit.pjit_check_aval_sharding(
           (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
jax/_src/api_util.py

@@ -31,7 +31,6 @@ from jax._src.tree_util import (
     prefix_errors)
 from jax._src.tree_util import _replace_nones
 from jax._src import linear_util as lu
-from jax._src.linear_util import TracingDebugInfo
 from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
                            Unhashable, safe_zip)
 from jax._src import traceback_util
@@ -582,7 +581,7 @@ def api_hook(fun, tag: str):
   return fun


-def tracing_debug_info(
+def debug_info(
     traced_for: str,
     fun: Callable,
     args: Sequence[Any],
@@ -591,17 +590,17 @@ def tracing_debug_info(
     static_argnums: tuple[int, ...] = (),
     static_argnames: tuple[str, ...] = (),
     result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
-    # TODO(necula): check if we really need this, e.g., to speed up tracing.
+    # TODO(necula): check if we really need this, e.g., to speed up tracing?
     sourceinfo: str | None = None,
     signature: inspect.Signature | None = None,
-) -> TracingDebugInfo:
+) -> core.DebugInfo:
   if sourceinfo is None:
     sourceinfo = fun_sourceinfo(fun)
   if signature is None:
     signature = fun_signature(fun)
   arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
                                     static_argnames)
-  return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
+  return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


 def fun_signature(fun: Callable) -> inspect.Signature | None:
@@ -619,7 +618,7 @@ _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")

 # TODO(mattjj): make this function internal to this module
 def fun_sourceinfo(fun: Callable) -> str:
-  # See TracingDebugInfo.fun_src_info
+  # See DebugInfo.fun_src_info
   res = getattr(fun, "__fun_sourceinfo__", None)
   if res is not None: return res
   while isinstance(fun, partial):
@@ -675,30 +674,6 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
   arg_names = args_arg_names + kwargs_arg_names
   return arg_names

-@lu.transformation_with_aux2
-def result_paths(_fun, _store, *args, **kwargs):
-  "linear_util transform to get output pytree paths of pre-flattened function."
-  ans = _fun(*args, **kwargs)
-  _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
-  return ans
-
-# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
-def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
-                         trace_debug: TracingDebugInfo | None,
-                         result_paths: tuple[str, ...] | None = None,
-                         ) -> core.Jaxpr:
-  """Add debug info to jaxpr, given trace-time debug info and result paths."""
-  if trace_debug is None:
-    return jaxpr
-  # TODO(necula): re-enable this safety check
-  # assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
-  if result_paths is None:
-    result_paths = trace_debug.result_paths_thunk()  # type: ignore
-  debug_info = core.JaxprDebugInfo(
-      trace_debug.traced_for, trace_debug.func_src_info,
-      trace_debug.arg_names, tuple(result_paths))  # type: ignore
-  return jaxpr.replace(debug_info=debug_info)
-
 def hoist_obj_attrs(f, flat_args):
   idxs, objs, flat_args_ = [], [], []
   for i, x in enumerate(flat_args):
@@ -723,7 +698,7 @@ def register_class_with_attrs(t: type) -> None:
 _class_with_attrs: set[type] = set()

 # TODO(mattjj): make this function faster
-def _check_no_aliased_ref_args(dbg, avals, args):
+def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
   assert config.mutable_array_checks.value
   refs: dict[int, int] = {}
   for i, (a, x) in enumerate(zip(avals, args)):
@@ -737,7 +712,7 @@ def _check_no_aliased_ref_args(dbg, avals, args):
            if dbg else
            f"at both flat index {dup_idx} and flat index {i}") from None

-def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
+def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
   assert config.mutable_array_checks.value
   refs: set[int] = {id(core.get_referent(c)) for c in consts
                     if isinstance(core.get_aval(c), AbstractRef)}
@@ -748,4 +723,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
            f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
            f"array reference of type {a.str_short()} was both closed over and "
            f"passed as the argument "
-           f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
+           f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")
jax/_src/array.py

@@ -39,6 +39,7 @@ from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension as xe
+from jax._src.lib import xla_extension_version
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding,
@@ -55,7 +56,10 @@ PRNGKeyArray = Any  # TODO(jakevdp): fix cycles and import this.

 def _get_device(a: ArrayImpl) -> Device:
   devices = a.sharding._internal_device_list  # pytype: disable=attribute-error
-  assert len(devices) == 1
+  if len(devices) != 1:
+    raise ValueError(
+        "When making an array from single-device arrays the input arrays must "
+        f"have one shard each. An argument array had {len(devices)} shard(s).")
   return devices[0]

@@ -195,54 +199,102 @@ class ArrayImpl(basearray.Array):

     self.aval = aval
     self._sharding = sharding
-    self._arrays = [a._arrays[0] for a in arrays]
     self._committed = committed
     self._npy_value = None
+    arrays = [a._arrays[0] for a in arrays]

     # Don't rearrange if skip_checks is enabled because this assumes that the
     # input buffers are already arranged properly. This usually happens when
     # Array's are created as output of a JAX transformation
     # (like pjit, etc).
     if not _skip_checks or config.enable_checks.value:
-      self._check_and_rearrange()
+      arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
+    self._arrays = arrays  # type: ignore

-  def _check_and_rearrange(self):
-    device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
-
-    addressable_dev = self.sharding.addressable_devices
-    if len(self._arrays) != len(addressable_dev):
-      raise ValueError(
-          f"Expected {len(addressable_dev)} per-device arrays "
-          "(this is how many devices are addressable by the sharding), but "
-          f"got {len(self._arrays)}")
-
-    array_device_ids = set(device_id_to_buffer.keys())
-    addressable_device_ids = {d.id for d in addressable_dev}
-    # Calculate a symmetric difference because the device ids between sharding
-    # and _arrays should match.
-    diff = array_device_ids ^ addressable_device_ids
-    if diff:
-      dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
-      dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
-      err_msg = (
-          "Addressable devices and per-device arrays devices do not match.")
-      if dev_in_sharding_not_in_arrays:
-        err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
-                    "that are not present in per-device arrays.")
-      if dev_in_arrays_not_in_sharding:
-        err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
-                    "that are not present in the sharding.")
-      raise ValueError(err_msg)
-
-    _validate_shape_and_dtype_for_per_device_arrays(
-        self._arrays,
-        sharding=self.sharding,
-        aval=self.aval,
-        expected_shape=self.sharding.shard_shape(self.shape),
-    )
-    # Rearrange arrays based on the device assignment.
-    addressable_da = self.sharding._addressable_device_assignment
-    self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
+  if xla_extension_version >= 310:
+    def _check_and_rearrange(self, arrays, sharding, aval):
+      device_id_to_buffer = {_get_device(db).id: db for db in arrays}
+
+      addressable_dev = sharding.addressable_devices
+      if len(arrays) != len(addressable_dev):
+        raise ValueError(
+            f"Expected {len(addressable_dev)} per-device arrays "
+            "(this is how many devices are addressable by the sharding), but "
+            f"got {len(arrays)}")
+
+      array_device_ids = set(device_id_to_buffer.keys())
+      addressable_device_ids = {d.id for d in addressable_dev}
+      if len(array_device_ids) != len(arrays):
+        buffer_device_ids = [_get_device(db).id for db in arrays]
+        raise ValueError(
+            "When making an array from single-device arrays, the input arrays"
+            " must be from distinct devices, but got device IDs"
+            f" {buffer_device_ids}")
+
+      # Calculate a symmetric difference because the device ids between sharding
+      # and _arrays should match.
+      diff = array_device_ids ^ addressable_device_ids
+      if diff:
+        dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
+        dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
+        err_msg = (
+            "Addressable devices and per-device arrays devices do not match.")
+        if dev_in_sharding_not_in_arrays:
+          err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
+                      "that are not present in per-device arrays.")
+        if dev_in_arrays_not_in_sharding:
+          err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
+                      "that are not present in the sharding.")
+        raise ValueError(err_msg)
+
+      _validate_shape_and_dtype_for_per_device_arrays(
+          arrays,
+          sharding=sharding,
+          aval=aval,
+          expected_shape=sharding.shard_shape(aval.shape),
+      )
+
+      # Rearrange arrays based on the device assignment.
+      addressable_da = sharding._addressable_device_assignment
+      return [device_id_to_buffer[device.id] for device in addressable_da]
+  else:
+    def _check_and_rearrange(self):  # type: ignore
+      device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
+
+      addressable_dev = self.sharding.addressable_devices
+      if len(self._arrays) != len(addressable_dev):
+        raise ValueError(
+            f"Expected {len(addressable_dev)} per-device arrays "
+            "(this is how many devices are addressable by the sharding), but "
+            f"got {len(self._arrays)}")
+
+      array_device_ids = set(device_id_to_buffer.keys())
+      addressable_device_ids = {d.id for d in addressable_dev}
+      # Calculate a symmetric difference because the device ids between sharding
+      # and _arrays should match.
+      diff = array_device_ids ^ addressable_device_ids
+      if diff:
+        dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
+        dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
+        err_msg = (
+            "Addressable devices and per-device arrays devices do not match.")
+        if dev_in_sharding_not_in_arrays:
+          err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
+                      "that are not present in per-device arrays.")
+        if dev_in_arrays_not_in_sharding:
+          err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
+                      "that are not present in the sharding.")
+        raise ValueError(err_msg)
+
+      _validate_shape_and_dtype_for_per_device_arrays(
+          self._arrays,
+          sharding=self.sharding,
+          aval=self.aval,
+          expected_shape=self.sharding.shard_shape(self.shape),
+      )
+      # Rearrange arrays based on the device assignment.
+      addressable_da = self.sharding._addressable_device_assignment
+      self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]

   @property
   def shape(self) -> Shape:
@@ -1220,7 +1272,7 @@ pxla.shard_arg_handlers[core.Token] = _token_shard_arg

 def _token_global_result_handler(global_aval, out_sharding, committed):
   array_handler = _array_global_result_handler(
-      core.token_shaped_array, out_sharding, committed)
+      core.get_token_aval(), out_sharding, committed)

   def wrapper(*args, **kwargs):
     out_buf = array_handler(*args, **kwargs)
jax/_src/checkify.py

@@ -35,6 +35,7 @@ from jax._src import core
 from jax._src import custom_derivatives
 from jax._src import effects
 from jax._src import pjit
+from jax._src import mesh as mesh_lib
 from jax._src import sharding_impls
 from jax._src import source_info_util
 from jax._src import traceback_util
@@ -966,7 +967,8 @@ def shard_map_error_check(
       raise ValueError(f'Unsupported aval type: {type(v)}')
     in_avals[i] = sharder(mesh, new_in_names[i], v)

-  with core.extend_axis_env_nd(mesh.shape.items()):
+  with (core.extend_axis_env_nd(mesh.shape.items()),
+        mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh))):
     # jaxpr to checked_jaxpr
     checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
         pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
@@ -1202,11 +1204,11 @@ def checkify(f: Callable[..., Out],
   in_tree = jtu.tree_structure(((), {}))
   closed_f = lambda: f(*args, **kwargs)
   # stage:
-  debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
+  debug = api_util.debug_info("checkify", f, args, kwargs)
   fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
                                                      debug_info=debug),
                                         in_tree)
-  jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
+  jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ())
   jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
   # checkify:
   error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
jax/_src/config.py

@@ -1716,3 +1716,21 @@ memory_fitting_effort = float_state(
     default=0.0,
     help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
 )
+
+cpu_collectives_implementation = optional_enum_state(
+    name='jax_cpu_collectives_implementation',
+    enum_values=["gloo", "mpi", "megascale"],
+    default=None,
+    help=(
+        "Cross-process collective implementation used on CPU. Must be one of "
+        '("gloo", "mpi")'),
+)
+
+num_cpu_devices = int_state(
+    name="jax_num_cpu_devices",
+    default=-1,
+    help=(
+        "Number of CPU devices to use. If not provided, the value of "
+        "the XLA flag --xla_force_host_platform_device_count is used."
+        " Must be set before JAX is initialized."),
+)
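For illustration, a hedged sketch of driving the two options added above through `jax.config` (the values are arbitrary; per the help strings, both must be set before JAX is initialized):

```python
import jax

# Programmatic equivalents of the JAX_NUM_CPU_DEVICES and
# JAX_CPU_COLLECTIVES_IMPLEMENTATION env vars from the CHANGELOG entry above;
# both must run before any backend is created.
jax.config.update("jax_num_cpu_devices", 4)
jax.config.update("jax_cpu_collectives_implementation", "gloo")

print(jax.local_device_count("cpu"))  # expect 4
```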
152
jax/_src/core.py
152
jax/_src/core.py
@ -82,31 +82,7 @@ EffectTypeSet = effects.EffectTypeSet
|
||||
no_effects: Effects = effects.no_effects
|
||||
|
||||
|
||||
# TODO(necula): make this an extension of TracingDebugInfo
|
||||
class JaxprDebugInfo(NamedTuple):
|
||||
# An extension of lu.TracingDebugInfo; see comments there
|
||||
traced_for: str
|
||||
func_src_info: str
|
||||
arg_names: tuple[str | None, ...]
|
||||
# This is formed after tracing, when we have concrete `result_paths`
|
||||
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
|
||||
|
||||
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
|
||||
"""Get the arg_names with a safety check."""
|
||||
if len(self.arg_names) == expected:
|
||||
return self.arg_names
|
||||
else:
|
||||
# TODO(necula): this should not happen
|
||||
return (None,) * expected
|
||||
|
||||
def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
|
||||
"""Get the result_paths with a safety check."""
|
||||
if len(self.result_paths) == expected:
|
||||
return self.result_paths
|
||||
else:
|
||||
# TODO(necula): this should not happen
|
||||
return ("",) * expected
|
||||
|
||||
DebugInfo = lu.DebugInfo
|
||||
|
||||
class Jaxpr:
|
||||
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
|
||||
@ -117,7 +93,7 @@ class Jaxpr:
|
||||
_outvars: list[Atom]
|
||||
_eqns: list[JaxprEqn]
|
||||
_effects: Effects
|
||||
_debug_info: JaxprDebugInfo | None
|
||||
_debug_info: DebugInfo | None
|
||||
|
||||
@property
|
||||
def constvars(self) -> list[Var]:
|
||||
@ -140,13 +116,13 @@ class Jaxpr:
|
||||
return self._effects
|
||||
|
||||
@property
|
||||
def debug_info(self) -> JaxprDebugInfo | None:
|
||||
def debug_info(self) -> DebugInfo | None:
|
||||
return self._debug_info
|
||||
|
||||
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
|
||||
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
|
||||
effects: Effects = no_effects,
|
||||
debug_info: JaxprDebugInfo | None = None):
|
||||
debug_info: DebugInfo | None = None):
|
||||
"""
|
||||
Args:
|
||||
constvars: list of variables introduced for constants. Array constants are
|
||||
@ -157,14 +133,14 @@ class Jaxpr:
|
||||
eqns: list of equations.
|
||||
effects: set of effects. The effects on a jaxpr are a superset of the
|
||||
union of the effects for each equation.
|
||||
debug_info: optional JaxprDebugInfo.
|
||||
debug_info: optional DebugInfo.
|
||||
"""
|
||||
self._constvars = list(constvars)
|
||||
self._invars = list(invars)
|
||||
self._outvars = list(outvars)
|
||||
self._eqns = list(eqns)
|
||||
self._effects = effects
|
||||
self._debug_info = debug_info
|
||||
self._debug_info = debug_info and debug_info.resolve_result_paths()
|
||||
# TODO(necula): re-enable these safety checks
|
||||
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
|
||||
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
|
||||
@ -505,6 +481,8 @@ class Primitive:
|
||||
map_primitive: bool = False
|
||||
# set for ref primitives
|
||||
ref_primitive: bool = False
|
||||
# set for primitives that can skip canonicalization of values
|
||||
skip_canonicalization: bool = False
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
@ -513,6 +491,12 @@ class Primitive:
|
||||
return f'{self.name}'
|
||||
|
||||
def bind(self, *args, **params):
|
||||
if not config.sharding_in_types.value:
|
||||
return self._true_bind(*args, **params)
|
||||
args = args if self.skip_canonicalization else map(canonicalize_value, args)
|
||||
return self._true_bind(*args, **params)
|
||||
|
||||
def _true_bind(self, *args, **params):
|
||||
for arg in args:
|
||||
if (isinstance(arg, Tracer)
|
||||
and not arg._trace.is_valid()
@@ -610,8 +594,8 @@ def check_avals_context_mesh(avals, prim_name):
if config.sharding_in_types.value:
cur_mesh = mesh_lib.get_abstract_mesh()
for a in avals:
if a.sharding.mesh.empty or cur_mesh.empty:
continue
# avals can have meshes with different axis_names so allow that in
# full auto mode.
if a.sharding.mesh._are_all_axes_auto and cur_mesh._are_all_axes_auto:
continue
if a.sharding.mesh != cur_mesh:
@@ -621,21 +605,6 @@ def check_avals_context_mesh(avals, prim_name):
" error occurs at source: "
f" {source_info_util.summarize(source_info_util.current())}")

# TODO(yashkatariya, dougalm): Remove this and replace with canonicalize_value
# function which casts scalar, numpy arrays, etc to jax arrays so that values
# passed to primitives are always have avals, etc i.e. they are canonical and
# also does mesh casting, etc
def cast_from_auto_to_manual(avals):
if not config.sharding_in_types.value:
return avals

from jax._src.sharding_impls import NamedSharding # type: ignore
cur_mesh = mesh_lib.get_abstract_mesh()
return [a.update(sharding=NamedSharding(cur_mesh, P(*[None] * a.ndim)))
if (not a.sharding.mesh.empty and cur_mesh._are_all_axes_manual and
a.sharding.mesh._are_all_axes_auto)
else a for a in avals]

# -------------------- tracing --------------------

TracerType = TypeVar('TracerType', bound='Tracer')
@@ -1775,6 +1744,38 @@ def _make_lengths_same(sharding, ndim):
assert False, "unreachable"


# TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values
# passed to primitives always have avals, etc i.e. they are canonical.
def canonicalize_value(val):
if not config.sharding_in_types.value:
return val

from jax._src.pjit import NamedSharding, mesh_cast # type: ignore

try:
aval = get_aval(val)
except TypeError:
return val
if not isinstance(aval, ShapedArray):
return val

cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh == aval.sharding.mesh: # type: ignore
return val
if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto: # type: ignore
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) # type: ignore
if aval.sharding.mesh.empty and not cur_mesh.empty: # type: ignore
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) # type: ignore
return val
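A condensed, runnable model of the branch structure in `canonicalize_value` above (a sketch for orientation only; `MeshStub` and its fields stand in for the private `AbstractMesh` attributes `_are_all_axes_manual`, `_are_all_axes_auto`, and `.empty`):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MeshStub:                 # stand-in for jax's AbstractMesh
    all_manual: bool = False
    all_auto: bool = False
    empty: bool = False

def canonicalize_decision(cur: MeshStub, own: MeshStub) -> str:
    if cur == own:
        return "keep value as-is"
    if cur.all_manual and own.all_auto:
        return "mesh_cast to replicated sharding on the current mesh"
    if own.empty and not cur.empty:
        return "mesh_cast to replicated sharding on the current mesh"
    return "keep value as-is"

assert canonicalize_decision(MeshStub(), MeshStub()) == "keep value as-is"
assert "mesh_cast" in canonicalize_decision(MeshStub(all_manual=True),
                                            MeshStub(all_auto=True))
```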


def get_cur_mesh_sharding(spec=None):
from jax._src.sharding_impls import NamedSharding # type: ignore

spec = P() if spec is None else spec
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)


# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def modify_spec_for_auto_manual(spec, mesh) -> P:
@@ -1791,13 +1792,16 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
return P(*new_spec)

def _maybe_modify_sharding(sharding, ndim):
if len(sharding.spec) == 0 or all(s is None for s in sharding.spec):
if len(sharding.spec) != ndim:
return _make_lengths_same(sharding, ndim)
return sharding

if sharding.mesh._are_all_axes_explicit:
out = sharding
elif all(s is None for s in sharding.spec):
out = sharding
else:
out = sharding.with_spec(modify_spec_for_auto_manual(
sharding.spec, sharding.mesh))
return sharding

out = sharding.with_spec(modify_spec_for_auto_manual(
sharding.spec, sharding.mesh))
if (len(out.spec) != ndim and
(out.mesh._are_all_axes_auto or out.mesh._are_all_axes_manual)):
out = _make_lengths_same(out, ndim)
@@ -1807,18 +1811,14 @@ def _maybe_modify_sharding(sharding, ndim):
def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding # type: ignore

if sharding is not None:
out_s = _maybe_modify_sharding(sharding, ndim)
if len(out_s.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
else:
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh.empty:
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
out_s = NamedSharding(cur_mesh, P(*[None] * ndim))
if sharding is None:
return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim))

out_s = _maybe_modify_sharding(sharding, ndim)
if len(out_s.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {out_s.spec}, aval.ndim {ndim} and sharding {out_s}")
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
@@ -2112,7 +2112,8 @@ class AbstractToken(AbstractValue):
abstract_token: AbstractToken = AbstractToken()

# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))
def get_token_aval():
return ShapedArray((0,), np.dtype(np.bool_), sharding=get_cur_mesh_sharding())

# Concrete token object
class Token:
@@ -2377,7 +2378,8 @@ def dim_constant(ct: int):
return np.int64(ct)

def dim_value_aval() -> AbstractValue:
return ShapedArray((), dim_value_dtype(), weak_type=True)
return ShapedArray((), dim_value_dtype(), weak_type=True,
sharding=get_cur_mesh_sharding())

# ------------------- Call -------------------

@@ -2385,6 +2387,9 @@ class CallPrimitive(Primitive):
multiple_results = True
call_primitive = True

def bind(self, *args, **params):
return self._true_bind(*args, **params)

def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
args = fun_and_args[1:]
@@ -2393,7 +2398,8 @@ class CallPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
jaxpr, ())
if config.dynamic_shapes.value:
subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
@@ -2425,8 +2431,11 @@ class MapPrimitive(Primitive):
multiple_results = True
map_primitive = True

def bind(self, *args, **params):
return self._true_bind(*args, **params)

def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
fun: lu.WrappedFun = fun_and_args[0]
args = fun_and_args[1:]
assert len(params['in_axes']) == len(args)
return trace.process_map(self, fun, args, params)
@@ -2436,8 +2445,9 @@ class MapPrimitive(Primitive):

def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params

@@ -119,9 +119,11 @@ def element_type_to_backend_config_type_mapping(dtype):
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]

def get_max_seg_per_batch(q_offsets):
return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1
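The helper reads the per-batch segment capacity off the packed-offsets layout: `q_offsets` carries `M+1` offsets per batch row (documented with `dot_product_attention` further down), so a rank-2 input yields `M` and the rank-1 placeholder used for the non-packed path yields 1. A quick sketch, assuming the helper above is in scope:

```python
import jax.numpy as jnp

q_offsets = jnp.array([[0, 1, 2, -1],    # shape [B, M+1] = [2, 4]
                       [0, 1, -1, -1]])
assert get_max_seg_per_batch(q_offsets) == 3     # M = 4 - 1
assert get_max_seg_per_batch(jnp.zeros(0)) == 1  # rank-1 placeholder
```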

def create_dot_product_attention_backend_config_base(
batch, num_heads, seq_q, seq_kv, dtype,fmha_scale, mask_type, layout, is_bwd
batch, num_heads, seq_q, seq_kv, dtype, fmha_scale, mask_type, layout, is_bwd
):
# Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H
# P: BMM1 output in shape of BNTS
@@ -226,6 +228,7 @@ def create_dot_product_attention_backend_config(
mask_type,
layout,
sliding_window_length,
max_seg_per_batch,
is_bwd
):
backend_config = create_dot_product_attention_backend_config_base(
@@ -237,6 +240,7 @@ def create_dot_product_attention_backend_config(
backend_config['cudnn_fmha_backend_config']["dropout_rate"] = dropout_rate
backend_config['cudnn_fmha_backend_config']["seed"] = seed
backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length
backend_config['cudnn_fmha_backend_config']["max_seg_per_batch"] = max_seg_per_batch
return json.dumps(backend_config)

def create_dot_product_attention_fp8_backend_config(
@@ -268,7 +272,8 @@ get_fp8_custom_call_name = functools.partial(
get_custom_call_name, has_bias=False, has_dropout=False, is_fp8=True
)

def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
q_offsets, kv_offsets, layout):
def check_eq(a, b, c, msg):
if not (a == b == c):
raise ValueError(f"{msg} must be same, got {a}, {b}, {b}")
@@ -300,36 +305,36 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
if kS != vS:
raise ValueError(f"KV must have same seq length, got {kS} vs {vS}")

# check bias/q_seqlen/kv_seqlen
# check bias
if bias is not None:
_, _, bT, bS = bias.shape
if bT != qT or bS != vS:
raise ValueError(
f"Bias must have same seq length as QKV, got {bT} and {bS}")
if q_seqlen is not None:
q_seq_dtype = q_seqlen.dtype
q_seq_rank = len(q_seqlen.shape)
if q_seq_dtype != jnp.int32:
raise ValueError(f"q_seqlen must have int32 datatype, got {q_seq_dtype}")
if q_seq_rank != 1:
raise ValueError(f"q_seqlen must have a rank of 1, got {q_seq_rank}")
q_seq_b = q_seqlen.shape[0]
if q_seq_b != qB:
raise ValueError(f"q_seqlen must have same batch as Q, got {q_seq_b}")
if kv_seqlen is not None:
kv_seq_dtype = kv_seqlen.dtype
kv_seq_rank = len(kv_seqlen.shape)
if kv_seq_dtype != jnp.int32:
raise ValueError(
f"kv_seqlen must have int32 datatype, got {kv_seq_dtype}")
if kv_seq_rank != 1:
raise ValueError(f"kv_seq_rank must have a rank of 1, got {kv_seq_rank}")
kv_seq_b = kv_seqlen.shape[0]
if kv_seq_b != qB:
raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")

# check q_seqlen/kv_seqlen/q_offsets/kv_offsets
expected_rank = 2 if q_offsets is not None else 1
def check_seqlen_offsets(tensor, name):
if tensor is not None:
dtype = tensor.dtype
rank = len(tensor.shape)
if dtype != jnp.int32:
raise ValueError(f"{name} must have int32 datatype, got {dtype}")
if rank != expected_rank:
raise ValueError(f"{name} must have a rank of {expected_rank}, got {rank}")
b = tensor.shape[0]
if b != qB:
raise ValueError(f"{name} must have same batch as Q, got {b}")

check_seqlen_offsets(q_seqlen, "q_seqlen")
check_seqlen_offsets(kv_seqlen, "kv_seqlen")
check_seqlen_offsets(q_offsets, "q_offsets")
check_seqlen_offsets(kv_offsets, "kv_offsets")


def check_is_flash_attention(
query, key, layout: int, cudnn_version, has_bias, is_training, is_fp8=False):
query, key, layout: int, cudnn_version, has_bias, is_training, is_packed,
is_fp8=False):
# Extract sequence length (T) and head dim (H) based on layout
if layout == AttentionLayout.BNTH.value:
_, _, T, H = query.shape
@@ -363,6 +368,9 @@ def check_is_flash_attention(
f"Unsupported sequence length Q {T}, KV {S}."
)

if is_packed and cudnn_version < 90600:
raise NotImplementedError("Packed layout requires cudnn version >= 9.6.")

def check_cudnn_version():
# check if cuDNN is installed
if cuda_versions is None:
@@ -378,78 +386,142 @@ def check_compute_capability(capability):
return current >= target

def _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, False)
query, key, layout, cudnn_version, bias is not None, False,
get_max_seg_per_batch(q_offsets) > 1)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=False)
output = outputs[0]
return output

def _dot_product_attention_fwd_rule(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, True)
query, key, layout, cudnn_version, bias is not None, True,
get_max_seg_per_batch(q_offsets) > 1)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=True)
res = (query, key, value, bias, q_seqlen, kv_seqlen,
outputs[1], outputs[0])
res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
kv_offsets, outputs[1], outputs[0])
return outputs[0], res

def _dot_product_attention_bwd_rule(
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training, res, grad_output):
(query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output) = res
(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output) = res
grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length
)
grads = (*grads,) + (None,) * (6 - len(grads))
grads = (*grads,) + (None,) * (8 - len(grads))
return grads

def _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key):
# fix seqlen and offsets to what cuDNN expects in sequence packing.
# cuDNN expects seqlen to have shape [S] where S is the total number of segments
# while the SDPA API accepts seqlen with shape [B, M] where B is the batch and M
# is the maximum number of segments of one batch. B x M is larger than S and seqlen
# is filled with -1 for padded regions. Therefore, we need to shift all non-negative
# values to the left side to form a correct seqlen. A similar layout is required for
# the offsets tensors.
# cuDNN expects offsets to have an offset for each segment starting from the first
# segment, while the SDPA API accepts offsets for each segment starting from the
# current batch; therefore we need to calculate the cumulative offset of each
# segment starting from the first segment.
def _shift_to_left(x, fill_value):
# shift any non-negative value to left
# [[1, 3, -1, -1], [2, 3, 4, -1]]
# -> [[1, 3, 2, 3], [4, -1, -1, -1]]
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(x >= 0, size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape)

def _cu_offset(offsets, max_seq):
# calculate cumulative offset by batch
# [[1, 3, 5, 7], [4, 5, -1, -1]], max_seq = 8
# -> [[1, 3, 5, 7], [12, 13, -1, -1]]
batch = offsets.shape[0]
offsets = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis],
offsets,
)
return offsets

if get_max_seg_per_batch(q_offsets) > 1:
B, T, N, H = query.shape
_, S, _, _ = key.shape

q_seqlen = _shift_to_left(q_seqlen, -1)
kv_seqlen = _shift_to_left(kv_seqlen, -1)

q_offsets = _cu_offset(q_offsets, T)
kv_offsets = _cu_offset(kv_offsets, S)
q_offsets = _shift_to_left(q_offsets, -1)
kv_offsets = _shift_to_left(kv_offsets, -1)

# mark any invalid entries as maximum offset
q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets)
kv_offsets = jnp.where(kv_offsets < 0, B * S, kv_offsets)

# multiply by stride_per_token to get correct offsets
# do it here because the real stride changes after sharding
q_offsets = q_offsets * N * H
kv_offsets = kv_offsets * N * H

return q_seqlen, kv_seqlen, q_offsets, kv_offsets
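A small standalone sketch that reproduces the documented transforms, mirroring the helpers above (the function bodies are duplicated here so the snippet runs on its own):

```python
import jax.numpy as jnp

def shift_to_left(x, fill_value):
    # Gather the non-negative entries first, pad the rest with fill_value.
    shape = x.shape
    flat = x.flatten()
    idx = jnp.nonzero(flat >= 0, size=flat.size, fill_value=flat.size)[0]
    return jnp.take(flat, idx, fill_value=fill_value).reshape(shape)

seqlen = jnp.array([[1, 3, -1, -1], [2, 3, 4, -1]])
print(shift_to_left(seqlen, -1))
# [[ 1  3  2  3]
#  [ 4 -1 -1 -1]]   -- matches the _shift_to_left comment

offsets, max_seq = jnp.array([[1, 3, 5, 7], [4, 5, -1, -1]]), 8
batch = offsets.shape[0]
cu = jnp.where(offsets >= 0,
               offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis],
               offsets)
print(cu)
# [[ 1  3  5  7]
#  [12 13 -1 -1]]   -- matches the _cu_offset comment
```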

def _dot_product_attention_fwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
# args: {Q, K, V, mask*, bias*}
q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
_fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
outputs = _dot_product_attention_fwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)
return outputs

def _dot_product_attention_bwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length):
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
_fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
grads = _dot_product_attention_bwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length)
return grads

def _dot_product_attention_fwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, *, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
*, scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
if layout == AttentionLayout.BNTH.value:
@@ -459,7 +531,9 @@ def _dot_product_attention_fwd_abstract(
B, T, N, _ = query.shape
_, S, _, _ = key.shape
output_shape = query.shape
softmax_stat_shape = (B, N, T)

max_seg_per_batch = get_max_seg_per_batch(q_offsets)
softmax_stat_shape = (B * max_seg_per_batch, N, T)

if is_training:
return (
@@ -472,9 +546,9 @@ def _dot_product_attention_fwd_abstract(
)

def _dot_product_attention_bwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type,
layout, sliding_window_length):
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, *, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
key_dtype = dtypes.canonicalize_dtype(key.dtype)
value_dtype = dtypes.canonicalize_dtype(value.dtype)
@@ -511,9 +585,9 @@ def _dot_product_attention_bwd_abstract(
)

def _dot_product_attention_fwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type,
layout, sliding_window_length, is_training):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@@ -530,24 +604,30 @@ def _dot_product_attention_fwd_cuda_lowering(
output_layout = (3, 1, 2, 0)
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))

max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
output_shape = (B, N, T, H)
softmax_stat_shape = (B, N, T)
softmax_stat_shape = (B * max_seg_per_batch, N, T)
workspace_shape = (0,)
workspace_type = ir.IntegerType.get_unsigned(8)

has_bias, _ = variadic_args
backend_config = create_dot_product_attention_backend_config(
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=False,
)
# {Q, K, V, bias*, q_seqlen*, kv_seqlen*}
mask_type, layout, sliding_window_length, max_seg_per_batch,
is_bwd=False)
# {Q, K, V, bias*, q_seqlen*, kv_seqlen*, q_offsets*, kv_offsets*}
# {output, activation*, workspace}
has_dropout = dropout_rate > 0
has_bias, _ = variadic_args
operands = [query, key, value]
if has_bias:
operands.append(bias)
if has_padding(mask_type):
if has_padding(mask_type) or max_seg_per_batch > 1:
operands.append(q_seqlen)
operands.append(kv_seqlen)
if max_seg_per_batch > 1:
operands.append(q_offsets)
operands.append(kv_offsets)

custom_call_name = get_custom_call_name(has_bias, has_dropout, False)

if is_training:
@@ -581,9 +661,9 @@ def _dot_product_attention_fwd_cuda_lowering(
return [hlo.transpose(out.results[0], output_transpose_perm)]

def _dot_product_attention_bwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@@ -607,23 +687,29 @@ def _dot_product_attention_bwd_cuda_lowering(
grad_query_shape = (B, q_N, T, H)
grad_key_shape = (B, k_N, S, H)
grad_value_shape = (B, k_N, S, H)

has_bias, has_dbias = variadic_args
max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
backend_config = create_dot_product_attention_backend_config(
B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=True,
)
# {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*}
mask_type, layout, sliding_window_length, max_seg_per_batch,
is_bwd=True)
# {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*,
# q_offsets*, kv_offsets*}
# {dQ, dK, dV, dbias*, workspace}
has_dropout = dropout_rate > 0
has_bias, has_dbias = variadic_args
# create operands
operands = [query, key, value, activation, grad_output]
if has_bias:
# flash attention requires bias in the bwd for remat
operands.append(bias)
operands.append(fwd_output)
if has_padding(mask_type):
if has_padding(mask_type) or max_seg_per_batch > 1:
operands.append(q_seqlen)
operands.append(kv_seqlen)
if max_seg_per_batch > 1:
operands.append(q_offsets)
operands.append(kv_offsets)
# get custom call name
custom_call_name = get_custom_call_name(has_bias, has_dropout, True)

@@ -674,7 +760,8 @@ def _dot_product_attention_fwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length, is_training):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, q_seqlen, kv_seqlen = batched_args
query, key, value, bias, q_seqlen, kv_seqlen, \
q_offsets, kv_offsets = batched_args
query_bdim = batch_dims[0]
if is_training:
out_bdims = query_bdim, query_bdim
@@ -701,9 +788,9 @@ def _dot_product_attention_fwd_batcher(
kv_seqlen = jnp.reshape(kv_seqlen, (B, ))

outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)

# reshape to original shape
@@ -720,8 +807,8 @@ def _dot_product_attention_bwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, q_seqlen, \
kv_seqlen, activation, fwd_output, grad_output = batched_args
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, \
activation, fwd_output, grad_output = batched_args
query_bdim = batch_dims[0]
out_bdims = query_bdim, query_bdim, query_bdim

@@ -757,8 +844,8 @@ def _dot_product_attention_bwd_batcher(
grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:])

grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
@@ -834,7 +921,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layo
return [out_sharding]

_dot_product_attention_fwd_lower = custom_partitioning(
_dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
_dot_product_attention_fwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14, 15))

def _dot_product_attention_fwd_infer_sharding_from_operands(
scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length,
@@ -883,7 +970,7 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args):
return out_shardings

_dot_product_attention_bwd_lower = custom_partitioning(
_dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14, 15)
_dot_product_attention_bwd_impl, static_argnums=(11, 12, 13, 14, 15, 16, 17)
)

def _dot_product_attention_bwd_infer_sharding_from_operands(
@@ -1003,13 +1090,15 @@ dispatch.prim_requires_devices_during_lowering.add(
_dot_product_attention_bwd_p_wrapper
)

@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _dot_product_attention(query: Array,
key: Array,
value: Array,
bias: Array,
q_seqlen: Array,
kv_seqlen: Array,
q_offsets: Array,
kv_offsets: Array,
scale: float,
seed: int,
dropout_rate: float,
@@ -1019,9 +1108,10 @@ def _dot_product_attention(query: Array,
sliding_window_length: int | None,
cudnn_version: int):
output = _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
cudnn_version=cudnn_version)
return output

@@ -1612,6 +1702,8 @@ def dot_product_attention(
mask: Array | None = None,
q_seqlen: Array | None = None,
kv_seqlen: Array | None = None,
q_offsets: Array | None = None,
kv_offsets: Array | None = None,
fp8_params: FP8Params | None = None,
*,
scale: float = 1.0,
@@ -1647,8 +1739,26 @@ def dot_product_attention(
value: Values to be used in attention with a shape of BSNH or BNSH.
bias: Bias to be added to logits with a shape of BNTS.
mask: Mask used to filter out logits with a shape of BNTS.
q_seqlen: Non padded sequence length of Queries with a shape of B.
kv_seqlen: Non padded sequence length of Keys and Values with a shape of B.
q_seqlen: Non-padded sequence length of query with a shape of B.
If q_offsets is set, q_seqlen should have shape [B,M] where M is the
maximum number of segments per batch. For batches that have fewer
segments than the maximum, fill the padded entries with -1.
kv_seqlen: Non-padded sequence length of key and value with a shape of B.
If kv_offsets is set, kv_seqlen should have shape [B,M] where M is the
maximum number of segments per batch. For batches that have fewer
segments than the maximum, fill the padded entries with -1.
q_offsets: offset of each segment packed in query with a shape of [B,M+1]
where M is the maximum number of segments per batch. For batches that
have fewer segments than the maximum, fill the padded entries with -1.
E.g., if 2 batches have 3 and 2 segments respectively, each segment has
size 1, q_offsets = [[0,1,2,-1], [0,1,-1,-1]]. q_seqlen should be set
to indicate the size of each segment.
kv_offsets: offset of each segment packed in key with a shape of [B,M+1]
where M is the maximum number of segments per batch. For batches that
have fewer segments than the maximum, fill the padded entries with -1.
E.g., if 2 batches have 3 and 2 segments respectively, each segment has
size 1, kv_offsets = [[0,1,2,-1], [0,1,-1,-1]]. kv_seqlen should be set
to indicate the size of each segment.
scale: Scale for the query.
dropout_rate: Dropout rate.
qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
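A hypothetical construction of the packed-layout arguments following the docstring example above (2 batch entries with 3 and 2 segments of size 1; only the argument names come from this diff, the arrays are illustrative):

```python
import jax.numpy as jnp

q_seqlen   = jnp.array([[1, 1, 1], [1, 1, -1]], dtype=jnp.int32)  # [B, M]
kv_seqlen  = jnp.array([[1, 1, 1], [1, 1, -1]], dtype=jnp.int32)
q_offsets  = jnp.array([[0, 1, 2, -1], [0, 1, -1, -1]], dtype=jnp.int32)  # [B, M+1]
kv_offsets = jnp.array([[0, 1, 2, -1], [0, 1, -1, -1]], dtype=jnp.int32)

# out = dot_product_attention(query, key, value,
#                             q_seqlen=q_seqlen, kv_seqlen=kv_seqlen,
#                             q_offsets=q_offsets, kv_offsets=kv_offsets)
```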
@@ -1679,7 +1789,7 @@ def dot_product_attention(
f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}"
)
check_fp8_params(fp8_params)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
output, amax_s, amax_o = _dot_product_attention_fp8(
query, key, value, fp8_params,
scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version
@@ -1691,6 +1801,8 @@ def dot_product_attention(
if sliding_window_length is not None and sliding_window_length <= 0:
raise ValueError(
f"Require sliding_window_length > 0, got {sliding_window_length}")
if q_offsets is not None and (q_seqlen is None or kv_seqlen is None):
raise ValueError("Require q_seqlen and kv_seqlen to use packed layout")

if bias is not None:
# reshape bias to have 4D shape
@@ -1712,7 +1824,7 @@ def dot_product_attention(
bias = bias + mask

# check if input shape and data type is compatible
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
has_bias = bias is not None
has_dbias = has_bias and \
should_export_dbias(bias.shape, query.shape, layout)  # type: ignore[union-attr]
@@ -1724,8 +1836,12 @@ def dot_product_attention(
q_seqlen = jnp.zeros(0, dtype=query.dtype)
if kv_seqlen is None:
kv_seqlen = jnp.zeros(0, dtype=query.dtype)
if q_offsets is None:
q_offsets = jnp.zeros(0, dtype=query.dtype)
if kv_offsets is None:
kv_offsets = jnp.zeros(0, dtype=query.dtype)
output = _dot_product_attention(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length,
cudnn_version)
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout.value,
sliding_window_length, cudnn_version)
return output

@@ -147,13 +147,13 @@ class custom_vmap:
raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.")
debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {})
debug = api_util.debug_info("custom_vmap", self.fun, args, {})
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(self.fun, debug_info=debug),
in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None

@@ -127,12 +127,12 @@ class custom_dce:
"def_dce."
)
rule_name = util.fun_name(self.dce_rule)
debug = api_util.tracing_debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
if self.static_argnums:
static_argnums = set(self.static_argnums)
@@ -147,11 +147,11 @@ class custom_dce:
)
static_args = [args[i] for i in self.static_argnums]
dce_rule = api_util.prepend_static_args(
lu.wrap_init(self.dce_rule), static_args
lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args
)
else:
fun = lu.wrap_init(self.fun, debug_info=debug)
dce_rule = lu.wrap_init(self.dce_rule)
dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule)
dyn_args = args

args_flat, in_tree = tree_util.tree_flatten(dyn_args)
@@ -176,7 +176,7 @@ class custom_dce:
)
assert self.dce_rule is not None
dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
flat_rule, in_avals, debug_rule
flat_rule, in_avals
)

# This second round of DCE is used to work out which inputs are actually
@@ -191,7 +191,7 @@ class custom_dce:

return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins

jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
out_avals = closed_call.out_avals
out_flat = custom_dce_p.bind(
@@ -366,7 +366,8 @@ def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_):
# that most users of this API would compose this with a custom_jvp or
# custom_vjp, which makes this less urgent.
out = core.call_p.bind(
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr),
debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents
)

out_primals, out_tangents = util.split_list(out, [len(out_nz)])

@@ -348,6 +348,9 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args
class CustomJVPCallPrimitive(core.Primitive):
multiple_results = True

def bind(self, *args, **params):
return self._true_bind(*args, **params)

def bind_with_trace(self, trace, args, params):
fun, jvp, tracers = args[0], args[1], args[2:]
return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params)
@@ -866,6 +869,9 @@ def _temporary_shape_exception(a, a_) -> bool:
class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive

def bind(self, *args, **params):
return self._true_bind(*args, **params)

def bind_with_trace(self, trace, args, params):
fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:]
return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params)

@@ -308,7 +308,7 @@ class custom_partitioning:
that describes the sharding rule, or a Callable that produces either of
these. We borrow the idea from the einops.rearrange string, to use a space
separator between factors and allow multiple letters factor names. See
[jax-shardy-guide](https://colab.sandbox.google.com/github/openxla/shardy/blob/main/docs/getting_started_jax.ipynb)
`jax-shardy-guide <https://colab.sandbox.google.com/github/openxla/shardy/blob/main/docs/getting_started_jax.ipynb>`_
for more details and examples on how to use this.

When config.use_shardy_partitioner.value is True, `sharding_rule` is used;
@@ -468,9 +468,9 @@ class custom_partitioning:

def __call__(self, *args, **kwargs):
args = _resolve_kwargs(self.fun, args, kwargs)
debug = api_util.tracing_debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
if self.static_argnums:
static_argnums = set(self.static_argnums)
args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
@@ -485,13 +485,13 @@ class custom_partitioning:
_check_for_tracers(static_args)
else:
static_args = []
f_, dyn_args = lu.wrap_init(self.fun), args
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
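The hunks above all apply one migration: `api_util.tracing_debug_info` becomes `api_util.debug_info`, the debug info travels on the `lu.wrap_init(..., debug_info=...)` wrapper, and `pe.trace_to_jaxpr_dynamic` drops its separate debug argument. A condensed before/after sketch of the internal call sequence (not a stable public API):

```python
# Before this change:
#   debug = api_util.tracing_debug_info("transform", fun, args, {})
#   flat  = lu.wrap_init(fun)
#   jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat, in_avals, debug)
#
# After this change:
#   debug = api_util.debug_info("transform", fun, args, {})
#   flat  = lu.wrap_init(fun, debug_info=debug)
#   jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat, in_avals)
```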
|
||||
|
||||
|
@ -155,6 +155,9 @@ class CustomTransposePrimitive(core.Primitive):
|
||||
map_primitive = False
|
||||
multiple_results = True
|
||||
|
||||
def bind(self, *args, **params):
|
||||
return self._true_bind(*args, **params)
|
||||
|
||||
def bind_with_trace(self, trace, call_args, params):
|
||||
call, tracers = call_args[0], call_args[1:]
|
||||
return trace.process_custom_transpose(self, call, tracers, **params)
|
||||
|
@ -1042,7 +1042,7 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
*_CPU_FFI_KERNELS,
|
||||
*_GPU_FFI_KERNELS,
|
||||
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
||||
"cu_threefry2x32", "cu_threefry2x32_ffi",
|
||||
"cu_threefry2x32_ffi",
|
||||
# Triton IR does not guarantee stability.
|
||||
# "__gpu$xla.gpu.triton",
|
||||
# cholesky on CPU
|
||||
|
File diff suppressed because one or more lines are too long
@ -98,7 +98,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
|
||||
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
|
||||
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
|
||||
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
|
||||
residual_avals = map(get_aval, consts)
|
||||
if attrs_tracked:
|
||||
raise NotImplementedError("TODO: attrs")
|
||||
@ -143,11 +143,12 @@ def linearize_jaxpr(
|
||||
return _linearize_jaxpr(jaxpr, tuple(nonzeros))
|
||||
|
||||
@weakref_lru_cache
|
||||
@source_info_util.reset_name_stack()
|
||||
def _linearize_jaxpr(
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
nonzeros: tuple[bool, ...]
|
||||
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
|
||||
dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr)
|
||||
dbg = jaxpr.jaxpr.debug_info
|
||||
primal_trace = pe.DynamicJaxprTrace(dbg)
|
||||
tangent_trace = pe.DynamicJaxprTrace(dbg)
|
||||
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
|
||||
@ -166,16 +167,17 @@ def _linearize_jaxpr(
|
||||
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
|
||||
del lin_trace, ans, tracers, new_arg
|
||||
|
||||
debug_info = jaxpr.jaxpr.debug_info
|
||||
nzs_out = [type(t) is not Zero for t in out_tangents]
|
||||
out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t)
|
||||
for (nz, t) in zip(nzs_out, out_tangents) if nz)
|
||||
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
|
||||
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
|
||||
tangent_trace.invalidate()
|
||||
if attrs_tracked:
|
||||
raise NotImplementedError("TODO: attrs")
|
||||
residuals_and_primals = (*tangent_consts, *out_primals)
|
||||
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment]
|
||||
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
|
||||
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
|
||||
primal_trace.invalidate()
|
||||
num_residuals = len(tangent_consts)
|
||||
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
|
||||
@ -192,7 +194,8 @@ def direct_linearize(traceable: lu.WrappedFun,
|
||||
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
|
||||
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
|
||||
tracers = [t.full_lower() for t in tracers]
|
||||
with core.set_current_trace(linearize_trace, check_leaks=True):
|
||||
with (core.set_current_trace(linearize_trace, check_leaks=True),
|
||||
source_info_util.transform_name_stack('jvp')):
|
||||
if has_aux:
|
||||
ans, aux = traceable.call_wrapped(*tracers)
|
||||
aux_primals = [x.primal
|
||||
@ -207,7 +210,7 @@ def direct_linearize(traceable: lu.WrappedFun,
|
||||
out_nzs = [type(t) is not Zero for t in out_tangents]
|
||||
out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
|
||||
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents)
|
||||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
|
||||
tangent_trace.invalidate()
|
||||
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
|
||||
pe.PartialVal.known(zeros_like_aval(t.aval))
|
||||
@ -587,6 +590,10 @@ class LinearizeTrace(Trace):
|
||||
self.tag = core.TraceTag() if tag is None else tag
|
||||
self.parent_trace = parent_trace
|
||||
self.tangent_trace = tangent_trace
|
||||
self._name_stack_prefix_len = len(source_info_util.current_name_stack())
|
||||
|
||||
def _name_stack_suffix(self):
|
||||
return source_info_util.current_name_stack()[self._name_stack_prefix_len:]
|
||||
|
||||
def to_primal_tangent_pair(self, val):
|
||||
if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
|
||||
@ -605,7 +612,8 @@ class LinearizeTrace(Trace):
|
||||
with core.set_current_trace(self.parent_trace):
|
||||
primal_out, tangent_nzs_out, residuals, linearized = lin(
|
||||
tangent_nzs, *primals_in, **params)
|
||||
with core.set_current_trace(self.tangent_trace):
|
||||
with (core.set_current_trace(self.tangent_trace),
|
||||
source_info_util.set_name_stack(self._name_stack_suffix())):
|
||||
tangent_out = linearized(residuals, *tangents_in)
|
||||
if primitive.multiple_results:
|
||||
return [maybe_linearize_tracer(self, x, nz, t)
|
||||
@ -1019,12 +1027,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
|
||||
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
|
||||
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
|
||||
assert len(jaxpr.in_avals) == len(nonzeros)
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
debug_info = jaxpr.jaxpr.debug_info
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
|
||||
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
|
||||
nonzeros)
|
||||
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
||||
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
||||
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
||||
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
|
||||
f_jvp, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
|
||||
|
||||
@lu.transformation_with_aux2
|
||||
|
@ -451,8 +451,6 @@ class AxisData:
|
||||
|
||||
|
||||
def get_sharding_for_vmap(axis_data, orig_sharding, axis):
|
||||
if orig_sharding.mesh.empty:
|
||||
return None
|
||||
val = axis_data.explicit_mesh_axis
|
||||
new_spec = P(*tuple_insert(orig_sharding.spec, axis, val))
|
||||
return NamedSharding(orig_sharding.mesh, new_spec)
|
||||
@ -760,7 +758,8 @@ def _batch_jaxpr2(
|
||||
axis_data,
|
||||
in_axes: tuple[int | NotMapped | RaggedAxis, ...],
|
||||
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
|
||||
debug_info=closed_jaxpr.jaxpr.debug_info)
|
||||
f, out_axes = _batch_jaxpr_inner(f, axis_data)
|
||||
f = _batch_jaxpr_outer(f, axis_data, in_axes)
|
||||
in_axes2, avals_in = unzip2([
|
||||
|
@ -54,11 +54,14 @@ from jax._src.sharding_impls import (AUTO, NamedSharding,
|
||||
SdyArraySharding, SdyArrayShardingList)
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import dialects, ir, passmanager
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect, hlo
|
||||
from jax._src.lib.mlir import register_jax_dialects
|
||||
from jax._src.state.types import AbstractRef
|
||||
|
||||
# mypy: ignore-errors
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
@ -469,11 +472,20 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
|
||||
loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
|
||||
if loc is None:
|
||||
frame = source_info_util.raw_frame_to_frame(code, lasti)
|
||||
file_loc = ir.Location.file(
|
||||
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
|
||||
frame.start_line,
|
||||
frame.start_column,
|
||||
)
|
||||
if xla_extension_version >= 309:
|
||||
file_loc = ir.Location.file(
|
||||
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
|
||||
frame.start_line,
|
||||
frame.start_column,
|
||||
frame.end_line,
|
||||
frame.end_column,
|
||||
)
|
||||
else:
|
||||
file_loc = ir.Location.file(
|
||||
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
|
||||
frame.start_line,
|
||||
frame.start_column,
|
||||
)
|
||||
loc = ir.Location.name(frame.function_name, childLoc=file_loc)
|
||||
ctx.traceback_caches.location_cache[code_lasti] = loc
|
||||
frame_locs.append(loc)
|
||||
@ -1121,16 +1133,20 @@ def lower_jaxpr_to_module(
|
||||
"In multi-platform lowering either all or no lowering platforms "
|
||||
f"should support donation. Lowering for {platforms} of which "
|
||||
f"only {platforms_with_donation} support donation")
|
||||
input_output_aliases, donated_args, xla_donated_args = _set_up_aliases(
|
||||
input_output_aliases, in_avals, out_avals, donated_args,
|
||||
arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts,
|
||||
result_shardings if num_partitions > 1 else None)
|
||||
if (num_partitions > 1 and
|
||||
(result_shardings is None or
|
||||
all(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
|
||||
any(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
|
||||
for s in result_shardings))):
|
||||
xla_donated_args = donated_args
|
||||
donated_args = [False] * len(donated_args)
|
||||
if xla_donated_args is None:
|
||||
input_output_aliases, donated_args, xla_donated_args = _set_up_aliases(
|
||||
input_output_aliases, in_avals, out_avals, donated_args,
|
||||
arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts)
|
||||
if xla_donated_args is None:
|
||||
xla_donated_args = [False] * len(donated_args)
|
||||
for input_id in range(len(donated_args)):
|
||||
if donated_args[input_id]:
|
||||
xla_donated_args[input_id] = True
|
||||
donated_args[input_id] = False
|
||||
if any(donated_args):
|
||||
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
|
||||
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
|
||||
@ -1225,14 +1241,15 @@ def lower_jaxpr_to_module(
|
||||
|
||||
def _set_up_aliases(input_output_aliases, avals_in, avals_out,
|
||||
donated_args, arg_memory_kinds, result_memory_kinds,
|
||||
in_layouts, out_layouts):
|
||||
in_layouts, out_layouts, result_shardings):
|
||||
if input_output_aliases is None:
|
||||
input_output_aliases = [None] * len(avals_in)
|
||||
else:
|
||||
input_output_aliases = list(input_output_aliases)
|
||||
# To match-up in-avals to out-avals we only care about the number of
|
||||
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
|
||||
strip_metadata = lambda a: a.strip_weak_type()
|
||||
strip_metadata = lambda a: (a if a is core.abstract_token else
|
||||
core.ShapedArray(a.shape, a.dtype))
|
||||
avals_in = map(strip_metadata, avals_in)
|
||||
avals_out = map(strip_metadata, avals_out)
|
||||
|
||||
@ -1283,7 +1300,10 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out,
|
||||
" for the input and output layout to be chosen by XLA and not the"
|
||||
" layout of the input which might not be optimal.")
|
||||
if (in_layouts is None or out_layouts is None or
|
||||
in_layouts[input_id] == out_layouts[i]):
|
||||
in_layouts[input_id] == out_layouts[i]) and (
|
||||
result_shardings is None or not (
|
||||
(s := result_shardings[i]) is None or
|
||||
isinstance(s, AUTO) or contains_unconstrained(s))):
|
||||
input_output_aliases[input_id] = i
|
||||
else:
|
||||
# Fallback to xla donation if layouts don't match.
|
||||
@ -1393,7 +1413,6 @@ def lower_jaxpr_to_fun(
|
||||
MLIR func op
|
||||
"""
|
||||
util.test_event("lower_jaxpr_to_fun", name)
|
||||
|
||||
# The first dimension variable may be the platform index
|
||||
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
|
||||
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
|
||||
|
@ -42,7 +42,6 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
                           mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
                           InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src import tree_util
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
                                tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,

@ -502,7 +501,7 @@ call_param_updaters[core.closed_call_p] = _closed_call_param_updater

def abstract_eval_fun(fun, *avals, debug_info=None, **params):
  _, avals_out, _, () = trace_to_jaxpr_dynamic(
      lu.wrap_init(fun, params), avals, debug_info)
      lu.wrap_init(fun, params, debug_info=debug_info), avals)
  assert all(isinstance(aval, AbstractValue) for aval in avals_out)
  return avals_out

@ -590,7 +589,7 @@ def trace_to_subjaxpr_nounits(

@lu.transformation2
def trace_to_subjaxpr_nounits2(
    f,
    f: Callable,
    tag: TraceTag,
    instantiate: bool | Sequence[bool],
    in_pvals: Sequence[PartialVal]):

@ -932,7 +931,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
                                in_unknowns: Sequence[bool],
                                instantiate: bool | Sequence[bool]):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
                   debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr))
                   debug_info=jaxpr.jaxpr.debug_info)

  cell = []
  def fun(*known_vals_in):

@ -951,7 +950,9 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
    return [*known_vals_out, *residuals]

  known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
  jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
  jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(
      lu.wrap_init(fun, debug_info=f.debug_info),
      known_avals)
  (out_unknowns, jaxpr_unknown, res_avals), = cell  # pytype: disable=bad-unpacking

  # check jaxpr_known and jaxpr_unknown in isolation

@ -1125,7 +1126,7 @@ def _partial_eval_jaxpr_custom_cached(
  known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res,
                                     known_outvars, known_eqns)
  jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
                      known_eqns, known_effects)
                      known_eqns, known_effects, jaxpr.debug_info)
  config.enable_checks.value and core.check_jaxpr(jaxpr_known)

  _, ins_staged = partition_list(in_inst, jaxpr.invars)

@ -1334,10 +1335,10 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:

def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
  outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
  dbg = jaxpr.debug_info and core.JaxprDebugInfo(
  dbg = jaxpr.debug_info and core.DebugInfo(
      jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
      jaxpr.debug_info.arg_names,
      tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
      jaxpr.debug_info.filter_result_paths(used_outputs))
  new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
  config.enable_checks.value and core.check_jaxpr(new_jaxpr)
  return new_jaxpr

@ -1422,10 +1423,10 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
  eqns = new_eqns[::-1]
  jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)

  dbg = jaxpr.debug_info and core.JaxprDebugInfo(
  dbg = jaxpr.debug_info and core.DebugInfo(
      jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
      tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
      tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
      jaxpr.debug_info.filter_arg_names(used_inputs),
      jaxpr.debug_info.filter_result_paths(used_outputs))
  new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
  config.enable_checks.value and core.check_jaxpr(new_jaxpr)

@ -1623,9 +1624,9 @@ class JaxprStackFrame:
  attrs_tracked: list[tuple[Any, str]]
  attrs_inits: list
  attrs_vars: list[Var]
  debug_info: lu.TracingDebugInfo | None
  debug_info: core.DebugInfo | None

  def __init__(self, debug_info: lu.TracingDebugInfo | None):
  def __init__(self, debug_info: core.DebugInfo | None):
    self.gensym = core.gensym()
    self.tracer_to_var = {}
    self.constid_to_tracer = {}

@ -1642,7 +1643,9 @@ class JaxprStackFrame:
  def add_eqn(self, eqn: core.JaxprEqn):
    self.eqns.append(eqn)

  def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer]
  def to_jaxpr(self, trace: DynamicJaxprTrace,
               out_tracers: Sequence[Tracer],
               debug_info: core.DebugInfo | None,
               ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
    # It's not necessary, but we keep the tracer-to-var mapping injective:
    assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))

@ -1655,7 +1658,8 @@ class JaxprStackFrame:
    outvars = state_outvars + explicit_outvars
    constvars, constvals = unzip2(self.constvar_to_val.items())
    jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects,
                  debug_info)
    jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)  # type: ignore
    init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]

@ -1809,7 +1813,7 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
  __slots__ = ("frame",)

  def __init__(self, debug_info: lu.TracingDebugInfo | None):
  def __init__(self, debug_info: core.DebugInfo | None):
    self.frame = JaxprStackFrame(debug_info)

  def invalidate(self):

@ -1948,7 +1952,7 @@ class DynamicJaxprTrace(core.Trace):
                       for a, in_axis in zip(in_avals, params['in_axes'])]
    with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
      jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
          f, reduced_in_avals, f.debug_info)
          f, reduced_in_avals)
    ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
    if ordered_effects:
      raise ValueError("Ordered effects not supported for "

@ -2072,8 +2076,9 @@ class DynamicJaxprTrace(core.Trace):
    self.frame.add_eqn(eqn)
    return out_tracers

  def to_jaxpr(self, out_tracers: Sequence[Tracer]):
    return self.frame.to_jaxpr(self, out_tracers)
  def to_jaxpr(self, out_tracers: Sequence[Tracer],
               debug_info: core.DebugInfo | None):
    return self.frame.to_jaxpr(self, out_tracers, debug_info)


custom_staging_rules: dict[Primitive, Callable] = {}

@ -2114,14 +2119,12 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
def trace_to_jaxpr_dynamic(
    fun: lu.WrappedFun,
    in_avals: Sequence[AbstractValue],
    debug_info: lu.TracingDebugInfo | None = None,
    *,
    keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
  keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs

  trace = DynamicJaxprTrace(debug_info)
  trace = DynamicJaxprTrace(fun.debug_info)
  with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
    in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
    in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]

@ -2129,15 +2132,15 @@ def trace_to_jaxpr_dynamic(
    ans = fun.call_wrapped(*in_tracers)

    out_tracers = map(trace.to_jaxpr_tracer, ans)
    _check_no_returned_refs(debug_info, out_tracers)
    jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
    _check_no_returned_refs(fun.debug_info, out_tracers)
    jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
    del trace, fun, in_tracers, out_tracers, ans

  config.enable_checks.value and core.check_jaxpr(jaxpr)
  return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked

def _check_no_returned_refs(
    dbg: lu.TracingDebugInfo | None,
    dbg: core.DebugInfo | None,
    out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
  if not config.mutable_array_checks.value: return

@ -2148,10 +2151,8 @@ def _check_no_returned_refs(
        raise ValueError(
            f"function returned a mutable array reference of type {a.str_short()}, "
            "but mutable array references cannot be returned.")
      loc = (f' at output tree path {tree_util.keystr(ls[i])}'  # type: ignore
             if (dbg.result_paths_thunk and
                 (ls := dbg.result_paths_thunk()) and
                 ls[i]) else '')
      result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
      loc = f' at output tree path {result_paths[i]}'
      frame = t._trace.frame
      v = frame.tracer_to_var.get(id(t))
      eqn = next((e for e in frame.eqns if v in e.outvars), None)

@ -2160,7 +2161,7 @@ def _check_no_returned_refs(
        origin_info = ('\n\nThe returned mutable array was created on line '
                       f'{source_info_util.summarize(eqn.source_info)}.')
      elif v in frame.invars:
        arg_name = dbg.arg_names[frame.invars.index(v)]  # type: ignore
        arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)]  # type: ignore
        origin_info = ('\n\nThe returned mutable array was passed in as the '
                       f'argument {arg_name}.')
      else:

@ -2172,10 +2173,10 @@ def _check_no_returned_refs(

@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
    fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
    fun: lu.WrappedFun,
) -> tuple[Jaxpr, OutputType, list[Any]]:

  trace = DynamicJaxprTrace(debug_info)
  trace = DynamicJaxprTrace(fun.debug_info)
  with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
    in_avals, keep_inputs = unzip2(fun.in_type)
    in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
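Taken together, these hunks change the tracing entry points so that debug metadata rides on the `lu.WrappedFun` instead of being threaded as a separate argument. A minimal sketch of the new calling convention, using the internal APIs as they appear in this diff (the function `f` and its avals are my own example, not code from this change):

    from jax._src import api_util, core
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe
    import jax.numpy as jnp

    def f(x):
      return [jnp.sin(x)]  # flat list of outputs, as the tracer expects

    dbg = api_util.debug_info("example", f, (1.0,), {})
    jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f, debug_info=dbg),      # debug info attached here ...
        [core.ShapedArray((), jnp.float32)])  # ... so no separate debug_info arg
    print(jaxpr.debug_info.traced_for)        # -> 'example'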
@ -33,7 +33,6 @@ import numpy as np
import jax

from jax._src import api
from jax._src import api_util
from jax._src import compiler
from jax._src import config
from jax._src import core

@ -652,7 +651,6 @@ class ParallelCallableInfo:
  in_axes: Iterable[int | None]
  out_axes_thunk: Callable[[], Sequence[int | None]]
  avals: Sequence[core.AbstractValue]
  debug_info: api_util.TracingDebugInfo | None

  @cached_property
  def local_devices(self):

@ -723,8 +721,7 @@ def stage_parallel_callable(
      "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
      fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
    jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
        fun, sharded_avals, pci.debug_info)
  jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)
        fun, sharded_avals)

  assert len(out_sharded_avals) == len(pci.out_axes), (
      len(out_sharded_avals), len(pci.out_axes))

@ -758,7 +755,7 @@ def get_pmap_jaxpr(

  pci = ParallelCallableInfo(
      name, backend, axis_name, axis_size, global_axis_size, devices,
      in_axes, out_axes_thunk, avals, fun.debug_info)
      in_axes, out_axes_thunk, avals)
  with core.extend_axis_env_nd([(axis_name, axis_size)]):
    jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})

@ -964,7 +961,7 @@ class UnloadedPmapExecutable:
  ordered_effects: list[core.Effect]
  keepalive: Sequence[Any]
  host_callbacks: Sequence[Any]
  jaxpr_debug_info: core.JaxprDebugInfo
  jaxpr_debug_info: core.DebugInfo

  def build_execute_fun(self):
    input_indices = []

@ -992,7 +989,7 @@ class UnloadedPmapExecutable:

    return PmapExecutable(
        self.compiled, self.build_execute_fun, fingerprint,
        self.local_input_avals, self.jaxpr_debug_info, self)
        self.local_input_avals, self)

  @staticmethod
  def from_hlo(hlo: ir.Module,

@ -1004,7 +1001,7 @@ class UnloadedPmapExecutable:
               ordered_effects: list[core.Effect],
               host_callbacks: list[Any],
               keepalive: Any,
               jaxpr_debug_info: core.JaxprDebugInfo,
               jaxpr_debug_info: core.DebugInfo,
               platforms: Sequence[str],
               shape_poly_state: mlir.ShapePolyLoweringState | None = None,
               compiler_options=None):

@ -1119,24 +1116,23 @@ class UnloadedPmapExecutable:

class PmapExecutable(stages.XlaExecutable):
  __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
               "fingerprint", "in_avals", "_jaxpr_debug_info",
               "_unloaded_executable"]
               "fingerprint", "in_avals", "_unloaded_executable"]

  def __init__(self, xla_executable, build_unsafe_call, fingerprint,
               in_avals, jaxpr_debug_info, unloaded_executable):
               in_avals,
               unloaded_executable: UnloadedPmapExecutable):
    self.xla_executable = xla_executable
    self._unsafe_call = None
    self.build_unsafe_call = build_unsafe_call
    self.fingerprint = fingerprint
    self.in_avals = in_avals
    self._jaxpr_debug_info = jaxpr_debug_info
    self._unloaded_executable = unloaded_executable

  @property
  def unsafe_call(self) -> Callable[..., Any]:
    if self._unsafe_call is None:
      self._unsafe_call = self.build_unsafe_call()
    return self._unsafe_call
    return self._unsafe_call  # type: ignore

  # -- stages.XlaExecutable overrides

@ -1147,7 +1143,8 @@ class PmapExecutable(stages.XlaExecutable):
  def call(self, *args):
    # TODO(frostig): do we need to check sharding and sharded avals?
    arg_avals = map(core.abstractify, args)
    check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
    check_arg_avals_for_call(self.in_avals, arg_avals,
                             self._unloaded_executable.jaxpr_debug_info)
    return self.unsafe_call(*args)  # pylint: disable=not-callable


@ -2127,7 +2124,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
  """Avals and debug_info for all arguments prior to DCE."""
  in_avals: Sequence[core.ShapedArray]
  debug_info: core.JaxprDebugInfo | None
  debug_info: core.DebugInfo | None


@lru_cache(maxsize=2048)

@ -2588,7 +2585,7 @@ def try_matching_out_with_in_spec_for_all_auto(
    orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
  recover_in_s, recover_in_aval = None, None
  for in_s, in_aval in safe_zip(in_shardings, in_avals):
    if in_s is not None and type(in_s) in _orig_out_sharding_handlers:
    if isinstance(in_s, NamedSharding):
      recover_in_s, recover_in_aval = in_s, in_aval
      break
  if recover_in_s is None:

@ -3199,14 +3196,14 @@ def cc_shard_arg(x, sharding, layout):


def check_arg_avals_for_call(ref_avals, arg_avals,
                             jaxpr_debug_info: core.JaxprDebugInfo | None = None):
                             jaxpr_debug_info: core.DebugInfo | None = None):
  if len(ref_avals) != len(arg_avals):
    raise TypeError(
        f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")

  if jaxpr_debug_info is not None:
    arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
    arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
  else:
    num_args = len(ref_avals)
    arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]

@ -3258,7 +3255,7 @@ def check_array_xla_sharding_layout_match(
    args_after_dce,
    in_xla_shardings: Sequence[JSharding],
    in_xla_layouts: Sequence[DeviceLocalLayout],
    jaxpr_debug_info: core.JaxprDebugInfo | None,
    jaxpr_debug_info: core.DebugInfo | None,
    kept_var_idx: set[int]) -> None:
  from jax._src.array import ArrayImpl
  # jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
@ -53,19 +53,19 @@ def _typecheck_param(prim, param, name, msg_required, pred):
def _initial_style_open_jaxpr(fun: Callable,
                              in_tree: PyTreeDef,
                              in_avals: Sequence[core.AbstractValue],
                              debug_info: api_util.TracingDebugInfo):
                              debug_info: core.DebugInfo):
  wrapped_fun, out_tree = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun, debug_info=debug_info),
      in_tree)
  jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
      wrapped_fun, in_avals, debug_info)
      wrapped_fun, in_avals)
  return jaxpr, consts, out_tree(), attrs_tracked

@weakref_lru_cache
def _initial_style_jaxpr(fun: Callable,
                         in_tree: PyTreeDef,
                         in_avals: Sequence[core.AbstractValue],
                         debug_info: api_util.TracingDebugInfo):
                         debug_info: core.DebugInfo):
  jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
      fun, in_tree, in_avals, debug_info)
  closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))

@ -74,7 +74,7 @@ def _initial_style_jaxpr(fun: Callable,
def _initial_style_jaxpr_attrs(fun: Callable,
                               in_tree: PyTreeDef,
                               in_avals: Sequence[core.AbstractValue],
                               debug_info: api_util.TracingDebugInfo):
                               debug_info: core.DebugInfo):
  jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr(
      fun, in_tree, in_avals, debug_info)
  closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))

@ -83,7 +83,7 @@ def _initial_style_jaxpr_attrs(fun: Callable,
def _initial_style_jaxprs_with_common_consts(
    funs: Sequence[Callable],
    in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
    debug_infos: Sequence[api_util.TracingDebugInfo]):
    debug_infos: Sequence[core.DebugInfo]):
  # When staging the branches of a conditional into jaxprs, constants are
  # extracted from each branch and converted to jaxpr arguments. To use the
  # staged jaxprs as the branches to a conditional *primitive*, we need for

@ -134,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
  if (config.disable_jit.value and core.is_concrete(index)):
    return branches[int(index)](*operands)

  dbgs = [api_util.tracing_debug_info("switch", branch, operands, {})
  dbgs = [api_util.debug_info("switch", branch, operands, {})
          for branch in branches]
  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(map(core.get_aval, ops))

@ -237,10 +237,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(map(core.get_aval, ops))

  dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {})
  dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {})
  if config.mutable_array_checks.value:
    api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops)
  dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {})
  dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      (true_fun, false_fun), ops_tree, ops_avals,
      [dbg_true_fun, dbg_false_fun])

@ -561,7 +561,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  effects_known = _join_cond_effects(branches_known)
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
      effects_known, eqn.source_info)
      effects_known, eqn.source_info, eqn.ctx)

  # Build the staged eqn.
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)

@ -569,7 +569,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  effects_staged = _join_cond_effects(branches_staged)
  eqn_staged = pe.new_jaxpr_eqn(
      [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
      cond_p, params_staged, effects_staged, eqn.source_info)
      cond_p, params_staged, effects_staged, eqn.source_info, eqn.ctx)

  new_vars = [*new_inst, *res_binders]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars

@ -684,7 +684,7 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
  new_eqn = pe.new_jaxpr_eqn(
      [v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, new_effects, eqn.source_info)
      eqn.primitive, new_params, new_effects, eqn.source_info, eqn.ctx)

  assert all(len(new_eqn.invars ) == 1 + len(jaxpr.in_avals )
             for jaxpr in new_params['branches'])
@ -195,7 +195,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
  def _create_jaxpr(init):
    init_flat = tree_leaves(init)
    _, in_tree = tree_flatten((init, xs))
    dbg = api_util.tracing_debug_info("scan", f, (init, xs), {})
    dbg = api_util.debug_info("scan", f, (init, xs), {})
    carry_avals = tuple(map(core.get_aval, init_flat))
    jaxpr, _, out_tree = _initial_style_jaxpr(
        f, in_tree, carry_avals + x_avals, dbg)

@ -585,7 +585,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
                               core.closed_call_p, dict(call_jaxpr=call_jaxpr),
                               call_jaxpr.effects, eqn.source_info)
                               call_jaxpr.effects, eqn.source_info, eqn.ctx)

  jaxpr_staged = _convert_inputs_to_reads(nsteps, len(res_avals),
                                          jaxpr_staged_resin_,

@ -609,7 +609,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
  _, outvars = partition_list(out_inst, eqn.outvars)
  eqn_staged = pe.new_jaxpr_eqn([*resvars, *eqn.invars], outvars,
                                core.closed_call_p, dict(call_jaxpr=call_jaxpr),
                                call_jaxpr.effects, eqn.source_info)
                                call_jaxpr.effects, eqn.source_info, eqn.ctx)
  new_vars = [*new_inst, *resvars]
  return eqn_known, eqn_staged, in_unknowns, out_inst, new_vars


@ -273,7 +273,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
    return carry, stacked_y

  x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
  dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {})
  dbg_body = api_util.debug_info("scan", f, (init, xs), {})

  if config.mutable_array_checks.value:
    in_flat, in_tree = tree_flatten((init, xs))

@ -1357,10 +1357,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
  def _create_jaxpr(init_val):
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(core.get_aval, init_vals))
    cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {})
    cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals, cond_dbg)
    body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {})
    body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals, body_dbg)
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:

@ -1368,7 +1368,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
      raise TypeError(msg.format(cond_tree))
    pred_aval = cond_jaxpr.out_avals[0]
    if (not isinstance(pred_aval, ShapedArray)
        or pred_aval.strip_weak_type() != ShapedArray((), np.bool_)):
        or ShapedArray(pred_aval.shape, pred_aval.dtype) != ShapedArray((), np.bool_)):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

@ -1855,18 +1855,26 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
        f'Effects not supported in `while`: {disallowed_effects}')
  return body_jaxpr.out_avals, joined_effects

def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
                                  cond_nconsts, body_nconsts):
  # TODO(sharadmv): enable supporting state effects in the cond
  if any(isinstance(eff, state.RefEffect) for eff in cond_jaxpr.effects):
    raise NotImplementedError
  cond_consts_discharge, body_consts_discharge, carry_discharge = split_list(
      should_discharge, [cond_nconsts, body_nconsts])

  if any(cond_consts_discharge):
    raise NotImplementedError
  cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
  cond_consts_avals, body_consts_avals, carry_avals = split_list(in_avals,
                                                                 [cond_nconsts,
                                                                  body_nconsts])
  # There shouldn't be any `Ref`s in the `cond` (because of our check above).
  assert not any(isinstance(aval, state.AbstractRef) for aval in cond_consts_avals)
  is_ref = [isinstance(aval, state.AbstractRef) for aval in body_consts_avals]
  is_ref = [
      isinstance(aval, state.AbstractRef) and should
      for aval, should in zip(body_consts_avals, body_consts_discharge)
  ]
  remaining_body_consts, refs = partition_list(is_ref, body_consts)
  remaining_body_const_avals, ref_avals = partition_list(is_ref,
                                                         body_consts_avals)
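The rewritten rule leans on two `jax._src.util` list helpers throughout. For orientation, a minimal sketch of their behavior as I understand it from the usage above (the literal inputs are my own):

    from jax._src.util import split_list, partition_list

    # split_list cuts a list at the given prefix sizes; the tail is whatever
    # remains, mirroring the (cond_consts, body_consts, carry) split above.
    split_list([1, 2, 3, 4, 5], [2, 1])  # -> [[1, 2], [3], [4, 5]]

    # partition_list separates elements by a boolean mask, returning the
    # False-selected then the True-selected elements, as with (non-refs, refs).
    partition_list([False, True, False], ['a', 'b', 'c'])  # -> (['a', 'c'], ['b'])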
@ -1886,7 +1894,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
  # Therefore we need to rewrite the jaxpr to shuffle around the `Ref`s so that
  # they are part of the carry.
  discharged_body_jaxpr, discharged_consts = state_discharge.discharge_state(
      body_jaxpr, ())
      body_jaxpr, (), should_discharge=[*body_consts_discharge, *carry_discharge])
  if discharged_consts: raise NotImplementedError

  def new_body(*consts_refs_carry):

@ -1943,7 +1951,7 @@ batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
state_discharge.register_discharge_rule(while_p)(_while_discharge_rule)
state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule)


def _pred_bcast_select_hlo(ctx,
@ -93,16 +93,16 @@ def custom_root(f: Callable,
  """
  guess_flat, in_args_tree = tree_flatten((initial_guess,))
  guess_avals = tuple(_map(core.get_aval, guess_flat))
  f_debug = api_util.tracing_debug_info("custom_root", f, (initial_guess,), {})
  f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {})
  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
      f, in_args_tree, guess_avals, f_debug)

  in_tree, = treedef_children(in_args_tree)
  _check_tree("f", "initial_guess", out_tree, in_tree, False)

  solve_debug = api_util.tracing_debug_info("custom_root solve", solve,
                                            (f, initial_guess), {},
                                            static_argnums=(0,))
  solve_debug = api_util.debug_info("custom_root solve", solve,
                                    (f, initial_guess), {},
                                    static_argnums=(0,))
  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
      partial(solve, f), in_args_tree, guess_avals, solve_debug)
  _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)

@ -111,10 +111,10 @@ def custom_root(f: Callable,
    unchecked_zeros, f_jvp = api.linearize(f, x)
    return tangent_solve(f_jvp, b)

  tangent_solve_debug = api_util.tracing_debug_info("custom_root tangent_solve",
                                                    tangent_solve,
                                                    (f, initial_guess), {},
                                                    static_argnums=(0,))
  tangent_solve_debug = api_util.debug_info("custom_root tangent_solve",
                                            tangent_solve,
                                            (f, initial_guess), {},
                                            static_argnums=(0,))
  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2,
      tangent_solve_debug)

@ -265,17 +265,17 @@ def custom_linear_solve(

    return f_aux if has_aux else f

  matvec_debug = api_util.tracing_debug_info("custom_linear_solve",
                                             matvec, (b,), {})
  matvec_debug = api_util.debug_info("custom_linear_solve",
                                     matvec, (b,), {})
  # no auxiliary data assumed for matvec
  matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
      matvec_debug)
  _check_tree("matvec", "b", out_tree, tree, False)

  solve_debug = api_util.tracing_debug_info("custom_linear_solve solve",
                                            solve, (matvec, b), {},
                                            static_argnums=(0,))
  solve_debug = api_util.debug_info("custom_linear_solve solve",
                                    solve, (matvec, b), {},
                                    static_argnums=(0,))
  solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
      _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
      solve_debug)

@ -285,7 +285,7 @@ def custom_linear_solve(
    vecmat_jaxpr = tr_solve_jaxpr = None
    vecmat_consts = tr_solve_consts = []
  else:
    transpose_solve_debug = api_util.tracing_debug_info(
    transpose_solve_debug = api_util.debug_info(
        "custom_linear_solve transpose_solve", transpose_solve,
        (matvec, b), {}, static_argnums=(0,))
    if symmetric:

@ -325,7 +325,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
  num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
  if num_aux > 0:
    args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
  return args_to_raise
  return args_to_raise, jaxprs.solve.effects


def _custom_linear_solve_impl(*args, const_lengths, jaxprs):

@ -482,7 +482,7 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs):
linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
linear_solve_p.def_effectful_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(
@ -367,12 +367,46 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
  """
  return nextafter_p.bind(x1, x2)

@export
def floor(x: ArrayLike) -> Array:
  r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
  r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.

  This function lowers directly to the `stablehlo.floor`_ operation.

  Args:
    x: input array. Must have floating-point type.

  Returns:
    Array of same shape and dtype as ``x``, containing values rounded
    to the next integer toward negative infinity.

  See also:
    - :func:`jax.lax.ceil`: round to the next integer toward positive infinity
    - :func:`jax.lax.round`: round to the nearest integer

  .. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor
  """
  return floor_p.bind(x)

@export
def ceil(x: ArrayLike) -> Array:
  r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`."""
  r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.

  This function lowers directly to the `stablehlo.ceil`_ operation.

  Args:
    x: input array. Must have floating-point type.

  Returns:
    Array of same shape and dtype as ``x``, containing values rounded
    to the next integer toward positive infinity.

  See also:
    - :func:`jax.lax.floor`: round to the next integer toward negative infinity
    - :func:`jax.lax.round`: round to the nearest integer

  .. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil
  """
  return ceil_p.bind(x)

class RoundingMethod(enum.IntEnum):
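A quick illustration of the two directional rounding primitives documented above (the example values are mine, not part of the diff):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([-1.7, -0.5, 0.5, 1.7])
    lax.floor(x)  # Array([-2., -1.,  0.,  1.], dtype=float32)
    lax.ceil(x)   # Array([-1., -0.,  1.,  2.], dtype=float32)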
@ -388,20 +422,38 @@ class RoundingMethod(enum.IntEnum):
    as “banker’s rounding” (e.g., 0.5 -> 0, 1.5 -> 2).
  """

@export
def round(x: ArrayLike,
          rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
          ) -> Array:
  r"""Elementwise round.

  Rounds values to the nearest integer.
  Rounds values to the nearest integer. This function lowers directly to the
  `stablehlo.round`_ operation.

  Args:
    x: an array or scalar value to round.
    x: an array or scalar value to round. Must have floating-point type.
    rounding_method: the method to use when rounding halfway values
      (e.g., `0.5`). See :class:`jax.lax.RoundingMethod` for possible values.
      (e.g., ``0.5``). See :class:`jax.lax.RoundingMethod` for possible values.

  Returns:
    An array containing the elementwise rounding of x.
    An array of the same shape and dtype as ``x``, containing the elementwise
    rounding of ``x``.

  See also:
    - :func:`jax.lax.floor`: round to the next integer toward negative infinity
    - :func:`jax.lax.ceil`: round to the next integer toward positive infinity

  Examples:
    >>> import jax.numpy as jnp
    >>> from jax import lax
    >>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
    >>> jax.lax.round(x)  # default method is AWAY_FROM_ZERO
    Array([-2., -1., -1.,  0.,  1.,  1.,  2.], dtype=float32)
    >>> jax.lax.round(x, rounding_method=jax.lax.RoundingMethod.TO_NEAREST_EVEN)
    Array([-2., -1., -0.,  0.,  0.,  1.,  2.], dtype=float32)

  .. _stablehlo.round: https://openxla.org/stablehlo/spec#round
  """
  rounding_method = RoundingMethod(rounding_method)
  return round_p.bind(x, rounding_method=rounding_method)

@ -409,29 +461,126 @@ def round(x: ArrayLike,
def is_finite(x: ArrayLike) -> Array:
  r"""Elementwise :math:`\mathrm{isfinite}`.

  For each element x returns `True` if and only if x is not :math:`\pm\infty` or
  :math:`\mathit{NaN}`.
  This function lowers directly to the `stablehlo.is_finite`_ operation.

  Args:
    x: input array. Must have floating-point type.

  Returns:
    Array of boolean dtype with the same shape as ``x``, containing ``False`` where
    ``x`` is :math:`\pm\infty` or :math:`\mathit{NaN}`, and ``True`` otherwise.

  See also:
    - :func:`jax.numpy.isinf`: return True where array is infinite.
    - :func:`jax.numpy.isnan`: return True where array is NaN.

  .. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite
  """
  return is_finite_p.bind(x)

def exp(x: ArrayLike) -> Array:
  r"""Elementwise exponential: :math:`e^x`."""
  r"""Elementwise exponential: :math:`e^x`.

  This function lowers directly to the `stablehlo.exponential`_ operation.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    exponential.

  See also:
    - :func:`jax.lax.exp2`: elementwise base-2 exponential: :math:`2^x`.
    - :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
  """
  return exp_p.bind(x)

def exp2(x: ArrayLike) -> Array:
  r"""Elementwise base-2 exponential: :math:`2^x`."""
  r"""Elementwise base-2 exponential: :math:`2^x`.

  This function is implemented in terms of the `stablehlo.exponential`_
  and `stablehlo.multiply`_ operations.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    base-2 exponential.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
    - :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
  .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
  """
  return exp2_p.bind(x)

def expm1(x: ArrayLike) -> Array:
  r"""Elementwise :math:`e^{x} - 1`."""
  r"""Elementwise :math:`e^{x} - 1`.

  This function lowers directly to the `stablehlo.exponential_minus_one`_
  operation. Compared to the naive expression ``lax.exp(x) - 1``, it is
  more accurate for ``x`` near zero.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    exponential minus 1.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
    - :func:`jax.lax.log1p`: elementwise :math:`\mathrm{log}(1 + x)`.

  .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one
  """
  return expm1_p.bind(x)

def log(x: ArrayLike) -> Array:
  r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`."""
  r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  This function lowers directly to the `stablehlo.log`_ operation.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    natural logarithm.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.

  .. _stablehlo.log: https://openxla.org/stablehlo/spec#log
  """
  return log_p.bind(x)

def log1p(x: ArrayLike) -> Array:
  r"""Elementwise :math:`\mathrm{log}(1 + x)`."""
  r"""Elementwise :math:`\mathrm{log}(1 + x)`.

  This function lowers directly to the `stablehlo.log_plus_one`_ operation.
  Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
  for ``x`` near zero.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    natural logarithm of ``x + 1``.

  See also:
    - :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
    - :func:`jax.lax.log`: elementwise natural logarithm :math:`\mathrm{log}(x)`.

  .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one
  """
  return log1p_p.bind(x)

def tanh(x: ArrayLike) -> Array:
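The accuracy claims in the `expm1` and `log1p` docstrings are easy to see in float32; a small demonstration (mine, not part of the diff):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.float32(1e-8)
    lax.exp(x) - 1.0   # 0.0 -- the 1e-8 is lost when exp(x) rounds to 1.0
    lax.expm1(x)       # ~1e-08, computed without the cancellation
    lax.log(1.0 + x)   # 0.0 -- 1.0 + 1e-8 already rounds to 1.0 in float32
    lax.log1p(x)       # ~1e-08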
@ -745,9 +894,10 @@ def _trace_composite_to_jaxpr(fun: Callable,
                              in_tree: tree_util.PyTreeDef,
                              in_avals: Sequence[core.AbstractValue],
                              name: str,
                              debug_info: api_util.TracingDebugInfo):
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
                              debug_info: core.DebugInfo):
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun, debug_info=debug_info), in_tree)
  jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError(
        "Found a JAX Tracer as a constant in the decomposition for the "

@ -822,8 +972,8 @@ def composite(
  """
  @functools.wraps(decomposition)
  def _decorator(*args, **kwargs):
    debug_info = api_util.tracing_debug_info("composite", decomposition,
                                             args, kwargs)
    debug_info = api_util.debug_info("composite", decomposition,
                                     args, kwargs)
    flat_args, in_tree = tree_util.tree_flatten(args)
    in_avals = tuple(core.get_aval(x) for x in flat_args)
    closed_jaxpr, out_tree = _trace_composite_to_jaxpr(

@ -3274,7 +3424,7 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
  if isinstance(sharding, NamedSharding):
    return NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)
  else:
    return None
    return core.get_cur_mesh_sharding()
  return sharding

def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,

@ -6540,6 +6690,8 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding):
  if (not dyn_shape and
      not any(isinstance(d, core.DArray) and
              type(core.get_aval(d).dtype) is core.bint for d in shape)):
    if sharding is None:
      sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape)))
    return ShapedArray(shape, dtype, sharding=sharding)
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
  return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)
@ -733,7 +733,6 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
    raise ValueError(f"axis_index_groups can only be used with reductions over "
                     f"named axes, but got: {axes}")
  if config.sharding_in_types.value:
    args = core.cast_from_auto_to_manual(args)
    core.check_avals_context_mesh(args, 'all_reduce')
    out_avals = [
        ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,

@ -1884,12 +1884,12 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
                          mode, fill_value):
  # TODO(yashkatariya): Write a proper gather sharding rule.
  cur_mesh = mesh_lib.get_abstract_mesh()
  if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:  # type: ignore
    return None
  if (cur_mesh._are_all_axes_explicit and  # type: ignore
  if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
    return core.get_cur_mesh_sharding()
  if (cur_mesh._are_all_axes_explicit and
      all(s is None for s in operand.sharding.spec) and
      all(s is None for s in indices.sharding.spec)):
    return None
    return core.get_cur_mesh_sharding()
  raise GatherShardingError(
      "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
      " the gather indexing.")

@ -24,6 +24,7 @@ from jax._src import config
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src.util import safe_zip
from jax._src.partition_spec import PartitionSpec as P

zip, unsafe_zip = safe_zip, zip

@ -49,9 +50,14 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level

def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
  if config.sharding_in_types.value:
    from jax._src.pjit import _get_abstract_mesh_from_avals, NamedSharding
    cur_mesh = mesh_lib.get_abstract_mesh()
    if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
      return None if num_out is None else [None] * num_out
    aval_mesh = _get_abstract_mesh_from_avals(avals)
    # TODO(yashkatariya): `aval_mesh.empty` should be `aval_mesh.unset`
    aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh
    s = NamedSharding(aval_mesh, P())
    return s if num_out is None else [s] * num_out
  if rule is None:
    raise ValueError(
        f'sharding rule for {prim.name} is not implemented. Please file a'

@ -68,7 +74,6 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
  weak_type = weak_type_rule(*avals, **kwargs)
  least_specialized = type(max(avals, key=_get_array_abstraction_level))
  if least_specialized is core.ShapedArray:
    avals = core.cast_from_auto_to_manual(avals)
    core.check_avals_context_mesh(avals, prim.name)
    out_aval = core.ShapedArray(
        shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),

@ -94,7 +99,6 @@ def standard_multi_result_abstract_eval(
  least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
  weak_types = weak_type_rule(*avals, **kwargs)
  if least_specialized is core.ShapedArray:
    avals = core.cast_from_auto_to_manual(avals)
    out_shapes = shape_rule(*avals, **kwargs)
    out_dtypes = dtype_rule(*avals, **kwargs)
    core.check_avals_context_mesh(avals, prim.name)
@ -12,4 +12,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from typing import Protocol

from jaxlib.triton import dialect  # noqa: F401  # pytype: disable=import-error


class CompilationResult(Protocol):
  asm: str
  smem_bytes: int
  cluster_dim_x: int
  cluster_dim_y: int
  cluster_dim_z: int


class CompilationHandler(Protocol):

  def __call__(
      self,
      module: bytes,
      arch_name: str,
      num_warps: int,
      num_ctas: int,
      num_stages: int,
  ) -> CompilationResult:
    ...


_compilation_handlers: dict[str, CompilationHandler] = {}
_compilation_handlers_lock = threading.Lock()


def register_compilation_handler(
    platform: str, handler: CompilationHandler
) -> None:
  platform = platform.upper()
  with _compilation_handlers_lock:
    if existing_handler := _compilation_handlers.get(platform):
      raise RuntimeError(
          f'Platform {platform} already has a Triton compilation handler:'
          f' {existing_handler}'
      )
    _compilation_handlers[platform] = handler


def has_compilation_handler(platform: str) -> bool:
  platform = platform.upper()
  with _compilation_handlers_lock:
    return platform in _compilation_handlers


def compile(
    platform: str,
    module: bytes,
    arch_name: str,
    *,
    num_warps: int,
    num_ctas: int,
    num_stages: int,
) -> CompilationResult:
  platform = platform.upper()
  with _compilation_handlers_lock:
    handler = _compilation_handlers.get(platform)
  if handler is None:
    raise RuntimeError(
        f'Platform {platform} does not have a Triton compilation handler'
    )
  return handler(module, arch_name, num_warps, num_ctas, num_stages)
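A usage sketch for the registry above; the handler and its result type are hypothetical stand-ins (only `register_compilation_handler` and `has_compilation_handler` come from the module itself):

    from dataclasses import dataclass

    @dataclass
    class FakeResult:  # hypothetical; anything matching the CompilationResult protocol
      asm: str
      smem_bytes: int
      cluster_dim_x: int
      cluster_dim_y: int
      cluster_dim_z: int

    def fake_handler(module, arch_name, num_warps, num_ctas, num_stages):
      # A real handler would invoke the vendor's Triton toolchain here.
      return FakeResult(asm="...", smem_bytes=0,
                        cluster_dim_x=1, cluster_dim_y=1, cluster_dim_z=1)

    register_compilation_handler("rocm", fake_handler)
    assert has_compilation_handler("ROCM")  # platform keys are upper-cased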
@ -63,7 +63,7 @@ data must be immutable, because it will be stored in function memoization tables
"""
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, NamedTuple
import weakref

@ -71,6 +71,7 @@ import weakref
from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import keystr, generate_key_paths
from jax._src.util import curry, cache_clearing_funs, HashableFunction


@ -156,7 +157,7 @@ class WrappedFun:
               f_transformed: Callable,
               transforms,
               stores: tuple[Store | EqualStore | None, ...], params, in_type,
               debug_info: TracingDebugInfo | None):
               debug_info: DebugInfo | None):
    self.f = f
    self.f_transformed = f_transformed
    self.transforms = transforms

@ -253,12 +254,10 @@ def fun_name(f):
  except:
    return str(f)

class TracingDebugInfo(NamedTuple):
  """Tracing-time debugging info about a func and its arguments.

  Formed just before staging to a jaxpr and read in trace-time error messages.
  """
class DebugInfo(NamedTuple):
  """Debugging info about a func, its arguments, and results."""
  traced_for: str  # e.g. 'jit', 'scan', etc

  # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
  # no source location information. The first word is always the function name,
  # which may be '<unknown>'.

@ -270,23 +269,18 @@ class TracingDebugInfo(NamedTuple):
  # e.g., tangent args in jax.jvp.
  arg_names: tuple[str | None, ...]

  # The result paths are not available while we are tracing the function,
  # instead we keep a thunk. Once we are done tracing, we use
  # `self.resolve_result_paths()` to execute the thunk and replace the
  # actual result paths.
  # e.g. ('[0]', '[1]', ...)
  result_paths_thunk: Callable[[], tuple[str, ...]] | None
  result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None

  @classmethod
  def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None:
    jaxpr_dbg = jaxpr.jaxpr._debug_info
    if jaxpr_dbg is None: return None
    return TracingDebugInfo(jaxpr_dbg.traced_for,
                            jaxpr_dbg.func_src_info,
                            jaxpr_dbg.arg_names,
                            lambda: jaxpr_dbg.result_paths)

  def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
                       ) -> TracingDebugInfo:
    assert self.result_paths_thunk is None
    return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
                                                             closure=()))
  def resolve_result_paths(self) -> DebugInfo:
    """Return a debug info with resolved result paths."""
    if callable(self.result_paths):
      return self._replace(result_paths=tuple(self.result_paths()))
    return self

  def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
    """Get the arg_names with a safety check."""

@ -296,15 +290,47 @@ class TracingDebugInfo(NamedTuple):
      # TODO(necula): this should not happen
      return (None,) * expected

  def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]:
    """Keep only the arg_names for which `keep` is True."""
    return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b)

  def safe_result_paths(self, expected: int) -> tuple[str, ...]:
    """Get the result paths with a safety check."""
    assert not callable(self.result_paths), self
    if self.result_paths is not None and len(self.result_paths) == expected:
      return self.result_paths
    else:
      # TODO(necula): this should not happen
      return ("",) * expected

  def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
    """Keep only the result_paths for which `keep` is True."""
    assert not callable(self.result_paths), self
    return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b)
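A small sketch (mine, not from the diff) of how the merged `DebugInfo` behaves: result paths can start out as a thunk and are materialized by `resolve_result_paths()`, after which the `safe_*`/`filter_*` helpers apply their length guards:

    from jax._src import linear_util as lu

    dbg = lu.DebugInfo(
        traced_for="example",
        func_src_info="f at example.py:1",
        arg_names=("x", "y"),
        result_paths=lambda: ("[0]", "[1]"))    # still a thunk at this point

    resolved = dbg.resolve_result_paths()
    resolved.result_paths                        # ('[0]', '[1]')
    resolved.filter_result_paths([True, False])  # ('[0]',)
    resolved.safe_arg_names(2)                   # ('x', 'y')
    resolved.safe_arg_names(3)                   # (None, None, None): length-mismatch guard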
def wrap_init(f: Callable, params=None, *,
              debug_info: TracingDebugInfo | None = None) -> WrappedFun:
              debug_info: DebugInfo | None = None) -> WrappedFun:
  """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
  params_dict = {} if params is None else params
  params = () if params is None else tuple(sorted(params.items()))
  return WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
  fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, None)
  if debug_info:
    if debug_info.result_paths is None:
      fun, result_paths_thunk = _get_result_paths_thunk(fun)
      debug_info = debug_info._replace(
          result_paths=HashableFunction(result_paths_thunk, closure=()))
    fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
                     fun.params, fun.in_type, debug_info)
  return fun


@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
  ans = _fun(*args, **kwargs)
  _store.store([keystr(path) for path, _ in generate_key_paths(ans)])
  return ans

def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
  assert f.in_type is None
  if in_type is None:

@ -341,16 +367,9 @@ def _check_input_type(in_type: core.InputType) -> None:
      provided[d.val] = True
    assert all(provided)

def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None
                   ) -> WrappedFun:
  """Produce a new WrappedFun with debug_info attached."""
  assert f.debug_info is None
  if debug_info is None:
    return f
  return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info)


def cache(call: Callable, *, explain: Callable | None = None):
def cache(call: Callable, *,
          explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None):
  """Memoization decorator for functions taking a WrappedFun as first argument.

  Args:

@ -358,6 +377,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
      underlying transforms and params on the WrappedFun are used as part of the
      memoization cache key.

    explain: a function that is invoked upon cache misses to log an explanation
      of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`.

  Returns:
    A memoized version of ``call``.
  """

@ -373,7 +395,7 @@ def cache(call: Callable, *, explain: Callable | None = None):
    else:
      ans = call(fun, *args)
      if explain and config.explain_cache_misses.value:
        explain(fun.f, cache is new_cache, cache, key)
        explain(fun, cache is new_cache, cache, key)
      cache[key] = (ans, fun.stores)

    return ans
@ -530,8 +530,8 @@ class AbstractMesh:

  @staticmethod
  def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
    jax_config.abstract_mesh_context_manager.set_local(mesh)
    return
    prev = jax_config.abstract_mesh_context_manager.swap_local(mesh)
    return prev


# Create this indirection because pytype fails to recognize a property if a
@ -9744,11 +9744,12 @@ def einsum(

  contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)

  einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
  jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
  if spec is not None:
    einsum = jax.named_call(einsum, name=spec)
  return einsum(operands, contractions, precision,
                preferred_element_type, _dot_general, out_sharding)
    jit_einsum = jax.named_call(jit_einsum, name=spec)
  operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))
  return jit_einsum(operand_arrays, contractions, precision,
                    preferred_element_type, _dot_general, out_sharding)


# Enable other modules to override einsum_contract_path.

@ -9843,7 +9844,7 @@ def _removechars(s, chars):


def _einsum(
    operands: Sequence,
    operands: list[jax.Array],
    contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
    precision,
    preferred_element_type,

@ -9859,7 +9860,6 @@ def _einsum(
        "`out_sharding` argument of `einsum` only supports NamedSharding"
        " instances. Please file a bug if this is not enough for your use case.")
  dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
  operands = list(map(asarray, operands))
  if preferred_element_type is None:
    preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
  else:

@ -11649,7 +11649,8 @@ def take_along_axis(
  j = 0
  for i in range(rank):
    if i == axis_int:
      indices = _normalize_index(indices, axis_size)
      if mode != 'promise_in_bounds':
        indices = _normalize_index(indices, axis_size)
      gather_indices.append(lax.reshape(indices, gather_index_shape))
      slice_sizes.append(1)
      start_index_map.append(i)
@ -222,10 +222,6 @@ class AbstractMemoryRef(state.AbstractRef):
  def __repr__(self) -> str:
    return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'

  @property
  def sharding(self):
    return self.inner_aval.sharding

  def update_weak_type(self, weak_type):
    return AbstractMemoryRef(
        self.inner_aval.update_weak_type(weak_type), self.memory_space)

@ -413,9 +409,9 @@ class BlockSpec:

    fake_index_map_args, fake_index_map_kwargs = \
        index_map_tree.unflatten([False] * index_map_tree.num_leaves)
    debug = api_util.tracing_debug_info("pallas_call index_map",
                                        index_map_func, fake_index_map_args,
                                        fake_index_map_kwargs)
    debug = api_util.debug_info("pallas_call index_map",
                                index_map_func, fake_index_map_args,
                                fake_index_map_kwargs)
    flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
        lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
    index_map_src_info = NameAndSrcInfo.from_pallas_call(

@ -423,7 +419,7 @@ class BlockSpec:
    )
    with tracing_grid_env(grid, mapped_dims):
      jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
          flat_index_map_fun, index_map_avals, debug_info=debug
          flat_index_map_fun, index_map_avals
      )
    mapped_block_shape = tuple(mapped if s is None else s for s in block_shape)
    if len(out_avals) != len(block_shape):

@ -890,7 +886,8 @@ def get_grid_mapping(
  )
  # The inputs for the index maps
  index_map_avals = (
      (index_map_grid_aval.update(sharding=None),) * len(grid_spec.grid))
      index_map_grid_aval.update(sharding=jax_core.get_cur_mesh_sharding()),
  ) * len(grid_spec.grid)
  index_map_tree = tree_util.tree_structure((index_map_avals, {}))

  num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)
@ -49,3 +49,13 @@ class ArrayLike(Protocol):
|
||||
|
||||
def empty_like(x: ArrayLike, *, memory_space: Any = None):
|
||||
return empty(x.shape, x.dtype, memory_space=memory_space)
|
||||
|
||||
|
||||
def when(condition):
|
||||
def _wrapped(f):
|
||||
if isinstance(condition, bool):
|
||||
if condition:
|
||||
f()
|
||||
else:
|
||||
jax.lax.cond(condition, f, lambda: None)
|
||||
return _wrapped
|
||||
|
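A minimal usage sketch of the `when` helper above (the kernel body is hypothetical): with a Python bool the wrapped function runs or is skipped at trace time, while a traced predicate lowers to jax.lax.cond. The same pattern is exposed publicly as pl.when:

import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(o_ref):
  @pl.when(pl.program_id(0) == 0)  # traced predicate -> lax.cond
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)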
@@ -30,8 +30,9 @@ LOCATION_PATTERN = re.compile(
    r'(?P<location>loc\((?P<eqn_str>\".*?\")(?P<frames>.*)\))'
)
FRAME_PATTERN = re.compile(
    r'(?P<fun_name>\".*?\")\((?P<filename>\".*?\"):'
    r'(?P<lineno>[0-9]+):(?P<colno>[0-9]+)\)'
    r'(?P<fun_name>\".*?\")\((?P<filename>\"[^"]*?\"):'
    r'(?P<lineno>[0-9]+)?:(?P<colno>[0-9]+)'
    r'( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)'
)
MLIR_ERR_PREFIX = (
    'Pallas encountered an internal verification error.'
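For illustration, location strings of both shapes the widened FRAME_PATTERN accepts (the names and positions here are made up); only the first form matched the old pattern:

loc_point = '"my_kernel"("my_file.py":12:4)'
loc_range = '"my_kernel"("my_file.py":12:4 to 12:18)'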
@@ -16,6 +16,7 @@

import functools
import jax
from jax._src.pallas import helpers as pl_helpers
from jax._src.pallas import primitives as pl_primitives
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as plm_primitives

@@ -55,3 +56,40 @@ def sync_copy(src_ref, dst_ref):
      src_ref,
      dst_ref,
  )


def run_on_first_core(core_axis_name: str):
  """Runs a function on the first core in a given axis."""
  num_cores = jax.lax.psum(1, core_axis_name)
  if num_cores == 1:
    return lambda f: f()

  def wrapped(f):
    core_id = jax.lax.axis_index(core_axis_name)

    @pl_helpers.when(core_id == 0)
    @functools.wraps(f)
    def _():
      return f()

  return wrapped


def core_barrier(sem, *, core_axis_name: str):
  """Synchronizes all cores in a given axis."""
  num_cores = jax.lax.psum(1, core_axis_name)
  core_id = jax.lax.axis_index(core_axis_name)

  @pl_helpers.when(num_cores > 1)
  def _():
    with jax.named_scope("sync_cores"):

      def signal_core(i):
        # Don't signal ourselves
        @pl_helpers.when(core_id != i)
        def _():
          plm_primitives.semaphore_signal(sem, 1, core_index=i)

      for i in range(num_cores):
        signal_core(i)
      plm_primitives.semaphore_wait(sem, num_cores - 1)
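A sketch (not from the diff) of how these helpers compose inside a kernel body; the refs, the semaphore, and the axis name "core" are assumed to be in scope:

def body(src_ref, dst_ref, sem):
  @run_on_first_core("core")      # only core 0 issues the copy
  def _():
    sync_copy(src_ref, dst_ref)
  core_barrier(sem, core_axis_name="core")  # then all cores rendezvous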
@@ -1524,7 +1524,9 @@ def _masked_swap_lowering_rule(
      1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
      for b in ref_block_shape
  ]
  mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None)
  mem_aval = aval_out.update(
      shape=tuple(mem_slice_shape), sharding=jax_core.get_cur_mesh_sharding()
  )
  mem_aval_shape = ctx.lowering_context.dynamic_shape_replacement_fn(
      mem_aval.shape
  )

@@ -2127,7 +2129,11 @@ def _gather_lowering_rule(
      slice_sizes == (1, 1)
      and not unique_indices
      and not indices_are_sorted
      and mode == lax.GatherScatterMode.FILL_OR_DROP
      and mode
      in (
          lax.GatherScatterMode.FILL_OR_DROP,
          lax.GatherScatterMode.PROMISE_IN_BOUNDS,
      )
  ):
    if dimension_numbers == lax.GatherDimensionNumbers(
        offset_dims=(),

@@ -3011,6 +3017,11 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
lowering_rules[pjit.pjit_p] = _pjit_lowering_rule


def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
  return x
lowering_rules[pjit.mesh_cast_p] = _mesh_cast_lowering_rule


def _custom_jvp_call_lowering_rule(
    ctx: LoweringRuleContext,
    *args,

@@ -921,13 +921,18 @@ ARBITRARY = GridDimensionSemantics()

def _partition_grid(
    grid: tuple[int | jax.Array, ...],
    core_axis: int | None,
    core_axis: int | str | None,
    dimension_semantics: tuple[GridDimensionSemantics, ...] | None,
) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]:
  if core_axis is None:
    # We aren't partitioning the grid
    return grid, (0,) * len(grid)
  num_cores = pl.num_programs(core_axis)
  if isinstance(core_axis, int):
    num_cores = pl.num_programs(core_axis)
    core_id = pl.program_id(core_axis)
  else:
    num_cores = jax.lax.psum(1, core_axis)
    core_id = jax.lax.axis_index(core_axis)
  # Check that num_cores is statically known
  if not isinstance(num_cores, int):
    raise NotImplementedError(

@@ -966,7 +971,7 @@ def _partition_grid(
        i for i in range(len(dimension_semantics)) if i in divisible_dimensions
    )
    partitioned_dim_size = grid[first_divisible_dimension] // num_cores
    partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
    partitioned_dim_offset = core_id * partitioned_dim_size
    new_grid = jax_util.tuple_update(
        grid, first_divisible_dimension, partitioned_dim_size
    )

@@ -990,8 +995,7 @@ def _partition_grid(
    # We have some remainder iterations that we need to assign somewhere. We
    # know that rem < num_cores, so we can assign one extra iteration to each
    # core except for the last (num_cores - rem).
    core_index = pl.program_id(core_axis)
    num_iters = jnp.where(core_index < rem, base_num_iters + 1,
    num_iters = jnp.where(core_id < rem, base_num_iters + 1,
                          base_num_iters)
    new_grid = jax_util.tuple_update(grid, partition_dimension, num_iters)
    # Ordinarily, we would compute the offset as:

@@ -999,9 +1003,9 @@ def _partition_grid(
    # However, since we have some cores that don't have an extra iteration, we
    # need to adjust the offset by `rem`.
    grid_offset = jnp.where(
        core_index < rem,
        core_index * num_iters,
        core_index * base_num_iters + rem,
        core_id < rem,
        core_id * num_iters,
        core_id * base_num_iters + rem,
    )
    offsets = jax_util.tuple_update(
        (0,) * len(grid), partition_dimension, grid_offset
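A quick pure-Python check of the remainder arithmetic used above; `split` is a hypothetical stand-in for the partitioning logic, not code from the diff:

def split(total_iters, num_cores):
  base, rem = divmod(total_iters, num_cores)
  sizes = [base + 1 if c < rem else base for c in range(num_cores)]
  offsets = [c * (base + 1) if c < rem else c * base + rem
             for c in range(num_cores)]
  return sizes, offsets

# 10 iterations over 4 cores: the first rem=2 cores take one extra iteration,
# and the offsets tile the range contiguously: 0..3, 3..6, 6..8, 8..10.
assert split(10, 4) == ([3, 3, 2, 2], [0, 3, 6, 8])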
@@ -1015,8 +1019,9 @@ def emit_pipeline(
    grid: tuple[int | jax.Array, ...],
    in_specs=None,
    out_specs=None,
    should_accumulate_out=False,
    should_accumulate_out: bool = False,
    core_axis: int | None = None,
    core_axis_name: str | None = None,
    dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None,
    trace_scopes: bool = True,
):

@@ -1039,6 +1044,8 @@ def emit_pipeline(
      as accumulators.
    core_axis: optional int, indicates whether or not to partition the grid
      along the core axis.
    core_axis_name: optional str, indicates whether or not to partition the grid
      along the core axis.
    dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL
      or ARBITRARY).
    trace_scopes: optional bool, indicates whether to annotate each region in

@@ -1049,7 +1056,10 @@ def emit_pipeline(
    raise ValueError(
        f"Grid must consist of Python integers and JAX Arrays: {grid_types}"
    )
  grid, grid_offsets = _partition_grid(grid, core_axis, dimension_semantics)
  if not (core_axis is None or core_axis_name is None):
    raise ValueError("core_axis and core_axis_name cannot both be provided.")
  core_axis_ = core_axis_name if core_axis is None else core_axis
  grid, grid_offsets = _partition_grid(grid, core_axis_, dimension_semantics)

  num_steps = _grid_size(grid)
  if not isinstance(in_specs, (list, tuple)):
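An illustrative call shape (all arguments here are hypothetical): exactly one of core_axis / core_axis_name may be passed; an int selects a grid axis via pl.program_id, while a string resolves the axis with jax.lax.axis_index, as in the _partition_grid change above.

pipeline = emit_pipeline(
    inner_kernel,              # hypothetical kernel function
    grid=(steps_m, steps_n),
    core_axis_name="core",     # or core_axis=0, but never both
)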
@@ -550,8 +550,8 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn

def dma_start_discharge_rule(in_avals, out_avals,
                             *args, tree, device_id_type):
def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals,
                                     *args, tree, device_id_type):
  (
      src_ref,
      src_transforms,

@@ -575,7 +575,22 @@ def dma_start_discharge_rule(in_avals, out_avals,
      _,
  ) = tree_util.tree_unflatten(tree, in_avals)
  del out_avals

  (
      _,
      _,
      dst_discharge,
      _,
      dst_sem_discharge,
      _,
      *maybe_src_sem_discharge,
  ) = tree_util.tree_unflatten(tree, should_discharge)
  is_remote = device_id is not None
  src_sem_discharge = None

  if is_remote:
    src_sem_discharge = maybe_src_sem_discharge[0]

  if not is_remote:
    # Local async copies only use one semaphore.
    assert src_sem is None

@@ -586,7 +601,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
  num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
  num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))

  updates = state_discharge.transform_array(src_ref, src_transforms)
  updates = state_discharge.transform_array(src_ref[...], src_transforms)
  local_src = updates

  if is_remote:

@@ -641,47 +656,61 @@ def dma_start_discharge_rule(in_avals, out_avals,
        global_dst_transforms,
    )

  _, new_dst = state_discharge.transform_swap_array(
      dst_ref, dst_transforms, updates
  )
  def do_discharge_dst(dst_ref=dst_ref):
    _, ret = state_discharge.transform_swap_array(
        dst_ref, dst_transforms, updates
    )
    return ret

  # Update semaphore values.
  # TODO(justinfu): Potentially handle asymmetric copy sizes.
  recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
  recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
  dst_sem_value = _transform_semaphore(
      dst_sem, dst_sem_transforms, dst_sem_aval
  )
  _, new_dst_sem = state_discharge.transform_swap_array(
      dst_sem, dst_sem_transforms, dst_sem_value + recv_size
  )
  if is_remote:
  def do_discharge_dst_sem(dst_sem=dst_sem):
    recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
    recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
    dst_sem_value = _transform_semaphore(
        dst_sem, dst_sem_transforms, dst_sem_aval
    )
    _, ret = state_discharge.transform_swap_array(
        dst_sem, dst_sem_transforms, dst_sem_value[...] + recv_size
    )
    return ret

  def do_discharge_src_sem(src_sem=src_sem):
    send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
    send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
    src_sem_value = _transform_semaphore(
        src_sem, src_sem_transforms, src_sem_aval
    )
    _, new_src_sem = state_discharge.transform_swap_array(
        src_sem, src_sem_transforms, src_sem_value + send_size
    _, ret = state_discharge.transform_swap_array(
        src_sem, src_sem_transforms, src_sem_value[...] + send_size
    )
  else:
    new_src_sem = None
    return ret

  new_vals = (None,)  # src_val
  new_vals += (None,) * num_src_transform_vals
  new_vals += (new_dst,)  # dst_val
  new_vals += (do_discharge_dst() if dst_discharge else None,)  # dst_val
  new_vals += (None,) * num_dst_transform_vals
  new_vals += (new_dst_sem,)  # dst_sem
  new_vals += (do_discharge_dst_sem() if dst_sem_discharge else None,)  # dst_sem
  new_vals += (None,) * num_dst_sem_transforms
  if is_remote:
    new_vals += (new_src_sem,)  # src_sem
    new_vals += (do_discharge_src_sem() if src_sem_discharge else None,)  # src_sem
    new_vals += (None,) * num_src_sem_transforms
  new_vals += (None,)  # device_id
  assert (len(new_vals) ==
          len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"

  # If we didn't discharge everything we could, we should keep writes
  # to the references that are left over.
  if not dst_discharge:
    sp.ref_set(dst_ref, None, do_discharge_dst(dst_ref=dst_ref[...]))
  if not dst_sem_discharge:
    sp.ref_set(dst_sem, None, do_discharge_dst_sem(dst_sem=dst_sem[...]))
  if is_remote and not src_sem_discharge:
    sp.ref_set(src_sem, None, do_discharge_src_sem(src_sem=src_sem[...]))

  return new_vals, []

state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule)


dma_wait_p = jax_core.Primitive('dma_wait')

@@ -719,8 +748,9 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,

jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn

def dma_wait_discharge_rule(in_avals, out_avals,
                            *args, tree, device_id_type):
def dma_wait_partial_discharge_rule(should_discharge,
                                    in_avals, out_avals,
                                    *args, tree, device_id_type):
  # TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
  del out_avals, device_id_type
  _, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (

@@ -735,6 +765,14 @@ def dma_wait_discharge_rule(in_avals, out_avals,
      src_sem_transforms_avals,
      device_id_aval,
  ) = tree_util.tree_unflatten(tree, in_avals)

  # The only one we can discharge is the dst semaphore. The provided
  # buffers are only specified for their types and not their value so
  # it's completely irrelevant for us here if they are discharged.
  should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge)
  if not should_discharge_unflattened[4]:
    return (None,) * len(in_avals), []

  num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
  num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
  updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)

@@ -754,7 +792,7 @@ def dma_wait_discharge_rule(in_avals, out_avals,
  new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
  new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval))  # device_id
  return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_wait_p)(dma_wait_partial_discharge_rule)

def _get_ref_and_transforms(ref):
  if isinstance(ref, state.TransformedRef):
@@ -1121,6 +1121,9 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
      ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args
  )

@register_lowering_rule(pjit.mesh_cast_p)
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
  return x

@register_lowering_rule(lax.slice_p)
def _slice_lowering_rule(

@@ -1430,15 +1433,6 @@ def _run_scoped_lowering_rule(
      ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
  )

  for o in outs:
    # This is definitely one of the accumulators we produced. Each
    # run_scoped call is responsible for dereferencing its own
    # accumulators.
    if isinstance(o, mgpu.WGMMAAccumulator) or (
        isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type)
    ):
      raise ValueError(f"No references are allowed to escape a scope. (got {o})")

  assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
  return outs

@@ -1508,7 +1502,14 @@ def _lower_jaxpr_to_for_loop(
  if arg_avals:
    out_avals = ctx.avals_out[-len(arg_avals):]

  @mgpu.fori(length, [*map(_ensure_fa, args, arg_avals)])
  is_acc = [isinstance(v, mgpu.WGMMAAccumulator) for v in args]
  def as_fas(vals, avals):
    if is_acc != [isinstance(v, mgpu.WGMMAAccumulator) for v in vals]:
      raise ValueError("Unexpected loop carry w.r.t. accumulators.")

    return [v if a else _ensure_fa(v, av) for a, v, av in zip(is_acc, vals, avals)]

  @mgpu.fori(length, as_fas(args, arg_avals))
  def loop(loop_index, body_args):
    if has_loop_index:
      loop_index = arith_dialect.addi(loop_index, start)

@@ -1518,7 +1519,7 @@ def _lower_jaxpr_to_for_loop(
    outs = lower_jaxpr_to_mosaic_gpu(
        ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args
    )
    return map(_ensure_fa, outs, out_avals)
    return as_fas(outs, out_avals)

  return loop.results

@@ -1640,7 +1641,10 @@ def _while_lowering_rule(
  _cond_avals, body_avals, carry_avals = util.split_list(
      ctx.avals_in, [cond_nconsts, body_nconsts]
  )
  carry = map(_ensure_fa, carry, carry_avals)
  carry = [
      v if isinstance(v, mgpu.WGMMAAccumulator) else _ensure_fa(v, av)
      for v, av in zip(carry, carry_avals)
  ]
  # Flatten the carry to get a concatenated list of registers from each FA.
  # Note that the treedef is also used below to unflatten the body results.
  flat_carry, carry_treedef = jax.tree.flatten(carry)

@@ -1663,9 +1667,19 @@ def _while_lowering_rule(
    loop_out = lower_jaxpr_to_mosaic_gpu(
        ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args
    )
    loop_out = map(_ensure_fa, loop_out, carry_avals)
    loop_out = [
        v if isinstance(v, mgpu.WGMMAAccumulator) else _ensure_fa(v, av)
        for v, av in zip(loop_out, carry_avals)
    ]
    for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)):
      if carry_fa.layout != out_fa.layout:
      _is_acc = lambda x: isinstance(x, mgpu.WGMMAAccumulator)
      if _is_acc(carry_fa) != _is_acc(out_fa):
        raise ValueError(
            f"The loop body output has unexpected accumulator type: output[{idx}]"
            f" is {out_fa}, when it should be {carry_fa}."
        )

      if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout:
        raise ValueError(
            f"The loop body output has unexpected layout: output[{idx}] has"
            f" layout {out_fa.layout}, when it should be {carry_fa.layout}."

@@ -1865,6 +1879,19 @@ def merge_indexers(
    if indexer.int_indexer_shape:
      raise NotImplementedError()

  def _ensure_idx_fa(x):
    index = ir.IndexType.get()
    i32 = ir.IntegerType.get_signless(32)
    if isinstance(x, ir.Value):
      return mgpu.FragmentedArray.splat(
          x, (), is_signed=mgpu.utils.is_signed(x.type)
      ).astype(i32, is_signed=False)
    if isinstance(x, mgpu.FragmentedArray):
      return x.astype(i32, is_signed=False)
    if isinstance(x, int):
      return mgpu.FragmentedArray.splat(mgpu.c(x, i32), (), is_signed=False)
    raise NotImplementedError(x)

  num_skipped = 0
  for i in range(len(current_indices)):
    # Integer indexers remove dimensions which should be

@@ -1876,18 +1903,17 @@ def merge_indexers(
      current_index = current_indices[i]
      assert isinstance(current_index, indexing.Slice)

      current_start_index = _ensure_fa(current_index.start, jnp.int32)
      current_start_index = _ensure_idx_fa(current_index.start)
      if isinstance(dim_indexer, indexing.Slice):
        if dim_indexer.stride != 1:
          raise NotImplementedError("Non-unit strides not implemented.")
        current_indices[i] = indexing.Slice(
            current_start_index + _ensure_fa(dim_indexer.start, jnp.int32),
            current_start_index + _ensure_idx_fa(dim_indexer.start),
            dim_indexer.size,
            1,
        )
      else:
        current_indices[i] = current_start_index + _ensure_fa(
            dim_indexer, dtype=jnp.int32)
        current_indices[i] = current_start_index + _ensure_idx_fa(dim_indexer)
        removed_dimensions.add(i)
  return indexing.NDIndexer(
      indices=tuple(current_indices),

@@ -558,15 +558,9 @@ def _wgmma_lowering(
  if rhs_tiling != (swizzle_elems, swizzle_elems):
    raise NotImplementedError("WGMMA rhs tiling does not fit swizzle")

  new_acc = mgpu.wgmma(
      acc,
      a,
      b,
      swizzle=rhs_swizzle,
      b_order=mgpu.WGMMALayout.COL_MAJOR
      if rhs_transpose
      else mgpu.WGMMALayout.ROW_MAJOR,
  )
  if rhs_transpose:
    b = mgpu.memref_transpose(b, (0, 1, 3, 2))
  new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle)
  nvvm_dialect.wgmma_commit_group_sync_aligned()
  return new_acc
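A NumPy analogy (illustrative only, with a made-up tile shape) for the (0, 1, 3, 2) permutation used above: it swaps the two minor, in-tile axes while leaving the tile grid alone, which is what expressing the COL_MAJOR operand as an explicit memref transpose requires.

import numpy as np

b = np.zeros((4, 2, 64, 32))  # (tile rows, tile cols, in-tile rows, in-tile cols)
assert np.transpose(b, (0, 1, 3, 2)).shape == (4, 2, 32, 64)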
@@ -37,8 +37,8 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import helpers as pallas_helpers
from jax._src.pallas import hlo_interpreter
from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (

@@ -101,12 +101,12 @@ def _pallas_call_jvp_rule(
    primals,
    tangents,
    *,
    jaxpr,
    jaxpr: jax_core.Jaxpr,
    name_and_src_info,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping,
    debug,
    interpret,
    debug: bool,
    interpret: bool,
    compiler_params: Any,
    cost_estimate: CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],

@@ -739,7 +739,7 @@ def _pallas_call_batching_rule(
    # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
    # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")

    @pallas_utils.when(run_kernel)
    @pallas_helpers.when(run_kernel)
    def f():
      # Important! This allows us to trace the inner kernel with the correct
      # grid to preserve user program_id semantics. Ex: program_id(0) will

@@ -1098,13 +1098,14 @@ def pallas_call_checkify_rule(error: checkify.Error,
  retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
                      *error_memref_aval, *output_aval, *scratch_aval]
  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
  debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
                              retrace_in_avals, {})
  wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
  debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn,
                                      retrace_in_avals, {})
      lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree)

  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
        wrapped_kernel_with_err, jaxpr_flat_avals, debug)
        wrapped_kernel_with_err, jaxpr_flat_avals)

  # Prepare pallas_call inputs. We need to create new block specs
  # for the new error inputs and outputs.

@@ -1161,16 +1162,16 @@ def _trace_kernel_to_jaxpr(
    kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
    indexer: bool = False,
) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
  fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
  debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
  wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun), kernel_in_tree)
      lu.wrap_init(fun, debug_info=debug), kernel_in_tree)
  wrapped_kernel_fun = primitives.wrap_with_transforms(
      wrapped_kernel_fun, kernel_in_transforms
  )
  fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
  debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {})
  with grid_mapping.trace_env():
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                                                     kernel_avals, debug)
                                                     kernel_avals)
  if consts:
    consts_avals = [jax_core.get_aval(c) for c in consts]
    if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):

@@ -1568,7 +1569,7 @@ def pallas_call(
  kernel_fun_sig = api_util.fun_signature(kernel)
  arg_names = None
  if kernel_fun_sig:
    kernel_debug_info = api_util.tracing_debug_info(
    kernel_debug_info = api_util.debug_info(
        "pallas_call kernel",
        kernel,
        [1] * len(kernel_fun_sig.parameters), {})

@@ -896,22 +896,25 @@ def _run_scoped_discharge_rule(
    **_):
  del out_avals
  num_consts = len(args_flat)
  # discharge_state only discharges invars, not consts, so in order to
  # discharge the requested refs we need to move them to the invar set.
  jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
  num_return_values = len(jaxpr_noconst.outvars)
  should_discharge = should_discharge + [
      isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars
  ]
  discharged_body, new_consts = state_discharge.discharge_state(
      jaxpr_noconst, [], should_discharge=should_discharge)
      jaxpr_noconst,
      [],
      should_discharge=should_discharge + [False] * len(jaxpr.invars),
  )
  if new_consts:
    raise NotImplementedError(
        "Cannot handle new consts created by state discharge.")
  # Create inputs filled with uninitialized values to the body.
  body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
  init_vals = [uninitialized_value(
      aval.shape, aval.dtype) for aval in body_avals]
  init_vals_with_consts = args_flat + tuple(init_vals)
  out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts)

  # Lowering expects the jaxpr.consts to be the eqn.invals.
  discharged_body = pe.convert_invars_to_constvars(discharged_body, num_consts)

  # Run_scoped discharged the external variables but the scoped ones
  # are not discharged.
  out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body)
  # Order of outputs:
  # (1) return values, (2) closed refs, (3) scoped refs.
  return_values = out[:num_return_values]

@@ -919,8 +922,8 @@ def _run_scoped_discharge_rule(
  # We update all ref values with their updated values from the discharged
  # body. For other values we leave them in place.
  updates = [
      ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef)
      else None for aval in in_avals]
      ref_outputs.pop(0) if should and isinstance(aval, pallas_core.AbstractMemoryRef)
      else None for should, aval in zip(should_discharge, in_avals)]
  assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}'
  return updates, return_values

@@ -931,17 +934,20 @@ state_discharge.register_partial_discharge_rule(run_scoped_p)(

@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
  # This lowering rule gets triggered when run_scoped is not discharged.
  # In this case there are no stateful effects to handle.
  should_discharge = [
      isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
  ]
  jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
  num_return_values = len(jaxpr_noconst.outvars)
  discharged_body, new_consts = state_discharge.discharge_state(
      jaxpr_noconst, [], should_discharge=True)
  if new_consts: raise NotImplementedError(
      "Cannot handle new consts created by state discharge.")

  def _lower_fun(*lower_fun_args):
    updates, out = _run_scoped_discharge_rule(
        should_discharge,
        [], [], *lower_fun_args,
        jaxpr=jaxpr)
    assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
    return out
    # Create inputs filled with uninitialized values to the body.
    num_consts = len(lower_fun_args)
    body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
    init_vals = [uninitialized_value(
        aval.shape, aval.dtype) for aval in body_avals]
    out = jax_core.eval_jaxpr(discharged_body, [], *lower_fun_args, *init_vals)
    return out[:num_return_values]

  return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)

@@ -593,12 +593,19 @@ class _Extern:

  def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
    [out_aval] = ctx.avals_out
    bcast_args = []
    for aval, arg, arg_type in zip(ctx.avals_in, args, self.arg_types):
      bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
      if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
        bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
      bcast_args.append(bcast_arg)

    result_type = _dtype_to_ir_type(jnp.dtype(self.result_type))
    if out_aval.shape:
      result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
    return tt_dialect.extern_elementwise(
        result_type,
        args,
        bcast_args,
        libname="",
        libpath="",
        symbol=self.symbol,

@@ -608,10 +615,23 @@ class _Extern:

@dataclasses.dataclass(frozen=True)
class _Fallback:
  arg_types: Sequence[jax.typing.DTypeLike]
  lower: Callable[..., ir.Value]
  arg_classes: Sequence[jax.typing.DTypeLike]
  op: Callable[..., ir.Value]

  matches = _Extern.matches
  def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
    if len(avals) != len(self.arg_classes):
      return False
    return all(
        jnp.issubdtype(aval.dtype, arg_class)
        for aval, arg_class in zip(avals, self.arg_classes)
    )

  def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
    [out_aval] = ctx.avals_out
    bcast_args = []
    for aval, arg in zip(ctx.avals_in, args):
      bcast_args.append(_bcast_to(_ensure_ir_value(arg, aval), out_aval.shape))
    return self.op(*args)

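The class-based `matches` above relies on NumPy-style dtype subtyping, which JAX extends to its custom float types; for instance:

import jax.numpy as jnp

# A _Fallback declared over jnp.floating accepts any concrete float dtype,
# while integer inputs fall through to the next entry in the table.
assert jnp.issubdtype(jnp.bfloat16, jnp.floating)
assert not jnp.issubdtype(jnp.int32, jnp.floating)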
def _make_dispatch_table(

@@ -626,390 +646,452 @@ def _make_dispatch_table(
      raise NotImplementedError(
          f"unsupported types for {name}: {arg_aval_dtypes}"
      )

    [out_aval] = ctx.avals_out
    bcast_args = []
    for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
      bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
      if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
        bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
      bcast_args.append(bcast_arg)
    return h.lower(ctx, *bcast_args)
    return h.lower(ctx, *args)

  return inner


_abs_dispatch_table = _make_dispatch_table(
abs_dispatch_table = _make_dispatch_table(
    "abs",
    cuda=[
        _Extern([jnp.int32], "__nv_abs", jnp.int32),
        _Extern([jnp.int64], "__nv_llabs", jnp.int64),
        _Extern([jnp.float32], "__nv_fabsf", jnp.float32),
        _Extern([jnp.float64], "__nv_fabs", jnp.float64),
        _Fallback([jnp.integer], math_dialect.absi),
        _Fallback([jnp.floating], math_dialect.absf),
    ],
    rocm=[
        _Fallback([jnp.int32], lambda ctx, x: math_dialect.absi(x)),
        _Fallback([jnp.int64], lambda ctx, x: math_dialect.absi(x)),
        _Fallback([jnp.float32], lambda ctx, x: math_dialect.absf(x)),
        _Fallback([jnp.float64], lambda ctx, x: math_dialect.absf(x)),
        _Fallback([jnp.integer], math_dialect.absi),
        _Fallback([jnp.floating], math_dialect.absf),
    ],
)
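Entries in each list are tried in order, so the exact _Extern bindings win and the class-based _Fallback rows act as catch-alls. A minimal, hypothetical model of that first-match dispatch:

handlers = [
    ("f32 extern", lambda dt: dt == "float32"),
    ("float fallback", lambda dt: dt.startswith("float")),
]
pick = next(name for name, matches in handlers if matches("float32"))
assert pick == "f32 extern"  # the earlier, more specific entry wins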

ceil_dispatch_table = _make_dispatch_table(
    "ceil",
    cuda=[
        _Extern([jnp.float32], "__nv_ceilf", jnp.float32),
        _Extern([jnp.float64], "__nv_ceil", jnp.float64),
        _Fallback([jnp.floating], math_dialect.ceil),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.ceil),
    ],
)

@register_lowering(lax.abs_p)
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
  try:
    return _abs_dispatch_table(ctx, x)
  except NotImplementedError as e:
    [x_aval] = ctx.avals_in
    if jnp.issubdtype(x_aval, jnp.integer):
      return math_dialect.absi(x)
    elif jnp.issubdtype(x_aval, jnp.floating):
      return math_dialect.absf(x)
    else:
      raise e from None
floor_dispatch_table = _make_dispatch_table(
    "floor",
    cuda=[
        _Extern([jnp.float32], "__nv_floorf", jnp.float32),
        _Extern([jnp.float64], "__nv_floor", jnp.float64),
        _Fallback([jnp.floating], math_dialect.floor),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.floor),
    ],
)

exp_dispatch_table = _make_dispatch_table(
    "exp",
    cuda=[
        _Extern([jnp.float32], "__nv_expf", jnp.float32),
        _Extern([jnp.float64], "__nv_exp", jnp.float64),
        _Fallback([jnp.floating], math_dialect.exp),
    ],
    rocm=[
        _Fallback([jnp.float32], math_dialect.exp),
        _Fallback([jnp.float64], math_dialect.exp),
        _Fallback([jnp.floating], math_dialect.exp),
    ],
)

exp2_dispatch_table = _make_dispatch_table(
    "exp2",
    cuda=[
        _Extern([jnp.float32], "__nv_exp2f", jnp.float32),
        _Extern([jnp.float64], "__nv_exp2", jnp.float64),
        _Fallback([jnp.floating], math_dialect.exp2),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.exp2),
    ],
)

expm1_dispatch_table = _make_dispatch_table(
    "expm1",
    cuda=[
        _Extern([jnp.float32], "__nv_expm1f", jnp.float32),
        _Extern([jnp.float64], "__nv_expm1", jnp.float64),
        _Fallback([jnp.floating], math_dialect.expm1),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.expm1),
    ],
)

log_dispatch_table = _make_dispatch_table(
    "log",
    cuda=[
        _Extern([jnp.float32], "__nv_logf", jnp.float32),
        _Extern([jnp.float64], "__nv_log", jnp.float64),
        _Fallback([jnp.floating], math_dialect.log),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.log),
    ],
)

log1p_dispatch_table = _make_dispatch_table(
    "log1p",
    cuda=[
        _Extern([jnp.float32], "__nv_log1pf", jnp.float32),
        _Extern([jnp.float64], "__nv_log1p", jnp.float64),
        _Fallback([jnp.floating], math_dialect.log1p),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.log1p),
    ],
)

sqrt_dispatch_table = _make_dispatch_table(
    "sqrt",
    cuda=[
        _Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
        _Extern([jnp.float64], "__nv_sqrt", jnp.float64),
        _Fallback([jnp.floating], math_dialect.sqrt),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.sqrt),
    ],
)

pow_dispatch_table = _make_dispatch_table(
    "pow",
    cuda=[
        _Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
        _Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
        _Fallback(
            [jnp.floating, jnp.integer],
            math_dialect.fpowi
        ),
        _Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
        _Fallback(
            [jnp.floating, jnp.floating],
            math_dialect.powf
        ),
    ],
    rocm=[
        _Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
        _Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
        _Fallback(
            [jnp.floating, jnp.integer],
            math_dialect.fpowi
        ),
        _Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
        _Fallback(
            [jnp.floating, jnp.floating],
            math_dialect.powf
        ),
    ],
)

cbrt_dispatch_table = _make_dispatch_table(
    "cbrt",
    cuda=[
        _Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
        _Extern([jnp.float64], "__nv_cbrt", jnp.float64),
        _Fallback([jnp.floating], math_dialect.cbrt),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.cbrt),
    ],
)

rsqrt_dispatch_table = _make_dispatch_table(
    "rsqrt",
    cuda=[
        _Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
        _Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
        _Fallback([jnp.floating], math_dialect.rsqrt),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.rsqrt),
    ],
)

sin_dispatch_table = _make_dispatch_table(
    "sin",
    cuda=[
        _Extern([jnp.float32], "__nv_sinf", jnp.float32),
        _Extern([jnp.float64], "__nv_sin", jnp.float64),
        _Fallback([jnp.floating], math_dialect.sin),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.sin),
    ],
)

cos_dispatch_table = _make_dispatch_table(
    "cos",
    cuda=[
        _Extern([jnp.float32], "__nv_cosf", jnp.float32),
        _Extern([jnp.float64], "__nv_cos", jnp.float64),
        _Fallback([jnp.floating], math_dialect.cos),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.cos),
    ],
)

tan_dispatch_table = _make_dispatch_table(
    "tan",
    cuda=[
        _Extern([jnp.float32], "__nv_tanf", jnp.float32),
        _Extern([jnp.float64], "__nv_tan", jnp.float64),
        _Fallback([jnp.floating], math_dialect.tan),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.tan),
    ],
)

asin_dispatch_table = _make_dispatch_table(
    "asin",
    cuda=[
        _Extern([jnp.float32], "__nv_asinf", jnp.float32),
        _Extern([jnp.float64], "__nv_asin", jnp.float64),
        _Fallback([jnp.floating], math_dialect.asin),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.asin),
    ],
)

acos_dispatch_table = _make_dispatch_table(
    "acos",
    cuda=[
        _Extern([jnp.float32], "__nv_acosf", jnp.float32),
        _Extern([jnp.float64], "__nv_acos", jnp.float64),
        _Fallback([jnp.floating], math_dialect.acos),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.acos),
    ],
)

atan_dispatch_table = _make_dispatch_table(
    "atan",
    cuda=[
        _Extern([jnp.float32], "__nv_atanf", jnp.float32),
        _Extern([jnp.float64], "__nv_atan", jnp.float64),
        _Fallback([jnp.floating], math_dialect.atan),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.atan),
    ],
)

atan2_dispatch_table = _make_dispatch_table(
    "atan2",
    cuda=[
        _Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
        _Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
    ],
    rocm=[
        _Extern([jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64),
        _Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
    ],
)

sinh_dispatch_table = _make_dispatch_table(
    "sinh",
    cuda=[
        _Extern([jnp.float32], "__nv_sinhf", jnp.float32),
        _Extern([jnp.float64], "__nv_sinh", jnp.float64),
        _Fallback([jnp.floating], math_dialect.sinh),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.sinh),
    ],
)

cosh_dispatch_table = _make_dispatch_table(
    "cosh",
    cuda=[
        _Extern([jnp.float32], "__nv_coshf", jnp.float32),
        _Extern([jnp.float64], "__nv_cosh", jnp.float64),
        _Fallback([jnp.floating], math_dialect.cosh),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.cosh),
    ],
)

tanh_dispatch_table = _make_dispatch_table(
    "tanh",
    cuda=[
        _Extern([jnp.float32], "__nv_tanhf", jnp.float32),
        _Extern([jnp.float64], "__nv_tanh", jnp.float64),
        _Fallback([jnp.floating], math_dialect.tanh),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.tanh),
    ],
)

asinh_dispatch_table = _make_dispatch_table(
    "asinh",
    cuda=[
        _Extern([jnp.float32], "__nv_asinhf", jnp.float32),
        _Extern([jnp.float64], "__nv_asinh", jnp.float64),
        _Fallback([jnp.floating], math_dialect.asinh),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.asinh),
    ],
)

acosh_dispatch_table = _make_dispatch_table(
    "acosh",
    cuda=[
        _Extern([jnp.float32], "__nv_acoshf", jnp.float32),
        _Extern([jnp.float64], "__nv_acosh", jnp.float64),
        _Fallback([jnp.floating], math_dialect.acosh),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.acosh),
    ],
)

atanh_dispatch_table = _make_dispatch_table(
    "atanh",
    cuda=[
        _Extern([jnp.float32], "__nv_atanhf", jnp.float32),
        _Extern([jnp.float64], "__nv_atanh", jnp.float64),
        _Fallback([jnp.floating], math_dialect.atanh),
    ],
    rocm=[
        _Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
        _Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
        _Fallback([jnp.floating], math_dialect.atanh),
    ],
)

population_count_dispatch_table = _make_dispatch_table(
    "population_count",
    cuda=[
        _Extern([jnp.int32], "__nv_popc", jnp.int32),
        _Extern([jnp.int64], "__nv_popcll", jnp.int32),
        _Fallback([jnp.integer], math_dialect.ctpop),
    ],
    rocm=[
        _Fallback([jnp.integer], math_dialect.ctpop),
    ],
)

clz_dispatch_table = _make_dispatch_table(
    "clz",
    cuda=[
        _Extern([jnp.int32], "__nv_clz", jnp.int32),
        _Extern([jnp.int64], "__nv_clzll", jnp.int32),
        _Fallback([jnp.integer], math_dialect.ctlz),
    ],
    rocm=[
        _Fallback([jnp.integer], math_dialect.ctlz),
    ],
)

nextafter_dispatch_table = _make_dispatch_table(
    "nextafter",
    cuda=[
        _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32),
        _Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
    ],
    rocm=[
        _Extern(
            [jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
        ),
        _Extern(
            [jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
        ),
    ],
)

triton_lowering_rules.update({
    lax.abs_p: abs_dispatch_table,
    lax.neg_p: lambda ctx, x: _minus(x),
    lax.ceil_p: _make_dispatch_table(
        "ceil",
        cuda=[
            _Extern([jnp.float32], "__nv_ceilf", jnp.float32),
            _Extern([jnp.float64], "__nv_ceil", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
        ],
    ),
    lax.floor_p: _make_dispatch_table(
        "floor",
        cuda=[
            _Extern([jnp.float32], "__nv_floorf", jnp.float32),
            _Extern([jnp.float64], "__nv_floor", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
        ],
    ),
    lax.exp_p: _make_dispatch_table(
        "exp",
        cuda=[
            _Extern([jnp.float32], "__nv_expf", jnp.float32),
            _Extern([jnp.float64], "__nv_exp", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
        ],
        rocm=[
            _Fallback([jnp.float32], lambda ctx, x: math_dialect.exp(x)),
            _Fallback([jnp.float64], lambda ctx, x: math_dialect.exp(x)),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
        ],
    ),
    lax.exp2_p: _make_dispatch_table(
        "exp2",
        cuda=[
            _Extern([jnp.float32], "__nv_exp2f", jnp.float32),
            _Extern([jnp.float64], "__nv_exp2", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
        ],
    ),
    lax.expm1_p: _make_dispatch_table(
        "expm1",
        cuda=[
            _Extern([jnp.float32], "__nv_expm1f", jnp.float32),
            _Extern([jnp.float64], "__nv_expm1", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
        ],
    ),
    lax.log_p: _make_dispatch_table(
        "log",
        cuda=[
            _Extern([jnp.float32], "__nv_logf", jnp.float32),
            _Extern([jnp.float64], "__nv_log", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
        ],
    ),
    lax.log1p_p: _make_dispatch_table(
        "log1p",
        cuda=[
            _Extern([jnp.float32], "__nv_log1pf", jnp.float32),
            _Extern([jnp.float64], "__nv_log1p", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
        ],
    ),
    lax.sqrt_p: _make_dispatch_table(
        "sqrt",
        cuda=[
            _Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
            _Extern([jnp.float64], "__nv_sqrt", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
        ],
    ),
    lax.ceil_p: ceil_dispatch_table,
    lax.floor_p: floor_dispatch_table,
    lax.exp_p: exp_dispatch_table,
    lax.exp2_p: exp2_dispatch_table,
    lax.expm1_p: expm1_dispatch_table,
    lax.log_p: log_dispatch_table,
    lax.log1p_p: log1p_dispatch_table,
    lax.sqrt_p: sqrt_dispatch_table,
    lax.square_p: lambda ctx, x: _mul(x, x),
    lax.pow_p: _make_dispatch_table(
        "pow",
        cuda=[
            _Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
            _Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
            _Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
            _Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
            _Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
            _Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
            _Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
        ],
    ),
    lax.cbrt_p: _make_dispatch_table(
        "cbrt",
        cuda=[
            _Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
            _Extern([jnp.float64], "__nv_cbrt", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
        ],
    ),
    lax.rsqrt_p: _make_dispatch_table(
        "rsqrt",
        cuda=[
            _Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
            _Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
        ],
    ),
    lax.sin_p: _make_dispatch_table(
        "sin",
        cuda=[
            _Extern([jnp.float32], "__nv_sinf", jnp.float32),
            _Extern([jnp.float64], "__nv_sin", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
        ],
    ),
    lax.cos_p: _make_dispatch_table(
        "cos",
        cuda=[
            _Extern([jnp.float32], "__nv_cosf", jnp.float32),
            _Extern([jnp.float64], "__nv_cos", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
            _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
            _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
        ],
    ),
    lax.tan_p: _make_dispatch_table(
        "tan",
        cuda=[
            _Extern([jnp.float32], "__nv_tanf", jnp.float32),
            _Extern([jnp.float64], "__nv_tan", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
        ],
    ),
    lax.asin_p: _make_dispatch_table(
        "asin",
        cuda=[
            _Extern([jnp.float32], "__nv_asinf", jnp.float32),
            _Extern([jnp.float64], "__nv_asin", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
        ],
    ),
    lax.acos_p: _make_dispatch_table(
        "acos",
        cuda=[
            _Extern([jnp.float32], "__nv_acosf", jnp.float32),
            _Extern([jnp.float64], "__nv_acos", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
        ],
    ),
    lax.atan_p: _make_dispatch_table(
        "atan",
        cuda=[
            _Extern([jnp.float32], "__nv_atanf", jnp.float32),
            _Extern([jnp.float64], "__nv_atan", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
        ],
    ),
    lax.atan2_p: _make_dispatch_table(
        "atan2",
        cuda=[
            _Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
            _Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
        ],
        rocm=[
            _Extern(
                [jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32
            ),
            _Extern(
                [jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64
            ),
        ],
    ),
    lax.sinh_p: _make_dispatch_table(
        "sinh",
        cuda=[
            _Extern([jnp.float32], "__nv_sinhf", jnp.float32),
            _Extern([jnp.float64], "__nv_sinh", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
        ],
    ),
    lax.cosh_p: _make_dispatch_table(
        "cosh",
        cuda=[
            _Extern([jnp.float32], "__nv_coshf", jnp.float32),
            _Extern([jnp.float64], "__nv_cosh", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
        ],
    ),
    lax.tanh_p: _make_dispatch_table(
        "tanh",
        cuda=[
            _Extern([jnp.float32], "__nv_tanhf", jnp.float32),
            _Extern([jnp.float64], "__nv_tanh", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
        ],
    ),
    lax.asinh_p: _make_dispatch_table(
        "asinh",
        cuda=[
            _Extern([jnp.float32], "__nv_asinhf", jnp.float32),
            _Extern([jnp.float64], "__nv_asinh", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
        ],
    ),
    lax.acosh_p: _make_dispatch_table(
        "acosh",
        cuda=[
            _Extern([jnp.float32], "__nv_acoshf", jnp.float32),
            _Extern([jnp.float64], "__nv_acosh", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
        ],
    ),
    lax.atanh_p: _make_dispatch_table(
        "atanh",
        cuda=[
            _Extern([jnp.float32], "__nv_atanhf", jnp.float32),
            _Extern([jnp.float64], "__nv_atanh", jnp.float64),
        ],
        rocm=[
            _Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
            _Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
        ],
    ),
    lax.population_count_p: _make_dispatch_table(
        "population_count",
        cuda=[
            _Extern([jnp.int32], "__nv_popc", jnp.int32),
            _Extern([jnp.int64], "__nv_popcll", jnp.int32),
        ],
        rocm=[
            _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctpop(x)),
            _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctpop(x)),
        ],
    ),
    lax.clz_p: _make_dispatch_table(
        "clz",
        cuda=[
            _Extern([jnp.int32], "__nv_clz", jnp.int32),
            _Extern([jnp.int64], "__nv_clzll", jnp.int32),
        ],
        rocm=[
            _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctlz(x)),
            _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctlz(x)),
        ],
    ),
    lax.nextafter_p: _make_dispatch_table(
        "nextafter",
        cuda=[
            _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32),
            _Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
        ],
        rocm=[
            _Extern(
                [jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
            ),
            _Extern(
                [jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
            ),
        ],
    ),
    lax.pow_p: pow_dispatch_table,
    lax.cbrt_p: cbrt_dispatch_table,
    lax.rsqrt_p: rsqrt_dispatch_table,
    lax.sin_p: sin_dispatch_table,
    lax.cos_p: cos_dispatch_table,
    lax.tan_p: tan_dispatch_table,
    lax.asin_p: asin_dispatch_table,
    lax.acos_p: acos_dispatch_table,
    lax.atan_p: atan_dispatch_table,
    lax.atan2_p: atan2_dispatch_table,
    lax.sinh_p: sinh_dispatch_table,
    lax.cosh_p: cosh_dispatch_table,
    lax.tanh_p: tanh_dispatch_table,
    lax.asinh_p: asinh_dispatch_table,
    lax.acosh_p: acosh_dispatch_table,
    lax.atanh_p: atanh_dispatch_table,
|
||||
lax.population_count_p: population_count_dispatch_table,
|
||||
lax.clz_p: clz_dispatch_table,
|
||||
lax.nextafter_p: nextafter_dispatch_table,
|
||||
})
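Both the inline tables above and the named `*_dispatch_table` values that replace them encode the same rule: given the operand dtypes, pick the first matching entry, where `_Extern` lowers to a named libdevice/OCML symbol and `_Fallback` lowers through the MLIR math dialect. A minimal self-contained sketch of that selection step (the `Entry` and `pick_rule` names are hypothetical stand-ins, not JAX's actual classes):

import dataclasses

@dataclasses.dataclass
class Entry:  # hypothetical stand-in for _Extern/_Fallback
  arg_types: tuple  # dtypes the rule accepts, e.g. ("float32",)
  target: str       # e.g. the extern symbol "__nv_sinf"

def pick_rule(entries, arg_types):
  # First entry whose signature matches wins, mirroring the table order above.
  for entry in entries:
    if tuple(entry.arg_types) == tuple(arg_types):
      return entry
  raise NotImplementedError(f"no lowering rule for {arg_types}")

cuda_sin = [Entry(("float32",), "__nv_sinf"), Entry(("float64",), "__nv_sin")]
assert pick_rule(cuda_sin, ("float32",)).target == "__nv_sinf"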

@@ -2211,6 +2293,10 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
      ctx.context, jaxpr.jaxpr, ctx.block_infos, *args
  )

@register_lowering(pjit.mesh_cast_p)
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
  return x


@register_lowering(jax_core.closed_call_p)
@register_lowering(custom_derivatives.custom_jvp_call_p)
@@ -17,10 +17,16 @@
from __future__ import annotations

import io
import re
from typing import Any
import zlib

import jax
import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import triton
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.triton import lowering
@@ -51,7 +57,7 @@ def pallas_call_lowering(
    cost_estimate: pallas_core.CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
):
  del interpret, out_avals
  del interpret, cost_estimate, out_avals
  if grid_mapping.num_dynamic_grid_bounds:
    raise NotImplementedError(
        "dynamic grid bounds not supported in the Triton backend"
@@ -77,6 +83,11 @@ def pallas_call_lowering(
    print(f"The grid mapping for pallas_call {name_and_src_info}:")
    print(grid_mapping)

  # Sanitize the name to conform to NVPTX requirements. We do this here
  # to avoid the need to fetch the new name from PTX post compilation.
  name_and_src_info = name_and_src_info.replace(
      name=re.sub(r"[^a-zA-Z0-9_$]", "_", name_and_src_info.name)
  )
  lowering_result = lowering.lower_jaxpr_to_triton_module(
      jaxpr, grid_mapping, name_and_src_info, lowering_platform
  )
@@ -86,35 +97,93 @@ def pallas_call_lowering(
    print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))

  grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
  out_types = [
  buf = io.BytesIO()
  module_op.write_bytecode(buf)

  if jaxlib_version < (0, 5, 1):
    # AOT Triton compilation is only available on jaxlib 0.5.1+.
    out_types = [
        ir.RankedTensorType.get(bm.array_shape_dtype.shape,
                                mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype))
        for bm in grid_mapping.block_mappings_output
    ]
    buf = io.BytesIO()
    module_op.write_bytecode(buf)
    backend_config = dict(
        name=ir.StringAttr.get(name_and_src_info.name),
        ir=ir.StringAttr.get(buf.getvalue()),
        num_stages=mlir.i32_attr(num_stages),
        num_warps=mlir.i32_attr(num_warps),
        grid_x=mlir.i32_attr(grid_x),
        grid_y=mlir.i32_attr(grid_y),
        grid_z=mlir.i32_attr(grid_z),
        debug=ir.BoolAttr.get(debug),
  ]
  backend_config = dict(
      name=ir.StringAttr.get(name_and_src_info.name),
      ir=ir.StringAttr.get(buf.getvalue()),
      num_stages=mlir.i32_attr(num_stages),
      num_warps=mlir.i32_attr(num_warps),
      grid_x=mlir.i32_attr(grid_x),
      grid_y=mlir.i32_attr(grid_y),
      grid_z=mlir.i32_attr(grid_z),
      debug=ir.BoolAttr.get(debug),
    )
    if "serialized_metadata" in (triton_params or {}):
      # This field is unstable and may be removed in the future.
      if triton_params["serialized_metadata"] is not None:
        backend_config["serialized_metadata"] = ir.StringAttr.get(
            triton_params["serialized_metadata"]
        )
    return mlir.custom_call(
        call_target_name="__gpu$xla.gpu.triton",
        result_types=out_types,
        operands=in_nodes,
        backend_config=backend_config,
        api_version=4,
        operand_layouts=avals_to_layouts(ctx.avals_in),
        result_layouts=avals_to_layouts(ctx.avals_out),
        operand_output_aliases=dict(input_output_aliases),
    ).results

  # TODO(slebedev): Make this work for ROCm.
  try:
    gpu_device, *_ = jax.local_devices(backend="gpu")
  except RuntimeError:
    # GPU device is not available. Fall back to the minimum CC supported by Triton.
    # TODO(slebedev): Make the fallback CC configurable.
    arch_name = "8.0"
    cc = 80
  else:
    arch_name = str(gpu_device.compute_capability)
    cc = int(arch_name.replace(".", ""))
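For concreteness, this is the arithmetic the fallback and device branches above perform on compute-capability strings (worked values, not code from the change):

assert int("8.0".replace(".", "")) == 80  # the Ampere fallback above
assert int("9.0".replace(".", "")) == 90  # e.g. what an H100 would report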

  compilation_result = triton.compile(
      lowering_platform,
      buf.getvalue(),
      arch_name,
      num_warps=num_warps,
      num_ctas=1,
      num_stages=num_stages,
  )
  if "serialized_metadata" in (triton_params or {}):
    # This field is unstable and may be removed in the future.
    if triton_params["serialized_metadata"] is not None:
      backend_config["serialized_metadata"] = ir.StringAttr.get(
          triton_params["serialized_metadata"]
      )
  kernel = triton_kernel_call_lib.TritonKernel(
      name_and_src_info.name,
      num_warps,
      compilation_result.smem_bytes,
      compilation_result.asm,
      module_op.get_asm(enable_debug_info=True, pretty_debug_info=True),
      cc,
      compilation_result.cluster_dim_x,
      compilation_result.cluster_dim_y,
      compilation_result.cluster_dim_z,
  )
  kernel_call = triton_kernel_call_lib.TritonKernelCall(
      kernel,
      grid_x,
      grid_y,
      grid_z,
      [triton_kernel_call_lib.create_array_parameter(0, 16)]
      * (len(ctx.avals_in) + len(ctx.avals_out)),
  )
  # TODO(b/392558289): Migrate to ``jax.ffi``.
  return mlir.custom_call(
      call_target_name="__gpu$xla.gpu.triton",
      result_types=out_types,
      call_target_name="triton_kernel_call",
      result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)],  # type: ignore[list-item]
      operands=in_nodes,
      backend_config=backend_config,
      api_version=4,
      backend_config=zlib.compress(
          kernel_call.to_proto(
              name_and_src_info.name,
              triton_params.get("serialized_metadata") or b"",
          )
      ),
      operand_layouts=avals_to_layouts(ctx.avals_in),
      result_layouts=avals_to_layouts(ctx.avals_out),
      operand_output_aliases=dict(input_output_aliases),
@@ -25,15 +25,6 @@ import jax.numpy as jnp
import numpy as np


def when(condition):
  def _wrapped(f):
    if isinstance(condition, bool):
      if condition:
        f()
    else:
      lax.cond(condition, f, lambda: None)
  return _wrapped
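A usage sketch for the `when` helper above (hedged: assumes `when` is in scope, e.g. via its re-export from `jax.experimental.mosaic.gpu`): with a Python `bool` the body runs immediately; with a traced predicate it is staged through `lax.cond`.

import jax
import jax.numpy as jnp

def f(x):
  @when(x.sum() > 0)  # traced predicate -> staged via lax.cond
  def _():
    jax.debug.print("positive sum: {}", x.sum())
  return x * 2

jax.jit(f)(jnp.ones(4))  # prints from the true branch of the cond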

@overload
def cdiv(a: int, b: int) -> int:
  ...
@@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
from jax._src.api_util import (
    argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
    donation_vector, check_callable, resolve_argnums,
    argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
    argnames_partial_except, debug_info,
    hoist_obj_attrs, _check_no_aliased_ref_args,
    _check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
@@ -548,7 +548,7 @@ def _infer_params_impl(
    ji: PjitInfo,
    pjit_mesh: mesh_lib.Mesh | None,
    resource_env: mesh_lib.ResourceEnv | None,
    dbg: lu.TracingDebugInfo,
    dbg: core.DebugInfo,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    in_avals: tuple[core.AbstractValue, ...] | None,
@@ -567,9 +567,7 @@ def _infer_params_impl(

  axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)

  f = lu.wrap_init(fun)
  f, res_paths = result_paths(f)
  dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
  f = lu.wrap_init(fun, debug_info=dbg)
  f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
  del args

@@ -618,7 +616,7 @@ def _infer_params_impl(
  in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
      in_shardings_treedef, in_shardings_leaves,
      ji.in_layouts_treedef, ji.in_layouts_leaves,
      in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
      in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs)

  attr_token = _attr_token(flat_fun, in_type)

@@ -627,8 +625,7 @@ def _infer_params_impl(
                   if mesh_lib.get_abstract_mesh().empty else mesh_lib.get_abstract_mesh())
  with mesh_lib.set_abstract_mesh(abstract_mesh):
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
        flat_fun, in_type, attr_token, dbg,
        HashableFunction(res_paths, closure=()),
        flat_fun, in_type, attr_token,
        IgnoreKey(ji.inline))
  if config.mutable_array_checks.value:
    _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
@@ -733,7 +730,7 @@ def _infer_params(
        'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
        ' together is not allowed.')

  dbg = tracing_debug_info(
  dbg = debug_info(
      'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
      static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
      signature=ji.fun_signature)
@@ -756,7 +753,7 @@ def _infer_params(
    entry.pjit_params = p
  return entry.pjit_params, entry.pjit_params.consts + dynargs

def _infer_input_type(fun: Callable, dbg: lu.TracingDebugInfo | None,
def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
                      explicit_args) -> tuple[core.AbstractValue, ...]:
  avals = []
  try:
@@ -1171,17 +1168,18 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
callsites: set[str] = set()

def explain_tracing_cache_miss(
    f: Callable, unseen_f: bool, cache: dict, key: tuple):
    fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
  if config.check_tracer_leaks.value: return

  def unpack(key):
    transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key
    transforms, (), _, (in_type, _, inline), *_, ctx = key
    # TODO(dougalm,mattjj): enable cache miss explanation with attrs
    _, (_, (in_tree,)), *_ = transforms
    return in_tree, in_type, debug_info, inline.val, ctx
  in_tree, in_type, debug_info, inline, ctx = unpack(key)
    return in_tree, in_type, inline.val, ctx
  in_tree, in_type, inline, ctx = unpack(key)
  if inline: return

  debug_info = fun.debug_info
  msg: list[str] = []
  p = msg.append
  done = lambda: logger.log(logging.WARNING, '\n'.join(msg))
@@ -1190,7 +1188,7 @@ def explain_tracing_cache_miss(
  p(f"TRACING CACHE MISS at {callsite} because:")

  # have we seen this function before at all?
  fun_name = getattr(f, '__qualname__', f)
  fun_name = getattr(fun.f, '__qualname__', fun.f)
  if debug_info is not None and debug_info.func_src_info:
    # TODO(necula): clean up the extraction of the source info
    _, *rest = debug_info.func_src_info.split(' at ')
@@ -1198,7 +1196,7 @@ def explain_tracing_cache_miss(
  else:
    src_info = ''
  if unseen_f:
    p(f"  never seen function:\n    {fun_name} id={id(f)}{src_info}")
    p(f"  never seen function:\n    {fun_name} id={id(fun.f)}{src_info}")
    if callsite in callsites:
      p("  but seen another function defined on the same line; maybe the function is\n"
        "  being re-defined repeatedly, preventing caching?")
@@ -1263,7 +1261,7 @@ def explain_tracing_cache_miss(
  # have we never seen these input types (eg shapes, dtypes) before?
  types_match = [k for k in trees_match if k[1] == in_type]
  if not types_match:
    if len(in_type) < 5:
    if len(in_type) < 5 and debug_info is not None:
      in_type_str = ':\n  {}'.format(', '.join(
          f'{n}: {ty.str_short(short_dtypes=True)}'
          for n, ty in zip(debug_info.arg_names, in_type)))
@@ -1275,7 +1273,12 @@ def explain_tracing_cache_miss(
  num_mismatch = sum(map(op.ne, closest_ty, in_type))
  p(f"  closest seen input type signature has {num_mismatch} mismatches, including:")
  add_weak_type_hint = False
  for name, ty1, ty2 in zip(debug_info.arg_names, closest_ty, in_type):
  if debug_info:
    arg_names = debug_info.safe_arg_names(len(in_type))
  else:
    arg_names = (None,) * len(in_type)

  for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
    if ty1 != ty2:
      if type(ty1) == type(ty2) == core.ShapedArray:
        s1, s2 = ty1.str_short(True), ty2.str_short(True)
@@ -1302,8 +1305,6 @@ def _create_pjit_jaxpr(
    fun: lu.WrappedFun,
    in_type: core.InputType | Sequence[core.AbstractValue],
    attr_data: int,
    debug_info: lu.TracingDebugInfo,
    result_paths: Callable,
    ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
@@ -1317,17 +1318,13 @@ def _create_pjit_jaxpr(
      fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
    if config.dynamic_shapes.value:
      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
          lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
          lu.annotate(fun, cast(core.InputType, in_type)))
      attrs_tracked = []
    else:
      jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
          fun, in_type, debug_info=debug_info)
          fun, in_type)
      # assert attr_data is sentinel or attr_data matches attrs_tracked

  # TODO(dougalm,mattjj): enable debug info with attrs_tracked
  if not config.dynamic_shapes.value and not attrs_tracked:
    jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())

  if config.debug_key_reuse.value:
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
@@ -1346,7 +1343,7 @@ def _create_pjit_jaxpr(
def _check_and_canonicalize_out_shardings(
    out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
    out_layouts_leaves, out_tree, out_avals,
    debug_info: core.JaxprDebugInfo | None,
    debug_info: core.DebugInfo | None,
    device_or_backend_set):
  orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
  if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
@@ -1479,7 +1476,6 @@ def check_aval_layout_compatibility(
pjit_p = core.Primitive("pjit")
pjit_p.multiple_results = True


def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
  # If device or backend is set, return the default layout. This is because you
  # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
@@ -1928,7 +1924,9 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)


def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
                                    name: str, jaxpr: core.ClosedJaxpr,
                                    effects, in_shardings,
                                    out_shardings, in_layouts, out_layouts,
                                    api_name):
  mod_ctx = ctx.module_context
@@ -1959,7 +1957,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
  return func


def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
                   jaxpr: core.ClosedJaxpr, in_shardings,
                   out_shardings, in_layouts, out_layouts, resource_env,
                   donated_invars, keep_unused, inline, compiler_options_kvs):
  effects = list(ctx.tokens_in.effects())
@@ -1987,8 +1986,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
mlir.register_lowering(pjit_p, _pjit_lowering)


def _pjit_batcher(axis_data, vals_in, dims_in,
                  jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
def _pjit_batcher(axis_data, vals_in,
                  dims_in: tuple[int, ...],
                  jaxpr: core.ClosedJaxpr,
                  in_shardings, out_shardings, in_layouts, out_layouts,
                  resource_env, donated_invars, name, keep_unused, inline,
                  compiler_options_kvs):
  segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
@@ -2037,7 +2038,8 @@ batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule

def _pjit_batcher_for_sharding(
    s: Sharding | UnspecifiedValue,
    dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
    dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
    ndim: int):
  if isinstance(s, UnspecifiedValue):
    return s
  hlo_s = s._to_xla_hlo_sharding(ndim)
@@ -2049,7 +2051,7 @@ def _pjit_batcher_for_sharding(
      return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
  new_op = hlo_s.to_proto().clone()
  tad = list(new_op.tile_assignment_dimensions)
  tad.insert(dim, 1)
  tad.insert(dim, 1)  # type: ignore
  new_op.tile_assignment_dimensions = tad
  new_gs = GSPMDSharding(
      s._device_assignment, new_op,
@@ -2171,8 +2173,9 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
ad.primitive_linearizations[pjit_p] = _pjit_linearization


def _pjit_partial_eval(trace, *in_tracers,
                       jaxpr, in_shardings, out_shardings,
def _pjit_partial_eval(trace: pe.JaxprTrace,
                       *in_tracers,
                       jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
                       in_layouts, out_layouts, resource_env, donated_invars,
                       name, keep_unused, inline, compiler_options_kvs):
  in_pvals = [t.pval for t in in_tracers]
@@ -2191,7 +2194,7 @@ def _pjit_partial_eval(trace, *in_tracers,
  else:
    known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
        pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
  unknown_outs = tuple(unknown_outs)
  unknown_outs = tuple(unknown_outs)  # type: ignore[assignment]
  known_outs = tuple(not uk for uk in unknown_outs)
  num_residuals = len(res_avals)
  res_shardings = (UNSPECIFIED,) * num_residuals
@@ -2282,7 +2285,7 @@ def _pjit_partial_eval(trace, *in_tracers,
  unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
  unknown_out_avals = unknown_jaxpr.out_avals
  unknown_tracers_out = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
      pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)  # type: ignore
      for aval in unknown_out_avals
  ]
  eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
@@ -2707,14 +2710,20 @@ def mesh_cast(xs, out_shardings):
  return tree_unflatten(treedef, out_flat)

mesh_cast_p = core.Primitive('mesh_cast')
mesh_cast_p.skip_canonicalization = True

def _mesh_cast_abstract_eval(aval, dst_sharding):
  src_sharding = aval.sharding
  if src_sharding == dst_sharding:
    return aval
  if src_sharding.mesh.empty or dst_sharding.mesh.empty:
    return aval.update(sharding=dst_sharding)
  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
    raise ValueError(
        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
        ' match the mesh shape of the target sharding'
        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
  if src_sharding.mesh.axis_types == dst_sharding.mesh.axis_types:
  if (src_sharding.mesh.axis_types == dst_sharding.mesh.axis_types and
      src_sharding.spec != dst_sharding.spec):
    raise ValueError(
        'mesh_cast should only be used when AxisTypes changes between the'
        ' input mesh and the target mesh. Got src'
@@ -2746,7 +2755,9 @@ def _mesh_cast_abstract_eval(aval, dst_sharding):
mesh_cast_p.def_abstract_eval(_mesh_cast_abstract_eval)

def _mesh_cast_impl(x, dst_sharding):
  return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
  x_aval = core.shaped_abstractify(x)
  with mesh_lib.set_abstract_mesh(x_aval.sharding.mesh):
    return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
mesh_cast_p.def_impl(_mesh_cast_impl)

def _mesh_cast_transpose_rule(ct, x, dst_sharding):
@@ -2763,7 +2774,6 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, dst_sharding):
mlir.register_lowering(mesh_cast_p, _mesh_cast_hlo_lowering)

def _mesh_cast_batcher(axis_data, vals_in, dims_in, dst_sharding):
  assert axis_data.spmd_name is None
  x, = vals_in
  d, = dims_in
  vmapped_dst_sharding = batching.get_sharding_for_vmap(
@@ -2070,8 +2070,8 @@ def orthogonal(
  n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
  z = normal(key, (*shape, n, n), dtype)
  q, r = jnp.linalg.qr(z)
  d = jnp.diagonal(r, 0, -2, -1)
  return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2]))
  d = jnp.linalg.diagonal(r)
  return q * jnp.expand_dims(jnp.sign(d), -2)
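The rewritten `orthogonal` keeps the usual Haar-measure correction: a QR factorization is only unique up to the signs of R's diagonal, and multiplying each column of Q by `sign(d)` fixes that gauge so Q is uniformly distributed. The same correction in plain NumPy (illustrative, not the JAX implementation):

import numpy as np

z = np.random.default_rng(0).normal(size=(4, 4))
q, r = np.linalg.qr(z)
d = np.diagonal(r)
q_haar = q * np.sign(d)[None, :]  # same effect as jnp.expand_dims(jnp.sign(d), -2)
assert np.allclose(q_haar @ q_haar.T, np.eye(4), atol=1e-12)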

def generalized_normal(
    key: ArrayLike,
@@ -989,7 +989,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],

def initial_style_jaxpr(
    fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
    dbg: api_util.TracingDebugInfo,
    dbg: core.DebugInfo,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
  return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg)

@@ -997,17 +997,18 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun: Callable,
                         in_tree: api_util.PyTreeDef,
                         in_avals: Sequence[core.AbstractValue],
                         debug: api_util.TracingDebugInfo):
  fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
                         debug: core.DebugInfo):
  fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun, debug_info=debug),
      tree_util.treedef_tuple((in_tree,)))
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals)
  return jaxpr, consts, out_tree_thunk()


T = TypeVar('T')
def run_state(f: Callable[..., None]) -> Callable[[T], T]:
  def wrapped(args):
    dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
    dbg = api_util.debug_info("run_state", f, (args,), {})
    flat_args, in_tree = tree_util.tree_flatten(args)
    ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
    # There may be some uninitialized values here in ref_args.
@@ -1027,7 +1028,7 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]:

def run_state_reference(f: Callable[..., None]):
  def wrapped(args):
    dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
    dbg = api_util.debug_info("run_state", f, (args,), {})
    flat_args, in_tree = tree_util.tree_flatten(args)
    ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
    jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg)
@@ -217,7 +217,7 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
      # TODO(yashkatariya): Transform the sharding too instead of setting it to
      # None.
      out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
                                            sharding=None)
                                            sharding=core.get_cur_mesh_sharding())
    else:
      if transforms:
        raise ValueError("Cannot index non-shaped array with nontrivial indices.")

@@ -266,6 +266,14 @@ class TransformedRef:
        (*self.transforms, RefReshaper.from_ref_new_shape(self, *shape)),
    )

  def set(self, value, idx=()):
    from jax._src.state.primitives import ref_set  # pytype: disable=import-error
    return ref_set(self, idx, value)

  def get(self, idx=()):
    from jax._src.state.primitives import ref_get  # pytype: disable=import-error
    return ref_get(self, idx)

  def __getattr__(self, name):
    return getattr(self.ref, name)
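A usage sketch for the new accessors (hypothetical Pallas-style kernel; the `reshape` producing a `TransformedRef` follows the `RefReshaper` context above, and the kernel and ref names are illustrative, not from the change):

def kernel(x_ref, o_ref):
  tile = x_ref.reshape(8, 16)  # a TransformedRef view
  v = tile.get()               # replaces an explicit ref_get(tile, ()) call
  tile.set(v * 2.0)            # replaces an explicit ref_set(tile, (), ...) call
  o_ref[...] = x_ref[...]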

@@ -537,7 +537,7 @@ def request_cpu_devices(nr_devices: int):
  invoked. Test cases that require a specific number of devices should skip
  themselves if that number is not met.
  """
  if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
  if config.num_cpu_devices.value < nr_devices:
    xla_bridge.get_backend.cache_clear()
    config.update("jax_num_cpu_devices", nr_devices)

@@ -104,17 +104,6 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag(
    help="Deprecated, please use jax_cpu_collectives_implementation instead.",
)

CPU_COLLECTIVES_IMPLEMENTATIONS = ["none", "gloo", "mpi"]
CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag(
    name="jax_cpu_collectives_implementation",
    default="none",
    enum_values=CPU_COLLECTIVES_IMPLEMENTATIONS,
    help=(
        "Cross-process collective implementation used on CPU. Must be one of"
        f" {CPU_COLLECTIVES_IMPLEMENTATIONS}"
    ),
)

_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
    name="jax_cpu_enable_async_dispatch",
    default=True,
@@ -122,14 +111,6 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
    "inline without async dispatch.",
)

NUM_CPU_DEVICES = config.int_flag(
    name="jax_num_cpu_devices",
    default=-1,
    help="Number of CPU devices to use. If not provided, the value of "
    "the XLA flag --xla_force_host_platform_device_count is used."
    " Must be set before JAX is initialized.",
)


# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
@@ -255,7 +236,7 @@ def make_cpu_client(
    The created CPU client.
  """
  if collectives is None:
    collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
    collectives_impl = config.cpu_collectives_implementation.value
    if _CPU_ENABLE_GLOO_COLLECTIVES.value:
      collectives_impl = 'gloo'
      warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
@@ -271,12 +252,13 @@ def make_cpu_client(
      collectives = xla_client._xla.make_mpi_collectives()
      collectives.Init()
      atexit.register(collectives.Finalize)
    elif collectives_impl != 'none':
      raise RuntimeError(f"Unknown collectives implementation "
                         f"{collectives_impl}. Available implementations are "
                         f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
    elif collectives_impl == 'megascale':
      raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must be "gloo" or "mpi"')
    else:
      # Already validated by config module
      assert collectives_impl is None

  num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
  num_devices = config.num_cpu_devices.value if config.num_cpu_devices.value >= 0 else None
  return xla_client.make_cpu_client(
      asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
      distributed_client=distributed.global_state.client,
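With the flag moved into the config module, selecting a CPU collectives backend looks like this (a sketch; the environment-variable spelling appears in the error message above, and the `config.update` call follows JAX's usual flag conventions):

# JAX_CPU_COLLECTIVES_IMPLEMENTATION=gloo python train.py
import jax
jax.config.update("jax_cpu_collectives_implementation", "gloo")  # before init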
@@ -20,13 +20,13 @@ from jax._src.core import (
  AbstractValue as AbstractValue,
  Atom as Atom,
  CallPrimitive as CallPrimitive,
  DebugInfo as DebugInfo,
  DShapedArray as DShapedArray,
  DropVar as DropVar,
  Effect as Effect,
  Effects as Effects,
  get_opaque_trace_state as get_opaque_trace_state,
  InconclusiveDimensionOperation as InconclusiveDimensionOperation,
  JaxprDebugInfo as JaxprDebugInfo,
  JaxprPpContext as JaxprPpContext,
  JaxprPpSettings as JaxprPpSettings,
  JaxprTypeError as JaxprTypeError,
@@ -85,8 +85,9 @@ from .utils import (
    warpgroup_idx as warpgroup_idx,
    when as when,
)
# The import below shadows the module, so we need to rename it.
from . import wgmma as _wgmma  # noqa: F401
from .wgmma import (
    WGMMAAccumulator as WGMMAAccumulator,
    WGMMALayout as WGMMALayout,
    wgmma as wgmma,
)
@@ -220,7 +220,7 @@ def build_kernel(
        # TODO(apaszke): Support WGMMA without an initial accumulator.
        qk_acc = WGMMAAccumulator.zero(blocks.q, blocks.kv)
        q, k = qo_smem, memref_slice(k_smem, slot)
        qk_acc = wgmma(qk_acc, q, k, b_order=WGMMALayout.COL_MAJOR)
        qk_acc = wgmma(qk_acc, q, memref_transpose(k, (0, 1, 3, 2)))
        nvvm.wgmma_commit_group_sync_aligned()

        perform_schedule_barrier()
@@ -441,7 +441,7 @@ def build_kernel(
        # TODO(apaszke): Support WGMMA without an initial accumulator.
        qk_acc = WGMMAAccumulator.zero(blocks.q, blocks.kv)
        q, k = qo_smem, memref_slice(k_smem, slot)
        qk_acc = wgmma(qk_acc, q, k, b_order=WGMMALayout.COL_MAJOR)
        qk_acc = wgmma(qk_acc, q, memref_transpose(k, (0, 1, 3, 2)))
        nvvm.wgmma_commit_group_sync_aligned()

        # We hide the TMA overhead by overlapping it with the QK matmul.
@@ -68,7 +68,7 @@ class WGMMADefaultImpl:
      block_tiling: Tiling,
      tma_tiling: Tiling,
      lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
      rhs_transpose: WGMMALayout,
      rhs_transpose: bool,
  ) -> dict[str, jax.ShapeDtypeStruct]:
    del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose  # Unused.
    return ()
@@ -81,7 +81,6 @@ class WGMMADefaultImpl:
  def wgmma(
      smem_scratch: Any,  # pylint: disable=unused-argument
      acc: WGMMAAccumulator,
      b_order: WGMMALayout,
      a_slice: SmemRef,
      b_slice: SmemRef,
      swizzle: int,
@@ -91,7 +90,7 @@ class WGMMADefaultImpl:
    This function must guarantee that all WGMMA operations queued before it was
    called have completed before returning.
    """
    acc = wgmma(acc, a_slice, b_slice, b_order=b_order, swizzle=swizzle)
    acc = wgmma(acc, a_slice, b_slice, swizzle=swizzle)
    nvvm.wgmma_commit_group_sync_aligned()
    nvvm.wgmma_wait_group_sync_aligned(1)
    return acc
@@ -250,11 +249,10 @@ def build_kernel(
      with ctx.named_region("WGMMA"):
        a_slice = memref_slice(lhs_smem, si)
        b_slice = memref_slice(rhs_smem, si)
        rhs_smem_order = (
            WGMMALayout.COL_MAJOR if rhs_transpose else WGMMALayout.ROW_MAJOR
        )
        if rhs_transpose:
          b_slice = memref_transpose(b_slice, (0, 1, 3, 2))
        accs = wgmma_impl.wgmma(
            impl_smem, accs, rhs_smem_order, a_slice, b_slice, swizzle=swizzle
            impl_smem, accs, a_slice, b_slice, swizzle=swizzle
        )

      with ctx.named_region("TMA start"):
197 jax/experimental/mosaic/gpu/examples/matmul_blackwell.py (new file)
@@ -0,0 +1,197 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Matmul kernel for Blackwell."""

import jax
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import nvvm
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import c, ds, utils
from jax.experimental.mosaic.gpu import tcgen05
import jax.numpy as jnp
import jax.random as jr
import numpy as np


BLACKWELL_MMA_FP16_K = 16
TMA_WARP = 1
MMA_WARP = 0


def bytecount(shape, dtype):
  return int(np.prod(shape) * dtype.dtype.itemsize)


def build_kernel(
    m, n, k,
    tile_m: int = 128,
    tile_n: int = 128,
):
  i1 = ir.IntegerType.get_signless(1)
  i32 = ir.IntegerType.get_signless(32)
  f32 = ir.F32Type.get()
  index = ir.IndexType.get()
  ptr6 = ir.Type.parse("!llvm.ptr<6>")  # TMEM

  swizzle = 128
  tile_k = 64  # TODO(apaszke): I think we need to tile TMA to change this.
  in_dtype = jnp.float16
  k_loop_iter = k // tile_k
  tma_tile_m = 128
  tma_tile_kn = 64

  if m % tile_m != 0:
    raise ValueError(f"{m=} must be divisible by {tile_m=}")
  if n % tile_n != 0:
    raise ValueError(f"{n=} must be divisible by {tile_n=}")
  if k % tile_k != 0:
    raise ValueError(f"{k=} must be divisible by {tile_k=}")

  def kernel(ctx, a, b, d, smem):
    # TODO(apaszke): Use more SMEM slots to avoid oversynchronizing warps.
    a_smem, b_smem, d_smem, barriers, tmem_addr = smem
    (ab_full_barrier, ab_empty_barrier, mma_done_barrier) = barriers

    warp_idx = mgpu.warp_idx(sync=True)
    warp_leader = nvvm.elect_sync(i1)

    is_warp = lambda i: arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32))

    m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m, index))
    n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n, index))

    with mgpu.when(arith.andi(is_warp(TMA_WARP), warp_leader)):
      @mgpu.fori(c(k_loop_iter, index), None)
      def _tma_body(ki, _):
        # TODO(apaszke): Use a predicate instead of a conditional.
        with mgpu.when(arith.cmpi(arith.CmpIPredicate.ugt, ki, c(0, index))):
          ab_empty_barrier.wait()
        ab_full_barrier.arrive_expect_tx(
            bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
        )
        k_start = arith.muli(ki, c(tile_k, index))
        common_args = dict(
            swizzle=swizzle, barrier=ab_full_barrier, arrive=False, uniform=False,
        )
        ctx.async_copy(
            src_ref=a,
            dst_ref=a_smem,
            gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)),
            gmem_transform=mgpu.TileTransform((tma_tile_m, tma_tile_kn)),
            **common_args,
        )
        ctx.async_copy(
            src_ref=b,
            dst_ref=b_smem,
            gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)),
            gmem_transform=(
                mgpu.TileTransform((tma_tile_kn, tma_tile_kn)),
                mgpu.TransposeTransform((1, 0, 2, 3)),
            ),
            **common_args,
        )

    with mgpu.when(is_warp(MMA_WARP)):
      tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
      tcgen05.tmem_alloc(tmem_addr_addr, tile_n)
      tcgen05.tmem_relinquish_alloc_permit()
      with mgpu.when(warp_leader):
        tmem_addr_value = llvm.load(ptr6, tmem_addr_addr)
        @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
        def _mma_body(ki, accumulate):
          ab_full_barrier.wait()
          tcgen05.mma(
              tmem_addr_value,
              a_smem,
              mgpu.memref_transpose(b_smem, (0, 1, 3, 2)),
              a_swizzle=swizzle,
              b_swizzle=swizzle,
              accumulate=accumulate,
          )
          accumulate = arith.constant(i1, 1)
          is_last_iter = arith.cmpi(
              arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
          )
          barrier_ptr = arith.select(
              is_last_iter, mma_done_barrier.get_ptr(), ab_empty_barrier.get_ptr()
          )
          tcgen05.commit_arrive(barrier_ptr)
          return accumulate

    gpu.barrier()
    mma_done_barrier.wait()

    tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
    tmem_ref[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
    mgpu.commit_shared()
    ctx.async_copy(
        src_ref=d_smem,
        dst_ref=d,
        gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)),
        gmem_transform=mgpu.TileTransform((128, 64)),
        swizzle=swizzle,
    )
    ctx.await_async_copy(0)

  smem = (
      jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_k), (tma_tile_m, tma_tile_kn)), jnp.float16),
      jax.ShapeDtypeStruct(mgpu.tile_shape((tile_k, tile_n), (tma_tile_kn, tma_tile_kn)), jnp.float16),
      jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
      [mgpu.Barrier(arrival_count=1)] * 3,
      jax.ShapeDtypeStruct((1,), np.uint32),  # TMEM address
  )
  return mgpu.as_gpu_kernel(
      kernel,
      (n // tile_n, m // tile_m, 1),
      (128, 1, 1),
      (
          jax.ShapeDtypeStruct((m, k), jnp.float16),
          jax.ShapeDtypeStruct((n, k), jnp.float16),
      ),
      jax.ShapeDtypeStruct((m, n), jnp.float16),
      smem,
  )
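Launch-shape arithmetic for the defaults used in `main` below: with m = n = 2048 and k = 1024, the grid is (n // tile_n, m // tile_m, 1) blocks of 128 threads and the K loop runs k // tile_k times (worked values, not extra code in the change):

assert (2048 // 128, 2048 // 128, 1) == (16, 16, 1)  # launch grid
assert 1024 // 64 == 16                              # k_loop_iter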


def main(unused_argv):
  m_tile = 128
  n_tile = 128
  k_tile = 64
  m = 16*m_tile
  n = 16*n_tile
  k = 16*k_tile

  ka, kb = jr.split(jr.key(0), 2)
  a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
  b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16)

  with mlir.make_ir_context(), ir.Location.unknown():
    f = build_kernel(m, n, k, tile_m=m_tile, tile_n=n_tile)
    y = f(a, b).block_until_ready()

  ref = np.asarray(a) @ np.asarray(b).T
  np.testing.assert_allclose(y, ref, atol=1e-3, rtol=1e-3)
  print("OK!")


if __name__ == "__main__":
  from absl import app
  import jax
  jax.config.config_with_absl()
  app.run(main)

338 jax/experimental/mosaic/gpu/tcgen05.py (new file)
@@ -0,0 +1,338 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import dataclasses
import enum

from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
import numpy as np

from . import utils
from . import fragmented_array as fa
from . import _wgmma

# MyPy does a terrible job with the MLIR API.
# mypy: ignore-errors


TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46

def create_smem_descriptor(
    memref_arg,
    leading_byte_offset: int,
    stride_byte_offset: int,
    swizzle: int | mgpu_dialect.SwizzlingMode | None,
):
  return _wgmma.create_descriptor(
      memref_arg,
      leading_byte_offset,
      stride_byte_offset,
      swizzle,
      memory_space=3,
      const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
  )

def create_instr_descriptor(
    m: int,
    n: int,
    acc_dtype,
    input_dtype,
    transpose_a: bool = False,
    transpose_b: bool = False,
):
  f32 = ir.F32Type.get()
  bf16 = ir.BF16Type.get()
  f16 = ir.F16Type.get()
  if input_dtype not in {f16, bf16}:
    raise NotImplementedError("Only float16 and bfloat16 inputs supported")
  if acc_dtype not in {f32, f16}:
    raise NotImplementedError("Only float32 and float16 accumulators supported")

  desc = 0
  # We ignore sparsity in bits 0-3
  desc |= (acc_dtype == f32) << 4  # D dtype, bits 4-5
  # Bit 6 is reserved
  desc |= (input_dtype == bf16) << 7  # A dtype, bits 7-9
  desc |= (input_dtype == bf16) << 10  # B dtype, bits 10-12
  # We ignore negate bits 13-14
  desc |= transpose_a << 15  # Transpose A
  desc |= transpose_b << 16  # Transpose B
  if n % 8 or n > 256:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  desc |= (n >> 3) << 17  # N, bits 17-22
  # Bit 23 is reserved
  if m % 16 or m > 256:
    raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
  desc |= (m >> 4) << 24  # M >> 4, bits 24-28
  # Bit 29 is reserved
  # We ignore max shift under .ws, bits 30-31
  return arith.constant(ir.IntegerType.get_signless(32), desc)
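A worked instance of the bit packing above (plain Python arithmetic mirroring the branches for an f32 accumulator, f16 inputs, m = n = 128, no transposes):

desc = 0
desc |= 1 << 4            # f32 accumulator (bits 4-5)
desc |= (128 >> 3) << 17  # n // 8 == 16 (bits 17-22)
desc |= (128 >> 4) << 24  # m // 16 == 8 (bits 24-28)
assert desc == 0x08200010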


def mma(
    d: ir.Value,
    a: ir.Value,
    b: ir.Value,
    *,
    a_swizzle: int = 128,
    b_swizzle: int = 128,
    num_cta: int = 1,
    accumulate: ir.Value | bool = True,
):
  if not ir.MemRefType.isinstance(a.type):
    raise ValueError(f"A must be a memref, got {a.type}")
  if not ir.MemRefType.isinstance(b.type):
    raise ValueError(f"B must be a memref, got: {b.type}")
  if a_swizzle != 128 or b_swizzle != 128:
    raise NotImplementedError("Only swizzle=128 has been tested")
  if num_cta != 1:
    raise NotImplementedError("Only num_cta=1 supported")
  if isinstance(accumulate, bool):
    accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)

  (
      a_desc_base,
      b_desc_base,
      (m, k, n),
      (m_tiling, kn_tiling),
      element_type,
      mma_params,
      a_k_byte_stride,
      b_k_byte_stride,
  ) = _wgmma._validate_mma(
      a,
      b,
      a_swizzle,
      _wgmma.WGMMALayout.ROW_MAJOR,
      _wgmma.WGMMALayout.COL_MAJOR,
      descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
  )

  if m_tiling != 128:
    raise ValueError(f"A must have rows tiled by 128, got: {m_tiling}")
  a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
  a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)

  groups_k = k // kn_tiling
  groups_m = m // m_tiling

  # TODO(apaszke): Verify ACC shape.

  i64 = ir.IntegerType.get_signless(64)
  for mi in range(groups_m):
    for ki in range(groups_k):
      a_mk = arith.addi(
          a_desc_base,
          utils.c(_wgmma.wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
      )
      b_k = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(ki * b_k_byte_stride), i64))
      accumulate = _do_mma(
          d,
          a_mk,
          b_k,
          d_type=ir.F32Type.get(),
          m=m_tiling,
          **mma_params,
          accumulate=accumulate,
      )


def _do_mma(
    d_addr: ir.Value,
    a_desc: ir.Value,
    b_desc: ir.Value,
    a_transpose: bool,
    b_transpose: bool,
    a_k_stride: int,
    b_k_stride: int,
    m: int,
    n: int,
    swizzle: int,
    element_type: ir.Type,
    d_type: ir.Type,
    accumulate: ir.Value,
):
  i1 = ir.IntegerType.get_signless(1)
  i64 = ir.IntegerType.get_signless(64)
  kn_tiling = swizzle // utils.bytewidth(element_type)
  instr_k = 32 // utils.bytewidth(element_type)
  if a_k_stride % 16 or b_k_stride % 16:
    raise ValueError

  i_desc = create_instr_descriptor(
      m, n, d_type, element_type, a_transpose, b_transpose
  )
  for _ in range(kn_tiling // instr_k):
    llvm.inline_asm(
        ir.Type.parse("!llvm.void"),
        [d_addr, a_desc, b_desc, i_desc, accumulate],
        f"tcgen05.mma.cta_group::1.kind::{element_type} [$0], $1, $2, $3, $4;",
        "r,l,l,r,b",
        has_side_effects=True,
    )
    accumulate = arith.constant(i1, 1)
    a_desc = arith.addi(a_desc, arith.constant(i64, a_k_stride >> 4))
    b_desc = arith.addi(b_desc, arith.constant(i64, b_k_stride >> 4))
  return accumulate


def commit_arrive(barrier):
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [barrier],
      "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$0];",
      "l",
      has_side_effects=True
  )

def tmem_alloc(tmem_addr, ncols: int):
  if ncols.bit_count() != 1 or not 32 <= ncols <= 512:
    raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}")
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [tmem_addr],
      f"tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], {ncols};",
      "r",
      has_side_effects=True,
  )

def tmem_relinquish_alloc_permit():
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [],
      "tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;",
      "",
      has_side_effects=True,
  )

def tmem_load(tmem_addr, shape, num):
  if num.bit_count() != 1 or num > 128:
    raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
  match shape:
    case "16x128b":
      num_out_regs = 2
    case "16x256b":
      num_out_regs = 4
    case _:
      raise NotImplementedError(f"{shape=} is unsupported")
  num_out_regs *= num
  i32 = ir.IntegerType.get_signless(32)
  out_regs = ",".join("$" + str(i) for i in range(num_out_regs))
  regs = llvm.inline_asm(
      ir.Type.parse(
          "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>"
      ),
      [tmem_addr],
      f"tcgen05.ld.sync.aligned.{shape}.x{num}.b32 {{{out_regs}}}, [${num_out_regs}];",
      "=r," * num_out_regs + "r",
      has_side_effects=True,
  )
  return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
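Register bookkeeping for the shapes above (worked values, not extra code in the change): each "16x256b" repetition yields 4 output registers, so num = 8 repetitions produce 32 i32 registers, addressed as $0..$31 in the inline assembly, with the TMEM address bound to the next operand, $32:

num, regs_per_rep = 8, 4  # shape == "16x256b"
assert regs_per_rep * num == 32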


class TMEMLayout(enum.Enum):
  """Layout of the array in TMEM.

  See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-organization
  """
  D = "D"


@dataclasses.dataclass(frozen=True)
class TMEMRef:
  address: ir.Value
  layout: TMEMLayout
  num_cols: int
  dtype: ir.Type

  @classmethod
  def from_alloc(cls, tmem_addr_ref: ir.Value, layout: TMEMLayout, num_cols: int, dtype: ir.Type):
    i32 = ir.IntegerType.get_signless(32)
    if not ir.MemRefType.isinstance(tmem_addr_ref.type):
      raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
    addr_ref_ty = ir.MemRefType(tmem_addr_ref.type)
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
    if addr_ref_ty.memory_space != smem:
      raise ValueError(f"tmem_addr_ref must be in workgroup memory, got: {addr_ref_ty}")
    if addr_ref_ty.element_type != i32:
      raise ValueError(f"tmem_addr_ref must be an i32 memref, got: {addr_ref_ty}")
    tmem_addr = memref.load(tmem_addr_ref, [arith.ConstantOp.create_index(0)])
    # TODO: Do we have to do this??
    # warp_idx = utils.warp_idx(sync=False)
    # tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32)))
    return cls(tmem_addr, layout, num_cols, dtype)

  @property
  def num_rows(self):
    match self.layout:
      case TMEMLayout.D:
        return 128
      case _:
        raise NotImplementedError(self.layout)

  @property
  def shape(self):
    return (self.num_rows, self.num_cols)

  def __getitem__(self, *idxs):
    i32 = ir.IntegerType.get_signless(32)
    base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
    if any(is_squeezed):
      raise ValueError("TMEM loads only support slicing")
    if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
      raise NotImplementedError("Slicing of TMEM not implemented yet")
    if self.layout != TMEMLayout.D:
      raise NotImplementedError(self.layout)
    if self.num_cols % 8:
      raise NotImplementedError
    if self.dtype != ir.F32Type.get():
      raise NotImplementedError(self.dtype)
    layout = _m128_256bit_32bit_layout(self.shape)
    regs_shape = layout.registers_shape(self.shape)
    num = self.num_cols // 8
    registers = np.empty(regs_shape, dtype=object)
    # We load 16 lanes at a time, but need 32 in total.
    for row_group in range(2):
      addr = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16))
      regs = tmem_load(addr, "16x256b", num)
      regs = [llvm.bitcast(self.dtype, r) for r in regs]
      vector_regs = []
      undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype))
      for r_low, r_high in zip(regs[::2], regs[1::2]):
        high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
        vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
        vector_regs.append(vreg)
      # Dimension 4 is the one where we split 32 rows into tiles of 8.
      regs_slice = [slice(None)] * 4 + [slice(row_group * 2, (row_group + 1) * 2)]
      registers[*regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[*regs_slice].shape)
    return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)


def _m128_256bit_32bit_layout(shape: tuple[int, ...]):
  """Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth."""
  if len(shape) != 2:
    raise ValueError(f"Shape {shape} is not 2D")
  if shape[0] % 128 != 0 or shape[1] % 8 != 0:
    raise ValueError(f"Shape {shape} is not a multiple of 128x8")
  return fa.TiledLayout(
      fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))),
      warp_dim=-8,
      lane_dims=(-4, -3),
      vector_dim=-1,
  )
@@ -17,6 +17,7 @@ import dataclasses
import enum
import functools
import itertools
from typing import Any

import jax
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
@@ -100,6 +101,7 @@ def create_descriptor(
    stride_byte_offset: int,
    swizzle: int | mgpu_dialect.SwizzlingMode | None,
    memory_space: int | None = None,
    const_init: int = 0,
):
  i64 = ir.IntegerType.get_signless(64)
  ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space))
@@ -118,7 +120,8 @@ def create_descriptor(
  )
  # We ignore the offset
  desc_const = (
      (wgmma_encode(leading_byte_offset) << 16)
      const_init
      | (wgmma_encode(leading_byte_offset) << 16)
      | (wgmma_encode(stride_byte_offset) << 32)
  )
  desc = llvm.or_(
@@ -299,132 +302,218 @@ class WGMMALayout(enum.Enum):
  COL_MAJOR = enum.auto()


def _validate_mma(
    a: Any,
    b: ir.Value,
    swizzle: int,
    a_layout: WGMMALayout,
    b_layout: WGMMALayout,
    descriptor_const_init: int = 0,
):
  # We need swizzle >= 32 to ensure that our K tiling is larger than the MMA
  # instruction's K width.
  if swizzle < 32:
    raise ValueError(f"Unsupported swizzle: {swizzle}")

  # Get A type.
  if a_in_smem := isinstance(a, ir.Value):
    if not ir.MemRefType.isinstance(a.type):
      raise ValueError(f"When A is an ir.Value, it must be a memref, got: {a.type}")
    a_ty = ir.MemRefType(a.type)
    a_element_type = a_ty.element_type
    a_shape = tuple(a_ty.shape)
    if a_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
      raise ValueError("A must be in workgroup memory when it's a reference")
    if len(a_shape) != 4:
      raise ValueError(f"A must be 4D when it's a reference, got rank {len(a_shape)}")
  elif hasattr(a, "shape") and hasattr(a, "mlir_dtype"):
    a_element_type = a.mlir_dtype
    a_shape = a.shape
  else:
    raise NotImplementedError(f"Unsupported A type: {type(a)}")

  # Get B type (always a reference).
  b_ty = ir.MemRefType(b.type)
  if b_ty.rank != 4:
    raise ValueError(f"B must be 4D, got rank {b_ty.rank}")

  # Verify element types and compute the tiling.
  if (element_type := a_element_type) != b_ty.element_type:
    raise ValueError(
        f"A and B must have the same element type, got: {a_element_type} and"
        f" {b_ty.element_type}"
    )
  supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
  if element_type not in supported_types:
    raise ValueError(a_element_type)
  element_bytewidth = bytewidth(element_type)
  kn_tiling = swizzle // element_bytewidth
|
||||
|
||||
# Verify the shape and strides of B are as expected.
|
||||
k_tiles, n_tiles, k_tiling, n_tiling = b_ty.shape
|
||||
if k_tiling != kn_tiling:
|
||||
raise ValueError(b_ty.shape)
|
||||
# Note that while this technically allows n to be smaller than kn_tile,
|
||||
# the stride checks above will still enforce that the memory region is padded.
|
||||
# It might be possible to relax that requirement, but I haven't tested it.
|
||||
if n_tiling > kn_tiling and n_tiling % kn_tiling:
|
||||
raise ValueError(n_tiling, kn_tiling)
|
||||
k = k_tiles * kn_tiling
|
||||
n = n_tiles * n_tiling
|
||||
|
||||
b_strides, _ = b_ty.get_strides_and_offset()
|
||||
b_byte_strides = [s * element_bytewidth for s in b_strides]
|
||||
b_k_byte_stride = b_byte_strides[0]
|
||||
if b_byte_strides[1] != swizzle * kn_tiling:
|
||||
raise ValueError(b_byte_strides)
|
||||
if b_byte_strides[2:] == [swizzle, element_bytewidth]:
|
||||
b_order = WGMMALayout.ROW_MAJOR
|
||||
elif b_byte_strides[2:] == [element_bytewidth, swizzle]:
|
||||
b_order = WGMMALayout.COL_MAJOR
|
||||
else:
|
||||
raise ValueError(b_byte_strides)
|
||||
|
||||
# Verify the shape and strides of A are as expected.
|
||||
if not a_in_smem:
|
||||
m = a_shape[0]
|
||||
a_order = m_tiling = None
|
||||
else:
|
||||
a_ty = ir.MemRefType(a.type)
|
||||
m_tiles, k_tiles, m_tiling, k_tiling = a_ty.shape
|
||||
m = m_tiles * m_tiling
|
||||
if k_tiling != kn_tiling or k_tiles * k_tiling != k:
|
||||
raise ValueError(a_ty.shape)
|
||||
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
|
||||
a_byte_strides = [s * element_bytewidth for s in a_strides]
|
||||
if a_byte_strides[2:] == [swizzle, element_bytewidth]:
|
||||
a_order = WGMMALayout.ROW_MAJOR
|
||||
elif a_byte_strides[2:] == [element_bytewidth, swizzle]:
|
||||
a_order = WGMMALayout.COL_MAJOR
|
||||
else:
|
||||
raise ValueError(a_byte_strides)
|
||||
if a_order == WGMMALayout.COL_MAJOR and swizzle != 128:
|
||||
# Not sure what the layout is like, since the tiles aren't square.
|
||||
raise NotImplementedError
|
||||
|
||||
tnsp_lbo = swizzle * (swizzle // 32)
|
||||
sbo = swizzle // 2
|
||||
a_desc_fields = dict(
|
||||
leading_byte_offset=(1 if a_order == a_layout else tnsp_lbo) << 4,
|
||||
stride_byte_offset=sbo << 4,
|
||||
swizzle=swizzle,
|
||||
memory_space=3,
|
||||
)
|
||||
b_desc_fields = dict(
|
||||
leading_byte_offset=(1 if b_order == b_layout else tnsp_lbo) << 4,
|
||||
stride_byte_offset=sbo << 4,
|
||||
swizzle=swizzle,
|
||||
memory_space=3,
|
||||
)
|
||||
wgmma_params = dict(
|
||||
a_transpose=a_order != a_layout,
|
||||
b_transpose=b_order != b_layout,
|
||||
a_k_stride=(2 if a_order == a_layout else swizzle) << 4,
|
||||
b_k_stride=(2 if b_order == b_layout else swizzle) << 4,
|
||||
n=n,
|
||||
swizzle=swizzle,
|
||||
element_type=ir.FloatTF32Type.get()
|
||||
if ir.F32Type.isinstance(element_type)
|
||||
else element_type,
|
||||
)
|
||||
if not a_in_smem:
|
||||
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
|
||||
a_k_byte_stride = a_desc_base = None
|
||||
else:
|
||||
a_desc_base = create_descriptor(
|
||||
a, **a_desc_fields, const_init=descriptor_const_init
|
||||
)
|
||||
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
|
||||
a_k_byte_stride = a_strides[1] * element_bytewidth
|
||||
b_desc_base = create_descriptor(
|
||||
b, **b_desc_fields, const_init=descriptor_const_init
|
||||
)
|
||||
|
||||
return (
|
||||
a_desc_base,
|
||||
b_desc_base,
|
||||
(m, k, n),
|
||||
(m_tiling, kn_tiling),
|
||||
element_type,
|
||||
wgmma_params,
|
||||
a_k_byte_stride,
|
||||
b_k_byte_stride,
|
||||
)
|
||||
|
||||
|
||||
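To make the expected operand geometry concrete, a quick arithmetic sketch under the checks above (values illustrative):

# f16 (2 bytes) with swizzle=128 gives kn_tiling = 128 // 2 = 64, so a
# 128x256 @ 256x192 matmul expects the tiled SMEM shapes
#   A: (m_tiles, k_tiles, 64, kn_tiling) = (2, 4, 64, 64)
#   B: (k_tiles, n_tiles, kn_tiling, kn_tiling) = (4, 3, 64, 64)
element_bytewidth, swizzle = 2, 128
kn_tiling = swizzle // element_bytewidth
m, k, n = 128, 256, 192
assert (m // 64, k // kn_tiling, n // kn_tiling) == (2, 4, 3)
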
# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
# transpositions from memref strides.
def wgmma(
    acc: WGMMAAccumulator,
    a,
    b,
    a: fa.FragmentedArray | ir.Value,
    b: ir.Value,
    *,
    swizzle: int = 128,
    # Order only applies within each tile!
    a_order: WGMMALayout | None = None,
    b_order: WGMMALayout = WGMMALayout.ROW_MAJOR,
    swizzle: int = 128,
):
  if a_in_regs := isinstance(a, fa.FragmentedArray):
    a_element_type = a.mlir_dtype
    a_shape = a.shape
  else:
    a_ty = ir.MemRefType(a.type)
    a_element_type = a_ty.element_type
    a_shape = a_ty.shape
  b_ty = ir.MemRefType(b.type)
  supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
  if a_element_type not in supported_types:
    raise ValueError(a_element_type)
  if b_ty.element_type not in supported_types:
    raise ValueError(b_ty.element_type)
  if (element_type := a_element_type) != b_ty.element_type:
    raise ValueError
  element_bytewidth = bytewidth(element_type)
  kn_tile = swizzle // element_bytewidth
  """Perform acc += a @ b using the WGMMA instruction.

  groups_k, groups_n = b_ty.shape[:2]
  k_group_size, n_group_size = (
      b_ty.shape[2:] if b_order == WGMMALayout.ROW_MAJOR else b_ty.shape[:1:-1]
  )
  # Note that while this technically allows n to be smaller than kn_tile,
  # the stride checks below will still enforce that the memory region is padded.
  # It might be possible to relax that requirement, but I haven't tested it.
  if n_group_size > kn_tile and n_group_size % kn_tile:
    raise ValueError(n_group_size, kn_tile)
  if k_group_size != kn_tile:
    raise ValueError(b_ty.shape)
  The expected memref shapes are:
    a: (m, k, 64, S)
    b: (k, n,  S, S)
  where S = swizzle // bytewidth(element_type).

  The refs must be contiguous or be contiguous except for having their two minor
  dimensions swapped.
  """
  a_in_regs = isinstance(a, fa.FragmentedArray)
  if not a_in_regs and not ir.MemRefType.isinstance(a.type):
    raise ValueError(f"Unsupported A type: {type(a)}")
  if not ir.MemRefType.isinstance(b.type):
    raise ValueError(f"B must be a memref, got: {b.type}")

  (
      a_desc_base,
      b_desc_base,
      (m, k, n),
      (m_tiling, kn_tiling),
      element_type,
      wgmma_params,
      a_k_byte_stride,
      b_k_byte_stride,
  ) = _validate_mma(a, b, swizzle, WGMMALayout.ROW_MAJOR, WGMMALayout.COL_MAJOR)

  if a_in_regs:
    if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get():
      raise ValueError(a_element_type)
    if a_shape[0] % 64 or a_shape[1] % kn_tile:
      raise ValueError(a_shape)
    if a_shape[1] // kn_tile != groups_k:
      raise ValueError(a_shape[1] // kn_tile, groups_k)
    groups_m = a_shape[0] // 64
    if a_order is not None:
    if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
      raise ValueError(
          "a_order can only be specified when A is in shared memory"
          f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
      )
    if a.shape[0] % 64:
      raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
    a_m_byte_stride = None
  else:
    groups_m = a_shape[0]
    if a_shape[1] != groups_k:
      raise ValueError(a_shape[1], groups_k)
    if a_shape[2:] != [64, kn_tile]:
      raise ValueError(a_shape)
    if a_order is None:
      a_order = WGMMALayout.ROW_MAJOR
    if m_tiling != 64:
      raise ValueError(f"A must have rows tiled by 64, got: {m_tiling}")
    a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
    a_m_byte_stride = a_strides[0] * bytewidth(element_type)

  if a_order == WGMMALayout.COL_MAJOR and swizzle != 128:
    # Not sure what the layout is like, since the tiles aren't square.
    raise NotImplementedError
  expected_acc_shape = (groups_m * 64, groups_n * n_group_size)
  groups_k = k // kn_tiling
  groups_m = m // 64

  expected_acc_shape = (groups_m * 64, n)
  if acc.value.shape != expected_acc_shape:
    raise ValueError(
        f"Accumulator shape mismatch: expected {expected_acc_shape}, got"
        f" {acc.value.shape}"
    )

  row_major = WGMMALayout.ROW_MAJOR
  col_major = WGMMALayout.COL_MAJOR
  tnsp_lbo = swizzle * (swizzle // 32)
  sbo = swizzle // 2
  a_desc_fields = dict(
      leading_byte_offset=(1 if a_order == row_major else tnsp_lbo) << 4,
      stride_byte_offset=sbo << 4,
      swizzle=swizzle,
      memory_space=3,
  )
  b_desc_fields = dict(
      leading_byte_offset=(tnsp_lbo if b_order == row_major else 1) << 4,
      stride_byte_offset=sbo << 4,
      swizzle=swizzle,
      memory_space=3,
  )
  wgmma_params = dict(
      a_transpose=a_order == col_major,
      b_transpose=b_order == row_major,
      a_k_stride=(2 if a_order == row_major else 128) << 4,
      b_k_stride=(swizzle if b_order == row_major else 2) << 4,
      n=(groups_n * n_group_size),
      swizzle=swizzle,
      element_type=ir.FloatTF32Type.get()
      if ir.F32Type.isinstance(element_type)
      else element_type,
  )
  if a_in_regs:
    wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None

  if a_in_regs:
    a = wgmma_fence(a)  # Make sure the registers are ready.
    a_m_byte_stride = a_k_byte_stride = a_desc_base = None  # Silence pytype.
  else:
    a_desc_base = create_descriptor(a, **a_desc_fields)
    a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
    a_byte_strides = [s * element_bytewidth for s in a_strides]
    a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2]
    if a_byte_strides[2:] != [swizzle, element_bytewidth]:
      raise ValueError(a_byte_strides)
  b_desc_base = create_descriptor(b, **b_desc_fields)
  b_strides, _ = b_ty.get_strides_and_offset()
  b_byte_strides = [s * element_bytewidth for s in b_strides]
  b_k_byte_stride = b_byte_strides[0]
  if b_byte_strides[1:] != [swizzle * kn_tile, swizzle, element_bytewidth]:
    raise ValueError(b_byte_strides)

  i64 = ir.IntegerType.get_signless(64)
  new_acc_regs = acc.value.registers.copy()
  for mi in range(groups_m):
    for ki in range(groups_k):
      if a_in_regs:
        a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile]
        a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tiling : (ki + 1) * kn_tiling]
      else:
        a_mk = llvm_add(
            a_desc_base,
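Schematically, the loop above advances the descriptors by an encoded byte offset per (mi, ki) step; a sketch reusing the 16-byte-unit encoding assumed earlier (stride values illustrative):

def desc_offset(mi, ki, m_byte_stride, k_byte_stride):
  # Offsets enter the descriptor in 16-byte units, as in wgmma_encode above.
  return ((mi * m_byte_stride + ki * k_byte_stride) & 0x3FFFF) >> 4

# e.g. f16, swizzle=128: one 64-row, 64-column tile of A spans
# 64 * 64 * 2 = 8192 bytes, so stepping one tile along M adds 8192 >> 4 = 512.
assert desc_offset(1, 0, 8192, 0) == 512
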
@@ -23,9 +23,9 @@ from jax._src.pallas.core import BlockSpec as BlockSpec
from jax._src.pallas.core import CompilerParams as CompilerParams
from jax._src.pallas.core import core_map as core_map
from jax._src.pallas.core import CostEstimate as CostEstimate
from jax._src.pallas.core import lower_as_mlir as lower_as_mlir
from jax._src.pallas.core import GridSpec as GridSpec
from jax._src.pallas.core import IndexingMode as IndexingMode
from jax._src.pallas.core import lower_as_mlir as lower_as_mlir
from jax._src.pallas.core import MemoryRef as MemoryRef
from jax._src.pallas.core import MemorySpace as MemorySpace
from jax._src.pallas.core import no_block_spec as no_block_spec
@@ -34,6 +34,7 @@ from jax._src.pallas.core import unblocked as unblocked
from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost
from jax._src.pallas.helpers import empty as empty
from jax._src.pallas.helpers import empty_like as empty_like
from jax._src.pallas.helpers import when as when
from jax._src.pallas.pallas_call import pallas_call as pallas_call
from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p
from jax._src.pallas.primitives import atomic_add as atomic_add
@@ -57,7 +58,6 @@ from jax._src.pallas.primitives import swap as swap
from jax._src.pallas.utils import cdiv as cdiv
from jax._src.pallas.utils import next_power_of_2 as next_power_of_2
from jax._src.pallas.utils import strides_from_shape as strides_from_shape
from jax._src.pallas.utils import when as when
from jax._src.state.discharge import run_state as run_state
from jax._src.state.indexing import ds as ds
from jax._src.state.indexing import dslice as dslice

@@ -25,6 +25,8 @@ from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams
from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert  # noqa: F401
from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy
from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier
from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core
from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException
from jax._src.pallas.mosaic.pipeline import ARBITRARY as ARBITRARY
from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef

@@ -456,6 +456,9 @@ MaybeTracer = Union[JaxType, Tracer]
class ShardMapPrimitive(core.Primitive):
  multiple_results = True

  def bind(self, *args, **params):
    return self._true_bind(*args, **params)

  def bind_with_trace(self, trace, fun_and_args, params):
    fun, *args = fun_and_args
    return trace.process_shard_map(shard_map_p, fun, args, **params)
@@ -1160,7 +1163,8 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),

for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p,
          control_flow.loops.cumprod_p, control_flow.loops.cummax_p,
          control_flow.loops.cummin_p, pjit.sharding_constraint_p]:
          control_flow.loops.cummin_p, pjit.sharding_constraint_p,
          pjit.mesh_cast_p]:
  register_standard_check(p)
  register_standard_rewrite(p)

@@ -1715,7 +1719,9 @@ def _partial_eval_jaxpr_custom_rule(
  idx_map = {id(v): i for i, v in enumerate(out_vars)}
  out_fwd = [idx_map.get(id(v)) for v in res_vars]
  which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
  with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
  mesh = eqn.params['mesh']
  with (core.extend_axis_env_nd(mesh.shape.items()),
        set_abstract_mesh(_as_manual_mesh(mesh))):
    jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
    jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
    jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)

@@ -42,7 +42,8 @@ def generate_sourcemaps(
  with tempfile.TemporaryDirectory() as work_dir:
    for pass_to_eval in passes:
      if pass_to_eval.compile_fn not in compile_cache:
        pass_work_dir = os.path.join(work_dir, pass_to_eval.name)
        dirname = pass_to_eval.name.replace(":", "__")
        pass_work_dir = os.path.join(work_dir, dirname)
        os.makedirs(pass_work_dir, exist_ok=False)
        compile_result = pass_to_eval.compile_fn(
            pass_work_dir, f, args, kwargs
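The rename presumably exists because pass names may contain ':', which is not a safe path component on every platform. A quick illustration (pass name hypothetical):

name = "convert-gpu-to-llvm:post"
dirname = name.replace(":", "__")
assert dirname == "convert-gpu-to-llvm__post"
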
@@ -24,7 +24,8 @@ LOC_REGEX = re.compile(r"loc\(#loc(?P<id>[0-9]+)\)")

SRC_REGEX = re.compile(
    r"#loc(?P<id>[0-9]+) ="
    r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)\)"
    r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)"
    r"( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)"
)
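The extended pattern accepts both the old single-location form and the new ranged "to" form. A standalone check (file and line values illustrative):

import re
SRC_REGEX = re.compile(
    r"#loc(?P<id>[0-9]+) ="
    r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)"
    r"( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)"
)
m = SRC_REGEX.match('#loc4 = loc("kernel.py":12:8 to 12:20)')
assert m and m.group("endcolno") == "20"
m = SRC_REGEX.match('#loc5 = loc("kernel.py":3:1)')  # no end range
assert m and m.group("endlineno") is None
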

SCOPED_REGEX = re.compile(
@@ -86,9 +86,9 @@ from jax._src.interpreters.partial_eval import (


# TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes
def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None):  # noqa
def trace_to_jaxpr_dynamic(fun, in_avals, *, keep_inputs=None):  # noqa
  jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic(
      fun, in_avals, debug_info, keep_inputs=keep_inputs)
      fun, in_avals, keep_inputs=keep_inputs)
  return jaxpr, out_avals, consts

@@ -16,7 +16,6 @@ load("@rules_python//python:defs.bzl", "py_library")
load(
    "//jaxlib:jax.bzl",
    "py_deps",
    "pytype_strict_library",
)

licenses(["notice"])
@@ -46,8 +45,3 @@ py_library(
        "//jax/experimental/jax2tf",
    ] + py_deps("tensorflow_core"),
)

pytype_strict_library(
    name = "build_utils",
    srcs = ["build_utils.py"],
)

@@ -35,6 +35,8 @@ def _get_version_string() -> str:
  # In this case we return it directly.
  if _release_version is not None:
    return _release_version
  if os.getenv("WHEEL_VERSION_SUFFIX"):
    return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
  return _version_from_git_tree(_version) or _version_from_todays_date(_version)


@@ -71,16 +73,23 @@ def _get_version_for_build() -> str:
  """Determine the version at build time.

  The returned version string depends on which environment variables are set:
  - if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc"
    Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc".
    Please note that the WHEEL_VERSION_SUFFIX value is not the same as the
    JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by the Bazel
    wheel build rule.
  - if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
  - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
  - if none are set: version looks like "0.4.16.dev20230906+ge58560fdc"
  """
  if _release_version is not None:
    return _release_version
  if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
    return _version_from_todays_date(_version)
  if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
  if os.getenv("WHEEL_VERSION_SUFFIX"):
    return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
  if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
    return _version
  if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
    return _version_from_todays_date(_version)
  return _version_from_git_tree(_version) or _version_from_todays_date(_version)

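The precedence implemented above, condensed into a runnable sketch (the helpers are simplified stand-ins for the module's private functions):

import datetime, os

def _version_from_todays_date(base):
  return f"{base}.dev{datetime.date.today():%Y%m%d}"

def version_for_build(base, release_version=None):
  if release_version is not None:
    return release_version  # an exact release tag wins
  if os.getenv("WHEEL_VERSION_SUFFIX"):
    return base + os.environ["WHEEL_VERSION_SUFFIX"]
  if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
    return base
  if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
    return _version_from_todays_date(base)
  # Otherwise: git-tree version if available, else a dated dev version.
  return _version_from_todays_date(base)
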
@@ -18,6 +18,7 @@ import logging
import os
import pathlib

from jax._src.lib import triton
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb

@@ -99,5 +100,11 @@ def initialize():
            cuda_plugin_extension.register_custom_type_id, c_api
        ),
    )
    triton.register_compilation_handler(
        "CUDA",
        functools.partial(
            cuda_plugin_extension.compile_triton_to_asm, c_api
        ),
    )
  else:
    logger.warning('cuda_plugin_extension is not found.')
@@ -234,8 +234,8 @@ cc_library(
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@nanobind",
        "@tsl//tsl/platform:statusor",
        "@xla//xla:util",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/pjrt:status_casters",
@@ -243,6 +243,7 @@ cc_library(
        "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
        "@xla//xla/pjrt/c:pjrt_c_api_helpers",
        "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
        "@xla//xla/python:py_client_gpu",
        "@xla//xla/tsl/python/lib/core:numpy",
    ],
@@ -41,8 +41,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_cholesky_update", CholeskyUpdate,
                                         "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
                                         "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA",
                         GetrfFfi);
@@ -23,15 +23,10 @@ namespace {

namespace nb = nanobind;

std::string BuildThreeFry2x32Descriptor(std::int64_t n) {
  return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
}
nb::dict Registrations() {
  nb::dict dict;
  dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] =
      EncapsulateFfiHandler(ThreeFry2x32Ffi);
  // TODO(b/338022728): remove after 6 months
  dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
  return dict;
}

@@ -33,29 +33,6 @@ namespace JAX_GPU_NAMESPACE {

namespace ffi = xla::ffi;

namespace {

// TODO(b/338022728): old custom call target, remove after 6 months
absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers,
                           const char* opaque, std::size_t opaque_len) {
  auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  LaunchThreeFry2x32Kernel(stream, buffers, **s);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
  return absl::OkStatus();
}

}  // namespace

// TODO(b/338022728): remove after 6 months
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    std::string_view message = s.message();
    XlaCustomCallStatusSetFailure(status, message.data(), message.length());
  }
}

namespace {
ffi::Error ThreeFry2x32Impl(gpuStream_t stream,
@@ -121,35 +121,6 @@ void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
                                 out1, n, nullptr);
}

// TODO(b/338022728): remove after 6 months
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
                              ThreeFry2x32Descriptor descriptor) {
  std::array<const std::uint32_t*, 2> keys;
  keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
  keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
  std::array<const std::uint32_t*, 2> data;
  data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
  data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
  std::int64_t n = descriptor.n;
  int output_idx = 4;
  std::int64_t* n_ptr = nullptr;
  if (n < 0) {
    // n is an operand in device memory.
    n_ptr = reinterpret_cast<std::int64_t*>(buffers[4]);
    output_idx = 5;
  }

  std::array<std::uint32_t*, 2> out;
  out[0] = reinterpret_cast<std::uint32_t*>(buffers[output_idx]);
  out[1] = reinterpret_cast<std::uint32_t*>(buffers[output_idx + 1]);
  const int block_dim = 128;
  const std::int64_t grid_dim =
      n < 0 ? 32
            : std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
  ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
                       stream>>>(keys[0], keys[1], data[0], data[1], out[0],
                                 out[1], n, n_ptr);
}

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax
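For the record, the deleted launcher sized its grid as a capped ceiling division (fixed at 32 blocks when the length lives in device memory); the same arithmetic as a two-line sketch:

block_dim = 128
grid_dim = lambda n: 32 if n < 0 else min(1024, (n + block_dim - 1) // block_dim)
assert grid_dim(100_000) == 782 and grid_dim(1 << 24) == 1024
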
@@ -26,19 +26,6 @@ limitations under the License.
namespace jax {
namespace JAX_GPU_NAMESPACE {

// TODO(b/338022728): remove after 6 months
struct ThreeFry2x32Descriptor {
  std::int64_t n;  // If -1 then the length is passed as a 5th operand
};

// TODO(b/338022728): remove after 6 months
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
                              ThreeFry2x32Descriptor descriptor);

// TODO(b/338022728): remove after 6 months
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status);

void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
                                 std::int64_t n,
                                 std::uint32_t *keys0, std::uint32_t *keys1,
@@ -16,23 +16,28 @@ limitations under the License.
#include "jaxlib/gpu_plugin_extension.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>

#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h"  // IWYU pragma: keep
#include "nanobind/stl/string_view.h"  // IWYU pragma: keep
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_triton_extension.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/py_client_gpu.h"
#include "xla/tsl/python/lib/core/numpy.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"

namespace nb = nanobind;

@@ -40,6 +45,44 @@ namespace xla {

namespace {

struct TritonCompilationResult {
  std::string asm_text;
  int64_t smem_bytes;
  int cluster_dim_x;
  int cluster_dim_y;
  int cluster_dim_z;
};

absl::StatusOr<TritonCompilationResult> CompileTritonToASM(
    const PJRT_Api* c_api, absl::string_view module,
    absl::string_view arch_name, int num_warps, int num_ctas, int num_stages) {
  const PJRT_Triton_Extension* triton_ext =
      pjrt::FindExtension<PJRT_Triton_Extension>(
          c_api, PJRT_Extension_Type::PJRT_Extension_Type_Triton);
  if (triton_ext == nullptr) {
    return Unimplemented("The plugin does not have a Triton extension.");
  }
  PJRT_Triton_Compile_Args args;
  args.struct_size = PJRT_Triton_Compile_Args_STRUCT_SIZE;
  args.module = module.data();
  args.module_size = module.size();
  args.arch_name = arch_name.data();
  args.arch_name_size = arch_name.size();
  args.num_warps = num_warps;
  args.num_ctas = num_ctas;
  args.num_stages = num_stages;
  RETURN_STATUS_IF_PJRT_ERROR(triton_ext->compile(&args), c_api);
  auto asm_text = std::string(args.out_asm, args.out_asm_size);
  delete[] args.out_asm;
  return TritonCompilationResult{
      .asm_text = asm_text,
      .smem_bytes = args.out_smem_bytes,
      .cluster_dim_x = args.out_cluster_dim_x,
      .cluster_dim_y = args.out_cluster_dim_y,
      .cluster_dim_z = args.out_cluster_dim_z,
  };
}

absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
                                      const char* fn_name_c_str,
                                      size_t fn_name_size, nb::object fn,
@@ -170,6 +213,24 @@ nb::dict Registrations() {

void BuildGpuPluginExtension(nanobind::module_& m) {
  tsl::ImportNumpy();

  nb::class_<TritonCompilationResult>(m, "TritonCompilationResult")
      .def_ro("asm", &TritonCompilationResult::asm_text)
      .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes)
      .def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x)
      .def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y)
      .def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z);

  m.def("compile_triton_to_asm",
        [](nb::capsule c_api, nb::bytes module, absl::string_view arch_name,
           int num_warps, int num_ctas, int num_stages) {
          return xla::ValueOrThrow(CompileTritonToASM(
              static_cast<const PJRT_Api*>(c_api.data()),
              absl::string_view(static_cast<const char*>(module.data()),
                                module.size()),
              arch_name, num_warps, num_ctas, num_stages));
        });

  m.def(
      "register_custom_call_target",
      [](nb::capsule c_api, nb::object fn_name_py, nb::object fn,
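A hedged usage sketch of the new binding from Python (values illustrative; arguments are positional, matching the lambda above: PJRT C-API capsule, serialized Triton module, arch name, then num_warps, num_ctas, num_stages):

result = cuda_plugin_extension.compile_triton_to_asm(
    c_api, ttir_bytes, "sm_90", 4, 1, 3)
print(result.asm, result.smem_bytes,
      (result.cluster_dim_x, result.cluster_dim_y, result.cluster_dim_z))
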
@@ -36,10 +36,8 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:

  if _cuda_prng:
    for _name, _value in _cuda_prng.registrations().items():
      # TODO(b/338022728): remove after 6 months, always api_version=1
      api_version = 1 if "_ffi" in _name else 0
      xla_client.register_custom_call_target(_name, _value, platform="CUDA",
                                             api_version=api_version)
                                             api_version=1)

for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
  try:
112 jaxlib/jax.bzl
@@ -14,7 +14,10 @@

"""Bazel macros used by the JAX build."""

load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION")
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
@@ -50,6 +53,15 @@ jax_internal_test_harnesses_visibility = []
jax_test_util_visibility = []
loops_visibility = []

PLATFORM_TAGS_DICT = {
    ("Linux", "x86_64"): ("manylinux2014", "x86_64"),
    ("Linux", "aarch64"): ("manylinux2014", "aarch64"),
    ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
    ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
    ("Darwin", "arm64"): ("macosx_11_0", "arm64"),
    ("Windows", "AMD64"): ("win", "amd64"),
}

# TODO(vam): remove this once zstandard builds against Python 3.13
def get_zstandard():
    if HERMETIC_PYTHON_VERSION == "3.13":
@@ -106,6 +118,7 @@ def jax_visibility(_target):
    return []

jax_extra_deps = []
jax_gpu_support_deps = []
jax2tf_deps = []

def pytype_library(name, pytype_srcs = None, **kwargs):
@@ -208,7 +221,7 @@ def if_building_jaxlib(
            "@pypi_jax_cuda12_pjrt//:pkg",
        ],
        if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]):
    """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.
    """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.

    This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase.

@@ -267,7 +280,7 @@ def jax_multiplatform_test(
    ]
    test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
    if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
        test_tags += ["manual"]
        test_tags.append("manual")
    if backend == "gpu":
        test_tags += tf_cuda_tests_tags()
    native.py_test(
@@ -308,15 +321,60 @@ def jax_generate_backend_suites(backends = []):
        tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
    )

def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
    if no_abi:
        wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
    else:
        wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
    python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
    return wheel_name_template.format(
        package_name = package_name,
        python_version = python_version,
        major_python_version = python_version[0],
        wheel_version = wheel_version,
        wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
    )

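Filled in by hand, the ABI template above produces names like the following (version and tags illustrative):

template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
print(template.format(
    package_name="jaxlib",
    wheel_version="0.5.1.dev20250218",
    python_version="312",
    wheel_platform_tag="manylinux2014_x86_64",
))
# -> jaxlib-0.5.1.dev20250218-cp312-cp312-manylinux2014_x86_64.whl
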
def _jax_wheel_impl(ctx):
    include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
    override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
    output_path = ctx.attr.output_path[BuildSettingInfo].value
    git_hash = ctx.attr.git_hash[BuildSettingInfo].value
    executable = ctx.executable.wheel_binary

    output = ctx.actions.declare_directory(ctx.label.name)
    if include_cuda_libs and not override_include_cuda_libs:
        fail("JAX wheel shouldn't be built directly against the CUDA libraries." +
             " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." +
             " If you absolutely need to build links directly against the CUDA libraries, provide" +
             " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")

    env = {}
    args = ctx.actions.args()
    args.add("--output_path", output.path)  # required argument
    args.add("--cpu", ctx.attr.platform_tag)  # required argument
    jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
    args.add("--jaxlib_git_hash", jaxlib_git_hash)  # required argument

    full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX)
    env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX
    if BUILD_TAG:
        env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG)
        full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
    if not WHEEL_VERSION_SUFFIX and not BUILD_TAG:
        env["JAX_RELEASE"] = "1"

    cpu = ctx.attr.cpu
    platform_name = ctx.attr.platform_name
    wheel_name = _get_full_wheel_name(
        package_name = ctx.attr.wheel_name,
        no_abi = ctx.attr.no_abi,
        platform_name = platform_name,
        cpu_name = cpu,
        wheel_version = full_wheel_version,
    )
    output_file = ctx.actions.declare_file(output_path +
                                           "/" + wheel_name)
    wheel_dir = output_file.path[:output_file.path.rfind("/")]

    args.add("--output_path", wheel_dir)  # required argument
    args.add("--cpu", cpu)  # required argument
    args.add("--jaxlib_git_hash", git_hash)  # required argument

    if ctx.attr.enable_cuda:
        args.add("--enable-cuda", "True")
@@ -335,11 +393,13 @@ def _jax_wheel_impl(ctx):
    args.use_param_file("@%s", use_always = False)
    ctx.actions.run(
        arguments = [args],
        inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
        outputs = [output],
        inputs = [],
        outputs = [output_file],
        executable = executable,
        env = env,
    )
    return [DefaultInfo(files = depset(direct = [output]))]

    return [DefaultInfo(files = depset(direct = [output_file]))]

_jax_wheel = rule(
    attrs = {
@@ -349,19 +409,25 @@ _jax_wheel = rule(
            # b/365588895 Investigate cfg = "exec" for multi platform builds
            cfg = "target",
        ),
        "platform_tag": attr.string(mandatory = True),
        "git_hash": attr.label(allow_single_file = True),
        "wheel_name": attr.string(mandatory = True),
        "no_abi": attr.bool(default = False),
        "cpu": attr.string(mandatory = True),
        "platform_name": attr.string(mandatory = True),
        "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
        "output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
        "enable_cuda": attr.bool(default = False),
        # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
        "platform_version": attr.string(mandatory = True, default = ""),
        "skip_gpu_kernels": attr.bool(default = False),
        "enable_rocm": attr.bool(default = False),
        "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
        "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
    },
    implementation = _jax_wheel_impl,
    executable = False,
)

def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
    """Create jax artifact wheels.

    Common artifact attributes are grouped within a single macro.
@@ -369,6 +435,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
    Args:
      name: the name of the wheel
      wheel_binary: the binary to use to build the wheel
      wheel_name: the name of the wheel
      no_abi: whether to build a wheel without ABI
      enable_cuda: whether to build a cuda wheel
      platform_version: the cuda version to use for the wheel

@@ -378,18 +446,20 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
    _jax_wheel(
        name = name,
        wheel_binary = wheel_binary,
        wheel_name = wheel_name,
        no_abi = no_abi,
        enable_cuda = enable_cuda,
        platform_version = platform_version,
        # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
        # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
        # the git hash file needs to be created first.
        git_hash = select({
            "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
            "//conditions:default": None,
        # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
        # flag in bazel command to pass the git hash for nightly or release builds.
        platform_name = select({
            "@platforms//os:osx": "Darwin",
            "@platforms//os:macos": "Darwin",
            "@platforms//os:windows": "Windows",
            "@platforms//os:linux": "Linux",
        }),
        # Following the convention in jax/tools/build_utils.py.
        # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
        platform_tag = select({
        cpu = select({
            "//jaxlib/tools:macos_arm64": "arm64",
            "//jaxlib/tools:win_amd64": "AMD64",
            "//jaxlib/tools:arm64": "aarch64",
Some files were not shown because too many files have changed in this diff.