mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Merge branch 'main' into scipy-expon

Qazalbash 2025-02-05 23:33:51 +05:00
commit 7fc605f783
128 changed files with 3616 additions and 1653 deletions
.github/workflows
CHANGELOG.md
WORKSPACE
build
docs
examples/ffi
jax
jax_plugins/cuda
jaxlib

@ -14,10 +14,15 @@ on:
pull_request:
branches:
- main
push:
branches:
- main
- 'release/**'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
# Don't cancel in-progress jobs for main/release branches.
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
run_tests:
@ -26,14 +31,22 @@ jobs:
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
# Begin Presubmit Naming Check - name modification requires internal check to be updated
strategy:
matrix:
python: ["3.10", "3.13"]
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
enable-x_64: [1, 0]
name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
exclude:
# Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have
# coverage for one of each, we don't need to run both.
- python: "3.10"
enable-x_64: 1
- python: "3.13"
enable-x_64: 0
name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
# End Presubmit Naming Check github-cpu-presubmits
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

@ -14,10 +14,15 @@ on:
pull_request:
branches:
- main
push:
branches:
- main
- 'release/**'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
# Don't cancel in-progress jobs for main/release branches.
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
run_tests:
@ -25,14 +30,22 @@ jobs:
runs-on: ${{ matrix.runner }}
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
# Begin Presubmit Naming Check - name modification requires internal check to be updated
strategy:
matrix:
python: ["3.10", "3.13"]
runner: ["linux-x86-n2-16"]
enable-x_64: [1, 0]
name: "Bazel single accelerator CUDA tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
exclude:
# Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have
# coverage for one of each, we don't need to run both.
- python: "3.10"
enable-x_64: 1
- python: "3.13"
enable-x_64: 0
name: "Bazel single accelerator CUDA tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})"
# End Presubmit Naming Check github-cuda-presubmits
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

@ -31,7 +31,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python 3.11
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: 3.11
- run: python -m pip install pre-commit
@ -70,22 +70,13 @@ jobs:
apt update
apt install -y libssl-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
- name: Install dependencies
run: |
pip install .[minimum-jaxlib] -r build/test-requirements.txt
pip install uv
uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
- name: Run tests
env:
@ -117,22 +108,13 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
- name: Install dependencies
run: |
pip install -r docs/requirements.txt
pip install uv
uv pip install --system -r docs/requirements.txt
- name: Test documentation
env:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
@ -140,7 +122,7 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py
@ -160,22 +142,13 @@ jobs:
apt update
apt install -y libssl-dev libsqlite3-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
- name: Install dependencies
run: |
pip install -r docs/requirements.txt
pip install uv
uv pip install --system -r docs/requirements.txt
- name: Render documentation
run: |
sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html
@ -195,22 +168,13 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
- name: Install dependencies
run: |
pip install .[minimum-jaxlib] tensorflow -r build/test-requirements.txt
pip install uv
uv pip install --system .[minimum-jaxlib] tensorflow -r build/test-requirements.txt
- name: Run tests
env:
@ -236,23 +200,15 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: 3.12
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
- name: Install JAX
run: pip install .[cuda12]
run: |
pip install uv
uv pip install --system .[cuda12]
- name: Build and install example project
run: python -m pip install -v ./examples/ffi[test]
run: uv pip install --system ./examples/ffi[test]
env:
# We test building using GCC instead of clang. All other JAX builds use
# clang, but it is useful to make sure that FFI users can compile using

@ -17,6 +17,10 @@ on:
pull_request:
branches:
- main
push:
branches:
- main
- 'release/**'
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
@ -25,7 +29,8 @@ permissions:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
# Don't cancel in-progress jobs for main/release branches.
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
cloud-tpu-test:

@ -32,13 +32,14 @@ jobs:
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install .[ci]
python -m pip install pytest-xdist -r array-api-tests/requirements.txt
pip install uv
uv pip install --system .[ci]
uv pip install --system pytest-xdist -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.numpy

@ -6,7 +6,7 @@ concurrency:
on:
schedule:
- cron: "0 12 * * *" # Daily at 12:00 UTC
- cron: "0 5 * * *" # Daily at 05:00 UTC == 00:00 EST == 21:00 PST
workflow_dispatch: # allows triggering the workflow run manually
pull_request: # Automatically trigger on pull requests affecting this file
branches:
@ -72,7 +72,7 @@ jobs:
# Create archive to be used with bazel as hermetic python:
cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan
- name: Save CPython with TSAN
- name: Save TSAN CPython
id: cache-cpython-tsan-save
if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
@ -102,9 +102,11 @@ jobs:
# If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz
if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then
echo "Extract cpython from python-tsan.tgz"
pushd .
ls ${GITHUB_WORKSPACE}/python-tsan.tgz
cd ${GITHUB_WORKSPACE} && tar -xvzf python-tsan.tgz
cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz
ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/
popd
fi
export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH
@ -172,7 +174,6 @@ jobs:
--clang_path=/usr/bin/clang-18
# Update the patch to use TSAN instrumented numpy
sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch
cat .github/workflows/requirements_lock_3_13_ft.patch

@ -33,7 +33,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.python-version }}
- name: Install JAX test requirements

@ -27,7 +27,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.pyver }}
cache: 'pip'

@ -35,7 +35,7 @@ jobs:
with:
path: jax
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
python-version: ${{ matrix.pyver }}
cache: 'pip'

@ -22,6 +22,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
environment variables. Previously they could only be specified via `jax.config` or flags.
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.
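For illustration, a minimal sketch of the new environment-variable support (the values shown are examples, not defaults):
```python
# Set the new options as environment variables before importing JAX
# (illustrative values; they must be set before JAX initializes its backends).
import os
os.environ["JAX_NUM_CPU_DEVICES"] = "8"
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "gloo"

import jax
print(jax.device_count())  # on a CPU-only machine, reports 8 devices
```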
## jax 0.5.0 (Jan 17, 2025)
As of this release, JAX now uses

@ -62,6 +62,21 @@ xla_workspace0()
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()
load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")
jax_python_wheel_repository(
name = "jax_wheel",
version_key = "_version",
version_source = "//jax:version.py",
)
load(
"@tsl//third_party/py:python_wheel.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
name = "jax_wheel_version_suffix",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",

@ -7,7 +7,8 @@ flatbuffers
hypothesis
mpmath>=1.3
pillow>=10.4.0
portpicker
# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
portpicker; python_version<"3.13"
pytest-xdist
wheel
rich

@ -146,7 +146,7 @@
"around calls to `bind`. These wrappers let us control how arguments are passed\n",
"to `bind`, and in particular we follow a handy internal convention: when we\n",
"call `bind`, we pass values representing array data as positional arguments,\n",
"and we pass metadata like the `axis` argument to `sum_p` via keyword. This\n",
"and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This\n",
"calling convention simplifies some core logic (since e.g. instances of the\n",
"`Tracer` class to be defined below can only occur in positional arguments to\n",
"`bind`). The wrappers can also provide docstrings!\n",

@ -133,7 +133,7 @@ The functions that user code calls, like `add` and `sin`, are just wrappers
around calls to `bind`. These wrappers let us control how arguments are passed
to `bind`, and in particular we follow a handy internal convention: when we
call `bind`, we pass values representing array data as positional arguments,
and we pass metadata like the `axis` argument to `sum_p` via keyword. This
and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
calling convention simplifies some core logic (since e.g. instances of the
`Tracer` class to be defined below can only occur in positional arguments to
`bind`). The wrappers can also provide docstrings!
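For concreteness, a minimal sketch of one such wrapper, reusing the `bind1` helper and the `reduce_sum_p` primitive defined elsewhere in this text:

```python
# Sketch: array data is passed to `bind` positionally, while metadata
# (here `axis`) is passed by keyword, matching the convention above.
def reduce_sum(x, axis=None):
  """Sum `x` along `axis`."""
  return bind1(reduce_sum_p, x, axis=axis)
```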

@ -123,7 +123,7 @@ def bind1(prim, *args, **params):
# around calls to `bind`. These wrappers let us control how arguments are passed
# to `bind`, and in particular we follow a handy internal convention: when we
# call `bind`, we pass values representing array data as positional arguments,
# and we pass metadata like the `axis` argument to `sum_p` via keyword. This
# and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
# calling convention simplifies some core logic (since e.g. instances of the
# `Tracer` class to be defined below can only occur in positional arguments to
# `bind`). The wrappers can also provide docstrings!

@ -168,6 +168,36 @@ so it is important for the persistent cache to be in a shared file system (eg: N
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
in subsequent runs as a result of a compilation cache miss.
### Pre-compiling multi-node programs on single node
JAX can populate the compilation cache with compiled programs for multiple nodes
on a single node. Preparing the cache on a single node helps reduce the costly
compilation time on a cluster. To compile and run multi-node programs on a single
node, users can create fake remote devices using the
`jax_mock_gpu_topology` configuration option.
For instance, the snippet below instructs JAX to mock a cluster with four
nodes, each node running eight processes with each process attached to one GPU.
```python
jax.config.update("jax_mock_gpu_topology", "4x8x1")
```
After populating the cache with this config, users can run the program
without recompilation on four nodes, eight processes per node,
one GPU per process.
Important notes:
* The process running the mocked program must have the same number of GPUs
and the same GPU model as the nodes that would use the cache. For instance,
a mocked topology `8x4x2` must run in a process with two GPUs.
* When running programs with mocked topology, the results of communications
with other nodes are undefined, so the outputs of JAX programs running
in mocked environments will likely be incorrect.
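A minimal sketch of this workflow (the cache directory path is illustrative, and `f` below is a stand-in for the real multi-node program):
```python
import jax
import jax.numpy as jnp

# Write compiled programs to a persistent cache (illustrative path; on the
# cluster this directory should live on a shared file system such as NFS).
jax.config.update("jax_compilation_cache_dir", "/shared/jax_cache")

# Mock a 4-node cluster, 8 processes per node, 1 GPU per process; per the
# notes above, this process must itself have one GPU of the target model.
jax.config.update("jax_mock_gpu_topology", "4x8x1")

@jax.jit
def f(x):  # stand-in for the real multi-node program
  return jnp.sin(x) * 2.0

# Running once populates the cache. Outputs under the mocked topology may be
# incorrect (cross-node communication results are undefined), but the cached
# executables can be reused by the real 4x8x1 run.
f(jnp.ones(8))
```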
## Logging cache activity
It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.

@ -13,12 +13,12 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
find_package(nanobind CONFIG REQUIRED)
set(
JAX_FFI_EXAMPLE_PROJECTS
JAX_FFI_EXAMPLE_CPU_PROJECTS
"rms_norm"
"cpu_examples"
)
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
foreach(PROJECT ${JAX_FFI_EXAMPLE_CPU_PROJECTS})
nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
@ -26,9 +26,16 @@ endforeach()
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
CUDA_STANDARD 17)
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
nanobind_add_module(_gpu_examples NB_STATIC "src/jax_ffi_example/gpu_examples.cc")
target_include_directories(_gpu_examples PUBLIC ${XLA_DIR})
target_link_libraries(_gpu_examples PRIVATE CUDA::cudart)
install(TARGETS _gpu_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endif()

@ -0,0 +1,62 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include "nanobind/nanobind.h"
#include "cuda_runtime_api.h"
#include "xla/ffi/api/ffi.h"
namespace nb = nanobind;
namespace ffi = xla::ffi;
struct State {
static xla::ffi::TypeId id;
explicit State(int32_t value) : value(value) {}
int32_t value;
};
ffi::TypeId State::id = {};
static ffi::ErrorOr<std::unique_ptr<State>> StateInstantiate() {
return std::make_unique<State>(42);
}
static ffi::Error StateExecute(cudaStream_t stream, State* state,
ffi::ResultBufferR0<ffi::S32> out) {
cudaMemcpyAsync(out->typed_data(), &state->value, sizeof(int32_t),
cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
ffi::Ffi::BindInstantiate());
XLA_FFI_DEFINE_HANDLER(kStateExecute, StateExecute,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>()
.Ctx<ffi::State<State>>()
.Ret<ffi::BufferR0<ffi::S32>>());
NB_MODULE(_gpu_examples, m) {
m.def("type_id",
[]() { return nb::capsule(reinterpret_cast<void*>(&State::id)); });
m.def("handler", []() {
nb::dict d;
d["instantiate"] = nb::capsule(reinterpret_cast<void*>(kStateInstantiate));
d["execute"] = nb::capsule(reinterpret_cast<void*>(kStateExecute));
return d;
});
}

@ -0,0 +1,24 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
from jax_ffi_example import _gpu_examples
import jax.numpy as jnp
jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA")
jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA")
def read_state():
return jax.ffi.ffi_call("state", jax.ShapeDtypeStruct((), jnp.int32))()

@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
jax.config.parse_flags_with_absl()
class GpuExamplesTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cuda"]):
self.skipTest("Unsupported platform")
# Import here to avoid trying to load the library when it's not built.
from jax_ffi_example import gpu_examples # pylint: disable=g-import-not-at-top
self.read_state = gpu_examples.read_state
def test_basic(self):
self.assertEqual(self.read_state(), 42)
self.assertEqual(jax.jit(self.read_state)(), 42)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

@ -45,8 +45,9 @@ array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-
def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
raise ValueError(
"numpy masked arrays are not supported as direct inputs to JAX functions."
" Use arr.filled() to convert the value to a standard numpy array.")
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
@ -54,7 +55,8 @@ core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype),
sharding=core.get_cur_mesh_sharding(core.P(*[None] * x.ndim)))
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
@ -62,7 +64,9 @@ core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
shape = np.shape(x)
return ShapedArray(shape, dtypes.canonicalize_dtype(dtype),
sharding=core.get_cur_mesh_sharding(core.P(*[None] * len(shape))))
for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
@ -74,7 +78,8 @@ def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
weak_type=typ is not bool,
sharding=core.get_cur_mesh_sharding())
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)

@ -323,7 +323,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
debug = api_util.tracing_debug_info(
debug = api_util.debug_info(
"checkpoint / remat", fun,
args, kwargs, static_argnums=static_argnums)
fun_, args = _remat_static_argnums(fun, static_argnums, args)
@ -418,11 +418,11 @@ _dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
def _trace_to_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: lu.TracingDebugInfo
debug: core.DebugInfo
) -> tuple[core.Jaxpr, Sequence[Any], PyTreeDef]:
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(fun, debug_info=debug), in_tree)
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
except core.ConcretizationTypeError as e:
msg, = e.args
if 'for checkpoint' in msg:
@ -447,7 +447,7 @@ def saved_residuals(f: Callable,
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)
debug_info = api_util.tracing_debug_info("saved_residuals", f, args, kwargs)
debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
@ -699,7 +699,8 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
assert next(ins_iter, None) is None
with source_info_util.extend_name_stack('rematted_computation'):
lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info),
in_pvals, False)
# Transpose the linear jaxpr (which only has linear inputs).
out_cts_iter = iter(out_cts_flat)

@ -57,11 +57,11 @@ from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, tracing_debug_info,
result_paths, flat_out_axes)
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
apply_flat_fun_nokwargs, check_callable, debug_info,
flat_out_axes)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@ -452,7 +452,7 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
raise TypeError(f"differentiating with respect to {argnums=} requires at least "
f"{max_argnum + 1} positional arguments to be passed by the caller, "
f"but got only {len(args)} positional arguments.")
dbg = tracing_debug_info('value_and_grad', fun, args, kwargs)
dbg = debug_info('value_and_grad', fun, args, kwargs)
f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
f_partial, dyn_args = argnums_partial(f, argnums, args,
@ -1021,7 +1021,7 @@ def _mapped_axis_spec(args_flat, in_axes_flat):
try:
# Duck type arrays like BCOO arrays can be passed to vmap.
return shaped_abstractify(arg).sharding.spec[i]
except TypeError:
except (IndexError, TypeError):
return None
temp_spec = None
@ -1426,11 +1426,12 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
dbg = tracing_debug_info(
dbg = debug_info(
"pmap", fun, args, kwargs,
static_argnums=static_broadcasted_tuple)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=dbg)
del dbg
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
raise ValueError(
@ -1477,9 +1478,6 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple,
raise ValueError(msg) from None
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
f, res_paths = result_paths(f)
dbg = dbg.add_result_paths(res_paths)
f = lu.add_debug_info(f, dbg)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)
@ -2235,7 +2233,7 @@ def _check_sharding(aval, s):
f" invalid value: {s}")
if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.token_shaped_array
aval = core.get_token_aval()
if not isinstance(s, PmapSharding):
pjit.pjit_check_aval_sharding(
(s,), (aval,), None, "device_put args", allow_uneven_sharding=False)

@ -31,7 +31,6 @@ from jax._src.tree_util import (
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable, safe_zip)
from jax._src import traceback_util
@ -582,7 +581,7 @@ def api_hook(fun, tag: str):
return fun
def tracing_debug_info(
def debug_info(
traced_for: str,
fun: Callable,
args: Sequence[Any],
@ -591,17 +590,17 @@ def tracing_debug_info(
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing.
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> core.DebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)
def fun_signature(fun: Callable) -> inspect.Signature | None:
@ -619,7 +618,7 @@ _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info
# See DebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
@ -675,30 +674,6 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
arg_names = args_arg_names + kwargs_arg_names
return arg_names
@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
# TODO(necula): simplify this function, all it needs is to add the trace_debug to the Jaxpr
def add_jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
if trace_debug is None:
return jaxpr
# TODO(necula): re-enable this safety check
# assert (result_paths is not None) ^ (trace_debug.result_paths_thunk is not None)
if result_paths is None:
result_paths = trace_debug.result_paths_thunk() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)
def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):
@ -723,7 +698,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()
# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg, avals, args):
def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
@ -737,7 +712,7 @@ def _check_no_aliased_ref_args(dbg, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None
def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
@ -748,4 +723,4 @@ def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}")

@ -39,6 +39,7 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding,
@ -55,7 +56,10 @@ PRNGKeyArray = Any # TODO(jakevdp): fix cycles and import this.
def _get_device(a: ArrayImpl) -> Device:
devices = a.sharding._internal_device_list # pytype: disable=attribute-error
assert len(devices) == 1
if len(devices) != 1:
raise ValueError(
"When making an array from single-device arrays the input arrays must "
f"have one shard each. An argument array had {len(devices)} shard(s).")
return devices[0]
@ -195,54 +199,102 @@ class ArrayImpl(basearray.Array):
self.aval = aval
self._sharding = sharding
self._arrays = [a._arrays[0] for a in arrays]
self._committed = committed
self._npy_value = None
arrays = [a._arrays[0] for a in arrays]
# Don't rearrange if skip_checks is enabled because this assumes that the
# input buffers are already arranged properly. This usually happens when
# Array's are created as output of a JAX transformation
# (like pjit, etc).
if not _skip_checks or config.enable_checks.value:
self._check_and_rearrange()
arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
self._arrays = arrays # type: ignore
def _check_and_rearrange(self):
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
if xla_extension_version >= 310:
def _check_and_rearrange(self, arrays, sharding, aval):
device_id_to_buffer = {_get_device(db).id: db for db in arrays}
addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")
addressable_dev = sharding.addressable_devices
if len(arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
if len(array_device_ids) != len(arrays):
buffer_device_ids = [_get_device(db).id for db in arrays]
raise ValueError(
"When making an array from single-device arrays, the input arrays"
" must be from distinct devices, but got device IDs"
f" {buffer_device_ids}")
_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays(
arrays,
sharding=sharding,
aval=aval,
expected_shape=sharding.shard_shape(aval.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = sharding._addressable_device_assignment
return [device_id_to_buffer[device.id] for device in addressable_da]
else:
def _check_and_rearrange(self): # type: ignore
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
@property
def shape(self) -> Shape:
@ -1220,7 +1272,7 @@ pxla.shard_arg_handlers[core.Token] = _token_shard_arg
def _token_global_result_handler(global_aval, out_sharding, committed):
array_handler = _array_global_result_handler(
core.token_shaped_array, out_sharding, committed)
core.get_token_aval(), out_sharding, committed)
def wrapper(*args, **kwargs):
out_buf = array_handler(*args, **kwargs)

@ -35,6 +35,7 @@ from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import pjit
from jax._src import mesh as mesh_lib
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
@ -966,7 +967,8 @@ def shard_map_error_check(
raise ValueError(f'Unsupported aval type: {type(v)}')
in_avals[i] = sharder(mesh, new_in_names[i], v)
with core.extend_axis_env_nd(mesh.shape.items()):
with (core.extend_axis_env_nd(mesh.shape.items()),
mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh))):
# jaxpr to checked_jaxpr
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
@ -1202,11 +1204,11 @@ def checkify(f: Callable[..., Out],
in_tree = jtu.tree_structure(((), {}))
closed_f = lambda: f(*args, **kwargs)
# stage:
debug = api_util.tracing_debug_info("checkify", f, args, kwargs)
debug = api_util.debug_info("checkify", f, args, kwargs)
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
debug_info=debug),
in_tree)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, ())
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)

@ -1716,3 +1716,21 @@ memory_fitting_effort = float_state(
default=0.0,
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
)
cpu_collectives_implementation = optional_enum_state(
name='jax_cpu_collectives_implementation',
enum_values=["gloo", "mpi", "megascale"],
default=None,
help=(
"Cross-process collective implementation used on CPU. Must be one of "
'("gloo", "mpi")'),
)
num_cpu_devices = int_state(
name="jax_num_cpu_devices",
default=-1,
help=(
"Number of CPU devices to use. If not provided, the value of "
"the XLA flag --xla_force_host_platform_device_count is used."
" Must be set before JAX is initialized."),
)
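A hedged sketch of setting these new options programmatically instead of via environment variables (values are illustrative; per the help strings above, they must be set before JAX initializes its backends):
```python
import jax

# Equivalent to the JAX_NUM_CPU_DEVICES / JAX_CPU_COLLECTIVES_IMPLEMENTATION
# environment variables; must run before any computation triggers backend
# initialization.
jax.config.update("jax_num_cpu_devices", 4)
jax.config.update("jax_cpu_collectives_implementation", "gloo")
```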

@ -82,31 +82,7 @@ EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects
# TODO(necula): make this an extension of TracingDebugInfo
class JaxprDebugInfo(NamedTuple):
# An extension of lu.TracingDebugInfo; see comments there
traced_for: str
func_src_info: str
arg_names: tuple[str | None, ...]
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected
def safe_result_paths(self, expected: int) -> tuple[str | None, ...]:
"""Get the result_paths with a safety check."""
if len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected
DebugInfo = lu.DebugInfo
class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
@ -117,7 +93,7 @@ class Jaxpr:
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: JaxprDebugInfo | None
_debug_info: DebugInfo | None
@property
def constvars(self) -> list[Var]:
@ -140,13 +116,13 @@ class Jaxpr:
return self._effects
@property
def debug_info(self) -> JaxprDebugInfo | None:
def debug_info(self) -> DebugInfo | None:
return self._debug_info
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
debug_info: JaxprDebugInfo | None = None):
debug_info: DebugInfo | None = None):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
@ -157,14 +133,14 @@ class Jaxpr:
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
debug_info: optional JaxprDebugInfo.
debug_info: optional DebugInfo.
"""
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
self._debug_info = debug_info and debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
@ -505,6 +481,8 @@ class Primitive:
map_primitive: bool = False
# set for ref primitives
ref_primitive: bool = False
# set for primitives that can skip canonicalization of values
skip_canonicalization: bool = False
def __init__(self, name: str):
self.name = name
@ -513,6 +491,12 @@ class Primitive:
return f'{self.name}'
def bind(self, *args, **params):
if not config.sharding_in_types.value:
return self._true_bind(*args, **params)
args = args if self.skip_canonicalization else map(canonicalize_value, args)
return self._true_bind(*args, **params)
def _true_bind(self, *args, **params):
for arg in args:
if (isinstance(arg, Tracer)
and not arg._trace.is_valid()
@ -610,8 +594,8 @@ def check_avals_context_mesh(avals, prim_name):
if config.sharding_in_types.value:
cur_mesh = mesh_lib.get_abstract_mesh()
for a in avals:
if a.sharding.mesh.empty or cur_mesh.empty:
continue
# avals can have meshes with different axis_names so allow that in
# full auto mode.
if a.sharding.mesh._are_all_axes_auto and cur_mesh._are_all_axes_auto:
continue
if a.sharding.mesh != cur_mesh:
@ -621,21 +605,6 @@ def check_avals_context_mesh(avals, prim_name):
" error occurs at source: "
f" {source_info_util.summarize(source_info_util.current())}")
# TODO(yashkatariya, dougalm): Remove this and replace with canonicalize_value
# function which casts scalar, numpy arrays, etc to jax arrays so that values
# passed to primitives are always have avals, etc i.e. they are canonical and
# also does mesh casting, etc
def cast_from_auto_to_manual(avals):
if not config.sharding_in_types.value:
return avals
from jax._src.sharding_impls import NamedSharding # type: ignore
cur_mesh = mesh_lib.get_abstract_mesh()
return [a.update(sharding=NamedSharding(cur_mesh, P(*[None] * a.ndim)))
if (not a.sharding.mesh.empty and cur_mesh._are_all_axes_manual and
a.sharding.mesh._are_all_axes_auto)
else a for a in avals]
# -------------------- tracing --------------------
TracerType = TypeVar('TracerType', bound='Tracer')
@ -1775,6 +1744,38 @@ def _make_lengths_same(sharding, ndim):
assert False, "unreachable"
# TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values
# passed to primitives always have avals, etc., i.e. they are canonical.
def canonicalize_value(val):
if not config.sharding_in_types.value:
return val
from jax._src.pjit import NamedSharding, mesh_cast # type: ignore
try:
aval = get_aval(val)
except TypeError:
return val
if not isinstance(aval, ShapedArray):
return val
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh == aval.sharding.mesh: # type: ignore
return val
if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto: # type: ignore
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) # type: ignore
if aval.sharding.mesh.empty and not cur_mesh.empty: # type: ignore
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) # type: ignore
return val
def get_cur_mesh_sharding(spec=None):
from jax._src.sharding_impls import NamedSharding # type: ignore
spec = P() if spec is None else spec
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)
# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def modify_spec_for_auto_manual(spec, mesh) -> P:
@ -1791,13 +1792,16 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
return P(*new_spec)
def _maybe_modify_sharding(sharding, ndim):
if len(sharding.spec) == 0 or all(s is None for s in sharding.spec):
if len(sharding.spec) != ndim:
return _make_lengths_same(sharding, ndim)
return sharding
if sharding.mesh._are_all_axes_explicit:
out = sharding
elif all(s is None for s in sharding.spec):
out = sharding
else:
out = sharding.with_spec(modify_spec_for_auto_manual(
sharding.spec, sharding.mesh))
return sharding
out = sharding.with_spec(modify_spec_for_auto_manual(
sharding.spec, sharding.mesh))
if (len(out.spec) != ndim and
(out.mesh._are_all_axes_auto or out.mesh._are_all_axes_manual)):
out = _make_lengths_same(out, ndim)
@ -1807,18 +1811,14 @@ def _maybe_modify_sharding(sharding, ndim):
def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding # type: ignore
if sharding is not None:
out_s = _maybe_modify_sharding(sharding, ndim)
if len(out_s.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
else:
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh.empty:
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
out_s = NamedSharding(cur_mesh, P(*[None] * ndim))
if sharding is None:
return NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim))
out_s = _maybe_modify_sharding(sharding, ndim)
if len(out_s.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {out_s.spec}, aval.ndim {ndim} and sharding {out_s}")
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
@ -2112,7 +2112,8 @@ class AbstractToken(AbstractValue):
abstract_token: AbstractToken = AbstractToken()
# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))
def get_token_aval():
return ShapedArray((0,), np.dtype(np.bool_), sharding=get_cur_mesh_sharding())
# Concrete token object
class Token:
@ -2377,7 +2378,8 @@ def dim_constant(ct: int):
return np.int64(ct)
def dim_value_aval() -> AbstractValue:
return ShapedArray((), dim_value_dtype(), weak_type=True)
return ShapedArray((), dim_value_dtype(), weak_type=True,
sharding=get_cur_mesh_sharding())
# ------------------- Call -------------------
@ -2385,6 +2387,9 @@ class CallPrimitive(Primitive):
multiple_results = True
call_primitive = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
args = fun_and_args[1:]
@ -2393,7 +2398,8 @@ class CallPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
jaxpr, ())
if config.dynamic_shapes.value:
subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
@ -2425,8 +2431,11 @@ class MapPrimitive(Primitive):
multiple_results = True
map_primitive = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, fun_and_args, params):
fun = fun_and_args[0]
fun: lu.WrappedFun = fun_and_args[0]
args = fun_and_args[1:]
assert len(params['in_axes']) == len(args)
return trace.process_map(self, fun, args, params)
@ -2436,8 +2445,9 @@ class MapPrimitive(Primitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr), jaxpr, ())
jaxpr: Jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(eval_jaxpr,
debug_info=jaxpr.debug_info), jaxpr, ())
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params

@ -119,9 +119,11 @@ def element_type_to_backend_config_type_mapping(dtype):
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
def get_max_seg_per_batch(q_offsets):
return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1
def create_dot_product_attention_backend_config_base(
batch, num_heads, seq_q, seq_kv, dtype,fmha_scale, mask_type, layout, is_bwd
batch, num_heads, seq_q, seq_kv, dtype, fmha_scale, mask_type, layout, is_bwd
):
# Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H
# P: BMM1 output in shape of BNTS
@ -226,6 +228,7 @@ def create_dot_product_attention_backend_config(
mask_type,
layout,
sliding_window_length,
max_seg_per_batch,
is_bwd
):
backend_config = create_dot_product_attention_backend_config_base(
@ -237,6 +240,7 @@ def create_dot_product_attention_backend_config(
backend_config['cudnn_fmha_backend_config']["dropout_rate"] = dropout_rate
backend_config['cudnn_fmha_backend_config']["seed"] = seed
backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length
backend_config['cudnn_fmha_backend_config']["max_seg_per_batch"] = max_seg_per_batch
return json.dumps(backend_config)
def create_dot_product_attention_fp8_backend_config(
@ -268,7 +272,8 @@ get_fp8_custom_call_name = functools.partial(
get_custom_call_name, has_bias=False, has_dropout=False, is_fp8=True
)
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen,
q_offsets, kv_offsets, layout):
def check_eq(a, b, c, msg):
if not (a == b == c):
raise ValueError(f"{msg} must be same, got {a}, {b}, {b}")
@ -300,36 +305,36 @@ def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
if kS != vS:
raise ValueError(f"KV must have same seq length, got {kS} vs {vS}")
# check bias/q_seqlen/kv_seqlen
# check bias
if bias is not None:
_, _, bT, bS = bias.shape
if bT != qT or bS != vS:
raise ValueError(
f"Bias must have same seq length as QKV, got {bT} and {bS}")
if q_seqlen is not None:
q_seq_dtype = q_seqlen.dtype
q_seq_rank = len(q_seqlen.shape)
if q_seq_dtype != jnp.int32:
raise ValueError(f"q_seqlen must have int32 datatype, got {q_seq_dtype}")
if q_seq_rank != 1:
raise ValueError(f"q_seqlen must have a rank of 1, got {q_seq_rank}")
q_seq_b = q_seqlen.shape[0]
if q_seq_b != qB:
raise ValueError(f"q_seqlen must have same batch as Q, got {q_seq_b}")
if kv_seqlen is not None:
kv_seq_dtype = kv_seqlen.dtype
kv_seq_rank = len(kv_seqlen.shape)
if kv_seq_dtype != jnp.int32:
raise ValueError(
f"kv_seqlen must have int32 datatype, got {kv_seq_dtype}")
if kv_seq_rank != 1:
raise ValueError(f"kv_seq_rank must have a rank of 1, got {kv_seq_rank}")
kv_seq_b = kv_seqlen.shape[0]
if kv_seq_b != qB:
raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")
# check q_seqlen/kv_seqlen/q_offsets/kv_offsets
expected_rank = 2 if q_offsets is not None else 1
def check_seqlen_offsets(tensor, name):
if tensor is not None:
dtype = tensor.dtype
rank = len(tensor.shape)
if dtype != jnp.int32:
raise ValueError(f"{name} must have int32 datatype, got {dtype}")
if rank != expected_rank:
raise ValueError(f"{name} must have a rank of {expected_rank}, got {rank}")
b = tensor.shape[0]
if b != qB:
raise ValueError(f"{name} must have same batch as Q, got {b}")
check_seqlen_offsets(q_seqlen, "q_seqlen")
check_seqlen_offsets(kv_seqlen, "kv_seqlen")
check_seqlen_offsets(q_offsets, "q_offsets")
check_seqlen_offsets(kv_offsets, "kv_offsets")
def check_is_flash_attention(
query, key, layout: int, cudnn_version, has_bias, is_training, is_fp8=False):
query, key, layout: int, cudnn_version, has_bias, is_training, is_packed,
is_fp8=False):
# Extract sequence length (T) and head dim (H) based on layout
if layout == AttentionLayout.BNTH.value:
_, _, T, H = query.shape
@ -363,6 +368,9 @@ def check_is_flash_attention(
f"Unsupported sequence length Q {T}, KV {S}."
)
if is_packed and cudnn_version < 90600:
raise NotImplementedError("Packed layout requires cudnn version >= 9.6.")
def check_cudnn_version():
# check if cuDNN is installed
if cuda_versions is None:
@ -378,78 +386,142 @@ def check_compute_capability(capability):
return current >= target
def _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, False)
query, key, layout, cudnn_version, bias is not None, False,
get_max_seg_per_batch(q_offsets) > 1)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=False)
output = outputs[0]
return output
def _dot_product_attention_fwd_rule(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, True)
query, key, layout, cudnn_version, bias is not None, True,
get_max_seg_per_batch(q_offsets) > 1)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=True)
res = (query, key, value, bias, q_seqlen, kv_seqlen,
outputs[1], outputs[0])
res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
kv_offsets, outputs[1], outputs[0])
return outputs[0], res
def _dot_product_attention_bwd_rule(
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training, res, grad_output):
(query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output) = res
(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output) = res
grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length
)
grads = (*grads,) + (None,) * (6 - len(grads))
grads = (*grads,) + (None,) * (8 - len(grads))
return grads
def _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key):
# Fix seqlen and offsets to what cuDNN expects for sequence packing.
# cuDNN expects seqlen to have shape [S], where S is the total number of segments,
# while the SDPA API accepts seqlen with shape [B, M], where B is the batch size
# and M is the maximum number of segments in one batch. B x M is larger than S and
# seqlen is filled with -1 for padded regions, so all non-negative values are
# shifted to the left to form a valid seqlen. The same layout fix is needed for
# the offsets tensors.
# cuDNN expects each segment's offset to be counted from the very first segment,
# while the SDPA API accepts offsets counted from the start of the current batch,
# so the cumulative offset of each segment relative to the first segment is
# computed here.
def _shift_to_left(x, fill_value):
# shift all non-negative values to the left
# [[1, 3, -1, -1], [2, 3, 4, -1]]
# -> [[1, 3, 2, 3], [4, -1, -1, -1]]
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(x >= 0, size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape)
def _cu_offset(offsets, max_seq):
# calculate the cumulative offset within each batch
# [[1, 3, 5, 7], [4, 5, -1, -1]], max_seq = 8
# -> [[1, 3, 5, 7], [12, 13, -1, -1]]
batch = offsets.shape[0]
offsets = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis],
offsets,
)
return offsets
if get_max_seg_per_batch(q_offsets) > 1:
B, T, N, H = query.shape
_, S, _, _ = key.shape
q_seqlen = _shift_to_left(q_seqlen, -1)
kv_seqlen = _shift_to_left(kv_seqlen, -1)
q_offsets = _cu_offset(q_offsets, T)
kv_offsets = _cu_offset(kv_offsets, S)
q_offsets = _shift_to_left(q_offsets, -1)
kv_offsets = _shift_to_left(kv_offsets, -1)
# mark any invalid entries as maximum offset
q_offsets = jnp.where(q_offsets < 0, B * T, q_offsets)
kv_offsets = jnp.where(kv_offsets < 0, B * S, kv_offsets)
# multiply by the per-token stride (N * H) to get the correct offsets
# do it here because the real stride changes after sharding
q_offsets = q_offsets * N * H
kv_offsets = kv_offsets * N * H
return q_seqlen, kv_seqlen, q_offsets, kv_offsets
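An illustrative sketch of what the two helpers above do to a tiny packed batch; it is not part of the diff, and it simply re-applies the same jnp operations standalone so the comment examples can be checked.
import jax.numpy as jnp
seqlen = jnp.array([[1, 3, -1, -1], [2, 3, 4, -1]], dtype=jnp.int32)
flat = seqlen.flatten()
idx = jnp.nonzero(flat >= 0, size=flat.size, fill_value=flat.size)[0]
shifted = jnp.take(flat, idx, fill_value=-1).reshape(seqlen.shape)
# shifted == [[1, 3, 2, 3], [4, -1, -1, -1]], matching _shift_to_left
offsets, max_seq = jnp.array([[1, 3, 5, 7], [4, 5, -1, -1]], dtype=jnp.int32), 8
cu = jnp.where(offsets >= 0,
               offsets + (jnp.arange(offsets.shape[0]) * max_seq)[..., jnp.newaxis],
               offsets)
# cu == [[1, 3, 5, 7], [12, 13, -1, -1]], matching _cu_offset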
def _dot_product_attention_fwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
# args: {Q, K, V, mask*, bias*}
q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
_fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
outputs = _dot_product_attention_fwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)
return outputs
def _dot_product_attention_bwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length):
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
q_seqlen, kv_seqlen, q_offsets, kv_offsets = \
_fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key)
grads = _dot_product_attention_bwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length)
return grads
def _dot_product_attention_fwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, *, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
*, scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
if layout == AttentionLayout.BNTH.value:
@ -459,7 +531,9 @@ def _dot_product_attention_fwd_abstract(
B, T, N, _ = query.shape
_, S, _, _ = key.shape
output_shape = query.shape
softmax_stat_shape = (B, N, T)
max_seg_per_batch = get_max_seg_per_batch(q_offsets)
softmax_stat_shape = (B * max_seg_per_batch, N, T)
if is_training:
return (
@ -472,9 +546,9 @@ def _dot_product_attention_fwd_abstract(
)
def _dot_product_attention_bwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type,
layout, sliding_window_length):
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, *, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
key_dtype = dtypes.canonicalize_dtype(key.dtype)
value_dtype = dtypes.canonicalize_dtype(value.dtype)
@ -511,9 +585,9 @@ def _dot_product_attention_bwd_abstract(
)
def _dot_product_attention_fwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type,
layout, sliding_window_length, is_training):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@ -530,24 +604,30 @@ def _dot_product_attention_fwd_cuda_lowering(
output_layout = (3, 1, 2, 0)
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
output_shape = (B, N, T, H)
softmax_stat_shape = (B, N, T)
softmax_stat_shape = (B * max_seg_per_batch, N, T)
workspace_shape = (0,)
workspace_type = ir.IntegerType.get_unsigned(8)
has_bias, _ = variadic_args
backend_config = create_dot_product_attention_backend_config(
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=False,
)
# {Q, K, V, bias*, q_seqlen*, kv_seqlen*}
mask_type, layout, sliding_window_length, max_seg_per_batch,
is_bwd=False)
# {Q, K, V, bias*, q_seqlen*, kv_seqlen*, q_offsets*, kv_offsets*}}
# {output, activation*, workspace}
has_dropout = dropout_rate > 0
has_bias, _ = variadic_args
operands = [query, key, value]
if has_bias:
operands.append(bias)
if has_padding(mask_type):
if has_padding(mask_type) or max_seg_per_batch > 1:
operands.append(q_seqlen)
operands.append(kv_seqlen)
if max_seg_per_batch > 1:
operands.append(q_offsets)
operands.append(kv_offsets)
custom_call_name = get_custom_call_name(has_bias, has_dropout, False)
if is_training:
@ -581,9 +661,9 @@ def _dot_product_attention_fwd_cuda_lowering(
return [hlo.transpose(out.results[0], output_transpose_perm)]
def _dot_product_attention_bwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale, seed, dropout_rate,
variadic_args, mask_type, layout, sliding_window_length):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
@ -607,23 +687,29 @@ def _dot_product_attention_bwd_cuda_lowering(
grad_query_shape = (B, q_N, T, H)
grad_key_shape = (B, k_N, S, H)
grad_value_shape = (B, k_N, S, H)
has_bias, has_dbias = variadic_args
max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type))
backend_config = create_dot_product_attention_backend_config(
B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=True,
)
# {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*}
mask_type, layout, sliding_window_length, max_seg_per_batch,
is_bwd=True)
# {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*,
# q_offsets*, kv_offsets*}
# {dQ, dK, dV, dbias*, workspace}
has_dropout = dropout_rate > 0
has_bias, has_dbias = variadic_args
# create operands
operands = [query, key, value, activation, grad_output]
if has_bias:
# flash attention requires bias in the bwd for remat
operands.append(bias)
operands.append(fwd_output)
if has_padding(mask_type):
if has_padding(mask_type) or max_seg_per_batch > 1:
operands.append(q_seqlen)
operands.append(kv_seqlen)
if max_seg_per_batch > 1:
operands.append(q_offsets)
operands.append(kv_offsets)
# get custom call name
custom_call_name = get_custom_call_name(has_bias, has_dropout, True)
@ -674,7 +760,8 @@ def _dot_product_attention_fwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length, is_training):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, q_seqlen, kv_seqlen = batched_args
query, key, value, bias, q_seqlen, kv_seqlen, \
q_offsets, kv_offsets = batched_args
query_bdim = batch_dims[0]
if is_training:
out_bdims = query_bdim, query_bdim
@ -701,9 +788,9 @@ def _dot_product_attention_fwd_batcher(
kv_seqlen = jnp.reshape(kv_seqlen, (B, ))
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)
# reshape to original shape
@ -720,8 +807,8 @@ def _dot_product_attention_bwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, q_seqlen, \
kv_seqlen, activation, fwd_output, grad_output = batched_args
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, \
activation, fwd_output, grad_output = batched_args
query_bdim = batch_dims[0]
out_bdims = query_bdim, query_bdim, query_bdim
@ -757,8 +844,8 @@ def _dot_product_attention_bwd_batcher(
grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:])
grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
@ -834,7 +921,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layo
return [out_sharding]
_dot_product_attention_fwd_lower = custom_partitioning(
_dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
_dot_product_attention_fwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _dot_product_attention_fwd_infer_sharding_from_operands(
scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length,
@ -883,7 +970,7 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args):
return out_shardings
_dot_product_attention_bwd_lower = custom_partitioning(
_dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14, 15)
_dot_product_attention_bwd_impl, static_argnums=(11, 12, 13, 14, 15, 16, 17)
)
def _dot_product_attention_bwd_infer_sharding_from_operands(
@ -1003,13 +1090,15 @@ dispatch.prim_requires_devices_during_lowering.add(
_dot_product_attention_bwd_p_wrapper
)
@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _dot_product_attention(query: Array,
key: Array,
value: Array,
bias: Array,
q_seqlen: Array,
kv_seqlen: Array,
q_offsets: Array,
kv_offsets: Array,
scale: float,
seed: int,
dropout_rate: float,
@ -1019,9 +1108,10 @@ def _dot_product_attention(query: Array,
sliding_window_length: int | None,
cudnn_version: int):
output = _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length,
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
cudnn_version=cudnn_version)
return output
@ -1612,6 +1702,8 @@ def dot_product_attention(
mask: Array | None = None,
q_seqlen: Array | None = None,
kv_seqlen: Array | None = None,
q_offsets: Array | None = None,
kv_offsets: Array | None = None,
fp8_params: FP8Params | None = None,
*,
scale: float = 1.0,
@ -1647,8 +1739,26 @@ def dot_product_attention(
value: Values to be used in attention with a shape of BSNH or BNSH.
bias: Bias to be added to logits with a shape of BNTS.
mask: Mask used to filter out logits with a shape of BNTS.
q_seqlen: Non padded sequence length of Queries with a shape of B.
kv_seqlen: Non padded sequence length of Keys and Values with a shape of B.
q_seqlen: Non-padded sequence length of the query with a shape of B.
If q_offsets is set, q_seqlen should have shape [B, M] where M is the
maximum number of segments per batch. For batches that have fewer
segments than the maximum, fill the padded entries with -1.
kv_seqlen: Non-padded sequence length of the key and value with a shape of B.
If kv_offsets is set, kv_seqlen should have shape [B, M] where M is the
maximum number of segments per batch. For batches that have fewer
segments than the maximum, fill the padded entries with -1.
q_offsets: Offset of each segment packed in the query with a shape of [B, M+1]
where M is the maximum number of segments per batch. For batches that have
fewer segments than the maximum, fill the padded entries with -1.
E.g., if 2 batches have 3 and 2 segments respectively, each of size 1,
then q_offsets = [[0,1,2,-1], [0,1,-1,-1]]. q_seqlen should be set to
indicate the size of each segment.
kv_offsets: Offset of each segment packed in the key and value with a shape of
[B, M+1] where M is the maximum number of segments per batch. For batches
that have fewer segments than the maximum, fill the padded entries with -1.
E.g., if 2 batches have 3 and 2 segments respectively, each of size 1,
then kv_offsets = [[0,1,2,-1], [0,1,-1,-1]]. kv_seqlen should be set to
indicate the size of each segment.
scale: Scale for the query.
dropout_rate: Dropout rate.
qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
@ -1679,7 +1789,7 @@ def dot_product_attention(
f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}"
)
check_fp8_params(fp8_params)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
output, amax_s, amax_o = _dot_product_attention_fp8(
query, key, value, fp8_params,
scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version
@ -1691,6 +1801,8 @@ def dot_product_attention(
if sliding_window_length is not None and sliding_window_length <= 0:
raise ValueError(
f"Require sliding_window_length > 0, got {sliding_window_length}")
if q_offsets is not None and (q_seqlen is None or kv_seqlen is None):
raise ValueError("Require q_seqlen and kv_seqlen to use packed layout")
if bias is not None:
# reshape bias to have 4D shape
@ -1712,7 +1824,7 @@ def dot_product_attention(
bias = bias + mask
# check if the input shapes and data types are compatible
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout)
has_bias = bias is not None
has_dbias = has_bias and \
should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr]
@ -1724,8 +1836,12 @@ def dot_product_attention(
q_seqlen = jnp.zeros(0, dtype=query.dtype)
if kv_seqlen is None:
kv_seqlen = jnp.zeros(0, dtype=query.dtype)
if q_offsets is None:
q_offsets = jnp.zeros(0, dtype=query.dtype)
if kv_offsets is None:
kv_offsets = jnp.zeros(0, dtype=query.dtype)
output = _dot_product_attention(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length,
cudnn_version)
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout.value,
sliding_window_length, cudnn_version)
return output
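A hedged usage sketch of the packed-segment arguments documented above; it is not part of the diff, the shapes are hypothetical, and actually running it needs a CUDA GPU with cuDNN >= 9.6 (per the check added earlier). It packs two segments of lengths 3 and 5 into a single batch of sequence length 8 and calls the dot_product_attention defined in this file; additional arguments such as mask_type may be needed depending on the use case.
import jax.numpy as jnp
B, T, N, H = 1, 8, 4, 16
q = k = v = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)
q_seqlen = kv_seqlen = jnp.array([[3, 5]], dtype=jnp.int32)        # [B, M]
q_offsets = kv_offsets = jnp.array([[0, 3, -1]], dtype=jnp.int32)  # [B, M+1], -1 padded
out = dot_product_attention(q, k, v,
                            q_seqlen=q_seqlen, kv_seqlen=kv_seqlen,
                            q_offsets=q_offsets, kv_offsets=kv_offsets)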

@ -147,13 +147,13 @@ class custom_vmap:
raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.")
debug = api_util.tracing_debug_info("custom_vmap", self.fun, args, {})
debug = api_util.debug_info("custom_vmap", self.fun, args, {})
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(self.fun, debug_info=debug),
in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None
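This hunk and the ones below all follow the same migration, so here is a minimal sketch of the new pattern (not part of the diff, and these are internal JAX APIs whose exact signatures may differ between versions): the debug info is built with api_util.debug_info, attached to the traced function via lu.wrap_init(..., debug_info=...), and pe.trace_to_jaxpr_dynamic now reads it from the wrapped function instead of taking a separate argument.
from jax._src import api_util, core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe

def trace_with_debug_info(fun, *example_args):
  # Build the DebugInfo once and hang it on the WrappedFun; the tracer then
  # picks it up from fun.debug_info.
  debug = api_util.debug_info("example", fun, example_args, {})
  wrapped = lu.wrap_init(fun, debug_info=debug)
  in_avals = [core.get_aval(x) for x in example_args]
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, in_avals)
  return jaxpr, consts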

@ -127,12 +127,12 @@ class custom_dce:
"def_dce."
)
rule_name = util.fun_name(self.dce_rule)
debug = api_util.tracing_debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.tracing_debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_dce", self.fun,
args, {},
static_argnums=self.static_argnums)
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
args, {},
static_argnums=self.static_argnums)
args = api_util.resolve_kwargs(self.fun, args, kwargs)
if self.static_argnums:
static_argnums = set(self.static_argnums)
@ -147,11 +147,11 @@ class custom_dce:
)
static_args = [args[i] for i in self.static_argnums]
dce_rule = api_util.prepend_static_args(
lu.wrap_init(self.dce_rule), static_args
lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args
)
else:
fun = lu.wrap_init(self.fun, debug_info=debug)
dce_rule = lu.wrap_init(self.dce_rule)
dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule)
dyn_args = args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
@ -176,7 +176,7 @@ class custom_dce:
)
assert self.dce_rule is not None
dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
flat_rule, in_avals, debug_rule
flat_rule, in_avals
)
# This second round of DCE is used to work out which inputs are actually
@ -191,7 +191,7 @@ class custom_dce:
return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
out_avals = closed_call.out_avals
out_flat = custom_dce_p.bind(
@ -366,7 +366,8 @@ def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_):
# that most users of this API would compose this with a custom_jvp or
# custom_vjp, which makes this less urgent.
out = core.call_p.bind(
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr)), *primals, *tangents
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr),
debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents
)
out_primals, out_tangents = util.split_list(out, [len(out_nz)])

@ -348,6 +348,9 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args
class CustomJVPCallPrimitive(core.Primitive):
multiple_results = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, args, params):
fun, jvp, tracers = args[0], args[1], args[2:]
return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params)
@ -866,6 +869,9 @@ def _temporary_shape_exception(a, a_) -> bool:
class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, args, params):
fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:]
return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params)

@ -308,7 +308,7 @@ class custom_partitioning:
that describes the sharding rule, or a Callable that produces either of
these. We borrow the idea from the einops.rearrange string, using a space
separator between factors and allowing multi-letter factor names. See
[jax-shardy-guide](https://colab.sandbox.google.com/github/openxla/shardy/blob/main/docs/getting_started_jax.ipynb)
`jax-shardy-guide <https://colab.sandbox.google.com/github/openxla/shardy/blob/main/docs/getting_started_jax.ipynb>`_
for more details and examples on how to use this.
When config.use_shardy_partitioner.value is True, `sharding_rule` is used;
@ -468,9 +468,9 @@ class custom_partitioning:
def __call__(self, *args, **kwargs):
args = _resolve_kwargs(self.fun, args, kwargs)
debug = api_util.tracing_debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
debug = api_util.debug_info("custom_partitioning", self.fun,
args, kwargs,
static_argnums=self.static_argnums)
if self.static_argnums:
static_argnums = set(self.static_argnums)
args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
@ -485,13 +485,13 @@ class custom_partitioning:
_check_for_tracers(static_args)
else:
static_args = []
f_, dyn_args = lu.wrap_init(self.fun), args
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.get_aval(x) for x in args_flat]
mesh = mesh_lib.thread_resources.env.physical_mesh
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

@ -155,6 +155,9 @@ class CustomTransposePrimitive(core.Primitive):
map_primitive = False
multiple_results = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, call_args, params):
call, tracers = call_args[0], call_args[1:]
return trace.process_custom_transpose(self, call, tracers, **params)

@ -1042,7 +1042,7 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
*_CPU_FFI_KERNELS,
*_GPU_FFI_KERNELS,
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"cu_threefry2x32", "cu_threefry2x32_ffi",
"cu_threefry2x32_ffi",
# Triton IR does not guarantee stability.
# "__gpu$xla.gpu.triton",
# cholesky on CPU


@ -98,7 +98,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
nzs_out = tuple(type(t) is not Zero for t in out_tangents)
out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment]
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, None)
residual_avals = map(get_aval, consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
@ -143,11 +143,12 @@ def linearize_jaxpr(
return _linearize_jaxpr(jaxpr, tuple(nonzeros))
@weakref_lru_cache
@source_info_util.reset_name_stack()
def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
dbg = lu.TracingDebugInfo.from_jaxpr(jaxpr)
dbg = jaxpr.jaxpr.debug_info
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
@ -166,16 +167,17 @@ def _linearize_jaxpr(
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
del lin_trace, ans, tracers, new_arg
debug_info = jaxpr.jaxpr.debug_info
nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info)
tangent_trace.invalidate()
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment]
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
primal_trace.invalidate()
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
@ -192,7 +194,8 @@ def direct_linearize(traceable: lu.WrappedFun,
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with core.set_current_trace(linearize_trace, check_leaks=True):
with (core.set_current_trace(linearize_trace, check_leaks=True),
source_info_util.transform_name_stack('jvp')):
if has_aux:
ans, aux = traceable.call_wrapped(*tracers)
aux_primals = [x.primal
@ -207,7 +210,7 @@ def direct_linearize(traceable: lu.WrappedFun,
out_nzs = [type(t) is not Zero for t in out_tangents]
out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
pe.PartialVal.known(zeros_like_aval(t.aval))
@ -587,6 +590,10 @@ class LinearizeTrace(Trace):
self.tag = core.TraceTag() if tag is None else tag
self.parent_trace = parent_trace
self.tangent_trace = tangent_trace
self._name_stack_prefix_len = len(source_info_util.current_name_stack())
def _name_stack_suffix(self):
return source_info_util.current_name_stack()[self._name_stack_prefix_len:]
def to_primal_tangent_pair(self, val):
if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
@ -605,7 +612,8 @@ class LinearizeTrace(Trace):
with core.set_current_trace(self.parent_trace):
primal_out, tangent_nzs_out, residuals, linearized = lin(
tangent_nzs, *primals_in, **params)
with core.set_current_trace(self.tangent_trace):
with (core.set_current_trace(self.tangent_trace),
source_info_util.set_name_stack(self._name_stack_suffix())):
tangent_out = linearized(residuals, *tangents_in)
if primitive.multiple_results:
return [maybe_linearize_tracer(self, x, nz, t)
@ -1019,12 +1027,14 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
debug_info = jaxpr.jaxpr.debug_info
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(
f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
@lu.transformation_with_aux2

@ -451,8 +451,6 @@ class AxisData:
def get_sharding_for_vmap(axis_data, orig_sharding, axis):
if orig_sharding.mesh.empty:
return None
val = axis_data.explicit_mesh_axis
new_spec = P(*tuple_insert(orig_sharding.spec, axis, val))
return NamedSharding(orig_sharding.mesh, new_spec)
@ -760,7 +758,8 @@ def _batch_jaxpr2(
axis_data,
in_axes: tuple[int | NotMapped | RaggedAxis, ...],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
debug_info=closed_jaxpr.jaxpr.debug_info)
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
in_axes2, avals_in = unzip2([

@ -54,11 +54,14 @@ from jax._src.sharding_impls import (AUTO, NamedSharding,
SdyArraySharding, SdyArrayShardingList)
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects, ir, passmanager
from jax._src.lib.mlir.dialects import func as func_dialect, hlo
from jax._src.lib.mlir import register_jax_dialects
from jax._src.state.types import AbstractRef
# mypy: ignore-errors
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
@ -469,11 +472,20 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
if loc is None:
frame = source_info_util.raw_frame_to_frame(code, lasti)
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
)
if xla_extension_version >= 309:
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
frame.end_line,
frame.end_column,
)
else:
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
)
loc = ir.Location.name(frame.function_name, childLoc=file_loc)
ctx.traceback_caches.location_cache[code_lasti] = loc
frame_locs.append(loc)
@ -1121,16 +1133,20 @@ def lower_jaxpr_to_module(
"In multi-platform lowering either all or no lowering platforms "
f"should support donation. Lowering for {platforms} of which "
f"only {platforms_with_donation} support donation")
input_output_aliases, donated_args, xla_donated_args = _set_up_aliases(
input_output_aliases, in_avals, out_avals, donated_args,
arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts,
result_shardings if num_partitions > 1 else None)
if (num_partitions > 1 and
(result_shardings is None or
all(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
any(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
for s in result_shardings))):
xla_donated_args = donated_args
donated_args = [False] * len(donated_args)
if xla_donated_args is None:
input_output_aliases, donated_args, xla_donated_args = _set_up_aliases(
input_output_aliases, in_avals, out_avals, donated_args,
arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts)
if xla_donated_args is None:
xla_donated_args = [False] * len(donated_args)
for input_id in range(len(donated_args)):
if donated_args[input_id]:
xla_donated_args[input_id] = True
donated_args[input_id] = False
if any(donated_args):
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
@ -1225,14 +1241,15 @@ def lower_jaxpr_to_module(
def _set_up_aliases(input_output_aliases, avals_in, avals_out,
donated_args, arg_memory_kinds, result_memory_kinds,
in_layouts, out_layouts):
in_layouts, out_layouts, result_shardings):
if input_output_aliases is None:
input_output_aliases = [None] * len(avals_in)
else:
input_output_aliases = list(input_output_aliases)
# To match up in-avals to out-avals we only care about the number of
# bytes, so we strip off unrelated aval metadata (e.g. the named shape)
strip_metadata = lambda a: a.strip_weak_type()
strip_metadata = lambda a: (a if a is core.abstract_token else
core.ShapedArray(a.shape, a.dtype))
avals_in = map(strip_metadata, avals_in)
avals_out = map(strip_metadata, avals_out)
@ -1283,7 +1300,10 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out,
" for the input and output layout to be chosen by XLA and not the"
" layout of the input which might not be optimal.")
if (in_layouts is None or out_layouts is None or
in_layouts[input_id] == out_layouts[i]):
in_layouts[input_id] == out_layouts[i]) and (
result_shardings is None or not (
(s := result_shardings[i]) is None or
isinstance(s, AUTO) or contains_unconstrained(s))):
input_output_aliases[input_id] = i
else:
# Fallback to xla donation if layouts don't match.
@ -1393,7 +1413,6 @@ def lower_jaxpr_to_fun(
MLIR func op
"""
util.test_event("lower_jaxpr_to_fun", name)
# The first dimension variable may be the platform index
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars

@ -42,7 +42,6 @@ from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src import tree_util
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
tree_flatten, tree_structure)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@ -502,7 +501,7 @@ call_param_updaters[core.closed_call_p] = _closed_call_param_updater
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params), avals, debug_info)
lu.wrap_init(fun, params, debug_info=debug_info), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
@ -590,7 +589,7 @@ def trace_to_subjaxpr_nounits(
@lu.transformation2
def trace_to_subjaxpr_nounits2(
f,
f: Callable,
tag: TraceTag,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
@ -932,7 +931,7 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
in_unknowns: Sequence[bool],
instantiate: bool | Sequence[bool]):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=lu.TracingDebugInfo.from_jaxpr(jaxpr))
debug_info=jaxpr.jaxpr.debug_info)
cell = []
def fun(*known_vals_in):
@ -951,7 +950,9 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
return [*known_vals_out, *residuals]
known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, debug_info=f.debug_info),
known_avals)
(out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking
# check jaxpr_known and jaxpr_unknown in isolation
@ -1125,7 +1126,7 @@ def _partial_eval_jaxpr_custom_cached(
known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res,
known_outvars, known_eqns)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
known_eqns, known_effects)
known_eqns, known_effects, jaxpr.debug_info)
config.enable_checks.value and core.check_jaxpr(jaxpr_known)
_, ins_staged = partition_list(in_inst, jaxpr.invars)
@ -1334,10 +1335,10 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
jaxpr.debug_info.filter_result_paths(used_outputs))
new_jaxpr = jaxpr.replace(outvars=outvars, debug_info=dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
return new_jaxpr
@ -1422,10 +1423,10 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
dbg = jaxpr.debug_info and core.JaxprDebugInfo(
dbg = jaxpr.debug_info and core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
tuple(v for v, b in zip(jaxpr.debug_info.arg_names, used_inputs) if b),
tuple(v for v, b in zip(jaxpr.debug_info.result_paths, used_outputs) if b))
jaxpr.debug_info.filter_arg_names(used_inputs),
jaxpr.debug_info.filter_result_paths(used_outputs))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
@ -1623,9 +1624,9 @@ class JaxprStackFrame:
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: lu.TracingDebugInfo | None
debug_info: core.DebugInfo | None
def __init__(self, debug_info: lu.TracingDebugInfo | None):
def __init__(self, debug_info: core.DebugInfo | None):
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_tracer = {}
@ -1642,7 +1643,9 @@ class JaxprStackFrame:
def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)
def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer]
def to_jaxpr(self, trace: DynamicJaxprTrace,
out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None,
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
@ -1655,7 +1658,8 @@ class JaxprStackFrame:
outvars = state_outvars + explicit_outvars
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects,
debug_info)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore
init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
@ -1809,7 +1813,7 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)
def __init__(self, debug_info: lu.TracingDebugInfo | None):
def __init__(self, debug_info: core.DebugInfo | None):
self.frame = JaxprStackFrame(debug_info)
def invalidate(self):
@ -1948,7 +1952,7 @@ class DynamicJaxprTrace(core.Trace):
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]):
jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic(
f, reduced_in_avals, f.debug_info)
f, reduced_in_avals)
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
if ordered_effects:
raise ValueError("Ordered effects not supported for "
@ -2072,8 +2076,9 @@ class DynamicJaxprTrace(core.Trace):
self.frame.add_eqn(eqn)
return out_tracers
def to_jaxpr(self, out_tracers: Sequence[Tracer]):
return self.frame.to_jaxpr(self, out_tracers)
def to_jaxpr(self, out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None):
return self.frame.to_jaxpr(self, out_tracers, debug_info)
custom_staging_rules: dict[Primitive, Callable] = {}
@ -2114,14 +2119,12 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
def trace_to_jaxpr_dynamic(
fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: lu.TracingDebugInfo | None = None,
*,
keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
trace = DynamicJaxprTrace(debug_info)
trace = DynamicJaxprTrace(fun.debug_info)
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
@ -2129,15 +2132,15 @@ def trace_to_jaxpr_dynamic(
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.to_jaxpr_tracer, ans)
_check_no_returned_refs(debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
_check_no_returned_refs(fun.debug_info, out_tracers)
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info)
del trace, fun, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
def _check_no_returned_refs(
dbg: lu.TracingDebugInfo | None,
dbg: core.DebugInfo | None,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
if not config.mutable_array_checks.value: return
@ -2148,10 +2151,8 @@ def _check_no_returned_refs(
raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.")
loc = (f' at output tree path {tree_util.keystr(ls[i])}' # type: ignore
if (dbg.result_paths_thunk and
(ls := dbg.result_paths_thunk()) and
ls[i]) else '')
result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
loc = f' at output tree path {result_paths[i]}'
frame = t._trace.frame
v = frame.tracer_to_var.get(id(t))
eqn = next((e for e in frame.eqns if v in e.outvars), None)
@ -2160,7 +2161,7 @@ def _check_no_returned_refs(
origin_info = ('\n\nThe returned mutable array was created on line '
f'{source_info_util.summarize(eqn.source_info)}.')
elif v in frame.invars:
arg_name = dbg.arg_names[frame.invars.index(v)] # type: ignore
arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)] # type: ignore
origin_info = ('\n\nThe returned mutable array was passed in as the '
f'argument {arg_name}.')
else:
@ -2172,10 +2173,10 @@ def _check_no_returned_refs(
@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
fun: lu.WrappedFun,
) -> tuple[Jaxpr, OutputType, list[Any]]:
trace = DynamicJaxprTrace(debug_info)
trace = DynamicJaxprTrace(fun.debug_info)
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
in_avals, keep_inputs = unzip2(fun.in_type)
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)

@ -33,7 +33,6 @@ import numpy as np
import jax
from jax._src import api
from jax._src import api_util
from jax._src import compiler
from jax._src import config
from jax._src import core
@ -652,7 +651,6 @@ class ParallelCallableInfo:
in_axes: Iterable[int | None]
out_axes_thunk: Callable[[], Sequence[int | None]]
avals: Sequence[core.AbstractValue]
debug_info: api_util.TracingDebugInfo | None
@cached_property
def local_devices(self):
@ -723,8 +721,7 @@ def stage_parallel_callable(
"Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
fun, sharded_avals, pci.debug_info)
jaxpr = api_util.add_jaxpr_debug_info(jaxpr, pci.debug_info)
fun, sharded_avals)
assert len(out_sharded_avals) == len(pci.out_axes), (
len(out_sharded_avals), len(pci.out_axes))
@ -758,7 +755,7 @@ def get_pmap_jaxpr(
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals, fun.debug_info)
in_axes, out_axes_thunk, avals)
with core.extend_axis_env_nd([(axis_name, axis_size)]):
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
@ -964,7 +961,7 @@ class UnloadedPmapExecutable:
ordered_effects: list[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
jaxpr_debug_info: core.JaxprDebugInfo
jaxpr_debug_info: core.DebugInfo
def build_execute_fun(self):
input_indices = []
@ -992,7 +989,7 @@ class UnloadedPmapExecutable:
return PmapExecutable(
self.compiled, self.build_execute_fun, fingerprint,
self.local_input_avals, self.jaxpr_debug_info, self)
self.local_input_avals, self)
@staticmethod
def from_hlo(hlo: ir.Module,
@ -1004,7 +1001,7 @@ class UnloadedPmapExecutable:
ordered_effects: list[core.Effect],
host_callbacks: list[Any],
keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo,
jaxpr_debug_info: core.DebugInfo,
platforms: Sequence[str],
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
compiler_options=None):
@ -1119,24 +1116,23 @@ class UnloadedPmapExecutable:
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
"fingerprint", "in_avals", "_jaxpr_debug_info",
"_unloaded_executable"]
"fingerprint", "in_avals", "_unloaded_executable"]
def __init__(self, xla_executable, build_unsafe_call, fingerprint,
in_avals, jaxpr_debug_info, unloaded_executable):
in_avals,
unloaded_executable: UnloadedPmapExecutable):
self.xla_executable = xla_executable
self._unsafe_call = None
self.build_unsafe_call = build_unsafe_call
self.fingerprint = fingerprint
self.in_avals = in_avals
self._jaxpr_debug_info = jaxpr_debug_info
self._unloaded_executable = unloaded_executable
@property
def unsafe_call(self) -> Callable[..., Any]:
if self._unsafe_call is None:
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call
return self._unsafe_call # type: ignore
# -- stages.XlaExecutable overrides
@ -1147,7 +1143,8 @@ class PmapExecutable(stages.XlaExecutable):
def call(self, *args):
# TODO(frostig): do we need to check sharding and sharded avals?
arg_avals = map(core.abstractify, args)
check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
check_arg_avals_for_call(self.in_avals, arg_avals,
self._unloaded_executable.jaxpr_debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable
@ -2127,7 +2124,7 @@ MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
"""Avals and debug_info for all arguments prior to DCE."""
in_avals: Sequence[core.ShapedArray]
debug_info: core.JaxprDebugInfo | None
debug_info: core.DebugInfo | None
@lru_cache(maxsize=2048)
@ -2588,7 +2585,7 @@ def try_matching_out_with_in_spec_for_all_auto(
orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
recover_in_s, recover_in_aval = None, None
for in_s, in_aval in safe_zip(in_shardings, in_avals):
if in_s is not None and type(in_s) in _orig_out_sharding_handlers:
if isinstance(in_s, NamedSharding):
recover_in_s, recover_in_aval = in_s, in_aval
break
if recover_in_s is None:
@ -3199,14 +3196,14 @@ def cc_shard_arg(x, sharding, layout):
def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.JaxprDebugInfo | None = None):
jaxpr_debug_info: core.DebugInfo | None = None):
if len(ref_avals) != len(arg_avals):
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
f"but called with {len(arg_avals)}")
if jaxpr_debug_info is not None:
arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]
else:
num_args = len(ref_avals)
arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]
@ -3258,7 +3255,7 @@ def check_array_xla_sharding_layout_match(
args_after_dce,
in_xla_shardings: Sequence[JSharding],
in_xla_layouts: Sequence[DeviceLocalLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None,
jaxpr_debug_info: core.DebugInfo | None,
kept_var_idx: set[int]) -> None:
from jax._src.array import ArrayImpl
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.

@ -53,19 +53,19 @@ def _typecheck_param(prim, param, name, msg_required, pred):
def _initial_style_open_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
debug_info: core.DebugInfo):
wrapped_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug_info),
in_tree)
jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
wrapped_fun, in_avals, debug_info)
wrapped_fun, in_avals)
return jaxpr, consts, out_tree(), attrs_tracked
@weakref_lru_cache
def _initial_style_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
debug_info: core.DebugInfo):
jaxpr, consts, out_tree, () = _initial_style_open_jaxpr(
fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
@ -74,7 +74,7 @@ def _initial_style_jaxpr(fun: Callable,
def _initial_style_jaxpr_attrs(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: api_util.TracingDebugInfo):
debug_info: core.DebugInfo):
jaxpr, consts, out_tree, attrs_tracked = _initial_style_open_jaxpr(
fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
@ -83,7 +83,7 @@ def _initial_style_jaxpr_attrs(fun: Callable,
def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable],
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
debug_infos: Sequence[api_util.TracingDebugInfo]):
debug_infos: Sequence[core.DebugInfo]):
# When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for

@ -134,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
if (config.disable_jit.value and core.is_concrete(index)):
return branches[int(index)](*operands)
dbgs = [api_util.tracing_debug_info("switch", branch, operands, {})
dbgs = [api_util.debug_info("switch", branch, operands, {})
for branch in branches]
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
@ -237,10 +237,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(core.get_aval, ops))
dbg_true_fun = api_util.tracing_debug_info("cond", true_fun, operands, {})
dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {})
if config.mutable_array_checks.value:
api_util._check_no_aliased_ref_args(dbg_true_fun, ops_avals, ops)
dbg_false_fun = api_util.tracing_debug_info("cond", false_fun, operands, {})
dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals,
[dbg_true_fun, dbg_false_fun])
@ -561,7 +561,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
effects_known = _join_cond_effects(branches_known)
eqn_known = pe.new_jaxpr_eqn(
ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
effects_known, eqn.source_info)
effects_known, eqn.source_info, eqn.ctx)
# Build the staged eqn.
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
@ -569,7 +569,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
effects_staged = _join_cond_effects(branches_staged)
eqn_staged = pe.new_jaxpr_eqn(
[eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
cond_p, params_staged, effects_staged, eqn.source_info)
cond_p, params_staged, effects_staged, eqn.source_info, eqn.ctx)
new_vars = [*new_inst, *res_binders]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars
@ -684,7 +684,7 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_effects, eqn.source_info)
eqn.primitive, new_params, new_effects, eqn.source_info, eqn.ctx)
assert all(len(new_eqn.invars ) == 1 + len(jaxpr.in_avals )
for jaxpr in new_params['branches'])

@ -195,7 +195,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
def _create_jaxpr(init):
init_flat = tree_leaves(init)
_, in_tree = tree_flatten((init, xs))
dbg = api_util.tracing_debug_info("scan", f, (init, xs), {})
dbg = api_util.debug_info("scan", f, (init, xs), {})
carry_avals = tuple(map(core.get_aval, init_flat))
jaxpr, _, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, dbg)
@ -585,7 +585,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
core.closed_call_p, dict(call_jaxpr=call_jaxpr),
call_jaxpr.effects, eqn.source_info)
call_jaxpr.effects, eqn.source_info, eqn.ctx)
jaxpr_staged = _convert_inputs_to_reads(nsteps, len(res_avals),
jaxpr_staged_resin_,
@ -609,7 +609,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
_, outvars = partition_list(out_inst, eqn.outvars)
eqn_staged = pe.new_jaxpr_eqn([*resvars, *eqn.invars], outvars,
core.closed_call_p, dict(call_jaxpr=call_jaxpr),
call_jaxpr.effects, eqn.source_info)
call_jaxpr.effects, eqn.source_info, eqn.ctx)
new_vars = [*new_inst, *resvars]
return eqn_known, eqn_staged, in_unknowns, out_inst, new_vars

@ -273,7 +273,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
return carry, stacked_y
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {})
dbg_body = api_util.debug_info("scan", f, (init, xs), {})
if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs))
@ -1357,10 +1357,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(core.get_aval, init_vals))
cond_dbg = api_util.tracing_debug_info("while_cond", cond_fun, (init_val,), {})
cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, cond_dbg)
body_dbg = api_util.tracing_debug_info("while_body", body_fun, (init_val,), {})
body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, body_dbg)
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
@ -1368,7 +1368,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
raise TypeError(msg.format(cond_tree))
pred_aval = cond_jaxpr.out_avals[0]
if (not isinstance(pred_aval, ShapedArray)
or pred_aval.strip_weak_type() != ShapedArray((), np.bool_)):
or ShapedArray(pred_aval.shape, pred_aval.dtype) != ShapedArray((), np.bool_)):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree
@ -1855,18 +1855,26 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
f'Effects not supported in `while`: {disallowed_effects}')
return body_jaxpr.out_avals, joined_effects
def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
def _while_partial_discharge_rule(should_discharge, in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
cond_nconsts, body_nconsts):
# TODO(sharadmv): enable supporting state effects in the cond
if any(isinstance(eff, state.RefEffect) for eff in cond_jaxpr.effects):
raise NotImplementedError
cond_consts_discharge, body_consts_discharge, carry_discharge = split_list(
should_discharge, [cond_nconsts, body_nconsts])
if any(cond_consts_discharge):
raise NotImplementedError
cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
cond_consts_avals, body_consts_avals, carry_avals = split_list(in_avals,
[cond_nconsts,
body_nconsts])
# There shouldn't be any `Ref`s in the `cond` (because of our check above).
assert not any(isinstance(aval, state.AbstractRef) for aval in cond_consts_avals)
is_ref = [isinstance(aval, state.AbstractRef) for aval in body_consts_avals]
is_ref = [
isinstance(aval, state.AbstractRef) and should
for aval, should in zip(body_consts_avals, body_consts_discharge)
]
remaining_body_consts, refs = partition_list(is_ref, body_consts)
remaining_body_const_avals, ref_avals = partition_list(is_ref,
body_consts_avals)
@ -1886,7 +1894,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
# Therefore we need to rewrite the jaxpr to shuffle around the `Ref`s so that
# they are part of the carry.
discharged_body_jaxpr, discharged_consts = state_discharge.discharge_state(
body_jaxpr, ())
body_jaxpr, (), should_discharge=[*body_consts_discharge, *carry_discharge])
if discharged_consts: raise NotImplementedError
def new_body(*consts_refs_carry):
@ -1943,7 +1951,7 @@ batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
state_discharge.register_discharge_rule(while_p)(_while_discharge_rule)
state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule)
def _pred_bcast_select_hlo(ctx,

@ -93,16 +93,16 @@ def custom_root(f: Callable,
"""
guess_flat, in_args_tree = tree_flatten((initial_guess,))
guess_avals = tuple(_map(core.get_aval, guess_flat))
f_debug = api_util.tracing_debug_info("custom_root", f, (initial_guess,), {})
f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {})
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
f, in_args_tree, guess_avals, f_debug)
in_tree, = treedef_children(in_args_tree)
_check_tree("f", "initial_guess", out_tree, in_tree, False)
solve_debug = api_util.tracing_debug_info("custom_root solve", solve,
(f, initial_guess), {},
static_argnums=(0,))
solve_debug = api_util.debug_info("custom_root solve", solve,
(f, initial_guess), {},
static_argnums=(0,))
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
partial(solve, f), in_args_tree, guess_avals, solve_debug)
_check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)
@ -111,10 +111,10 @@ def custom_root(f: Callable,
unchecked_zeros, f_jvp = api.linearize(f, x)
return tangent_solve(f_jvp, b)
tangent_solve_debug = api_util.tracing_debug_info("custom_root tangent_solve",
tangent_solve,
(f, initial_guess), {},
static_argnums=(0,))
tangent_solve_debug = api_util.debug_info("custom_root tangent_solve",
tangent_solve,
(f, initial_guess), {},
static_argnums=(0,))
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2,
tangent_solve_debug)
@ -265,17 +265,17 @@ def custom_linear_solve(
return f_aux if has_aux else f
matvec_debug = api_util.tracing_debug_info("custom_linear_solve",
matvec, (b,), {})
matvec_debug = api_util.debug_info("custom_linear_solve",
matvec, (b,), {})
# no auxiliary data assumed for matvec
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
_shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
matvec_debug)
_check_tree("matvec", "b", out_tree, tree, False)
solve_debug = api_util.tracing_debug_info("custom_linear_solve solve",
solve, (matvec, b), {},
static_argnums=(0,))
solve_debug = api_util.debug_info("custom_linear_solve solve",
solve, (matvec, b), {},
static_argnums=(0,))
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
solve_debug)
@ -285,7 +285,7 @@ def custom_linear_solve(
vecmat_jaxpr = tr_solve_jaxpr = None
vecmat_consts = tr_solve_consts = []
else:
transpose_solve_debug = api_util.tracing_debug_info(
transpose_solve_debug = api_util.debug_info(
"custom_linear_solve transpose_solve", transpose_solve,
(matvec, b), {}, static_argnums=(0,))
if symmetric:
@ -325,7 +325,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
if num_aux > 0:
args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
return args_to_raise
return args_to_raise, jaxprs.solve.effects
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
@ -482,7 +482,7 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs):
linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
linear_solve_p.def_effectful_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.register_initial_style_primitive(linear_solve_p)
mlir.register_lowering(

@ -367,12 +367,46 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
"""
return nextafter_p.bind(x1, x2)
@export
def floor(x: ArrayLike) -> Array:
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.
This function lowers directly to the `stablehlo.floor`_ operation.
Args:
x: input array. Must have floating-point type.
Returns:
Array of same shape and dtype as ``x``, containing values rounded
to the next integer toward negative infinity.
See also:
- :func:`jax.lax.ceil`: round to the next integer toward positive infinity
- :func:`jax.lax.round`: round to the nearest integer
.. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor
"""
return floor_p.bind(x)
@export
def ceil(x: ArrayLike) -> Array:
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`."""
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.
This function lowers directly to the `stablehlo.ceil`_ operation.
Args:
x: input array. Must have floating-point type.
Returns:
Array of same shape and dtype as ``x``, containing values rounded
to the next integer toward positive infinity.
See also:
- :func:`jax.lax.floor`: round to the next integer toward negative infinity
- :func:`jax.lax.round`: round to the nearest integer
.. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil
"""
return ceil_p.bind(x)
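# A minimal illustration of the rounding directions documented above:
import jax.numpy as jnp
from jax import lax

x = jnp.array([-1.7, -0.2, 0.2, 1.7])
print(lax.floor(x))  # [-2. -1.  0.  1.]
print(lax.ceil(x))   # [-1. -0.  1.  2.]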
class RoundingMethod(enum.IntEnum):
@ -388,20 +422,38 @@ class RoundingMethod(enum.IntEnum):
as bankers rounding (e.g., 0.5 -> 0, 1.5 -> 2).
"""
@export
def round(x: ArrayLike,
rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
) -> Array:
r"""Elementwise round.
Rounds values to the nearest integer.
Rounds values to the nearest integer. This function lowers directly to the
`stablehlo.round`_ operation.
Args:
x: an array or scalar value to round.
x: an array or scalar value to round. Must have floating-point type.
rounding_method: the method to use when rounding halfway values
(e.g., `0.5`). See :class:`jax.lax.RoundingMethod` for possible values.
(e.g., ``0.5``). See :class:`jax.lax.RoundingMethod` for possible values.
Returns:
An array containing the elementwise rounding of x.
An array of the same shape and dtype as ``x``, containing the elementwise
rounding of ``x``.
See also:
- :func:`jax.lax.floor`: round to the next integer toward negative infinity
- :func:`jax.lax.ceil`: round to the next integer toward positive infinity
Examples:
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
>>> lax.round(x)  # the default rounding_method is AWAY_FROM_ZERO
Array([-2., -1., -1., 0., 1., 1., 2.], dtype=float32)
>>> lax.round(x, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
Array([-2., -1., -0., 0., 0., 1., 2.], dtype=float32)
.. _stablehlo.round: https://openxla.org/stablehlo/spec#round
"""
rounding_method = RoundingMethod(rounding_method)
return round_p.bind(x, rounding_method=rounding_method)
@ -409,29 +461,126 @@ def round(x: ArrayLike,
def is_finite(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{isfinite}`.
For each element x returns `True` if and only if x is not :math:`\pm\infty` or
:math:`\mathit{NaN}`.
This function lowers directly to the `stablehlo.is_finite`_ operation.
Args:
x: input array. Must have floating-point type.
Returns:
Array of boolean dtype with the same shape as ``x``, containing ``False`` where
``x`` is :math:`\pm\infty` or :math:`\mathit{NaN}`, and ``True`` otherwise.
See also:
- :func:`jax.numpy.isinf`: return True where array is infinite.
- :func:`jax.numpy.isnan`: return True where array is NaN.
.. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite
"""
return is_finite_p.bind(x)
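# A quick illustration of is_finite on special values:
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, jnp.inf, -jnp.inf, jnp.nan])
print(lax.is_finite(x))  # [ True False False False]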
def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`."""
r"""Elementwise exponential: :math:`e^x`.
This function lowers directly to the `stablehlo.exponential`_ operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential.
See also:
- :func:`jax.lax.exp2`: elementwise base-2 exponential: :math:`2^x`.
- :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
"""
return exp_p.bind(x)
def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`."""
r"""Elementwise base-2 exponential: :math:`2^x`.
This function is implemented in terms of the `stablehlo.exponential`_
and `stablehlo.multiply`_ operations.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
base-2 exponential.
See also:
- :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
- :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.
.. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
.. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
"""
return exp2_p.bind(x)
def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`."""
r"""Elementwise :math:`e^{x} - 1`.
This function lowers directly to the `stablehlo.exponential_minus_one`_
operation. Compared to the naive expression ``lax.exp(x) - 1``, it is
more accurate for ``x`` near zero.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential minus 1.
See also:
- :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
- :func:`jax.lax.log1p`: elementwise :math:`\mathrm{log}(1 + x)`.
.. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one
"""
return expm1_p.bind(x)
def log(x: ArrayLike) -> Array:
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`."""
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
This function lowers directly to the `stablehlo.log`_ operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm.
See also:
- :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
.. _stablehlo.log: https://openxla.org/stablehlo/spec#log
"""
return log_p.bind(x)
def log1p(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{log}(1 + x)`."""
r"""Elementwise :math:`\mathrm{log}(1 + x)`.
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
for ``x`` near zero.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm of ``x + 1``.
See also:
- :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
- :func:`jax.lax.log`: elementwise natural logarithm :math:`\mathrm{log}(x)`.
.. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one
"""
return log1p_p.bind(x)
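# Why expm1 and log1p matter near zero: in float32 the naive forms lose all
# precision for tiny x, while the fused forms stay accurate (a minimal sketch):
import jax.numpy as jnp
from jax import lax

x = jnp.array(1e-8, dtype=jnp.float32)
print(lax.exp(x) - 1)  # 0.0 (catastrophic cancellation)
print(lax.expm1(x))    # ~1e-08
print(lax.log(1 + x))  # 0.0
print(lax.log1p(x))    # ~1e-08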
def tanh(x: ArrayLike) -> Array:
@ -745,9 +894,10 @@ def _trace_composite_to_jaxpr(fun: Callable,
in_tree: tree_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
name: str,
debug_info: api_util.TracingDebugInfo):
flat_fun, out_tree = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug_info)
debug_info: core.DebugInfo):
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug_info), in_tree)
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError(
"Found a JAX Tracer as a constant in the decomposition for the "
@ -822,8 +972,8 @@ def composite(
"""
@functools.wraps(decomposition)
def _decorator(*args, **kwargs):
debug_info = api_util.tracing_debug_info("composite", decomposition,
args, kwargs)
debug_info = api_util.debug_info("composite", decomposition,
args, kwargs)
flat_args, in_tree = tree_util.tree_flatten(args)
in_avals = tuple(core.get_aval(x) for x in flat_args)
closed_jaxpr, out_tree = _trace_composite_to_jaxpr(
@ -3274,7 +3424,7 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
if isinstance(sharding, NamedSharding):
return NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)
else:
return None
return core.get_cur_mesh_sharding()
return sharding
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
@ -6540,6 +6690,8 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding):
if (not dyn_shape and
not any(isinstance(d, core.DArray) and
type(core.get_aval(d).dtype) is core.bint for d in shape)):
if sharding is None:
sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape)))
return ShapedArray(shape, dtype, sharding=sharding)
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)

@ -733,7 +733,6 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
if config.sharding_in_types.value:
args = core.cast_from_auto_to_manual(args)
core.check_avals_context_mesh(args, 'all_reduce')
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,

@ -1884,12 +1884,12 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
mode, fill_value):
# TODO(yashkatariya): Write a proper gather sharding rule.
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual: # type: ignore
return None
if (cur_mesh._are_all_axes_explicit and # type: ignore
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
return core.get_cur_mesh_sharding()
if (cur_mesh._are_all_axes_explicit and
all(s is None for s in operand.sharding.spec) and
all(s is None for s in indices.sharding.spec)):
return None
return core.get_cur_mesh_sharding()
raise GatherShardingError(
"Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
" the gather indexing.")

@ -24,6 +24,7 @@ from jax._src import config
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src.util import safe_zip
from jax._src.partition_spec import PartitionSpec as P
zip, unsafe_zip = safe_zip, zip
@ -49,9 +50,14 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
if config.sharding_in_types.value:
from jax._src.pjit import _get_abstract_mesh_from_avals, NamedSharding
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
return None if num_out is None else [None] * num_out
aval_mesh = _get_abstract_mesh_from_avals(avals)
# TODO(yashkatariya): `aval_mesh.empty` should be `aval_mesh.unset`
aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh
s = NamedSharding(aval_mesh, P())
return s if num_out is None else [s] * num_out
if rule is None:
raise ValueError(
f'sharding rule for {prim.name} is not implemented. Please file a'
@ -68,7 +74,6 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
weak_type = weak_type_rule(*avals, **kwargs)
least_specialized = type(max(avals, key=_get_array_abstraction_level))
if least_specialized is core.ShapedArray:
avals = core.cast_from_auto_to_manual(avals)
core.check_avals_context_mesh(avals, prim.name)
out_aval = core.ShapedArray(
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
@ -94,7 +99,6 @@ def standard_multi_result_abstract_eval(
least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
weak_types = weak_type_rule(*avals, **kwargs)
if least_specialized is core.ShapedArray:
avals = core.cast_from_auto_to_manual(avals)
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
core.check_avals_context_mesh(avals, prim.name)

@ -12,4 +12,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import Protocol
from jaxlib.triton import dialect # noqa: F401 # pytype: disable=import-error
class CompilationResult(Protocol):
asm: str
smem_bytes: int
cluster_dim_x: int
cluster_dim_y: int
cluster_dim_z: int
class CompilationHandler(Protocol):
def __call__(
self,
module: bytes,
arch_name: str,
num_warps: int,
num_ctas: int,
num_stages: int,
) -> CompilationResult:
...
_compilation_handlers: dict[str, CompilationHandler] = {}
_compilation_handlers_lock = threading.Lock()
def register_compilation_handler(
platform: str, handler: CompilationHandler
) -> None:
platform = platform.upper()
with _compilation_handlers_lock:
if existing_handler := _compilation_handlers.get(platform):
raise RuntimeError(
f'Platform {platform} already has a Triton compilation handler:'
f' {existing_handler}'
)
_compilation_handlers[platform] = handler
def has_compilation_handler(platform: str) -> bool:
platform = platform.upper()
with _compilation_handlers_lock:
return platform in _compilation_handlers
def compile(
platform: str,
module: bytes,
arch_name: str,
*,
num_warps: int,
num_ctas: int,
num_stages: int,
) -> CompilationResult:
platform = platform.upper()
with _compilation_handlers_lock:
handler = _compilation_handlers.get(platform)
if handler is None:
raise RuntimeError(
f'Platform {platform} does not have a Triton compilation handler'
)
return handler(module, arch_name, num_warps, num_ctas, num_stages)
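# A minimal usage sketch of the registry above; the handler and its result are
# hypothetical stand-ins that merely satisfy the protocols, not a real backend.
import dataclasses

@dataclasses.dataclass
class _FakeCompilationResult:
  asm: str
  smem_bytes: int
  cluster_dim_x: int
  cluster_dim_y: int
  cluster_dim_z: int

def _fake_handler(module, arch_name, num_warps, num_ctas, num_stages):
  # Ignore the inputs and return a dummy result.
  return _FakeCompilationResult("<asm>", 0, 1, 1, 1)

register_compilation_handler("cuda", _fake_handler)
assert has_compilation_handler("CUDA")  # platform lookups are case-insensitive
result = compile("cuda", b"<ttir bytes>", "sm_90",
                 num_warps=4, num_ctas=1, num_stages=3)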

@ -63,7 +63,7 @@ data must be immutable, because it will be stored in function memoization tables
"""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, NamedTuple
import weakref
@ -71,6 +71,7 @@ import weakref
from jax._src import config
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import keystr, generate_key_paths
from jax._src.util import curry, cache_clearing_funs, HashableFunction
@ -156,7 +157,7 @@ class WrappedFun:
f_transformed: Callable,
transforms,
stores: tuple[Store | EqualStore | None, ...], params, in_type,
debug_info: TracingDebugInfo | None):
debug_info: DebugInfo | None):
self.f = f
self.f_transformed = f_transformed
self.transforms = transforms
@ -253,12 +254,10 @@ def fun_name(f):
except:
return str(f)
class TracingDebugInfo(NamedTuple):
"""Tracing-time debugging info about a func and its arguments.
Formed just before staging to a jaxpr and read in trace-time error messages.
"""
class DebugInfo(NamedTuple):
"""Debugging info about a func, its arguments, and results."""
traced_for: str # e.g. 'jit', 'scan', etc
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
# no source location information. The first word is always the function name,
# which may be '<unknown>'.
@ -270,23 +269,18 @@ class TracingDebugInfo(NamedTuple):
# e.g., tangent args in jax.jvp.
arg_names: tuple[str | None, ...]
# The result paths are not available while we are tracing the function,
# instead we keep a thunk. Once we are done tracing, we use
# `self.resolve_result_paths()` to execute the thunk and replace the
# actual result paths.
# e.g. ('[0]', '[1]', ...)
result_paths_thunk: Callable[[], tuple[str, ...]] | None
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
@classmethod
def from_jaxpr(cls, jaxpr: core.ClosedJaxpr) -> TracingDebugInfo | None:
jaxpr_dbg = jaxpr.jaxpr._debug_info
if jaxpr_dbg is None: return None
return TracingDebugInfo(jaxpr_dbg.traced_for,
jaxpr_dbg.func_src_info,
jaxpr_dbg.arg_names,
lambda: jaxpr_dbg.result_paths)
def add_result_paths(self, result_paths_thunk: Callable[[], tuple[str, ...]]
) -> TracingDebugInfo:
assert self.result_paths_thunk is None
return self._replace(result_paths_thunk=HashableFunction(result_paths_thunk,
closure=()))
def resolve_result_paths(self) -> DebugInfo:
"""Return a debug info with resolved result paths."""
if callable(self.result_paths):
return self._replace(result_paths=tuple(self.result_paths()))
return self
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
@ -296,15 +290,47 @@ class TracingDebugInfo(NamedTuple):
# TODO(necula): this should not happen
return (None,) * expected
def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]:
"""Keep only the arg_names for which `keep` is True."""
return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b)
def safe_result_paths(self, expected: int) -> tuple[str, ...]:
"""Get the result paths with a safety check."""
assert not callable(self.result_paths), self
if self.result_paths is not None and len(self.result_paths) == expected:
return self.result_paths
else:
# TODO(necula): this should not happen
return ("",) * expected
def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
"""Keep only the result_paths for which `keep` is True."""
assert not callable(self.result_paths), self
return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b)
def wrap_init(f: Callable, params=None, *,
debug_info: TracingDebugInfo | None = None) -> WrappedFun:
debug_info: DebugInfo | None = None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, None)
if debug_info:
if debug_info.result_paths is None:
fun, result_paths_thunk = _get_result_paths_thunk(fun)
debug_info = debug_info._replace(
result_paths=HashableFunction(result_paths_thunk, closure=()))
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
fun.params, fun.in_type, debug_info)
return fun
@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
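# For intuition, the thunk above records result paths like the ones the public
# jax.tree_util helpers produce (a minimal sketch, independent of WrappedFun):
from jax.tree_util import keystr, tree_flatten_with_path

out = {"loss": 1.0, "aux": (2.0, 3.0)}
paths = [keystr(path) for path, _ in tree_flatten_with_path(out)[0]]
print(paths)  # roughly: ["['aux'][0]", "['aux'][1]", "['loss']"]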
def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
assert f.in_type is None
if in_type is None:
@ -341,16 +367,9 @@ def _check_input_type(in_type: core.InputType) -> None:
provided[d.val] = True
assert all(provided)
def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None
) -> WrappedFun:
"""Produce a new WrappedFun with debug_info attached."""
assert f.debug_info is None
if debug_info is None:
return f
return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info)
def cache(call: Callable, *, explain: Callable | None = None):
def cache(call: Callable, *,
explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None):
"""Memoization decorator for functions taking a WrappedFun as first argument.
Args:
@ -358,6 +377,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
underlying transforms and params on the WrappedFun are used as part of the
memoization cache key.
explain: a function that is invoked upon cache misses to log an explanation
of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`.
Returns:
A memoized version of ``call``.
"""
@ -373,7 +395,7 @@ def cache(call: Callable, *, explain: Callable | None = None):
else:
ans = call(fun, *args)
if explain and config.explain_cache_misses.value:
explain(fun.f, cache is new_cache, cache, key)
explain(fun, cache is new_cache, cache, key)
cache[key] = (ans, fun.stores)
return ans
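# A hedged sketch of an `explain` callback with the documented signature; the
# logging format is illustrative only (`fun` is the WrappedFun being traced).
def explain_cache_miss(fun, is_cache_first_use, cache, key):
  name = getattr(fun.f, "__name__", repr(fun.f))
  kind = "first use of this cache" if is_cache_first_use else "cache miss"
  print(f"{kind}: tracing {name}; cache currently holds {len(cache)} entries")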

@ -530,8 +530,8 @@ class AbstractMesh:
@staticmethod
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
jax_config.abstract_mesh_context_manager.set_local(mesh)
return
prev = jax_config.abstract_mesh_context_manager.swap_local(mesh)
return prev
# Create this indirection because pytype fails to recognize a property if a

@ -9744,11 +9744,12 @@ def einsum(
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
return einsum(operands, contractions, precision,
preferred_element_type, _dot_general, out_sharding)
jit_einsum = jax.named_call(jit_einsum, name=spec)
operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))
return jit_einsum(operand_arrays, contractions, precision,
preferred_element_type, _dot_general, out_sharding)
# Enable other modules to override einsum_contact_path.
@ -9843,7 +9844,7 @@ def _removechars(s, chars):
def _einsum(
operands: Sequence,
operands: list[jax.Array],
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
precision,
preferred_element_type,
@ -9859,7 +9860,6 @@ def _einsum(
"`out_sharding` argument of `einsum` only supports NamedSharding"
" instances. Please file a bug if this is not enough for your use case.")
dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
operands = list(map(asarray, operands))
if preferred_element_type is None:
preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
else:
@ -11649,7 +11649,8 @@ def take_along_axis(
j = 0
for i in range(rank):
if i == axis_int:
indices = _normalize_index(indices, axis_size)
if mode != 'promise_in_bounds':
indices = _normalize_index(indices, axis_size)
gather_indices.append(lax.reshape(indices, gather_index_shape))
slice_sizes.append(1)
start_index_map.append(i)

@ -222,10 +222,6 @@ class AbstractMemoryRef(state.AbstractRef):
def __repr__(self) -> str:
return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'
@property
def sharding(self):
return self.inner_aval.sharding
def update_weak_type(self, weak_type):
return AbstractMemoryRef(
self.inner_aval.update_weak_type(weak_type), self.memory_space)
@ -413,9 +409,9 @@ class BlockSpec:
fake_index_map_args, fake_index_map_kwargs = \
index_map_tree.unflatten([False] * index_map_tree.num_leaves)
debug = api_util.tracing_debug_info("pallas_call index_map",
index_map_func, fake_index_map_args,
fake_index_map_kwargs)
debug = api_util.debug_info("pallas_call index_map",
index_map_func, fake_index_map_args,
fake_index_map_kwargs)
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
index_map_src_info = NameAndSrcInfo.from_pallas_call(
@ -423,7 +419,7 @@ class BlockSpec:
)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
flat_index_map_fun, index_map_avals, debug_info=debug
flat_index_map_fun, index_map_avals
)
mapped_block_shape = tuple(mapped if s is None else s for s in block_shape)
if len(out_avals) != len(block_shape):
@ -890,7 +886,8 @@ def get_grid_mapping(
)
# The inputs for the index maps
index_map_avals = (
(index_map_grid_aval.update(sharding=None),) * len(grid_spec.grid))
index_map_grid_aval.update(sharding=jax_core.get_cur_mesh_sharding()),
) * len(grid_spec.grid)
index_map_tree = tree_util.tree_structure((index_map_avals, {}))
num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)

@ -49,3 +49,13 @@ class ArrayLike(Protocol):
def empty_like(x: ArrayLike, *, memory_space: Any = None):
return empty(x.shape, x.dtype, memory_space=memory_space)
def when(condition):
def _wrapped(f):
if isinstance(condition, bool):
if condition:
f()
else:
jax.lax.cond(condition, f, lambda: None)
return _wrapped
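# Usage sketch for the helper above (this module is imported elsewhere in this
# change as `pl_helpers`). With a Python bool the body runs, or is skipped,
# immediately at trace time; with a traced predicate it is staged through
# jax.lax.cond, so the body should only have ref side effects, as in kernels.
from jax._src.pallas import helpers as pl_helpers

ran = []
@pl_helpers.when(True)  # static predicate: executes the body right away
def _():
  ran.append("yes")
assert ran == ["yes"]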

@ -30,8 +30,9 @@ LOCATION_PATTERN = re.compile(
r'(?P<location>loc\((?P<eqn_str>\".*?\")(?P<frames>.*)\))'
)
FRAME_PATTERN = re.compile(
r'(?P<fun_name>\".*?\")\((?P<filename>\".*?\"):'
r'(?P<lineno>[0-9]+):(?P<colno>[0-9]+)\)'
r'(?P<fun_name>\".*?\")\((?P<filename>\"[^"]*?\"):'
r'(?P<lineno>[0-9]+)?:(?P<colno>[0-9]+)'
r'( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)'
)
MLIR_ERR_PREFIX = (
'Pallas encountered an internal verification error.'

@ -16,6 +16,7 @@
import functools
import jax
from jax._src.pallas import helpers as pl_helpers
from jax._src.pallas import primitives as pl_primitives
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as plm_primitives
@ -55,3 +56,40 @@ def sync_copy(src_ref, dst_ref):
src_ref,
dst_ref,
)
def run_on_first_core(core_axis_name: str):
"""Runs a function on the first core in a given axis."""
num_cores = jax.lax.psum(1, core_axis_name)
if num_cores == 1:
return lambda f: f()
def wrapped(f):
core_id = jax.lax.axis_index(core_axis_name)
@pl_helpers.when(core_id == 0)
@functools.wraps(f)
def _():
return f()
return wrapped
def core_barrier(sem, *, core_axis_name: str):
"""Synchronizes all cores in a given axis."""
num_cores = jax.lax.psum(1, core_axis_name)
core_id = jax.lax.axis_index(core_axis_name)
@pl_helpers.when(num_cores > 1)
def _():
with jax.named_scope("sync_cores"):
def signal_core(i):
# Don't signal ourselves
@pl_helpers.when(core_id != i)
def _():
plm_primitives.semaphore_signal(sem, 1, core_index=i)
for i in range(num_cores):
signal_core(i)
plm_primitives.semaphore_wait(sem, num_cores - 1)

@ -1524,7 +1524,9 @@ def _masked_swap_lowering_rule(
1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
for b in ref_block_shape
]
mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None)
mem_aval = aval_out.update(
shape=tuple(mem_slice_shape), sharding=jax_core.get_cur_mesh_sharding()
)
mem_aval_shape = ctx.lowering_context.dynamic_shape_replacement_fn(
mem_aval.shape
)
@ -2127,7 +2129,11 @@ def _gather_lowering_rule(
slice_sizes == (1, 1)
and not unique_indices
and not indices_are_sorted
and mode == lax.GatherScatterMode.FILL_OR_DROP
and mode
in (
lax.GatherScatterMode.FILL_OR_DROP,
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
)
):
if dimension_numbers == lax.GatherDimensionNumbers(
offset_dims=(),
@ -3011,6 +3017,11 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
lowering_rules[pjit.pjit_p] = _pjit_lowering_rule
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
return x
lowering_rules[pjit.mesh_cast_p] = _mesh_cast_lowering_rule
def _custom_jvp_call_lowering_rule(
ctx: LoweringRuleContext,
*args,

@ -921,13 +921,18 @@ ARBITRARY = GridDimensionSemantics()
def _partition_grid(
grid: tuple[int | jax.Array, ...],
core_axis: int | None,
core_axis: int | str | None,
dimension_semantics: tuple[GridDimensionSemantics, ...] | None,
) -> tuple[tuple[int | jax.Array, ...], tuple[int | jax.Array, ...]]:
if core_axis is None:
# We aren't partitioning the grid
return grid, (0,) * len(grid)
num_cores = pl.num_programs(core_axis)
if isinstance(core_axis, int):
num_cores = pl.num_programs(core_axis)
core_id = pl.program_id(core_axis)
else:
num_cores = jax.lax.psum(1, core_axis)
core_id = jax.lax.axis_index(core_axis)
# Check that num_cores is statically known
if not isinstance(num_cores, int):
raise NotImplementedError(
@ -966,7 +971,7 @@ def _partition_grid(
i for i in range(len(dimension_semantics)) if i in divisible_dimensions
)
partitioned_dim_size = grid[first_divisible_dimension] // num_cores
partitioned_dim_offset = pl.program_id(core_axis) * partitioned_dim_size
partitioned_dim_offset = core_id * partitioned_dim_size
new_grid = jax_util.tuple_update(
grid, first_divisible_dimension, partitioned_dim_size
)
@ -990,8 +995,7 @@ def _partition_grid(
# We have some remainder iterations that we need to assign somewhere. We
# know that rem < num_cores, so we can assign one extra iteration to each
# core except for the last (num_cores - rem).
core_index = pl.program_id(core_axis)
num_iters = jnp.where(core_index < rem, base_num_iters + 1,
num_iters = jnp.where(core_id < rem, base_num_iters + 1,
base_num_iters)
new_grid = jax_util.tuple_update(grid, partition_dimension, num_iters)
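# Worked example (illustrative numbers): with a partitioned dimension of size
# 10 and num_cores = 4, base_num_iters = 2 and rem = 2, so the per-core
# iteration counts are [3, 3, 2, 2] and the grid offsets computed below come
# out to [0, 3, 6, 8], covering iterations 0-2, 3-5, 6-7, and 8-9 exactly once.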
# Ordinarily, we would compute the offset as:
@ -999,9 +1003,9 @@ def _partition_grid(
# However, since we have some cores that don't have an extra iteration, we
# need to adjust the offset by `rem`.
grid_offset = jnp.where(
core_index < rem,
core_index * num_iters,
core_index * base_num_iters + rem,
core_id < rem,
core_id * num_iters,
core_id * base_num_iters + rem,
)
offsets = jax_util.tuple_update(
(0,) * len(grid), partition_dimension, grid_offset
@ -1015,8 +1019,9 @@ def emit_pipeline(
grid: tuple[int | jax.Array, ...],
in_specs=None,
out_specs=None,
should_accumulate_out=False,
should_accumulate_out: bool = False,
core_axis: int | None = None,
core_axis_name: str | None = None,
dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None,
trace_scopes: bool = True,
):
@ -1039,6 +1044,8 @@ def emit_pipeline(
as accumulators.
core_axis: optional int, indicates whether or not to partition the grid
along the core axis.
core_axis_name: optional str, name of the core axis along which to partition
the grid.
dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL
or ARBITRARY).
trace_scopes: optional bool, indicates whether to annotate each region in
@ -1049,7 +1056,10 @@ def emit_pipeline(
raise ValueError(
f"Grid must consist of Python integers and JAX Arrays: {grid_types}"
)
grid, grid_offsets = _partition_grid(grid, core_axis, dimension_semantics)
if not (core_axis is None or core_axis_name is None):
raise ValueError("core_axis and core_axis_name cannot both be provided.")
core_axis_ = core_axis_name if core_axis is None else core_axis
grid, grid_offsets = _partition_grid(grid, core_axis_, dimension_semantics)
num_steps = _grid_size(grid)
if not isinstance(in_specs, (list, tuple)):

@ -550,8 +550,8 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn
def dma_start_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals,
*args, tree, device_id_type):
(
src_ref,
src_transforms,
@ -575,7 +575,22 @@ def dma_start_discharge_rule(in_avals, out_avals,
_,
) = tree_util.tree_unflatten(tree, in_avals)
del out_avals
(
_,
_,
dst_discharge,
_,
dst_sem_discharge,
_,
*maybe_src_sem_discharge,
) = tree_util.tree_unflatten(tree, should_discharge)
is_remote = device_id is not None
src_sem_discharge = None
if is_remote:
src_sem_discharge = maybe_src_sem_discharge[0]
if not is_remote:
# Local async copies only use one semaphore.
assert src_sem is None
@ -586,7 +601,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals))
num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals))
updates = state_discharge.transform_array(src_ref, src_transforms)
updates = state_discharge.transform_array(src_ref[...], src_transforms)
local_src = updates
if is_remote:
@ -641,47 +656,61 @@ def dma_start_discharge_rule(in_avals, out_avals,
global_dst_transforms,
)
_, new_dst = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
def do_discharge_dst(dst_ref=dst_ref):
_, ret = state_discharge.transform_swap_array(
dst_ref, dst_transforms, updates
)
return ret
# Update semaphore values.
# TODO(justinfu): Potentially handle asymmetric copy sizes.
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, new_dst_sem = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value + recv_size
)
if is_remote:
def do_discharge_dst_sem(dst_sem=dst_sem):
recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
dst_sem_value = _transform_semaphore(
dst_sem, dst_sem_transforms, dst_sem_aval
)
_, ret = state_discharge.transform_swap_array(
dst_sem, dst_sem_transforms, dst_sem_value[...] + recv_size
)
return ret
def do_discharge_src_sem(src_sem=src_sem):
send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE)
send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
src_sem_value = _transform_semaphore(
src_sem, src_sem_transforms, src_sem_aval
)
_, new_src_sem = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value + send_size
_, ret = state_discharge.transform_swap_array(
src_sem, src_sem_transforms, src_sem_value[...] + send_size
)
else:
new_src_sem = None
return ret
new_vals = (None,) # src_val
new_vals += (None,) * num_src_transform_vals
new_vals += (new_dst,) # dst_val
new_vals += (do_discharge_dst() if dst_discharge else None,) # dst_val
new_vals += (None,) * num_dst_transform_vals
new_vals += (new_dst_sem,) # dst_sem
new_vals += (do_discharge_dst_sem() if dst_sem_discharge else None,) # dst_sem
new_vals += (None,) * num_dst_sem_transforms
if is_remote:
new_vals += (new_src_sem,) # src_sem
new_vals += (do_discharge_src_sem() if src_sem_discharge else None,) # src_sem
new_vals += (None,) * num_src_sem_transforms
new_vals += (None,) # device_id
assert (len(new_vals) ==
len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}"
# If we didn't discharge everything we could we should keep writes
# to the references that are left over.
if not dst_discharge:
sp.ref_set(dst_ref, None, do_discharge_dst(dst_ref=dst_ref[...]))
if not dst_sem_discharge:
sp.ref_set(dst_sem, None, do_discharge_dst_sem(dst_sem=dst_sem[...]))
if is_remote and not src_sem_discharge:
sp.ref_set(src_sem, None, do_discharge_src_sem(src_sem=src_sem[...]))
return new_vals, []
state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule)
dma_wait_p = jax_core.Primitive('dma_wait')
@ -719,8 +748,9 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn
def dma_wait_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
def dma_wait_partial_discharge_rule(should_discharge,
in_avals, out_avals,
*args, tree, device_id_type):
# TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
del out_avals, device_id_type
_, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
@ -735,6 +765,14 @@ def dma_wait_discharge_rule(in_avals, out_avals,
src_sem_transforms_avals,
device_id_aval,
) = tree_util.tree_unflatten(tree, in_avals)
# The only value we can discharge here is the dst semaphore. The provided
# buffers are only specified for their types, not their values, so whether
# they are discharged is irrelevant for this rule.
should_discharge_unflattened = tree_util.tree_unflatten(tree, should_discharge)
if not should_discharge_unflattened[4]:
return (None,) * len(in_avals), []
num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
@ -754,7 +792,7 @@ def dma_wait_discharge_rule(in_avals, out_avals,
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id
return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)
state_discharge.register_partial_discharge_rule(dma_wait_p)(dma_wait_partial_discharge_rule)
def _get_ref_and_transforms(ref):
if isinstance(ref, state.TransformedRef):

@ -1121,6 +1121,9 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args
)
@register_lowering_rule(pjit.mesh_cast_p)
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
return x
@register_lowering_rule(lax.slice_p)
def _slice_lowering_rule(
@ -1430,15 +1433,6 @@ def _run_scoped_lowering_rule(
ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
)
for o in outs:
# This is definitely one of the accumulators we produced. Each
# run_scoped call is responsible for dereferencing its own
# accumulators.
if isinstance(o, mgpu.WGMMAAccumulator) or (
isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type)
):
raise ValueError(f"No references are allowed to escape a scope. (got {o})")
assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
return outs
@ -1508,7 +1502,14 @@ def _lower_jaxpr_to_for_loop(
if arg_avals:
out_avals = ctx.avals_out[-len(arg_avals):]
@mgpu.fori(length, [*map(_ensure_fa, args, arg_avals)])
is_acc = [isinstance(v, mgpu.WGMMAAccumulator) for v in args]
def as_fas(vals, avals):
if is_acc != [isinstance(v, mgpu.WGMMAAccumulator) for v in vals]:
raise ValueError("Unexpected loop carry w.r.t. accumulators.")
return [v if a else _ensure_fa(v, av) for a, v, av in zip(is_acc, vals, avals)]
@mgpu.fori(length, as_fas(args, arg_avals))
def loop(loop_index, body_args):
if has_loop_index:
loop_index = arith_dialect.addi(loop_index, start)
@ -1518,7 +1519,7 @@ def _lower_jaxpr_to_for_loop(
outs = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args
)
return map(_ensure_fa, outs, out_avals)
return as_fas(outs, out_avals)
return loop.results
@ -1640,7 +1641,10 @@ def _while_lowering_rule(
_cond_avals, body_avals, carry_avals = util.split_list(
ctx.avals_in, [cond_nconsts, body_nconsts]
)
carry = map(_ensure_fa, carry, carry_avals)
carry = [
v if isinstance(v, mgpu.WGMMAAccumulator) else _ensure_fa(v, av)
for v, av in zip(carry, carry_avals)
]
# Flatten the carry to get a concatenated list of registers from each FA.
# Note that the treedef is also used below to unflatten the body results.
flat_carry, carry_treedef = jax.tree.flatten(carry)
@ -1663,9 +1667,19 @@ def _while_lowering_rule(
loop_out = lower_jaxpr_to_mosaic_gpu(
ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args
)
loop_out = map(_ensure_fa, loop_out, carry_avals)
loop_out = [
v if isinstance(v, mgpu.WGMMAAccumulator) else _ensure_fa(v, av)
for v, av in zip(loop_out, carry_avals)
]
for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)):
if carry_fa.layout != out_fa.layout:
_is_acc = lambda x: isinstance(x, mgpu.WGMMAAccumulator)
if _is_acc(carry_fa) != _is_acc(out_fa):
raise ValueError(
f"The loop body output has unexpected accumulator type: output[{idx}]"
f" is {out_fa}, when it should be {carry_fa}."
)
if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout:
raise ValueError(
f"The loop body output has unexpected layout: output[{idx}] has"
f" layout {out_fa.layout}, when it should be {carry_fa.layout}."
@ -1865,6 +1879,19 @@ def merge_indexers(
if indexer.int_indexer_shape:
raise NotImplementedError()
def _ensure_idx_fa(x):
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
if isinstance(x, ir.Value):
return mgpu.FragmentedArray.splat(
x, (), is_signed=mgpu.utils.is_signed(x.type)
).astype(i32, is_signed=False)
if isinstance(x, mgpu.FragmentedArray):
return x.astype(i32, is_signed=False)
if isinstance(x, int):
return mgpu.FragmentedArray.splat(mgpu.c(x, i32), (), is_signed=False)
raise NotImplementedError(x)
num_skipped = 0
for i in range(len(current_indices)):
# Integer indexers remove dimensions which should be
@ -1876,18 +1903,17 @@ def merge_indexers(
current_index = current_indices[i]
assert isinstance(current_index, indexing.Slice)
current_start_index = _ensure_fa(current_index.start, jnp.int32)
current_start_index = _ensure_idx_fa(current_index.start)
if isinstance(dim_indexer, indexing.Slice):
if dim_indexer.stride != 1:
raise NotImplementedError("Non-unit strides not implemented.")
current_indices[i] = indexing.Slice(
current_start_index + _ensure_fa(dim_indexer.start, jnp.int32),
current_start_index + _ensure_idx_fa(dim_indexer.start),
dim_indexer.size,
1,
)
else:
current_indices[i] = current_start_index + _ensure_fa(
dim_indexer, dtype=jnp.int32)
current_indices[i] = current_start_index + _ensure_idx_fa(dim_indexer)
removed_dimensions.add(i)
return indexing.NDIndexer(
indices=tuple(current_indices),

@ -558,15 +558,9 @@ def _wgmma_lowering(
if rhs_tiling != (swizzle_elems, swizzle_elems):
raise NotImplementedError("WGMMA rhs tiling does not fit swizzle")
new_acc = mgpu.wgmma(
acc,
a,
b,
swizzle=rhs_swizzle,
b_order=mgpu.WGMMALayout.COL_MAJOR
if rhs_transpose
else mgpu.WGMMALayout.ROW_MAJOR,
)
if rhs_transpose:
b = mgpu.memref_transpose(b, (0, 1, 3, 2))
new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle)
nvvm_dialect.wgmma_commit_group_sync_aligned()
return new_acc

@ -37,8 +37,8 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import helpers as pallas_helpers
from jax._src.pallas import hlo_interpreter
from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
@ -101,12 +101,12 @@ def _pallas_call_jvp_rule(
primals,
tangents,
*,
jaxpr,
jaxpr: jax_core.Jaxpr,
name_and_src_info,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping,
debug,
interpret,
debug: bool,
interpret: bool,
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
@ -739,7 +739,7 @@ def _pallas_call_batching_rule(
# b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
# checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")
@pallas_utils.when(run_kernel)
@pallas_helpers.when(run_kernel)
def f():
# Important! This allows us to trace the inner kernel with the correct
# grid to preserve user program_id semantics. Ex: program_id(0) will
@ -1098,13 +1098,14 @@ def pallas_call_checkify_rule(error: checkify.Error,
retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
*error_memref_aval, *output_aval, *scratch_aval]
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
debug = api_util.debug_info("checkify_pallas", checked_kernel_fn,
retrace_in_avals, {})
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
debug = api_util.tracing_debug_info("checkify_pallas", checked_kernel_fn,
retrace_in_avals, {})
lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree)
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
wrapped_kernel_with_err, jaxpr_flat_avals, debug)
wrapped_kernel_with_err, jaxpr_flat_avals)
# Prepare pallas_call inputs. We need to create new block specs
# for the new error inputs and outputs.
@ -1161,16 +1162,16 @@ def _trace_kernel_to_jaxpr(
kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
indexer: bool = False,
) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
debug = api_util.debug_info("pallas_call", fun, fake_kernel_args, {})
wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), kernel_in_tree)
lu.wrap_init(fun, debug_info=debug), kernel_in_tree)
wrapped_kernel_fun = primitives.wrap_with_transforms(
wrapped_kernel_fun, kernel_in_transforms
)
fake_kernel_args = kernel_in_tree.unflatten(kernel_avals)
debug = api_util.tracing_debug_info("pallas_call", fun, fake_kernel_args, {})
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
kernel_avals)
if consts:
consts_avals = [jax_core.get_aval(c) for c in consts]
if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
@ -1568,7 +1569,7 @@ def pallas_call(
kernel_fun_sig = api_util.fun_signature(kernel)
arg_names = None
if kernel_fun_sig:
kernel_debug_info = api_util.tracing_debug_info(
kernel_debug_info = api_util.debug_info(
"pallas_call kernel",
kernel,
[1] * len(kernel_fun_sig.parameters), {})

@ -896,22 +896,25 @@ def _run_scoped_discharge_rule(
**_):
del out_avals
num_consts = len(args_flat)
# discharge_state only discharges invars, not consts, so in order to
# discharge the requested refs we need to move them to the invar set.
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
should_discharge = should_discharge + [
isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars
]
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=should_discharge)
jaxpr_noconst,
[],
should_discharge=should_discharge + [False] * len(jaxpr.invars),
)
if new_consts:
raise NotImplementedError(
"Cannot handle new consts created by state discharge.")
# Create inputs filled with uninitialized values to the body.
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
init_vals_with_consts = args_flat + tuple(init_vals)
out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts)
# Lowering expects that the jaxpr.consts to be the eqn.invals.
discharged_body = pe.convert_invars_to_constvars(discharged_body, num_consts)
# run_scoped discharges the external variables, but the scoped ones
# are not discharged.
out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body)
# Order of outputs:
# (1) return values, (2) closed refs, (3) scoped refs.
return_values = out[:num_return_values]
@ -919,8 +922,8 @@ def _run_scoped_discharge_rule(
# We update all ref values with their updated values from the discharged
# body. For other values we leave them in place.
updates = [
ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef)
else None for aval in in_avals]
ref_outputs.pop(0) if should and isinstance(aval, pallas_core.AbstractMemoryRef)
else None for should, aval in zip(should_discharge, in_avals)]
assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}'
return updates, return_values
@ -931,17 +934,20 @@ state_discharge.register_partial_discharge_rule(run_scoped_p)(
@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
# This lowering rule gets triggered when run_scoped is not discharged.
# In this case there are no stateful effects to handle.
should_discharge = [
isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
]
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=True)
if new_consts: raise NotImplementedError(
"Cannot handle new consts created by state discharge.")
def _lower_fun(*lower_fun_args):
updates, out = _run_scoped_discharge_rule(
should_discharge,
[], [], *lower_fun_args,
jaxpr=jaxpr)
assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
return out
# Create inputs filled with uninitialized values to the body.
num_consts = len(lower_fun_args)
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
out = jax_core.eval_jaxpr(discharged_body, [], *lower_fun_args, *init_vals)
return out[:num_return_values]
return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)

@ -593,12 +593,19 @@ class _Extern:
def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
[out_aval] = ctx.avals_out
bcast_args = []
for aval, arg, arg_type in zip(ctx.avals_in, args, self.arg_types):
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
bcast_args.append(bcast_arg)
result_type = _dtype_to_ir_type(jnp.dtype(self.result_type))
if out_aval.shape:
result_type = ir.RankedTensorType.get(out_aval.shape, result_type)
return tt_dialect.extern_elementwise(
result_type,
args,
bcast_args,
libname="",
libpath="",
symbol=self.symbol,
@ -608,10 +615,23 @@ class _Extern:
@dataclasses.dataclass(frozen=True)
class _Fallback:
arg_types: Sequence[jax.typing.DTypeLike]
lower: Callable[..., ir.Value]
arg_classes: Sequence[jax.typing.DTypeLike]
op: Callable[..., ir.Value]
matches = _Extern.matches
def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool:
if len(avals) != len(self.arg_classes):
return False
return all(
jnp.issubdtype(aval.dtype, arg_class)
for aval, arg_class in zip(avals, self.arg_classes)
)
def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]):
[out_aval] = ctx.avals_out
bcast_args = []
for aval, arg in zip(ctx.avals_in, args):
bcast_args.append(_bcast_to(_ensure_ir_value(arg, aval), out_aval.shape))
return self.op(*bcast_args)
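# For example, _Fallback([jnp.floating], math_dialect.absf) in the tables below
# matches float32 and float64 avals alike, because
# jnp.issubdtype(jnp.float32, jnp.floating) and
# jnp.issubdtype(jnp.float64, jnp.floating) are both True.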
def _make_dispatch_table(
@ -626,390 +646,452 @@ def _make_dispatch_table(
raise NotImplementedError(
f"unsupported types for {name}: {arg_aval_dtypes}"
)
[out_aval] = ctx.avals_out
bcast_args = []
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
if aval.weak_type and aval.dtype != jnp.dtype(arg_type):
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
bcast_args.append(bcast_arg)
return h.lower(ctx, *bcast_args)
return h.lower(ctx, *args)
return inner
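# Illustrative note (not part of the diff): with the tables below, lax.abs_p on
# CUDA with a float32 operand matches _Extern([jnp.float32], "__nv_fabsf", ...)
# exactly and emits a libdevice call, while a float16 operand matches no
# _Extern entry and falls through to _Fallback([jnp.floating], math_dialect.absf)
# via jnp.issubdtype(float16, jnp.floating), lowering through the MLIR math
# dialect instead.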
_abs_dispatch_table = _make_dispatch_table(
abs_dispatch_table = _make_dispatch_table(
"abs",
cuda=[
_Extern([jnp.int32], "__nv_abs", jnp.int32),
_Extern([jnp.int64], "__nv_llabs", jnp.int64),
_Extern([jnp.float32], "__nv_fabsf", jnp.float32),
_Extern([jnp.float64], "__nv_fabs", jnp.float64),
_Fallback([jnp.integer], math_dialect.absi),
_Fallback([jnp.floating], math_dialect.absf),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.absi(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.absi(x)),
_Fallback([jnp.float32], lambda ctx, x: math_dialect.absf(x)),
_Fallback([jnp.float64], lambda ctx, x: math_dialect.absf(x)),
_Fallback([jnp.integer], math_dialect.absi),
_Fallback([jnp.floating], math_dialect.absf),
],
)
ceil_dispatch_table = _make_dispatch_table(
"ceil",
cuda=[
_Extern([jnp.float32], "__nv_ceilf", jnp.float32),
_Extern([jnp.float64], "__nv_ceil", jnp.float64),
_Fallback([jnp.floating], math_dialect.ceil),
],
rocm=[
_Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.ceil),
],
)
@register_lowering(lax.abs_p)
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
try:
return _abs_dispatch_table(ctx, x)
except NotImplementedError as e:
[x_aval] = ctx.avals_in
if jnp.issubdtype(x_aval, jnp.integer):
return math_dialect.absi(x)
elif jnp.issubdtype(x_aval, jnp.floating):
return math_dialect.absf(x)
else:
raise e from None
floor_dispatch_table = _make_dispatch_table(
"floor",
cuda=[
_Extern([jnp.float32], "__nv_floorf", jnp.float32),
_Extern([jnp.float64], "__nv_floor", jnp.float64),
_Fallback([jnp.floating], math_dialect.floor),
],
rocm=[
_Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.floor),
],
)
exp_dispatch_table = _make_dispatch_table(
"exp",
cuda=[
_Extern([jnp.float32], "__nv_expf", jnp.float32),
_Extern([jnp.float64], "__nv_exp", jnp.float64),
_Fallback([jnp.floating], math_dialect.exp),
],
rocm=[
_Fallback([jnp.float32], math_dialect.exp),
_Fallback([jnp.float64], math_dialect.exp),
_Fallback([jnp.floating], math_dialect.exp),
],
)
exp2_dispatch_table = _make_dispatch_table(
"exp2",
cuda=[
_Extern([jnp.float32], "__nv_exp2f", jnp.float32),
_Extern([jnp.float64], "__nv_exp2", jnp.float64),
_Fallback([jnp.floating], math_dialect.exp2),
],
rocm=[
_Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.exp2),
],
)
expm1_dispatch_table = _make_dispatch_table(
"expm1",
cuda=[
_Extern([jnp.float32], "__nv_expm1f", jnp.float32),
_Extern([jnp.float64], "__nv_expm1", jnp.float64),
_Fallback([jnp.floating], math_dialect.expm1),
],
rocm=[
_Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.expm1),
],
)
log_dispatch_table = _make_dispatch_table(
"log",
cuda=[
_Extern([jnp.float32], "__nv_logf", jnp.float32),
_Extern([jnp.float64], "__nv_log", jnp.float64),
_Fallback([jnp.floating], math_dialect.log),
],
rocm=[
_Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.log),
],
)
log1p_dispatch_table = _make_dispatch_table(
"log1p",
cuda=[
_Extern([jnp.float32], "__nv_log1pf", jnp.float32),
_Extern([jnp.float64], "__nv_log1p", jnp.float64),
_Fallback([jnp.floating], math_dialect.log1p),
],
rocm=[
_Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.log1p),
],
)
sqrt_dispatch_table = _make_dispatch_table(
"sqrt",
cuda=[
_Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_sqrt", jnp.float64),
_Fallback([jnp.floating], math_dialect.sqrt),
],
rocm=[
_Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.sqrt),
],
)
pow_dispatch_table = _make_dispatch_table(
"pow",
cuda=[
_Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
_Fallback(
[jnp.floating, jnp.integer],
math_dialect.fpowi
),
_Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
_Fallback(
[jnp.floating, jnp.floating],
math_dialect.powf
),
],
rocm=[
_Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
_Fallback(
[jnp.floating, jnp.integer],
math_dialect.fpowi
),
_Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
_Fallback(
[jnp.floating, jnp.floating],
math_dialect.powf
),
],
)
cbrt_dispatch_table = _make_dispatch_table(
"cbrt",
cuda=[
_Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
_Extern([jnp.float64], "__nv_cbrt", jnp.float64),
_Fallback([jnp.floating], math_dialect.cbrt),
],
rocm=[
_Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.cbrt),
],
)
rsqrt_dispatch_table = _make_dispatch_table(
"rsqrt",
cuda=[
_Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
_Fallback([jnp.floating], math_dialect.rsqrt),
],
rocm=[
_Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.rsqrt),
],
)
sin_dispatch_table = _make_dispatch_table(
"sin",
cuda=[
_Extern([jnp.float32], "__nv_sinf", jnp.float32),
_Extern([jnp.float64], "__nv_sin", jnp.float64),
_Fallback([jnp.floating], math_dialect.sin),
],
rocm=[
_Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.sin),
],
)
cos_dispatch_table = _make_dispatch_table(
"cos",
cuda=[
_Extern([jnp.float32], "__nv_cosf", jnp.float32),
_Extern([jnp.float64], "__nv_cos", jnp.float64),
_Fallback([jnp.floating], math_dialect.cos),
],
rocm=[
_Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.cos),
],
)
tan_dispatch_table = _make_dispatch_table(
"tan",
cuda=[
_Extern([jnp.float32], "__nv_tanf", jnp.float32),
_Extern([jnp.float64], "__nv_tan", jnp.float64),
_Fallback([jnp.floating], math_dialect.tan),
],
rocm=[
_Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.tan),
],
)
asin_dispatch_table = _make_dispatch_table(
"asin",
cuda=[
_Extern([jnp.float32], "__nv_asinf", jnp.float32),
_Extern([jnp.float64], "__nv_asin", jnp.float64),
_Fallback([jnp.floating], math_dialect.asin),
],
rocm=[
_Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.asin),
],
)
acos_dispatch_table = _make_dispatch_table(
"acos",
cuda=[
_Extern([jnp.float32], "__nv_acosf", jnp.float32),
_Extern([jnp.float64], "__nv_acos", jnp.float64),
_Fallback([jnp.floating], math_dialect.acos),
],
rocm=[
_Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.acos),
],
)
atan_dispatch_table = _make_dispatch_table(
"atan",
cuda=[
_Extern([jnp.float32], "__nv_atanf", jnp.float32),
_Extern([jnp.float64], "__nv_atan", jnp.float64),
_Fallback([jnp.floating], math_dialect.atan),
],
rocm=[
_Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.atan),
],
)
atan2_dispatch_table = _make_dispatch_table(
"atan2",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
_Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
],
rocm=[
_Extern([jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64),
_Fallback([jnp.floating, jnp.floating], math_dialect.atan2),
],
)
sinh_dispatch_table = _make_dispatch_table(
"sinh",
cuda=[
_Extern([jnp.float32], "__nv_sinhf", jnp.float32),
_Extern([jnp.float64], "__nv_sinh", jnp.float64),
_Fallback([jnp.floating], math_dialect.sinh),
],
rocm=[
_Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.sinh),
],
)
cosh_dispatch_table = _make_dispatch_table(
"cosh",
cuda=[
_Extern([jnp.float32], "__nv_coshf", jnp.float32),
_Extern([jnp.float64], "__nv_cosh", jnp.float64),
_Fallback([jnp.floating], math_dialect.cosh),
],
rocm=[
_Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.cosh),
],
)
tanh_dispatch_table = _make_dispatch_table(
"tanh",
cuda=[
_Extern([jnp.float32], "__nv_tanhf", jnp.float32),
_Extern([jnp.float64], "__nv_tanh", jnp.float64),
_Fallback([jnp.floating], math_dialect.tanh),
],
rocm=[
_Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.tanh),
],
)
asinh_dispatch_table = _make_dispatch_table(
"asinh",
cuda=[
_Extern([jnp.float32], "__nv_asinhf", jnp.float32),
_Extern([jnp.float64], "__nv_asinh", jnp.float64),
_Fallback([jnp.floating], math_dialect.asinh),
],
rocm=[
_Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.asinh),
],
)
acosh_dispatch_table = _make_dispatch_table(
"acosh",
cuda=[
_Extern([jnp.float32], "__nv_acoshf", jnp.float32),
_Extern([jnp.float64], "__nv_acosh", jnp.float64),
_Fallback([jnp.floating], math_dialect.acosh),
],
rocm=[
_Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.acosh),
],
)
atanh_dispatch_table = _make_dispatch_table(
"atanh",
cuda=[
_Extern([jnp.float32], "__nv_atanhf", jnp.float32),
_Extern([jnp.float64], "__nv_atanh", jnp.float64),
_Fallback([jnp.floating], math_dialect.atanh),
],
rocm=[
_Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
_Fallback([jnp.floating], math_dialect.atanh),
],
)
population_count_dispatch_table = _make_dispatch_table(
"population_count",
cuda=[
_Extern([jnp.int32], "__nv_popc", jnp.int32),
_Extern([jnp.int64], "__nv_popcll", jnp.int32),
_Fallback([jnp.integer], math_dialect.ctpop),
],
rocm=[
_Fallback([jnp.integer], math_dialect.ctpop),
],
)
clz_dispatch_table = _make_dispatch_table(
"clz",
cuda=[
_Extern([jnp.int32], "__nv_clz", jnp.int32),
_Extern([jnp.int64], "__nv_clzll", jnp.int32),
_Fallback([jnp.integer], math_dialect.ctlz),
],
rocm=[
_Fallback([jnp.integer], math_dialect.ctlz),
],
)
nextafter_dispatch_table = _make_dispatch_table(
"nextafter",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
),
],
)
triton_lowering_rules.update({
lax.abs_p: abs_dispatch_table,
lax.neg_p: lambda ctx, x: _minus(x),
lax.ceil_p: _make_dispatch_table(
"ceil",
cuda=[
_Extern([jnp.float32], "__nv_ceilf", jnp.float32),
_Extern([jnp.float64], "__nv_ceil", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64),
],
),
lax.floor_p: _make_dispatch_table(
"floor",
cuda=[
_Extern([jnp.float32], "__nv_floorf", jnp.float32),
_Extern([jnp.float64], "__nv_floor", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_floor_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_floor_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)),
],
),
lax.exp_p: _make_dispatch_table(
"exp",
cuda=[
_Extern([jnp.float32], "__nv_expf", jnp.float32),
_Extern([jnp.float64], "__nv_exp", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
],
rocm=[
_Fallback([jnp.float32], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.float64], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)),
],
),
lax.exp2_p: _make_dispatch_table(
"exp2",
cuda=[
_Extern([jnp.float32], "__nv_exp2f", jnp.float32),
_Extern([jnp.float64], "__nv_exp2", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)),
],
),
lax.expm1_p: _make_dispatch_table(
"expm1",
cuda=[
_Extern([jnp.float32], "__nv_expm1f", jnp.float32),
_Extern([jnp.float64], "__nv_expm1", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64),
],
),
lax.log_p: _make_dispatch_table(
"log",
cuda=[
_Extern([jnp.float32], "__nv_logf", jnp.float32),
_Extern([jnp.float64], "__nv_log", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_log_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)),
],
),
lax.log1p_p: _make_dispatch_table(
"log1p",
cuda=[
_Extern([jnp.float32], "__nv_log1pf", jnp.float32),
_Extern([jnp.float64], "__nv_log1p", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64),
],
),
lax.sqrt_p: _make_dispatch_table(
"sqrt",
cuda=[
_Extern([jnp.float32], "__nv_sqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_sqrt", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
],
),
lax.ceil_p: ceil_dispatch_table,
lax.floor_p: floor_dispatch_table,
lax.exp_p: exp_dispatch_table,
lax.exp2_p: exp2_dispatch_table,
lax.expm1_p: expm1_dispatch_table,
lax.log_p: log_dispatch_table,
lax.log1p_p: log1p_dispatch_table,
lax.sqrt_p: sqrt_dispatch_table,
lax.square_p: lambda ctx, x: _mul(x, x),
lax.pow_p: _make_dispatch_table(
"pow",
cuda=[
_Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
_Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
],
rocm=[
_Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32),
_Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
_Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
],
),
lax.cbrt_p: _make_dispatch_table(
"cbrt",
cuda=[
_Extern([jnp.float32], "__nv_cbrtf", jnp.float32),
_Extern([jnp.float64], "__nv_cbrt", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64),
],
),
lax.rsqrt_p: _make_dispatch_table(
"rsqrt",
cuda=[
_Extern([jnp.float32], "__nv_rsqrtf", jnp.float32),
_Extern([jnp.float64], "__nv_rsqrt", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64),
],
),
lax.sin_p: _make_dispatch_table(
"sin",
cuda=[
_Extern([jnp.float32], "__nv_sinf", jnp.float32),
_Extern([jnp.float64], "__nv_sin", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_sin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sin_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)),
],
),
lax.cos_p: _make_dispatch_table(
"cos",
cuda=[
_Extern([jnp.float32], "__nv_cosf", jnp.float32),
_Extern([jnp.float64], "__nv_cos", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
],
rocm=[
_Extern([jnp.float32], "__ocml_cos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cos_f64", jnp.float64),
_Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)),
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)),
],
),
lax.tan_p: _make_dispatch_table(
"tan",
cuda=[
_Extern([jnp.float32], "__nv_tanf", jnp.float32),
_Extern([jnp.float64], "__nv_tan", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_tan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tan_f64", jnp.float64),
],
),
lax.asin_p: _make_dispatch_table(
"asin",
cuda=[
_Extern([jnp.float32], "__nv_asinf", jnp.float32),
_Extern([jnp.float64], "__nv_asin", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_asin_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asin_f64", jnp.float64),
],
),
lax.acos_p: _make_dispatch_table(
"acos",
cuda=[
_Extern([jnp.float32], "__nv_acosf", jnp.float32),
_Extern([jnp.float64], "__nv_acos", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_acos_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acos_f64", jnp.float64),
],
),
lax.atan_p: _make_dispatch_table(
"atan",
cuda=[
_Extern([jnp.float32], "__nv_atanf", jnp.float32),
_Extern([jnp.float64], "__nv_atan", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_atan_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atan_f64", jnp.float64),
],
),
lax.atan2_p: _make_dispatch_table(
"atan2",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64
),
],
),
lax.sinh_p: _make_dispatch_table(
"sinh",
cuda=[
_Extern([jnp.float32], "__nv_sinhf", jnp.float32),
_Extern([jnp.float64], "__nv_sinh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64),
],
),
lax.cosh_p: _make_dispatch_table(
"cosh",
cuda=[
_Extern([jnp.float32], "__nv_coshf", jnp.float32),
_Extern([jnp.float64], "__nv_cosh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64),
],
),
lax.tanh_p: _make_dispatch_table(
"tanh",
cuda=[
_Extern([jnp.float32], "__nv_tanhf", jnp.float32),
_Extern([jnp.float64], "__nv_tanh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64),
],
),
lax.asinh_p: _make_dispatch_table(
"asinh",
cuda=[
_Extern([jnp.float32], "__nv_asinhf", jnp.float32),
_Extern([jnp.float64], "__nv_asinh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64),
],
),
lax.acosh_p: _make_dispatch_table(
"acosh",
cuda=[
_Extern([jnp.float32], "__nv_acoshf", jnp.float32),
_Extern([jnp.float64], "__nv_acosh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64),
],
),
lax.atanh_p: _make_dispatch_table(
"atanh",
cuda=[
_Extern([jnp.float32], "__nv_atanhf", jnp.float32),
_Extern([jnp.float64], "__nv_atanh", jnp.float64),
],
rocm=[
_Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32),
_Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64),
],
),
lax.population_count_p: _make_dispatch_table(
"population_count",
cuda=[
_Extern([jnp.int32], "__nv_popc", jnp.int32),
_Extern([jnp.int64], "__nv_popcll", jnp.int32),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.ctpop(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.ctpop(x)),
],
),
lax.clz_p: _make_dispatch_table(
"clz",
cuda=[
_Extern([jnp.int32], "__nv_clz", jnp.int32),
_Extern([jnp.int64], "__nv_clzll", jnp.int32),
],
rocm=[
_Fallback([jnp.int32], lambda ctx, x: math_dialect.ctlz(x)),
_Fallback([jnp.int64], lambda ctx, x: math_dialect.ctlz(x)),
],
),
lax.nextafter_p: _make_dispatch_table(
"nextafter",
cuda=[
_Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64),
],
rocm=[
_Extern(
[jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32
),
_Extern(
[jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64
),
],
),
lax.pow_p: pow_dispatch_table,
lax.cbrt_p: cbrt_dispatch_table,
lax.rsqrt_p: rsqrt_dispatch_table,
lax.sin_p: sin_dispatch_table,
lax.cos_p: cos_dispatch_table,
lax.tan_p: tan_dispatch_table,
lax.asin_p: asin_dispatch_table,
lax.acos_p: acos_dispatch_table,
lax.atan_p: atan_dispatch_table,
lax.atan2_p: atan2_dispatch_table,
lax.sinh_p: sinh_dispatch_table,
lax.cosh_p: cosh_dispatch_table,
lax.tanh_p: tanh_dispatch_table,
lax.asinh_p: asinh_dispatch_table,
lax.acosh_p: acosh_dispatch_table,
lax.atanh_p: atanh_dispatch_table,
lax.population_count_p: population_count_dispatch_table,
lax.clz_p: clz_dispatch_table,
lax.nextafter_p: nextafter_dispatch_table,
})
@ -2211,6 +2293,10 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
ctx.context, jaxpr.jaxpr, ctx.block_infos, *args
)
@register_lowering(pjit.mesh_cast_p)
def _mesh_cast_lowering_rule(ctx, x, dst_sharding):
return x
@register_lowering(jax_core.closed_call_p)
@register_lowering(custom_derivatives.custom_jvp_call_p)

@ -17,10 +17,16 @@
from __future__ import annotations
import io
import re
from typing import Any
import zlib
import jax
import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import triton
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.triton import lowering
@ -51,7 +57,7 @@ def pallas_call_lowering(
cost_estimate: pallas_core.CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
):
del interpret, out_avals
del interpret, cost_estimate, out_avals
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
@ -77,6 +83,11 @@ def pallas_call_lowering(
print("The grid mapping for pallas_call {name_and_src_info}:")
print(grid_mapping)
# Sanitize the name to conform to NVPTX requirements. We do this here
# to avoid the need to fetch the new name from PTX post compilation.
name_and_src_info = name_and_src_info.replace(
name=re.sub(r"[^a-zA-Z0-9_$]", "_", name_and_src_info.name)
)
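  # For example, a hypothetical name "my_kernel/[0]" becomes "my_kernel__0_".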
lowering_result = lowering.lower_jaxpr_to_triton_module(
jaxpr, grid_mapping, name_and_src_info, lowering_platform
)
@ -86,35 +97,93 @@ def pallas_call_lowering(
print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
out_types = [
buf = io.BytesIO()
module_op.write_bytecode(buf)
if jaxlib_version < (0, 5, 1):
# AOT Triton compilation is only available on jaxlib 0.5.1+.
out_types = [
ir.RankedTensorType.get(bm.array_shape_dtype.shape,
mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype))
for bm in grid_mapping.block_mappings_output
]
buf = io.BytesIO()
module_op.write_bytecode(buf)
backend_config = dict(
name=ir.StringAttr.get(name_and_src_info.name),
ir=ir.StringAttr.get(buf.getvalue()),
num_stages=mlir.i32_attr(num_stages),
num_warps=mlir.i32_attr(num_warps),
grid_x=mlir.i32_attr(grid_x),
grid_y=mlir.i32_attr(grid_y),
grid_z=mlir.i32_attr(grid_z),
debug=ir.BoolAttr.get(debug),
]
backend_config = dict(
name=ir.StringAttr.get(name_and_src_info.name),
ir=ir.StringAttr.get(buf.getvalue()),
num_stages=mlir.i32_attr(num_stages),
num_warps=mlir.i32_attr(num_warps),
grid_x=mlir.i32_attr(grid_x),
grid_y=mlir.i32_attr(grid_y),
grid_z=mlir.i32_attr(grid_z),
debug=ir.BoolAttr.get(debug),
)
if "serialized_metadata" in (triton_params or {}):
# This field is unstable and may be removed in the future.
if triton_params["serialized_metadata"] is not None:
backend_config["serialized_metadata"] = ir.StringAttr.get(
triton_params["serialized_metadata"]
)
return mlir.custom_call(
call_target_name="__gpu$xla.gpu.triton",
result_types=out_types,
operands=in_nodes,
backend_config=backend_config,
api_version=4,
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
).results
# TODO(slebedev): Make this work for ROCm.
try:
gpu_device, *_ = jax.local_devices(backend="gpu")
except RuntimeError:
# GPU device is not available. Fall back to the minimum CC supported by Triton.
# TODO(slebedev): Make the fallback CC configurable.
arch_name = "8.0"
cc = 80
else:
arch_name = str(gpu_device.compute_capability)
cc = int(arch_name.replace(".", ""))
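      # Illustrative: an H100 reports compute_capability "9.0", so arch_name is
      # "9.0" and cc is 90; without a visible GPU the code above pins them to
      # the "8.0" / 80 fallback.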
compilation_result = triton.compile(
lowering_platform,
buf.getvalue(),
arch_name,
num_warps=num_warps,
num_ctas=1,
num_stages=num_stages,
)
if "serialized_metadata" in (triton_params or {}):
# This field is unstable and may be removed in the future.
if triton_params["serialized_metadata"] is not None:
backend_config["serialized_metadata"] = ir.StringAttr.get(
triton_params["serialized_metadata"]
)
kernel = triton_kernel_call_lib.TritonKernel(
name_and_src_info.name,
num_warps,
compilation_result.smem_bytes,
compilation_result.asm,
module_op.get_asm(enable_debug_info=True, pretty_debug_info=True),
cc,
compilation_result.cluster_dim_x,
compilation_result.cluster_dim_y,
compilation_result.cluster_dim_z,
)
kernel_call = triton_kernel_call_lib.TritonKernelCall(
kernel,
grid_x,
grid_y,
grid_z,
[triton_kernel_call_lib.create_array_parameter(0, 16)]
* (len(ctx.avals_in) + len(ctx.avals_out)),
)
# TODO(b/392558289): Migrate to ``jax.ffi``.
return mlir.custom_call(
call_target_name="__gpu$xla.gpu.triton",
result_types=out_types,
call_target_name="triton_kernel_call",
result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], # type: ignore[list-item]
operands=in_nodes,
backend_config=backend_config,
api_version=4,
backend_config=zlib.compress(
kernel_call.to_proto(
name_and_src_info.name,
triton_params.get("serialized_metadata") or b"",
)
),
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),

@ -25,15 +25,6 @@ import jax.numpy as jnp
import numpy as np
def when(condition):
def _wrapped(f):
if isinstance(condition, bool):
if condition:
f()
else:
lax.cond(condition, f, lambda: None)
return _wrapped
@overload
def cdiv(a: int, b: int) -> int:
...

@ -49,7 +49,7 @@ from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
argnames_partial_except, debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
@ -548,7 +548,7 @@ def _infer_params_impl(
ji: PjitInfo,
pjit_mesh: mesh_lib.Mesh | None,
resource_env: mesh_lib.ResourceEnv | None,
dbg: lu.TracingDebugInfo,
dbg: core.DebugInfo,
args: tuple[Any, ...],
kwargs: dict[str, Any],
in_avals: tuple[core.AbstractValue, ...] | None,
@ -567,9 +567,7 @@ def _infer_params_impl(
axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
dbg = dbg and dbg.add_result_paths(result_paths_thunk=res_paths)
f = lu.wrap_init(fun, debug_info=dbg)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
del args
@ -618,7 +616,7 @@ def _infer_params_impl(
in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
in_shardings_treedef, in_shardings_leaves,
ji.in_layouts_treedef, ji.in_layouts_leaves,
in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs)
attr_token = _attr_token(flat_fun, in_type)
@ -627,8 +625,7 @@ def _infer_params_impl(
if mesh_lib.get_abstract_mesh().empty else mesh_lib.get_abstract_mesh())
with mesh_lib.set_abstract_mesh(abstract_mesh):
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
flat_fun, in_type, attr_token,
IgnoreKey(ji.inline))
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
@ -733,7 +730,7 @@ def _infer_params(
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
' together is not allowed.')
dbg = tracing_debug_info(
dbg = debug_info(
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)
@ -756,7 +753,7 @@ def _infer_params(
entry.pjit_params = p
return entry.pjit_params, entry.pjit_params.consts + dynargs
def _infer_input_type(fun: Callable, dbg: lu.TracingDebugInfo | None,
def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None,
explicit_args) -> tuple[core.AbstractValue, ...]:
avals = []
try:
@ -1171,17 +1168,18 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
callsites: set[str] = set()
def explain_tracing_cache_miss(
f: Callable, unseen_f: bool, cache: dict, key: tuple):
fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
if config.check_tracer_leaks.value: return
def unpack(key):
transforms, (), _, (in_type, _, debug_info, _, inline), *_, ctx = key
transforms, (), _, (in_type, _, inline), *_, ctx = key
# TODO(dougalm,mattjj): enable cache miss explanation with attrs
_, (_, (in_tree,)), *_ = transforms
return in_tree, in_type, debug_info, inline.val, ctx
in_tree, in_type, debug_info, inline, ctx = unpack(key)
return in_tree, in_type, inline.val, ctx
in_tree, in_type, inline, ctx = unpack(key)
if inline: return
debug_info = fun.debug_info
msg: list[str] = []
p = msg.append
done = lambda: logger.log(logging.WARNING, '\n'.join(msg))
@ -1190,7 +1188,7 @@ def explain_tracing_cache_miss(
p(f"TRACING CACHE MISS at {callsite} because:")
# have we seen this function before at all?
fun_name = getattr(f, '__qualname__', f)
fun_name = getattr(fun.f, '__qualname__', fun.f)
if debug_info is not None and debug_info.func_src_info:
# TODO(necula): clean up the extraction of the source info
_, *rest = debug_info.func_src_info.split(' at ')
@ -1198,7 +1196,7 @@ def explain_tracing_cache_miss(
else:
src_info = ''
if unseen_f:
p(f" never seen function:\n {fun_name} id={id(f)}{src_info}")
p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}")
if callsite in callsites:
p(" but seen another function defined on the same line; maybe the function is\n"
" being re-defined repeatedly, preventing caching?")
@ -1263,7 +1261,7 @@ def explain_tracing_cache_miss(
# have we never seen these input types (eg shapes, dtypes) before?
types_match = [k for k in trees_match if k[1] == in_type]
if not types_match:
if len(in_type) < 5:
if len(in_type) < 5 and debug_info is not None:
in_type_str = ':\n {}'.format(', '.join(
f'{n}: {ty.str_short(short_dtypes=True)}'
for n, ty in zip(debug_info.arg_names, in_type)))
@ -1275,7 +1273,12 @@ def explain_tracing_cache_miss(
num_mismatch = sum(map(op.ne, closest_ty, in_type))
p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
add_weak_type_hint = False
for name, ty1, ty2 in zip(debug_info.arg_names, closest_ty, in_type):
if debug_info:
arg_names = debug_info.safe_arg_names(len(in_type))
else:
arg_names = (None,) * len(in_type)
for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
if ty1 != ty2:
if type(ty1) == type(ty2) == core.ShapedArray:
s1, s2 = ty1.str_short(True), ty2.str_short(True)
@ -1302,8 +1305,6 @@ def _create_pjit_jaxpr(
fun: lu.WrappedFun,
in_type: core.InputType | Sequence[core.AbstractValue],
attr_data: int,
debug_info: lu.TracingDebugInfo,
result_paths: Callable,
ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
@ -1317,17 +1318,13 @@ def _create_pjit_jaxpr(
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, cast(core.InputType, in_type)), debug_info=debug_info)
lu.annotate(fun, cast(core.InputType, in_type)))
attrs_tracked = []
else:
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_type, debug_info=debug_info)
fun, in_type)
# assert attr_data is sentinel or attr_data matches attrs_tracked
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
if not config.dynamic_shapes.value and not attrs_tracked:
jaxpr = add_jaxpr_debug_info(jaxpr, debug_info, result_paths())
if config.debug_key_reuse.value:
# Import here to avoid circular imports
from jax.experimental.key_reuse._core import check_key_reuse_jaxpr
@ -1346,7 +1343,7 @@ def _create_pjit_jaxpr(
def _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, out_avals,
debug_info: core.JaxprDebugInfo | None,
debug_info: core.DebugInfo | None,
device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
@ -1479,7 +1476,6 @@ def check_aval_layout_compatibility(
pjit_p = core.Primitive("pjit")
pjit_p.multiple_results = True
def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
# If device or backend is set, return the default layout. This is because you
# can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
@ -1928,7 +1924,9 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext,
name: str, jaxpr: core.ClosedJaxpr,
effects, in_shardings,
out_shardings, in_layouts, out_layouts,
api_name):
mod_ctx = ctx.module_context
@ -1959,7 +1957,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
return func
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
jaxpr: core.ClosedJaxpr, in_shardings,
out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, keep_unused, inline, compiler_options_kvs):
effects = list(ctx.tokens_in.effects())
@ -1987,8 +1986,10 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
mlir.register_lowering(pjit_p, _pjit_lowering)
def _pjit_batcher(axis_data, vals_in, dims_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
def _pjit_batcher(axis_data, vals_in,
dims_in: tuple[int, ...],
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
@ -2037,7 +2038,8 @@ batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule
def _pjit_batcher_for_sharding(
s: Sharding | UnspecifiedValue,
dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, mesh,
ndim: int):
if isinstance(s, UnspecifiedValue):
return s
hlo_s = s._to_xla_hlo_sharding(ndim)
@ -2049,7 +2051,7 @@ def _pjit_batcher_for_sharding(
return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
new_op = hlo_s.to_proto().clone()
tad = list(new_op.tile_assignment_dimensions)
tad.insert(dim, 1)
tad.insert(dim, 1) # type: ignore
new_op.tile_assignment_dimensions = tad
new_gs = GSPMDSharding(
s._device_assignment, new_op,
@ -2171,8 +2173,9 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
ad.primitive_linearizations[pjit_p] = _pjit_linearization
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
def _pjit_partial_eval(trace: pe.JaxprTrace,
*in_tracers,
jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars,
name, keep_unused, inline, compiler_options_kvs):
in_pvals = [t.pval for t in in_tracers]
@ -2191,7 +2194,7 @@ def _pjit_partial_eval(trace, *in_tracers,
else:
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
unknown_outs = tuple(unknown_outs)
unknown_outs = tuple(unknown_outs) # type: ignore[assignment]
known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(res_avals)
res_shardings = (UNSPECIFIED,) * num_residuals
@ -2282,7 +2285,7 @@ def _pjit_partial_eval(trace, *in_tracers,
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
unknown_out_avals = unknown_jaxpr.out_avals
unknown_tracers_out = [
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) # type: ignore
for aval in unknown_out_avals
]
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
@ -2707,14 +2710,20 @@ def mesh_cast(xs, out_shardings):
return tree_unflatten(treedef, out_flat)
mesh_cast_p = core.Primitive('mesh_cast')
mesh_cast_p.skip_canonicalization = True
def _mesh_cast_abstract_eval(aval, dst_sharding):
src_sharding = aval.sharding
if src_sharding == dst_sharding:
return aval
if src_sharding.mesh.empty or dst_sharding.mesh.empty:
return aval.update(sharding=dst_sharding)
if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
raise ValueError(
f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
' match the mesh shape of the target sharding'
f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
if src_sharding.mesh.axis_types == dst_sharding.mesh.axis_types:
if (src_sharding.mesh.axis_types == dst_sharding.mesh.axis_types and
src_sharding.spec != dst_sharding.spec):
raise ValueError(
'mesh_cast should only be used when AxisTypes changes between the'
' input mesh and the target mesh. Got src'
@ -2746,7 +2755,9 @@ def _mesh_cast_abstract_eval(aval, dst_sharding):
mesh_cast_p.def_abstract_eval(_mesh_cast_abstract_eval)
def _mesh_cast_impl(x, dst_sharding):
return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
x_aval = core.shaped_abstractify(x)
with mesh_lib.set_abstract_mesh(x_aval.sharding.mesh):
return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
mesh_cast_p.def_impl(_mesh_cast_impl)
def _mesh_cast_transpose_rule(ct, x, dst_sharding):
@ -2763,7 +2774,6 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, dst_sharding):
mlir.register_lowering(mesh_cast_p, _mesh_cast_hlo_lowering)
def _mesh_cast_batcher(axis_data, vals_in, dims_in, dst_sharding):
assert axis_data.spmd_name is None
x, = vals_in
d, = dims_in
vmapped_dst_sharding = batching.get_sharding_for_vmap(

@ -2070,8 +2070,8 @@ def orthogonal(
n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
z = normal(key, (*shape, n, n), dtype)
q, r = jnp.linalg.qr(z)
d = jnp.diagonal(r, 0, -2, -1)
return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2]))
d = jnp.linalg.diagonal(r)
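  # Note: jnp.sign(d) == d / abs(d) for nonzero d (in JAX this also holds for
  # complex inputs), so the rescaling matches the previous formulation.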
return q * jnp.expand_dims(jnp.sign(d), -2)
def generalized_normal(
key: ArrayLike,

@ -989,7 +989,7 @@ def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue],
def initial_style_jaxpr(
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue],
dbg: api_util.TracingDebugInfo,
dbg: core.DebugInfo,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals), dbg)
@ -997,17 +997,18 @@ def initial_style_jaxpr(
def _initial_style_jaxpr(fun: Callable,
in_tree: api_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug: api_util.TracingDebugInfo):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
debug: core.DebugInfo):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun, debug_info=debug),
tree_util.treedef_tuple((in_tree,)))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals)
return jaxpr, consts, out_tree_thunk()
T = TypeVar('T')
def run_state(f: Callable[..., None]) -> Callable[[T], T]:
def wrapped(args):
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
dbg = api_util.debug_info("run_state", f, (args,), {})
flat_args, in_tree = tree_util.tree_flatten(args)
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
# There may be some uninitialized values here in ref_args.
@ -1027,7 +1028,7 @@ def run_state(f: Callable[..., None]) -> Callable[[T], T]:
def run_state_reference(f: Callable[..., None]):
def wrapped(args):
dbg = api_util.tracing_debug_info("run_state", f, (args,), {})
dbg = api_util.debug_info("run_state", f, (args,), {})
flat_args, in_tree = tree_util.tree_flatten(args)
ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args))
jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals, dbg)

@ -217,7 +217,7 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
# TODO(yashkatariya): Transform the sharding too instead of setting it to
# None.
out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
sharding=None)
sharding=core.get_cur_mesh_sharding())
else:
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")

@ -266,6 +266,14 @@ class TransformedRef:
(*self.transforms, RefReshaper.from_ref_new_shape(self, *shape)),
)
def set(self, value, idx=()):
from jax._src.state.primitives import ref_set # pytype: disable=import-error
return ref_set(self, idx, value)
def get(self, idx=()):
from jax._src.state.primitives import ref_get # pytype: disable=import-error
return ref_get(self, idx)
def __getattr__(self, name):
return getattr(self.ref, name)

@ -537,7 +537,7 @@ def request_cpu_devices(nr_devices: int):
invoked. Test cases that require a specific number of devices should skip
themselves if that number is not met.
"""
if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
if config.num_cpu_devices.value < nr_devices:
xla_bridge.get_backend.cache_clear()
config.update("jax_num_cpu_devices", nr_devices)

@ -104,17 +104,6 @@ _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag(
help="Deprecated, please use jax_cpu_collectives_implementation instead.",
)
CPU_COLLECTIVES_IMPLEMENTATIONS = ["none", "gloo", "mpi"]
CPU_COLLECTIVES_IMPLEMENTATION = config.enum_flag(
name="jax_cpu_collectives_implementation",
default="none",
enum_values=CPU_COLLECTIVES_IMPLEMENTATIONS,
help=(
"Cross-process collective implementation used on CPU. Must be one of"
f" {CPU_COLLECTIVES_IMPLEMENTATIONS}"
),
)
_CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
name="jax_cpu_enable_async_dispatch",
default=True,
@ -122,14 +111,6 @@ _CPU_ENABLE_ASYNC_DISPATCH = config.bool_flag(
"inline without async dispatch.",
)
NUM_CPU_DEVICES = config.int_flag(
name="jax_num_cpu_devices",
default=-1,
help="Number of CPU devices to use. If not provided, the value of "
"the XLA flag --xla_force_host_platform_device_count is used."
" Must be set before JAX is initialized.",
)
# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
@ -255,7 +236,7 @@ def make_cpu_client(
The created CPU client.
"""
if collectives is None:
collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
collectives_impl = config.cpu_collectives_implementation.value
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
collectives_impl = 'gloo'
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
@ -271,12 +252,13 @@ def make_cpu_client(
collectives = xla_client._xla.make_mpi_collectives()
collectives.Init()
atexit.register(collectives.Finalize)
elif collectives_impl != 'none':
raise RuntimeError(f"Unknown collectives implementation "
f"{collectives_impl}. Available implementations are "
f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
elif collectives_impl == 'megascale':
      raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must be "gloo" or "mpi"')
else:
# Already validated by config module
assert collectives_impl is None
num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
num_devices = config.num_cpu_devices.value if config.num_cpu_devices.value >= 0 else None
return xla_client.make_cpu_client(
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
distributed_client=distributed.global_state.client,

@ -20,13 +20,13 @@ from jax._src.core import (
AbstractValue as AbstractValue,
Atom as Atom,
CallPrimitive as CallPrimitive,
DebugInfo as DebugInfo,
DShapedArray as DShapedArray,
DropVar as DropVar,
Effect as Effect,
Effects as Effects,
get_opaque_trace_state as get_opaque_trace_state,
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
JaxprDebugInfo as JaxprDebugInfo,
JaxprPpContext as JaxprPpContext,
JaxprPpSettings as JaxprPpSettings,
JaxprTypeError as JaxprTypeError,

@ -85,8 +85,9 @@ from .utils import (
warpgroup_idx as warpgroup_idx,
when as when,
)
# The import below shadows the module, so we need to rename it.
from . import wgmma as _wgmma # noqa: F401
from .wgmma import (
WGMMAAccumulator as WGMMAAccumulator,
WGMMALayout as WGMMALayout,
wgmma as wgmma,
)

@ -220,7 +220,7 @@ def build_kernel(
# TODO(apaszke): Support WGMMA without an initial accumulator.
qk_acc = WGMMAAccumulator.zero(blocks.q, blocks.kv)
q, k = qo_smem, memref_slice(k_smem, slot)
qk_acc = wgmma(qk_acc, q, k, b_order=WGMMALayout.COL_MAJOR)
qk_acc = wgmma(qk_acc, q, memref_transpose(k, (0, 1, 3, 2)))
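      # The memref_transpose of the last two tile dims replaces the removed
      # b_order=COL_MAJOR argument; wgmma now presumably infers the operand
      # layout from the memref strides.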
nvvm.wgmma_commit_group_sync_aligned()
perform_schedule_barrier()
@ -441,7 +441,7 @@ def build_kernel(
# TODO(apaszke): Support WGMMA without an initial accumulator.
qk_acc = WGMMAAccumulator.zero(blocks.q, blocks.kv)
q, k = qo_smem, memref_slice(k_smem, slot)
qk_acc = wgmma(qk_acc, q, k, b_order=WGMMALayout.COL_MAJOR)
qk_acc = wgmma(qk_acc, q, memref_transpose(k, (0, 1, 3, 2)))
nvvm.wgmma_commit_group_sync_aligned()
# We hide the TMA overhead by overlapping it with the QK matmul.

@ -68,7 +68,7 @@ class WGMMADefaultImpl:
block_tiling: Tiling,
tma_tiling: Tiling,
lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
rhs_transpose: WGMMALayout,
rhs_transpose: bool,
) -> dict[str, jax.ShapeDtypeStruct]:
del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose # Unused.
return ()
@ -81,7 +81,6 @@ class WGMMADefaultImpl:
def wgmma(
smem_scratch: Any, # pylint: disable=unused-argument
acc: WGMMAAccumulator,
b_order: WGMMALayout,
a_slice: SmemRef,
b_slice: SmemRef,
swizzle: int,
@ -91,7 +90,7 @@ class WGMMADefaultImpl:
This function must guarantee that all WGMMA operations queued before it was
called have completed before returning.
"""
acc = wgmma(acc, a_slice, b_slice, b_order=b_order, swizzle=swizzle)
acc = wgmma(acc, a_slice, b_slice, swizzle=swizzle)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(1)
return acc
@ -250,11 +249,10 @@ def build_kernel(
with ctx.named_region("WGMMA"):
a_slice = memref_slice(lhs_smem, si)
b_slice = memref_slice(rhs_smem, si)
rhs_smem_order = (
WGMMALayout.COL_MAJOR if rhs_transpose else WGMMALayout.ROW_MAJOR
)
if rhs_transpose:
b_slice = memref_transpose(b_slice, (0, 1, 3, 2))
accs = wgmma_impl.wgmma(
impl_smem, accs, rhs_smem_order, a_slice, b_slice, swizzle=swizzle
impl_smem, accs, a_slice, b_slice, swizzle=swizzle
)
with ctx.named_region("TMA start"):

@ -0,0 +1,197 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Matmul kernel for Blackwell."""
import jax
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import nvvm
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import c, ds, utils
from jax.experimental.mosaic.gpu import tcgen05
import jax.numpy as jnp
import jax.random as jr
import numpy as np
BLACKWELL_MMA_FP16_K = 16
TMA_WARP = 1
MMA_WARP = 0
def bytecount(shape, dtype):
  return int(np.prod(shape) * np.dtype(dtype).itemsize)
def build_kernel(
m, n, k,
tile_m: int = 128,
tile_n: int = 128,
):
i1 = ir.IntegerType.get_signless(1)
i32 = ir.IntegerType.get_signless(32)
f32 = ir.F32Type.get()
index = ir.IndexType.get()
ptr6 = ir.Type.parse("!llvm.ptr<6>") # TMEM
swizzle = 128
tile_k = 64 # TODO(apaszke): I think we need to tile TMA to change this.
in_dtype = jnp.float16
k_loop_iter = k // tile_k
tma_tile_m = 128
tma_tile_kn = 64
if m % tile_m != 0:
raise ValueError(f"{m=} must be divisible by {tile_m=}")
if n % tile_n != 0:
raise ValueError(f"{n=} must be divisible by {tile_n=}")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
def kernel(ctx, a, b, d, smem):
# TODO(apaszke): Use more SMEM slots to avoid oversynchronizing warps.
a_smem, b_smem, d_smem, barriers, tmem_addr = smem
(ab_full_barrier, ab_empty_barrier, mma_done_barrier) = barriers
warp_idx = mgpu.warp_idx(sync=True)
warp_leader = nvvm.elect_sync(i1)
is_warp = lambda i: arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32))
    m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m, index))
    n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n, index))
with mgpu.when(arith.andi(is_warp(TMA_WARP), warp_leader)):
@mgpu.fori(c(k_loop_iter, index), None)
def _tma_body(ki, _):
# TODO(apaszke): Use a predicate instead of a conditional.
with mgpu.when(arith.cmpi(arith.CmpIPredicate.ugt, ki, c(0, index))):
ab_empty_barrier.wait()
ab_full_barrier.arrive_expect_tx(
bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
)
k_start = arith.muli(ki, c(tile_k, index))
common_args = dict(
swizzle=swizzle, barrier=ab_full_barrier, arrive=False, uniform=False,
)
ctx.async_copy(
src_ref=a,
dst_ref=a_smem,
gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)),
gmem_transform=mgpu.TileTransform((tma_tile_m, tma_tile_kn)),
**common_args,
)
ctx.async_copy(
src_ref=b,
dst_ref=b_smem,
gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)),
gmem_transform=(
mgpu.TileTransform((tma_tile_kn, tma_tile_kn)),
mgpu.TransposeTransform((1, 0, 2, 3)),
),
**common_args,
)
with mgpu.when(is_warp(MMA_WARP)):
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
tcgen05.tmem_alloc(tmem_addr_addr, tile_n)
tcgen05.tmem_relinquish_alloc_permit()
with mgpu.when(warp_leader):
tmem_addr_value = llvm.load(ptr6, tmem_addr_addr)
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
def _mma_body(ki, accumulate):
ab_full_barrier.wait()
tcgen05.mma(
tmem_addr_value,
a_smem,
mgpu.memref_transpose(b_smem, (0, 1, 3, 2)),
a_swizzle=swizzle,
b_swizzle=swizzle,
accumulate=accumulate,
)
accumulate = arith.constant(i1, 1)
is_last_iter = arith.cmpi(
arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
)
barrier_ptr = arith.select(
is_last_iter, mma_done_barrier.get_ptr(), ab_empty_barrier.get_ptr()
)
tcgen05.commit_arrive(barrier_ptr)
return accumulate
gpu.barrier()
mma_done_barrier.wait()
tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
tmem_ref[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
mgpu.commit_shared()
ctx.async_copy(
src_ref=d_smem,
dst_ref=d,
gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)),
gmem_transform=mgpu.TileTransform((128, 64)),
swizzle=swizzle,
)
ctx.await_async_copy(0)
smem = (
jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_k), (tma_tile_m, tma_tile_kn)), jnp.float16),
jax.ShapeDtypeStruct(mgpu.tile_shape((tile_k, tile_n), (tma_tile_kn, tma_tile_kn)), jnp.float16),
jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
[mgpu.Barrier(arrival_count=1)] * 3,
jax.ShapeDtypeStruct((1,), np.uint32), # TMEM address
)
return mgpu.as_gpu_kernel(
kernel,
(n // tile_n, m // tile_m, 1),
(128, 1, 1),
(
jax.ShapeDtypeStruct((m, k), jnp.float16),
jax.ShapeDtypeStruct((n, k), jnp.float16),
),
jax.ShapeDtypeStruct((m, n), jnp.float16),
smem,
)
def main(unused_argv):
m_tile = 128
n_tile = 128
k_tile = 64
m = 16*m_tile
n = 16*n_tile
k = 16*k_tile
ka, kb = jr.split(jr.key(0), 2)
a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16)
with mlir.make_ir_context(), ir.Location.unknown():
f = build_kernel(m, n, k, tile_m=m_tile, tile_n=n_tile)
y = f(a, b).block_until_ready()
ref = np.asarray(a) @ np.asarray(b).T
np.testing.assert_allclose(y, ref, atol=1e-3, rtol=1e-3)
print("OK!")
if __name__ == "__main__":
from absl import app
import jax
jax.config.config_with_absl()
app.run(main)

@ -0,0 +1,338 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
import enum
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
import numpy as np
from . import utils
from . import fragmented_array as fa
from . import _wgmma
# MyPy does a terrible job with the MLIR API.
# mypy: ignore-errors
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
def create_smem_descriptor(
memref_arg,
leading_byte_offset: int,
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
):
return _wgmma.create_descriptor(
memref_arg,
leading_byte_offset,
stride_byte_offset,
swizzle,
memory_space=3,
const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
)
def create_instr_descriptor(
m: int,
n: int,
acc_dtype,
input_dtype,
transpose_a: bool = False,
transpose_b: bool = False,
):
f32 = ir.F32Type.get()
bf16 = ir.BF16Type.get()
f16 = ir.F16Type.get()
if input_dtype not in {f16, bf16}:
raise NotImplementedError("Only float16 and bfloat16 inputs supported")
if acc_dtype not in {f32, f16}:
raise NotImplementedError("Only float32 and float16 accumulators supported")
desc = 0
# We ignore sparsity in bits 0-3
desc |= (acc_dtype == f32) << 4 # D dtype, bits 4-5
# Bit 6 is reserved
desc |= (input_dtype == bf16) << 7 # A dtype, bits 7-9
desc |= (input_dtype == bf16) << 10 # B dtype, bits 10-12
# We ignore negate bits 13-14
desc |= transpose_a << 15 # Transpose A
desc |= transpose_b << 16 # Transpose B
if n % 8 or n > 256:
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
desc |= (n >> 3) << 17 # N, bits 17-22
# Bit 23 is reserved
if m % 16 or m > 256:
raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
desc |= (m >> 4) << 24 # M >> 4, bits 24-28
# Bit 29 is reserved
# We ignore max shift under .ws, bits 30-31
return arith.constant(ir.IntegerType.get_signless(32), desc)
def mma(
d: ir.Value,
a: ir.Value,
b: ir.Value,
*,
a_swizzle: int = 128,
b_swizzle: int = 128,
num_cta: int = 1,
accumulate: ir.Value | bool = True,
):
if not ir.MemRefType.isinstance(a.type):
raise ValueError(f"A must be a memref, got {a.type}")
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")
if a_swizzle != 128 or b_swizzle != 128:
raise NotImplementedError("Only swizzle=128 has been tested")
if num_cta != 1:
raise NotImplementedError("Only num_cta=1 supported")
if isinstance(accumulate, bool):
accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
(
a_desc_base,
b_desc_base,
(m, k, n),
(m_tiling, kn_tiling),
element_type,
mma_params,
a_k_byte_stride,
b_k_byte_stride,
) = _wgmma._validate_mma(
a,
b,
a_swizzle,
_wgmma.WGMMALayout.ROW_MAJOR,
_wgmma.WGMMALayout.COL_MAJOR,
descriptor_const_init=TCGEN05_SMEM_DESCRIPTOR_BIT,
)
if m_tiling != 128:
raise ValueError(f"A must have rows tiled by 128, got: {m_tiling}")
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_byte_stride = a_strides[0] * utils.bytewidth(element_type)
groups_k = k // kn_tiling
groups_m = m // m_tiling
# TODO(apaszke): Verify ACC shape.
i64 = ir.IntegerType.get_signless(64)
for mi in range(groups_m):
for ki in range(groups_k):
a_mk = arith.addi(
a_desc_base,
utils.c(_wgmma.wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
)
b_k = arith.addi(b_desc_base, utils.c(_wgmma.wgmma_encode(ki * b_k_byte_stride), i64))
accumulate = _do_mma(
d,
a_mk,
b_k,
d_type=ir.F32Type.get(),
m=m_tiling,
**mma_params,
accumulate=accumulate,
)
def _do_mma(
d_addr: ir.Value,
a_desc: ir.Value,
b_desc: ir.Value,
a_transpose: bool,
b_transpose: bool,
a_k_stride: int,
b_k_stride: int,
m: int,
n: int,
swizzle: int,
element_type: ir.Type,
d_type: ir.Type,
accumulate: ir.Value,
):
i1 = ir.IntegerType.get_signless(1)
i64 = ir.IntegerType.get_signless(64)
kn_tiling = swizzle // utils.bytewidth(element_type)
instr_k = 32 // utils.bytewidth(element_type)
if a_k_stride % 16 or b_k_stride % 16:
raise ValueError
i_desc = create_instr_descriptor(
m, n, d_type, element_type, a_transpose, b_transpose
)
for _ in range(kn_tiling // instr_k):
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[d_addr, a_desc, b_desc, i_desc, accumulate],
f"tcgen05.mma.cta_group::1.kind::{element_type} [$0], $1, $2, $3, $4;",
"r,l,l,r,b",
has_side_effects=True,
)
accumulate = arith.constant(i1, 1)
a_desc = arith.addi(a_desc, arith.constant(i64, a_k_stride >> 4))
b_desc = arith.addi(b_desc, arith.constant(i64, b_k_stride >> 4))
return accumulate
def commit_arrive(barrier):
return llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[barrier],
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$0];",
"l",
has_side_effects=True
)
def tmem_alloc(tmem_addr, ncols: int):
if ncols.bit_count() != 1 or not 32 <= ncols <= 512:
raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}")
return llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[tmem_addr],
f"tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], {ncols};",
"r",
has_side_effects=True,
)
def tmem_relinquish_alloc_permit():
return llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[],
"tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;",
"",
has_side_effects=True,
)
def tmem_load(tmem_addr, shape, num):
if num.bit_count() != 1 or num > 128:
raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
match shape:
case "16x128b":
num_out_regs = 2
case "16x256b":
num_out_regs = 4
case _:
raise NotImplementedError(f"{shape=} is unsupported")
num_out_regs *= num
i32 = ir.IntegerType.get_signless(32)
out_regs = ",".join("$" + str(i) for i in range(num_out_regs))
regs = llvm.inline_asm(
ir.Type.parse(
"!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>"
),
[tmem_addr],
f"tcgen05.ld.sync.aligned.{shape}.x{num}.b32 {{{out_regs}}}, [${num_out_regs}];",
"=r," * num_out_regs + "r",
has_side_effects=True,
)
return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]
class TMEMLayout(enum.Enum):
"""Layout of the array in TMEM.
See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-organization
"""
D = "D"
@dataclasses.dataclass(frozen=True)
class TMEMRef:
address: ir.Value
layout: TMEMLayout
num_cols: int
dtype: ir.Type
@classmethod
def from_alloc(cls, tmem_addr_ref: ir.Value, layout: TMEMLayout, num_cols: int, dtype: ir.Type):
i32 = ir.IntegerType.get_signless(32)
if not ir.MemRefType.isinstance(tmem_addr_ref.type):
raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
addr_ref_ty = ir.MemRefType(tmem_addr_ref.type)
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
if addr_ref_ty.memory_space != smem:
raise ValueError(f"tmem_addr_ref must be in workgroup memory, got: {addr_ref_ty}")
if addr_ref_ty.element_type != i32:
raise ValueError(f"tmem_addr_ref must be an i32 memref, got: {addr_ref_ty}")
tmem_addr = memref.load(tmem_addr_ref, [arith.ConstantOp.create_index(0)])
# TODO: Do we have to do this??
# warp_idx = utils.warp_idx(sync=False)
# tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32)))
return cls(tmem_addr, layout, num_cols, dtype)
@property
def num_rows(self):
match self.layout:
case TMEMLayout.D:
return 128
case _:
raise NotImplementedError(self.layout)
@property
def shape(self):
return (self.num_rows, self.num_cols)
def __getitem__(self, *idxs):
i32 = ir.IntegerType.get_signless(32)
base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
if any(is_squeezed):
raise ValueError("TMEM loads only support slicing")
if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape:
raise NotImplementedError("Slicing of TMEM not impelmented yet")
if self.layout != TMEMLayout.D:
raise NotImplementedError(self.layout)
if self.num_cols % 8:
raise NotImplementedError
if self.dtype != ir.F32Type.get():
raise NotImplementedError(self.dtype)
layout = _m128_256bit_32bit_layout(self.shape)
regs_shape = layout.registers_shape(self.shape)
num = self.num_cols // 8
registers = np.empty(regs_shape, dtype=object)
# We load 16 lanes at a time, but need 32 in total.
for row_group in range(2):
addr = arith.addi(self.address, arith.constant(i32, (row_group * 16) << 16))
regs = tmem_load(addr, "16x256b", num)
regs = [llvm.bitcast(self.dtype, r) for r in regs]
vector_regs = []
undef = llvm.mlir_undef(ir.VectorType.get((2,), self.dtype))
for r_low, r_high in zip(regs[::2], regs[1::2]):
high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32))
vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32))
vector_regs.append(vreg)
# Dimension 4 is the one where we split 32 rows into tiles of 8.
regs_slice = [slice(None)] * 4 + [slice(row_group * 2, (row_group + 1) * 2)]
registers[*regs_slice] = np.asarray(vector_regs, dtype=object).reshape(registers[*regs_slice].shape)
return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)
def _m128_256bit_32bit_layout(shape: tuple[int, ...]):
"""Returns a tiled layout that is easy to relayout to WGMMA layout after doubling the bitwidth."""
if len(shape) != 2:
raise ValueError(f"Shape {shape} is not 2D")
if shape[0] % 128 != 0 or shape[1] % 8 != 0:
raise ValueError(f"Shape {shape} is not a multiple of 64x8")
return fa.TiledLayout(
fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))),
warp_dim=-8,
lane_dims=(-4, -3),
vector_dim=-1,
)

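A minimal CPU-only sketch of the bit packing performed by create_instr_descriptor in the new file above, assuming exactly the field positions spelled out in its comments. It returns the plain integer rather than an arith.constant, and the helper name is purely illustrative.

def pack_tcgen05_instr_descriptor(m, n, acc_is_f32, inputs_are_bf16,
                                  transpose_a=False, transpose_b=False):
  # Mirrors the bit layout documented in create_instr_descriptor.
  if n % 8 or n > 256:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  if m % 16 or m > 256:
    raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
  desc = 0
  desc |= int(acc_is_f32) << 4        # D dtype, bits 4-5
  desc |= int(inputs_are_bf16) << 7   # A dtype, bits 7-9
  desc |= int(inputs_are_bf16) << 10  # B dtype, bits 10-12
  desc |= int(transpose_a) << 15      # Transpose A
  desc |= int(transpose_b) << 16      # Transpose B
  desc |= (n >> 3) << 17              # N >> 3, bits 17-22
  desc |= (m >> 4) << 24              # M >> 4, bits 24-28
  return desc

# f32 accumulator, f16 inputs, 128x128 tile -- the configuration the example kernel uses.
assert pack_tcgen05_instr_descriptor(128, 128, True, False) == (1 << 4) | (16 << 17) | (8 << 24)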
@ -17,6 +17,7 @@ import dataclasses
import enum
import functools
import itertools
from typing import Any
import jax
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
@ -100,6 +101,7 @@ def create_descriptor(
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
memory_space: int | None = None,
const_init: int = 0,
):
i64 = ir.IntegerType.get_signless(64)
ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space))
@ -118,7 +120,8 @@ def create_descriptor(
)
# We ignore the offset
desc_const = (
(wgmma_encode(leading_byte_offset) << 16)
const_init
| (wgmma_encode(leading_byte_offset) << 16)
| (wgmma_encode(stride_byte_offset) << 32)
)
desc = llvm.or_(
@ -299,132 +302,218 @@ class WGMMALayout(enum.Enum):
COL_MAJOR = enum.auto()
def _validate_mma(
a: Any,
b: ir.Value,
swizzle: int,
a_layout: WGMMALayout,
b_layout: WGMMALayout,
descriptor_const_init: int = 0,
):
# We need swizzle >= 32 to ensure that our K tiling is larger than the MMA
# instruction's K width.
if swizzle < 32:
raise ValueError(f"Unsupported swizzle: {swizzle}")
# Get A type.
if a_in_smem := isinstance(a, ir.Value):
if not ir.MemRefType.isinstance(a.type):
raise ValueError(f"When A is an ir.Value, it must be a memref, got: {a.type}")
a_ty = ir.MemRefType(a.type)
a_element_type = a_ty.element_type
a_shape = tuple(a_ty.shape)
if a_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
raise ValueError("A must be in workgroup memory when it's a reference")
if len(a_shape) != 4:
raise ValueError(f"A must be 4D when it's a reference, got rank {len(a_shape)}")
elif hasattr(a, "shape") and hasattr(a, "mlir_dtype"):
a_element_type = a.mlir_dtype
a_shape = a.shape
else:
raise NotImplementedError(f"Unsupported A type: {type(a)}")
# Get B type (always a reference).
b_ty = ir.MemRefType(b.type)
if b_ty.rank != 4:
raise ValueError(f"B must be 4D, got rank {b_ty.rank}")
# Verify element types and compute the tiling.
if (element_type := a_element_type) != b_ty.element_type:
raise ValueError(
f"A and B must have the same element type, got: {a_element_type} and"
f" {b_ty.element_type}"
)
supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
if element_type not in supported_types:
raise ValueError(a_element_type)
element_bytewidth = bytewidth(element_type)
kn_tiling = swizzle // element_bytewidth
# Verify the shape and strides of B are as expected.
k_tiles, n_tiles, k_tiling, n_tiling = b_ty.shape
if k_tiling != kn_tiling:
raise ValueError(b_ty.shape)
# Note that while this technically allows n to be smaller than kn_tile,
# the stride checks below will still enforce that the memory region is padded.
# It might be possible to relax that requirement, but I haven't tested it.
if n_tiling > kn_tiling and n_tiling % kn_tiling:
raise ValueError(n_tiling, kn_tiling)
k = k_tiles * kn_tiling
n = n_tiles * n_tiling
b_strides, _ = b_ty.get_strides_and_offset()
b_byte_strides = [s * element_bytewidth for s in b_strides]
b_k_byte_stride = b_byte_strides[0]
if b_byte_strides[1] != swizzle * kn_tiling:
raise ValueError(b_byte_strides)
if b_byte_strides[2:] == [swizzle, element_bytewidth]:
b_order = WGMMALayout.ROW_MAJOR
elif b_byte_strides[2:] == [element_bytewidth, swizzle]:
b_order = WGMMALayout.COL_MAJOR
else:
raise ValueError(b_byte_strides)
# Verify the shape and strides of A are as expected.
if not a_in_smem:
m = a_shape[0]
a_order = m_tiling = None
else:
a_ty = ir.MemRefType(a.type)
m_tiles, k_tiles, m_tiling, k_tiling = a_ty.shape
m = m_tiles * m_tiling
if k_tiling != kn_tiling or k_tiles * k_tiling != k:
raise ValueError(a_ty.shape)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_byte_strides = [s * element_bytewidth for s in a_strides]
if a_byte_strides[2:] == [swizzle, element_bytewidth]:
a_order = WGMMALayout.ROW_MAJOR
elif a_byte_strides[2:] == [element_bytewidth, swizzle]:
a_order = WGMMALayout.COL_MAJOR
else:
raise ValueError(a_byte_strides)
if a_order == WGMMALayout.COL_MAJOR and swizzle != 128:
# Not sure what the layout is like, since the tiles aren't square.
raise NotImplementedError
tnsp_lbo = swizzle * (swizzle // 32)
sbo = swizzle // 2
a_desc_fields = dict(
leading_byte_offset=(1 if a_order == a_layout else tnsp_lbo) << 4,
stride_byte_offset=sbo << 4,
swizzle=swizzle,
memory_space=3,
)
b_desc_fields = dict(
leading_byte_offset=(1 if b_order == b_layout else tnsp_lbo) << 4,
stride_byte_offset=sbo << 4,
swizzle=swizzle,
memory_space=3,
)
wgmma_params = dict(
a_transpose=a_order != a_layout,
b_transpose=b_order != b_layout,
a_k_stride=(2 if a_order == a_layout else swizzle) << 4,
b_k_stride=(2 if b_order == b_layout else swizzle) << 4,
n=n,
swizzle=swizzle,
element_type=ir.FloatTF32Type.get()
if ir.F32Type.isinstance(element_type)
else element_type,
)
if not a_in_smem:
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
a_k_byte_stride = a_desc_base = None
else:
a_desc_base = create_descriptor(
a, **a_desc_fields, const_init=descriptor_const_init
)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_k_byte_stride = a_strides[1] * element_bytewidth
b_desc_base = create_descriptor(
b, **b_desc_fields, const_init=descriptor_const_init
)
return (
a_desc_base,
b_desc_base,
(m, k, n),
(m_tiling, kn_tiling),
element_type,
wgmma_params,
a_k_byte_stride,
b_k_byte_stride,
)
# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
# transpositions from memref strides.
def wgmma(
acc: WGMMAAccumulator,
a,
b,
a: fa.FragmentedArray | ir.Value,
b: ir.Value,
*,
swizzle: int = 128,
# Order only applies within each tile!
a_order: WGMMALayout | None = None,
b_order: WGMMALayout = WGMMALayout.ROW_MAJOR,
swizzle: int = 128,
):
if a_in_regs := isinstance(a, fa.FragmentedArray):
a_element_type = a.mlir_dtype
a_shape = a.shape
else:
a_ty = ir.MemRefType(a.type)
a_element_type = a_ty.element_type
a_shape = a_ty.shape
b_ty = ir.MemRefType(b.type)
supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
if a_element_type not in supported_types:
raise ValueError(a_element_type)
if b_ty.element_type not in supported_types:
raise ValueError(b_ty.element_type)
if (element_type := a_element_type) != b_ty.element_type:
raise ValueError
element_bytewidth = bytewidth(element_type)
kn_tile = swizzle // element_bytewidth
"""Perform acc += a @ b using the WGMMA instruction.
groups_k, groups_n = b_ty.shape[:2]
k_group_size, n_group_size = (
b_ty.shape[2:] if b_order == WGMMALayout.ROW_MAJOR else b_ty.shape[:1:-1]
)
# Note that while this technically allows n to be smaller than kn_tile,
# the stride checks below will still enforce that the memory region is padded.
# It might be possible to relax that requirement, but I haven't tested it.
if n_group_size > kn_tile and n_group_size % kn_tile:
raise ValueError(n_group_size, kn_tile)
if k_group_size != kn_tile:
raise ValueError(b_ty.shape)
The expected memref shapes are:
a: (m, k, 64, S)
b: (k, n, S, S)
where S = swizzle // bytewidth(element_type).
The refs must be contiguous or be contiguous except for having their two minor
dimensions swapped.
"""
a_in_regs = isinstance(a, fa.FragmentedArray)
if not a_in_regs and not ir.MemRefType.isinstance(a.type):
raise ValueError(f"Unsupported A type: {type(a)}")
if not ir.MemRefType.isinstance(b.type):
raise ValueError(f"B must be a memref, got: {b.type}")
(
a_desc_base,
b_desc_base,
(m, k, n),
(m_tiling, kn_tiling),
element_type,
wgmma_params,
a_k_byte_stride,
b_k_byte_stride,
) = _validate_mma(a, b, swizzle, WGMMALayout.ROW_MAJOR, WGMMALayout.COL_MAJOR)
if a_in_regs:
if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get():
raise ValueError(a_element_type)
if a_shape[0] % 64 or a_shape[1] % kn_tile:
raise ValueError(a_shape)
if a_shape[1] // kn_tile != groups_k:
raise ValueError(a_shape[1] // kn_tile, groups_k)
groups_m = a_shape[0] // 64
if a_order is not None:
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
raise ValueError(
"a_order can only be specified when A is in shared memory"
f"Only 16-bit dtypes supported for A in registers, got {a.mlir_dtype}"
)
if a.shape[0] % 64:
raise ValueError(f"m must be a multiple of 64, got: {a.shape[0]}")
a_m_byte_stride = None
else:
groups_m = a_shape[0]
if a_shape[1] != groups_k:
raise ValueError(a_shape[1], groups_k)
if a_shape[2:] != [64, kn_tile]:
raise ValueError(a_shape)
if a_order is None:
a_order = WGMMALayout.ROW_MAJOR
if m_tiling != 64:
raise ValueError(f"A must have rows tiled by 64, got: {m_tiling}")
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_m_byte_stride = a_strides[0] * bytewidth(element_type)
if a_order == WGMMALayout.COL_MAJOR and swizzle != 128:
# Not sure what the layout is like, since the tiles aren't square.
raise NotImplementedError
expected_acc_shape = (groups_m * 64, groups_n * n_group_size)
groups_k = k // kn_tiling
groups_m = m // 64
expected_acc_shape = (groups_m * 64, n)
if acc.value.shape != expected_acc_shape:
raise ValueError(
f"Accumulator shape mismatch: expected {expected_acc_shape}, got"
f" {acc.value.shape}"
)
row_major = WGMMALayout.ROW_MAJOR
col_major = WGMMALayout.COL_MAJOR
tnsp_lbo = swizzle * (swizzle // 32)
sbo = swizzle // 2
a_desc_fields = dict(
leading_byte_offset=(1 if a_order == row_major else tnsp_lbo) << 4,
stride_byte_offset=sbo << 4,
swizzle=swizzle,
memory_space=3,
)
b_desc_fields = dict(
leading_byte_offset=(tnsp_lbo if b_order == row_major else 1) << 4,
stride_byte_offset=sbo << 4,
swizzle=swizzle,
memory_space=3,
)
wgmma_params = dict(
a_transpose=a_order == col_major,
b_transpose=b_order == row_major,
a_k_stride=(2 if a_order == row_major else 128) << 4,
b_k_stride=(swizzle if b_order == row_major else 2) << 4,
n=(groups_n * n_group_size),
swizzle=swizzle,
element_type=ir.FloatTF32Type.get()
if ir.F32Type.isinstance(element_type)
else element_type,
)
if a_in_regs:
wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
if a_in_regs:
a = wgmma_fence(a) # Make sure the registers are ready.
a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype.
else:
a_desc_base = create_descriptor(a, **a_desc_fields)
a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
a_byte_strides = [s * element_bytewidth for s in a_strides]
a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2]
if a_byte_strides[2:] != [swizzle, element_bytewidth]:
raise ValueError(a_byte_strides)
b_desc_base = create_descriptor(b, **b_desc_fields)
b_strides, _ = b_ty.get_strides_and_offset()
b_byte_strides = [s * element_bytewidth for s in b_strides]
b_k_byte_stride = b_byte_strides[0]
if b_byte_strides[1:] != [swizzle * kn_tile, swizzle, element_bytewidth]:
raise ValueError(b_byte_strides)
i64 = ir.IntegerType.get_signless(64)
new_acc_regs = acc.value.registers.copy()
for mi in range(groups_m):
for ki in range(groups_k):
if a_in_regs:
a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile]
a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tiling : (ki + 1) * kn_tiling]
else:
a_mk = llvm_add(
a_desc_base,

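As a quick sanity check on the new wgmma docstring above, a standalone sketch of the tiled memref shapes it asks for, assuming f16 operands and the default swizzle of 128; the logical problem size is made up for illustration.

import numpy as np

element = np.float16
swizzle = 128
S = swizzle // np.dtype(element).itemsize    # 64 elements per swizzled row for f16

M, K, N = 256, 512, 384                      # illustrative logical problem size
a_tiled_shape = (M // 64, K // S, 64, S)     # a: (m, k, 64, S) per the docstring
b_tiled_shape = (K // S, N // S, S, S)       # b: (k, n, S, S) per the docstring
print(S, a_tiled_shape, b_tiled_shape)       # 64 (4, 8, 64, 64) (8, 6, 64, 64)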
@ -23,9 +23,9 @@ from jax._src.pallas.core import BlockSpec as BlockSpec
from jax._src.pallas.core import CompilerParams as CompilerParams
from jax._src.pallas.core import core_map as core_map
from jax._src.pallas.core import CostEstimate as CostEstimate
from jax._src.pallas.core import lower_as_mlir as lower_as_mlir
from jax._src.pallas.core import GridSpec as GridSpec
from jax._src.pallas.core import IndexingMode as IndexingMode
from jax._src.pallas.core import lower_as_mlir as lower_as_mlir
from jax._src.pallas.core import MemoryRef as MemoryRef
from jax._src.pallas.core import MemorySpace as MemorySpace
from jax._src.pallas.core import no_block_spec as no_block_spec
@ -34,6 +34,7 @@ from jax._src.pallas.core import unblocked as unblocked
from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost
from jax._src.pallas.helpers import empty as empty
from jax._src.pallas.helpers import empty_like as empty_like
from jax._src.pallas.helpers import when as when
from jax._src.pallas.pallas_call import pallas_call as pallas_call
from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p
from jax._src.pallas.primitives import atomic_add as atomic_add
@ -57,7 +58,6 @@ from jax._src.pallas.primitives import swap as swap
from jax._src.pallas.utils import cdiv as cdiv
from jax._src.pallas.utils import next_power_of_2 as next_power_of_2
from jax._src.pallas.utils import strides_from_shape as strides_from_shape
from jax._src.pallas.utils import when as when
from jax._src.state.discharge import run_state as run_state
from jax._src.state.indexing import ds as ds
from jax._src.state.indexing import dslice as dslice

@ -25,6 +25,8 @@ from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams
from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401
from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy
from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier
from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core
from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException
from jax._src.pallas.mosaic.pipeline import ARBITRARY as ARBITRARY
from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef

@ -456,6 +456,9 @@ MaybeTracer = Union[JaxType, Tracer]
class ShardMapPrimitive(core.Primitive):
multiple_results = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, fun_and_args, params):
fun, *args = fun_and_args
return trace.process_shard_map(shard_map_p, fun, args, **params)
@ -1160,7 +1163,8 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p,
control_flow.loops.cumprod_p, control_flow.loops.cummax_p,
control_flow.loops.cummin_p, pjit.sharding_constraint_p]:
control_flow.loops.cummin_p, pjit.sharding_constraint_p,
pjit.mesh_cast_p]:
register_standard_check(p)
register_standard_rewrite(p)
@ -1715,7 +1719,9 @@ def _partial_eval_jaxpr_custom_rule(
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [idx_map.get(id(v)) for v in res_vars]
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
mesh = eqn.params['mesh']
with (core.extend_axis_env_nd(mesh.shape.items()),
set_abstract_mesh(_as_manual_mesh(mesh))):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)

@ -42,7 +42,8 @@ def generate_sourcemaps(
with tempfile.TemporaryDirectory() as work_dir:
for pass_to_eval in passes:
if pass_to_eval.compile_fn not in compile_cache:
pass_work_dir = os.path.join(work_dir, pass_to_eval.name)
dirname = pass_to_eval.name.replace(":", "__")
pass_work_dir = os.path.join(work_dir, dirname)
os.makedirs(pass_work_dir, exist_ok=False)
compile_result = pass_to_eval.compile_fn(
pass_work_dir, f, args, kwargs

@ -24,7 +24,8 @@ LOC_REGEX = re.compile(r"loc\(#loc(?P<id>[0-9]+)\)")
SRC_REGEX = re.compile(
r"#loc(?P<id>[0-9]+) ="
r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)\)"
r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)"
r"( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)"
)
SCOPED_REGEX = re.compile(

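To show what the widened SRC_REGEX above now accepts, a standalone sketch matching an MLIR location with and without the new "to line:col" end-range suffix; the sample loc strings are made up.

import re

SRC_REGEX = re.compile(
    r"#loc(?P<id>[0-9]+) ="
    r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)"
    r"( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)"
)

m1 = SRC_REGEX.match('#loc3 = loc("kernel.py":10:4)')
m2 = SRC_REGEX.match('#loc4 = loc("kernel.py":10:4 to 10:27)')
assert m1 and m1.group("endcolno") is None
assert m2 and (m2.group("endlineno"), m2.group("endcolno")) == ("10", "27")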
@ -86,9 +86,9 @@ from jax._src.interpreters.partial_eval import (
# TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes
def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None): # noqa
def trace_to_jaxpr_dynamic(fun, in_avals, *, keep_inputs=None): # noqa
jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic(
fun, in_avals, debug_info, keep_inputs=keep_inputs)
fun, in_avals, keep_inputs=keep_inputs)
return jaxpr, out_avals, consts

@ -16,7 +16,6 @@ load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"py_deps",
"pytype_strict_library",
)
licenses(["notice"])
@ -46,8 +45,3 @@ py_library(
"//jax/experimental/jax2tf",
] + py_deps("tensorflow_core"),
)
pytype_strict_library(
name = "build_utils",
srcs = ["build_utils.py"],
)

@ -35,6 +35,8 @@ def _get_version_string() -> str:
# In this case we return it directly.
if _release_version is not None:
return _release_version
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
return _version_from_git_tree(_version) or _version_from_todays_date(_version)
@ -71,16 +73,23 @@ def _get_version_for_build() -> str:
"""Determine the version at build time.
The returned version string depends on which environment variables are set:
- if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc"
Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc".
Please note that the WHEEL_VERSION_SUFFIX value is not the same as the
JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by Bazel
wheel build rule.
- if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
- if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
- if none are set: version looks like "0.4.16.dev20230906+ge58560fdc
"""
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
return _version_from_todays_date(_version)
if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
return _version
if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
return _version_from_todays_date(_version)
return _version_from_git_tree(_version) or _version_from_todays_date(_version)

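A self-contained sketch of the precedence the updated _get_version_for_build docstring above describes, with the git-tree/date fallback stubbed out; the function name and the sample values are illustrative only.

import datetime

def pick_build_version(base, env, release_version=None):
  # Mirrors the documented order: release tag, WHEEL_VERSION_SUFFIX,
  # JAX(LIB)_RELEASE, JAX(LIB)_NIGHTLY, then the git-tree/date fallback.
  if release_version is not None:
    return release_version
  if env.get("WHEEL_VERSION_SUFFIX"):
    return base + env["WHEEL_VERSION_SUFFIX"]
  if env.get("JAX_RELEASE") or env.get("JAXLIB_RELEASE"):
    return base
  if env.get("JAX_NIGHTLY") or env.get("JAXLIB_NIGHTLY"):
    return base + ".dev" + datetime.date.today().strftime("%Y%m%d")
  return base + ".dev0+fromgit"  # stand-in for the git-tree/date fallback

assert pick_build_version("0.5.1", {"JAX_RELEASE": "1"}) == "0.5.1"
assert (pick_build_version("0.5.1", {"WHEEL_VERSION_SUFFIX": ".dev20230906+ge58560fdc"})
        == "0.5.1.dev20230906+ge58560fdc")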
@ -18,6 +18,7 @@ import logging
import os
import pathlib
from jax._src.lib import triton
from jax._src.lib import xla_client
import jax._src.xla_bridge as xb
@ -99,5 +100,11 @@ def initialize():
cuda_plugin_extension.register_custom_type_id, c_api
),
)
triton.register_compilation_handler(
"CUDA",
functools.partial(
cuda_plugin_extension.compile_triton_to_asm, c_api
),
)
else:
logger.warning('cuda_plugin_extension is not found.')

@ -234,8 +234,8 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@nanobind",
"@tsl//tsl/platform:statusor",
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
@ -243,6 +243,7 @@ cc_library(
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
],

@ -41,8 +41,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_cholesky_update", CholeskyUpdate,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA",
GetrfFfi);

@ -23,15 +23,10 @@ namespace {
namespace nb = nanobind;
std::string BuildThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
}
nb::dict Registrations() {
nb::dict dict;
dict[JAX_GPU_PREFIX "_threefry2x32_ffi"] =
EncapsulateFfiHandler(ThreeFry2x32Ffi);
// TODO(b/338022728): remove after 6 months
dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
return dict;
}

@ -33,29 +33,6 @@ namespace JAX_GPU_NAMESPACE {
namespace ffi = xla::ffi;
namespace {
// TODO(b/338022728): old custom call target, remove after 6 months
absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
LaunchThreeFry2x32Kernel(stream, buffers, **s);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
return absl::OkStatus();
}
} // namespace
// TODO(b/338022728): remove after 6 months
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
std::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
namespace {
ffi::Error ThreeFry2x32Impl(gpuStream_t stream,

@ -121,35 +121,6 @@ void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
out1, n, nullptr);
}
// TODO(b/338022728): remove after 6 months
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor) {
std::array<const std::uint32_t*, 2> keys;
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
std::array<const std::uint32_t*, 2> data;
data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
std::int64_t n = descriptor.n;
int output_idx = 4;
std::int64_t* n_ptr = nullptr;
if (n < 0) {
// n is an operand in device memory.
n_ptr = reinterpret_cast<std::int64_t*>(buffers[4]);
output_idx = 5;
}
std::array<std::uint32_t*, 2> out;
out[0] = reinterpret_cast<std::uint32_t*>(buffers[output_idx]);
out[1] = reinterpret_cast<std::uint32_t*>(buffers[output_idx + 1]);
const int block_dim = 128;
const std::int64_t grid_dim =
n < 0 ? 32
: std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
out[1], n, n_ptr);
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

@ -26,19 +26,6 @@ limitations under the License.
namespace jax {
namespace JAX_GPU_NAMESPACE {
// TODO(b/338022728): remove after 6 months
struct ThreeFry2x32Descriptor {
std::int64_t n; // If -1 then the length is passed as a 5th operand
};
// TODO(b/338022728): remove after 6 months
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor);
// TODO(b/338022728): remove after 6 months
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
void LaunchThreeFry2x32KernelFfi(gpuStream_t stream,
std::int64_t n,
std::uint32_t *keys0, std::uint32_t *keys1,

@ -16,23 +16,28 @@ limitations under the License.
#include "jaxlib/gpu_plugin_extension.h"
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_triton_extension.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/py_client_gpu.h"
#include "xla/tsl/python/lib/core/numpy.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"
namespace nb = nanobind;
@ -40,6 +45,44 @@ namespace xla {
namespace {
struct TritonCompilationResult {
std::string asm_text;
int64_t smem_bytes;
int cluster_dim_x;
int cluster_dim_y;
int cluster_dim_z;
};
absl::StatusOr<TritonCompilationResult> CompileTritonToASM(
const PJRT_Api* c_api, absl::string_view module,
absl::string_view arch_name, int num_warps, int num_ctas, int num_stages) {
const PJRT_Triton_Extension* triton_ext =
pjrt::FindExtension<PJRT_Triton_Extension>(
c_api, PJRT_Extension_Type::PJRT_Extension_Type_Triton);
if (triton_ext == nullptr) {
return Unimplemented("The plugin does not have a Triton extension.");
}
PJRT_Triton_Compile_Args args;
args.struct_size = PJRT_Triton_Compile_Args_STRUCT_SIZE;
args.module = module.data();
args.module_size = module.size();
args.arch_name = arch_name.data();
args.arch_name_size = arch_name.size();
args.num_warps = num_warps;
args.num_ctas = num_ctas;
args.num_stages = num_stages;
RETURN_STATUS_IF_PJRT_ERROR(triton_ext->compile(&args), c_api);
auto asm_text = std::string(args.out_asm, args.out_asm_size);
delete[] args.out_asm;
return TritonCompilationResult{
.asm_text = asm_text,
.smem_bytes = args.out_smem_bytes,
.cluster_dim_x = args.out_cluster_dim_x,
.cluster_dim_y = args.out_cluster_dim_y,
.cluster_dim_z = args.out_cluster_dim_z,
};
}
absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
const char* fn_name_c_str,
size_t fn_name_size, nb::object fn,
@ -170,6 +213,24 @@ nb::dict Registrations() {
void BuildGpuPluginExtension(nanobind::module_& m) {
tsl::ImportNumpy();
nb::class_<TritonCompilationResult>(m, "TritonCompilationResult")
.def_ro("asm", &TritonCompilationResult::asm_text)
.def_ro("smem_bytes", &TritonCompilationResult::smem_bytes)
.def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x)
.def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y)
.def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z);
m.def("compile_triton_to_asm",
[](nb::capsule c_api, nb::bytes module, absl::string_view arch_name,
int num_warps, int num_ctas, int num_stages) {
return xla::ValueOrThrow(CompileTritonToASM(
static_cast<const PJRT_Api*>(c_api.data()),
absl::string_view(static_cast<const char*>(module.data()),
module.size()),
arch_name, num_warps, num_ctas, num_stages));
});
m.def(
"register_custom_call_target",
[](nb::capsule c_api, nb::object fn_name_py, nb::object fn,

@ -36,10 +36,8 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
if _cuda_prng:
for _name, _value in _cuda_prng.registrations().items():
# TODO(b/338022728): remove after 6 months, always api_version=1
api_version = 1 if "_ffi" in _name else 0
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
api_version=1)
for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
try:

@ -14,7 +14,10 @@
"""Bazel macros used by the JAX build."""
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION")
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
@ -50,6 +53,15 @@ jax_internal_test_harnesses_visibility = []
jax_test_util_visibility = []
loops_visibility = []
PLATFORM_TAGS_DICT = {
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "AMD64"): ("win", "amd64"),
}
# TODO(vam): remove this once zstandard builds against Python 3.13
def get_zstandard():
if HERMETIC_PYTHON_VERSION == "3.13":
@ -106,6 +118,7 @@ def jax_visibility(_target):
return []
jax_extra_deps = []
jax_gpu_support_deps = []
jax2tf_deps = []
def pytype_library(name, pytype_srcs = None, **kwargs):
@ -208,7 +221,7 @@ def if_building_jaxlib(
"@pypi_jax_cuda12_pjrt//:pkg",
],
if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"]):
"""Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.
"""Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources.
This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase.
@ -267,7 +280,7 @@ def jax_multiplatform_test(
]
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags += ["manual"]
test_tags.append("manual")
if backend == "gpu":
test_tags += tf_cuda_tests_tags()
native.py_test(
@ -308,15 +321,60 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)
def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
if no_abi:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
return wheel_name_template.format(
package_name = package_name,
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
)
def _jax_wheel_impl(ctx):
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
output_path = ctx.attr.output_path[BuildSettingInfo].value
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
executable = ctx.executable.wheel_binary
output = ctx.actions.declare_directory(ctx.label.name)
if include_cuda_libs and not override_include_cuda_libs:
fail("JAX wheel shouldn't be built directly against the CUDA libraries." +
" Please provide `--config=cuda_libraries_from_stubs` for bazel build command." +
" If you absolutely need to build links directly against the CUDA libraries, provide" +
" `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")
env = {}
args = ctx.actions.args()
args.add("--output_path", output.path) # required argument
args.add("--cpu", ctx.attr.platform_tag) # required argument
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument
full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX)
env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX
if BUILD_TAG:
env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG)
full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
if not WHEEL_VERSION_SUFFIX and not BUILD_TAG:
env["JAX_RELEASE"] = "1"
cpu = ctx.attr.cpu
platform_name = ctx.attr.platform_name
wheel_name = _get_full_wheel_name(
package_name = ctx.attr.wheel_name,
no_abi = ctx.attr.no_abi,
platform_name = platform_name,
cpu_name = cpu,
wheel_version = full_wheel_version,
)
output_file = ctx.actions.declare_file(output_path +
"/" + wheel_name)
wheel_dir = output_file.path[:output_file.path.rfind("/")]
args.add("--output_path", wheel_dir) # required argument
args.add("--cpu", cpu) # required argument
args.add("--jaxlib_git_hash", git_hash) # required argument
if ctx.attr.enable_cuda:
args.add("--enable-cuda", "True")
@ -335,11 +393,13 @@ def _jax_wheel_impl(ctx):
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
outputs = [output],
inputs = [],
outputs = [output_file],
executable = executable,
env = env,
)
return [DefaultInfo(files = depset(direct = [output]))]
return [DefaultInfo(files = depset(direct = [output_file]))]
_jax_wheel = rule(
attrs = {
@ -349,19 +409,25 @@ _jax_wheel = rule(
# b/365588895 Investigate cfg = "exec" for multi platform builds
cfg = "target",
),
"platform_tag": attr.string(mandatory = True),
"git_hash": attr.label(allow_single_file = True),
"wheel_name": attr.string(mandatory = True),
"no_abi": attr.bool(default = False),
"cpu": attr.string(mandatory = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
"platform_version": attr.string(mandatory = True, default = ""),
"skip_gpu_kernels": attr.bool(default = False),
"enable_rocm": attr.bool(default = False),
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
},
implementation = _jax_wheel_impl,
executable = False,
)
def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
"""Create jax artifact wheels.
Common artifact attributes are grouped within a single macro.
@ -369,6 +435,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
Args:
name: the name of the wheel
wheel_binary: the binary to use to build the wheel
wheel_name: the name of the wheel
no_abi: whether to build a wheel without ABI
enable_cuda: whether to build a cuda wheel
platform_version: the cuda version to use for the wheel
@ -378,18 +446,20 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
_jax_wheel(
name = name,
wheel_binary = wheel_binary,
wheel_name = wheel_name,
no_abi = no_abi,
enable_cuda = enable_cuda,
platform_version = platform_version,
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
# the git hash file needs to be created first.
git_hash = select({
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
"//conditions:default": None,
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
# flag in bazel command to pass the git hash for nightly or release builds.
platform_name = select({
"@platforms//os:osx": "Darwin",
"@platforms//os:macos": "Darwin",
"@platforms//os:windows": "Windows",
"@platforms//os:linux": "Linux",
}),
# Following the convention in jax/tools/build_utils.py.
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
platform_tag = select({
cpu = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",

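To make the new wheel-naming logic concrete, a pure-Python sketch of _get_full_wheel_name using the same templates and a subset of PLATFORM_TAGS_DICT; the package name and version below are examples, not values produced by this build.

PLATFORM_TAGS_DICT = {
    ("Linux", "x86_64"): ("manylinux2014", "x86_64"),
    ("Darwin", "arm64"): ("macosx_11_0", "arm64"),
    ("Windows", "AMD64"): ("win", "amd64"),
}

def full_wheel_name(package_name, no_abi, platform_name, cpu_name,
                    wheel_version, hermetic_python_version="3.12"):
  python_version = hermetic_python_version.replace(".", "")
  platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name])
  if no_abi:
    return f"{package_name}-{wheel_version}-py{python_version[0]}-none-{platform_tag}.whl"
  return f"{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{platform_tag}.whl"

# Example only: a jaxlib wheel for CPython 3.12 on manylinux2014 x86_64.
assert (full_wheel_name("jaxlib", False, "Linux", "x86_64", "0.5.1.dev20250205")
        == "jaxlib-0.5.1.dev20250205-cp312-cp312-manylinux2014_x86_64.whl")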
Some files were not shown because too many files have changed in this diff.