Merge branch 'jax-ml:main' into main

Michael Hudgins 2024-12-13 14:30:30 -05:00 committed by GitHub
commit 4de58e1af7
223 changed files with 7267 additions and 3399 deletions


@ -96,6 +96,11 @@ build:avx_windows --copt=/arch:AVX
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL).
build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true
build:mkl_aarch64_threadpool --@compute_library//:openmp=false
build:mkl_aarch64_threadpool -c opt
# Disable clang extension that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
@ -104,6 +109,8 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
build:clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:clang --copt=-Qunused-arguments
# Error on struct/class mismatches, since this causes link failures on Windows.
build:clang --copt=-Werror=mismatched-tags
# Configs for CUDA
build:cuda --repo_env TF_NEED_CUDA=1


@ -11,6 +11,9 @@ on:
options:
- 'yes'
- 'no'
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@ -21,15 +24,18 @@ jobs:
if: github.event.repository.fork == false
strategy:
matrix:
runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"]
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
enable-x_64: [1, 0]
runs-on: ${{ matrix.runner }}
# TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }}
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
steps:
- uses: actions/checkout@v3


@ -11,6 +11,9 @@ on:
options:
- 'yes'
- 'no'
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@ -22,12 +25,16 @@ jobs:
strategy:
matrix:
runner: ["linux-x86-n2-16"]
enable-x_64: [1, 0]
runs-on: ${{ matrix.runner }}
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
name: "Bazel single accelerator GPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
steps:
- uses: actions/checkout@v3


@ -35,7 +35,7 @@ jobs:
with:
python-version: 3.11
- run: python -m pip install pre-commit
- uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
@ -79,7 +79,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@ -126,7 +126,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@ -169,7 +169,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@ -204,7 +204,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@ -245,7 +245,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}


@ -38,11 +38,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install .[ci]
python -m pip install -r array-api-tests/requirements.txt
python -m pip install pytest-xdist -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.numpy
JAX_ENABLE_X64: 'true'
run: |
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt
pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt


@ -36,7 +36,7 @@ repos:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
additional_dependencies: [types-requests==2.31.0, jaxlib]
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0]
args: [--config=pyproject.toml]
- repo: https://github.com/mwouts/jupytext


@ -10,7 +10,35 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->
## jax 0.4.36
## jax 0.4.38
* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
  as shortcuts for the corresponding `tree_util` functions (see the sketch
  at the end of this release entry).
* Deprecations
* A number of APIs in the internal `jax.core` namespace have been deprecated.
Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}`jax.extend.core`; see the documentation for {mod}`jax.extend`
for information on the compatibility guarantees of these semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
`non_negative_dim`.
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
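
A minimal sketch of the new `jax.tree` shortcuts listed under Changes above,
assuming they mirror the corresponding `jax.tree_util` functions:

```python
import jax

tree = {"a": 1, "b": (2, 3)}

# Flatten while keeping the key path of each leaf.
leaves_with_paths, treedef = jax.tree.flatten_with_path(tree)
for path, leaf in leaves_with_paths:
    print(jax.tree_util.keystr(path), leaf)

# Map over leaves together with their paths.
doubled = jax.tree.map_with_path(lambda path, x: 2 * x, tree)
```
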
## jax 0.4.37 (Dec 9, 2024)
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
* Bug fixes
* Fixed a bug where `jit` would error if an argument was named `f` (#25329).
* Fixed a bug that raised an `index out of range` error in
  {func}`jax.lax.while_loop` when a user-registered pytree node class
  returned different aux data from its `flatten` and `flatten_with_path`
  methods.
* Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
## jax 0.4.36 (Dec 5, 2024)
* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing


@ -47,7 +47,7 @@ are instances of such transformations. Others are
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming of multiple accelerators, with more to come.
This is a research project, not an official Google product. Expect bugs and
This is a research project, not an official Google product. Expect
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you


@ -123,6 +123,15 @@ def add_global_arguments(parser: argparse.ArgumentParser):
help="Produce verbose output for debugging.",
)
parser.add_argument(
"--detailed_timestamped_log",
action="store_true",
help="""
Enable detailed logging of the Bazel command with timestamps. The logs
will be stored and can be accessed as artifacts.
""",
)
def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
"""Adds all the arguments that applies to the artifact subcommands."""
@ -399,7 +408,7 @@ async def main():
else:
requirements_command.append("//build:requirements.update")
result = await executor.run(requirements_command.get_command_as_string(), args.dry_run)
result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
if result.return_code != 0:
raise RuntimeError(f"Command failed with return code {result.return_code}")
else:
@ -476,7 +485,10 @@ async def main():
if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
wheel_build_command.append("--config=mkl_open_source_only")
if target_cpu == "aarch64":
wheel_build_command.append("--config=mkl_aarch64_threadpool")
else:
wheel_build_command.append("--config=mkl_open_source_only")
if args.target_cpu_features == "release":
if arch in ["x86_64", "AMD64"]:
@ -597,7 +609,7 @@ async def main():
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
# Exit with error if any wheel build fails.
if result.return_code != 0:
raise RuntimeError(f"Command failed with return code {result.return_code}")


@ -75,7 +75,7 @@ class SubprocessExecutor:
"""
self.environment = environment or dict(os.environ)
async def run(self, cmd: str, dry_run: bool = False) -> CommandResult:
async def run(self, cmd: str, dry_run: bool = False, detailed_timestamped_log: bool = False) -> CommandResult:
"""
Executes a subprocess command.
@ -96,14 +96,15 @@ class SubprocessExecutor:
process = await asyncio.create_subprocess_shell(
cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE if detailed_timestamped_log else None,
stderr=asyncio.subprocess.PIPE if detailed_timestamped_log else None,
env=self.environment,
)
await asyncio.gather(
_process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result)
)
if detailed_timestamped_log:
await asyncio.gather(
_process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result)
)
result.return_code = await process.wait()
result.end_time = datetime.datetime.now()


@ -69,7 +69,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
fi
# Build the artifact.
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.

File diff suppressed because one or more lines are too long



@ -25,4 +25,3 @@ some of JAX's (extensible) internals.
autodidax
jep/index
jax_internal_api


@ -689,22 +689,21 @@ minimization phase.
### Doctests
JAX uses pytest in doctest mode to test the code examples within the documentation.
You can run this using
You can find the up-to-date command to run doctests in
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml).
E.g., you can run:
```
pytest docs
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
```
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
function docstrings will run correctly. You can run this locally using, for example:
```
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py
```
Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml)
## Type checking


@ -70,3 +70,31 @@ Common causes of OOM failures
memory. Note, however, that the algorithm is basic and you can often get a better
trade-off between compute and memory by disabling the automatic remat pass and doing
it manually with the `jax.remat API <https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html>`_.
Experimental features
---------------------
Features here are experimental and should be used with caution.

``TF_GPU_ALLOCATOR=cuda_malloc_async``
This replaces XLA's own BFC memory allocator with `cudaMallocAsync
<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`_.
It removes the large fixed pre-allocation and instead uses a memory pool that
grows on demand. The expected benefit is that there is no need to set
`XLA_PYTHON_CLIENT_MEM_FRACTION`.

The risks are:

- Memory fragmentation behaves differently, so if you are close to the
  limit, the exact OOM cases caused by fragmentation will differ.
- The allocation cost is not all paid at startup but is incurred whenever
  the memory pool needs to grow, so you may see less stable speed at the
  start; for benchmarks it becomes even more important to ignore the first
  few iterations.

These risks can be mitigated by pre-allocating a significant chunk while
still keeping the benefit of a growing memory pool. This can be done with
`TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N`. If N is `-1`, it preallocates
the same amount as the default allocator would; otherwise, N is the number
of bytes to preallocate.
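
A minimal sketch of enabling this experimental allocator; the variable names
are the ones documented above, and they must be set before JAX initializes
the GPU backend:

```python
import os

# Both variables must be set before JAX (and XLA) initialize the GPU backend.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
# Optional: pre-allocate a chunk up front; -1 preallocates the same amount
# as the default allocator would.
os.environ["TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC"] = "-1"

import jax  # imported only after the environment variables are set

print(jax.devices())
```
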


@ -44,11 +44,8 @@ example, we can add this to the top of a Python file:
```python
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
```
@ -58,9 +55,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
### Code generation flags
* **--xla_gpu_enable_triton_softmax_fusion** This flag enables an automatic
softmax fusion, based on pattern-matching backed by Triton code generation.
The default value is False.
* **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for
any GEMM that it supports. The default value is False.
@ -69,6 +63,20 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
* **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
schedulers to overlap asynchronous communication with computation efficiently.
The default value is False.
* **--xla_gpu_memory_limit_slop_factor** This flag serves as a multiplier applied
to the total available memory, creating a threshold that guides the Latency Hiding
Scheduler (LHS) in balancing memory reduction and latency hiding optimizations.
The default value is 95.
This factor effectively establishes a memory limit for compiler passes, determining
when the scheduler should prioritize:
1. Memory reduction: When memory usage approaches or exceeds the calculated threshold.
2. Latency hiding: When memory usage is below the threshold, allowing for more
aggressive optimizations that may temporarily increase memory usage but improve
overall performance.
By adjusting this factor, users can fine-tune the trade-off between memory efficiency
and performance optimizations.
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
i-th layer computation. It also enables overlapping (i+1)-th layer


@ -253,18 +253,14 @@ simply run:
conda install jax -c conda-forge
```
To install it on a machine with an NVIDIA GPU, run:
If you run this command on a machine with an NVIDIA GPU, it should install a CUDA-enabled package of `jaxlib`.
To ensure that the jax version you are installing is indeed CUDA-enabled, run:
```bash
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
conda install "jaxlib=*=*cuda*" jax -c conda-forge
```
Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
JAX requires. You must therefore either install the `cuda-nvcc` package from
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
is in your path. The channel order above is important (`conda-forge` before
`nvidia`).
If you would like to override which release of CUDA is used by JAX, or to
install the CUDA build on a machine without GPUs, follow the instructions in the
[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)


@ -14,8 +14,11 @@ Classes
.. autosummary::
:toctree: _autosummary
Exported
DisabledSafetyCheck
.. autoclass:: Exported
:members:
.. autoclass:: DisabledSafetyCheck
:members:
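
A hedged sketch of the typical ``Exported`` workflow; ``DisabledSafetyCheck``
instances can optionally be passed to ``export`` to relax safety checks. The
function and shapes below are illustrative only:

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
  return jnp.sin(x) * 2.0

x_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
exported = export.export(jax.jit(f))(x_spec)   # an Exported object

blob = exported.serialize()                    # serialized bytes, suitable for storage
rehydrated = export.deserialize(blob)
print(rehydrated.call(jnp.arange(4, dtype=jnp.float32)))
```
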
Functions
---------

docs/jax.extend.core.rst (new file, 18 lines)

@ -0,0 +1,18 @@
``jax.extend.core`` module
==========================
.. automodule:: jax.extend.core
.. autosummary::
:toctree: _autosummary
ClosedJaxpr
Jaxpr
JaxprEqn
Literal
Primitive
Token
Var
array_types
jaxpr_as_fun
primitives
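
A hedged sketch of how these semi-public names can be used to inspect and
re-evaluate a jaxpr (see the ``jax.extend`` documentation for the
compatibility guarantees):

```python
import jax
import jax.numpy as jnp
from jax.extend import core as jex_core

# make_jaxpr is expected to return an instance of the re-exported ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
assert isinstance(closed_jaxpr, jex_core.ClosedJaxpr)

# Turn the jaxpr back into a callable and evaluate it.
fn = jex_core.jaxpr_as_fun(closed_jaxpr)
print(fn(1.0))
```
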


@ -11,6 +11,7 @@ Modules
.. toctree::
:maxdepth: 1
jax.extend.core
jax.extend.ffi
jax.extend.linear_util
jax.extend.mlir


@ -11,7 +11,6 @@ jax.lib.xla_bridge
.. autosummary::
:toctree: _autosummary
default_backend
get_backend
get_compile_options


@ -274,6 +274,7 @@ namespace; they are listed below.
mask_indices
matmul
matrix_transpose
matvec
max
maximum
mean
@ -428,6 +429,7 @@ namespace; they are listed below.
var
vdot
vecdot
vecmat
vectorize
vsplit
vstack


@ -102,6 +102,9 @@ Automatic differentiation
closure_convert
checkpoint
Customization
-------------
``custom_jvp``
~~~~~~~~~~~~~~
@ -121,6 +124,16 @@ Automatic differentiation
custom_vjp
custom_vjp.defvjp
``custom_batching``
~~~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: _autosummary
custom_batching.custom_vmap
custom_batching.custom_vmap.def_vmap
custom_batching.sequential_vmap
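
A hedged sketch of ``custom_batching.custom_vmap``: the decorated function is
given a user-supplied batching rule that ``jax.vmap`` will invoke. The
function and rule below are illustrative only:

```python
import jax
import jax.numpy as jnp

@jax.custom_batching.custom_vmap
def clip_by_norm(x):
  return x / jnp.maximum(jnp.linalg.norm(x), 1.0)

@clip_by_norm.def_vmap
def clip_by_norm_vmap_rule(axis_size, in_batched, xs):
  (x_batched,) = in_batched
  # Normalize each batch element independently instead of tracing the
  # unbatched function under vmap.
  norms = jnp.linalg.norm(xs, axis=-1, keepdims=True)
  out = xs / jnp.maximum(norms, 1.0)
  return out, x_batched

batched = jax.vmap(clip_by_norm)(jnp.ones((4, 3)))  # uses the custom rule
```
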
jax.Array (:code:`jax.Array`)
-----------------------------


@ -13,8 +13,11 @@ List of Functions
all
flatten
flatten_with_path
leaves
leaves_with_path
map
map_with_path
reduce
structure
transpose


@ -1,14 +0,0 @@
Internal API reference
======================
core
----
.. currentmodule:: jax.core
.. automodule:: jax.core
.. autosummary::
:toctree: _autosummary
Jaxpr
ClosedJaxpr


@ -11,7 +11,16 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->
## Released with jax 0.4.35
## Released with jax 0.4.37
* New functionality
* Added support for `DotAlgorithmPreset` precision arguments for `dot`
lowering on the Triton backend.
## Released with jax 0.4.36 (December 6, 2024)
## Released with jax 0.4.35 (October 22, 2024)
* Removals


@ -119,24 +119,44 @@ The output reference can be then used as an accumulator for partial results.
spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a
low-level compiler error message complaining about an out-of-memory error.
Dimension ordering is meaningful
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Array Layouts
^^^^^^^^^^^^^
Dimension ordering of arrays is meaningful in Pallas.
In JAX programs, the ordering of intermediate arrays inside ``jax.jit`` usually
has no impact on performance, as the compiler is free to rearrange them.
However, as Pallas is meant to expose lower-level capabilities, the dimension
order can have great impact on the quality of generated code.
Recall that TPUs perform the bulk of the computation on 2D vector registers.
Pallas TPU will only ever consider mapping the last two dimensions of
intermediate arrays to those vector register dimensions (sublanes and lanes
respectively). An array of shape ``(n, 1, 1)`` is guaranteed to require at least
``n`` vector registers to represent. If ``n`` becomes too large, this can lead
to spills, and potential VMEM OOM errors due to an overly large memory footprint.
But it also might not --- the low-level compiler is free to rearrange the
instructions to lower the register pressure, and is in fact very good at it.
Still, it is a good rule of thumb to keep the last two dimensions large
(especially the last dimension), while keeping the leading dimensions small.
TPUs perform the bulk of the computation on 2D vector registers, which are typically of
size 8x128 for 32-bit values (as of TPU v6).
When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``),
the last two dimensions of the array will be tiled into the registers.
Pallas will only ever consider mapping the last two dimensions of
intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes
respectively).
Here is a graphical example of how a 12x320 array can be tiled using 6 8x128
tiles:
.. image:: ../../_static/pallas/vector_layout_example.svg
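
As a quick sanity check of this example, here is a hedged back-of-the-envelope
computation of the tile count (the 8x128 tile shape is the 32-bit case
described above):

```python
import math

sublanes, lanes = 8, 128   # vector register tile for 32-bit values
rows, cols = 12, 320       # array shape from the example above
num_tiles = math.ceil(rows / sublanes) * math.ceil(cols / lanes)
print(num_tiles)           # 2 * 3 = 6 tiles
```
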
Tiled layouts have several important ramifications for kernel writers:

* The last two axes of an array are treated differently than other
  axes. For example, reductions, reshapes, and transposes are generally
  more expensive when involving the last two axes. Some reshapes
  involving the last two dimensions are not supported and will result in a
  compiler error, whereas reshapes over the other dimensions are "free" and
  performed at compile time.
* While sometimes unavoidable, it is generally wasteful to have singleton
dimensions in the last two axes, since they will occupy 1 element out of
the entire tile dimension. Consuming too many registers can
also potentially cause register spills into VMEM which degrades kernel
performance.
* Related to the above point, all vector computation is padded up to the tile
size. Adding two 1x1 arrays costs as much as adding two 8x128 arrays, and
adding two 8x128x1x1 arrays will be 1024 times as expensive as adding two
8x128 arrays, since the 8x128x1x1 array will be padded to 8x128x8x128.
Multicore TPU configurations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -196,18 +216,19 @@ for those arguments. But, the ``BlockSpec``\s for all subsequent arguments will
receive not only the grid indices, but also the SMEM references to the leading
operands.
.. note::
We are working on implementing examples for this feature. Stay tuned!
See :ref:`pallas_scalar_prefetch_guide` for examples of using this feature.
Supported data types
^^^^^^^^^^^^^^^^^^^^
At the moment Pallas TPU only supports the following data types:
At the moment Pallas TPU supports the following data types:
* ``jnp.float32``
* ``jnp.bfloat16``
* ``jnp.int*`` (all precisions, except for ``jnp.int4``)
* ``jnp.uint*`` (all precisions)
* ``jnp.bool_``
Computation placement
^^^^^^^^^^^^^^^^^^^^^
@ -306,14 +327,13 @@ Array constructors
^^^^^^^^^^^^^^^^^^
All constant array constructors are supported (``jnp.ones``, ``jnp.zeros``,
``jnp.full``). Notably, the ``jax.random`` module is **not** compatible with
Pallas as of today.
``jnp.full``).
Reductions
^^^^^^^^^^
Sum, maximum and minimum reductions are supported, but only on a single array
axis at a time.
``sum``, ``max``, ``min`` (for floating point values) reductions are supported, as well
as ``any`` and ``all`` for boolean values. Integer reductions are not supported.
Reductions over the last array dimension are generally the slowest.
Reductions over the second last dimension are faster, but still slower than
@ -338,6 +358,14 @@ of an array is when (1) some leading dimensions are flattened onto the second
to last dimension, or (2) it adds a dimension that was just removed by a
reduction.
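
Below is a hedged sketch of a single-axis reduction inside a Pallas kernel,
matching the Reductions subsection above; the kernel and shapes are
illustrative only (pass ``interpret=True`` to ``pallas_call`` to debug on CPU):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def row_sum_kernel(x_ref, o_ref):
  # Reduce over the last (lane) axis only, keeping the result 2D.
  o_ref[...] = jnp.sum(x_ref[...], axis=-1, keepdims=True)

x = jnp.ones((8, 128), jnp.float32)
out = pl.pallas_call(
    row_sum_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 1), jnp.float32),
)(x)
```
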
Random Number Generation
^^^^^^^^^^^^^^^^^^^^^^^^
Pallas supports the most commonly used functions from the ``jax.random`` module,
such as ``uniform``, ``normal``, and ``bernoulli``. The key should be a ``threefry2x32`` key,
which is the default setting in JAX. Keys can be directly passed into a kernel,
or generated inside of a kernel.
Control flow
^^^^^^^^^^^^


@ -6,6 +6,8 @@
"id": "ZHuzXqQ-9JUQ"
},
"source": [
"(pallas_scalar_prefetch_guide)=\n",
"\n",
"# Scalar Prefetch and Block-Sparse Computation\n",
"\n",
"In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory."


@ -14,6 +14,8 @@ kernelspec:
+++ {"id": "ZHuzXqQ-9JUQ"}
(pallas_scalar_prefetch_guide)=
# Scalar Prefetch and Block-Sparse Computation
In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.


@ -37,10 +37,10 @@ class AttrsTests(jtu.JaxTestCase):
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 1) # compiles once the first time
self.assertEqual(count(), 1) # compiles once the first time
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 0) # cache hit
self.assertEqual(count(), 0) # cache hit
def test_array_attr_no_jit(self):
with jax.disable_jit():


@ -455,6 +455,7 @@ pytype_strict_library(
":dtypes",
":effects",
":mesh",
":partition_spec",
":pretty_printer",
":source_info_util",
":traceback_util",
@ -558,6 +559,7 @@ pytype_strict_library(
":layout",
":op_shardings",
":partial_eval",
":partition_spec",
":path",
":pickle_util",
":sharding",


@ -28,7 +28,6 @@ ShapedArray = core.ShapedArray
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped
numpy_scalar_types: set[type] = { # pylint: disable=g-bare-generic
dtypes.int4, np.int8, np.int16, np.int32, np.int64,


@ -19,7 +19,7 @@ from typing import Any, TypeVar
from jax._src import core
from jax._src import traceback_util
from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval
from jax._src.core import Primitive, valid_jaxtype, get_aval
from jax._src.tree_util import register_pytree_node, tree_map
from jax._src.typing import Array, ArrayLike
from jax._src.util import safe_map
@ -51,7 +51,7 @@ def zeros_like_aval(aval: core.AbstractValue) -> Array:
aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}
def zeros_like_jaxval(val):
return zeros_like_aval(core.raise_to_shaped(core.get_aval(val)))
return zeros_like_aval(core.get_aval(val))
def instantiate(z: Zero | Array) -> Array:
if isinstance(z, Zero):
@ -67,7 +67,7 @@ class Zero:
return f'Zero({self.aval})'
@staticmethod
def from_primal_value(val: Any) -> Zero:
return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval())
return Zero(get_aval(val).to_tangent_aval())
register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))


@ -2356,7 +2356,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
f"len(devices) = {len(devices)}.")
def _device_put_sharded(*xs):
avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs]
avals = [core.get_aval(x) for x in xs]
if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
if a1 != a2)
@ -2418,7 +2418,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
"a non-empty sequence.")
def _device_put_replicated(x):
aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
core.raise_to_shaped(core.get_aval(x)))
core.get_aval(x))
assert isinstance(aval, ShapedArray)
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
if config.pmap_no_rank_reduction.value:


@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
@lu.transformation2
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs):
sentinel = object()
args = [sentinel] * (len(fixed_args) + len(dyn_args))
for i, arg in zip(dyn_argnums, dyn_args):
args = [sentinel] * (len(_fixed_args) + len(dyn_args))
for i, arg in zip(_dyn_argnums, dyn_args):
args[i] = arg
fixed_args_ = iter(fixed_args)
fixed_args_ = iter(_fixed_args)
args = [next(fixed_args_).val if x is sentinel else x for x in args]
assert next(fixed_args_, sentinel) is sentinel
return f(*args, **kwargs)
return _fun(*args, **kwargs)
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
kwargs: dict[str, Any]):
@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
@lu.transformation2
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
return f(*args, **kwargs)
def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs)
return _fun(*args, **kwargs)
@lru_cache(maxsize=4096)
@ -438,9 +438,9 @@ def flat_out_axes(
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
@lu.transformation_with_aux2
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
ans = f(*args, **kwargs)
spec = tree_unflatten(treedef, leaves)
def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs):
ans = _fun(*args, **kwargs)
spec = tree_unflatten(_treedef, _leaves)
try:
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
except ValueError:
@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
"pmapped function's output.")
raise ValueError(msg) from None
store.store(spec_flat)
_store.store(spec_flat)
return ans
def check_callable(fun):
@ -587,8 +587,7 @@ def _dtype(x):
def _shaped_abstractify_slow(x):
try:
return core.raise_to_shaped(
x if isinstance(x, core.AbstractValue) else core.get_aval(x))
return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
except TypeError:
pass
@ -687,10 +686,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
for path, l in generate_key_paths(x) if l is not static)
@lu.transformation_with_aux2
def result_paths(f, store, *args, **kwargs):
def result_paths(_fun, _store, *args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
ans = f(*args, **kwargs)
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,


@ -33,6 +33,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import profiler
from jax._src import util
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@ -40,7 +41,6 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding,
@ -1120,7 +1120,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
bufs.append(buf)
break
else:
bufs.append(buf)
bufs.append(candidates_list[-1])
return pxla.batched_device_put(x.aval, sharding, bufs, devices)
@ -1132,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
def _array_shard_arg(xs, shardings, layouts, copy_semantics):
util.test_event("_array_shard_arg")
results = []
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
batch_cs = []
@ -1169,12 +1170,9 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
results.append(
shard_sharded_device_array_slow_path(x, devices, indices, sharding))
if xla_extension_version >= 296:
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings, batch_cs)
else:
copy_outs = xc.batched_copy_array_to_devices_with_sharding( # pytype: disable=missing-parameter
batch_xs, batch_devs, batch_shardings)
util.test_event("batched_copy_array")
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings, batch_cs)
for i, copy_out in safe_zip(batch_indices, copy_outs):
assert results[i] is None
results[i] = copy_out


@ -24,7 +24,6 @@ from typing import cast as type_cast
from jax._src import config
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager as pm
import numpy as np
@ -301,10 +300,7 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
debug_options.xla_dump_hlo_as_long_text = False
debug_options.xla_dump_disable_metadata = False
debug_options.xla_dump_hlo_pipeline_re = ""
# "Requires jaxlib 0.4.36+"
if xla_extension_version > 296:
debug_options.xla_gpu_experimental_autotune_cache_mode = 0
debug_options.xla_gpu_experimental_autotune_cache_mode = 0
# Optional way to specify the cuda install path to be used by the compiler.
# This could possibly affect the cuda version compiled with, but this should


@ -387,9 +387,6 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
error = _reduce_any_error(error)
return error, out_vals
def get_shaped_aval(val):
return core.raise_to_shaped(core.get_aval(val))
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
error: Error, *args) -> tuple[Error, list[core.Value]]:
err_vals, err_tree = jtu.tree_flatten(error)
@ -760,7 +757,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
# Get the error-effects out of all branches so the cond can be called with
# a merged error with all these effects.
err_vals, err_tree = jtu.tree_flatten(error)
in_avals = map(get_shaped_aval, [*err_vals, *ops])
in_avals = map(core.get_aval, [*err_vals, *ops])
def get_error_effects_from_jaxpr(jxpr):
_, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
*in_avals)
@ -770,7 +767,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
err_vals, err_tree = jtu.tree_flatten(merged_error)
# Update branch jaxprs to be checkified jaxprs.
in_avals = map(get_shaped_aval, [*err_vals, *ops])
in_avals = map(core.get_aval, [*err_vals, *ops])
new_branches, out_trees, _ = unzip3(
jaxpr_to_checkify_jaxpr(
jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)
@ -792,11 +789,11 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
num_consts, num_carry, linear, unroll, _split_transpose):
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
# Query body effects to create a merged error containing all effects (such
# that in and out carried error are of the same type).
err_vals, err_tree = jtu.tree_flatten(error)
new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
_, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
err_tree, *new_in_aval)
@ -804,7 +801,7 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
err_vals, err_tree = jtu.tree_flatten(merged_error)
# Create checked-jaxpr, with the needed pre-processing on the inputs.
new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
err_tree, *new_in_aval)
@ -840,7 +837,7 @@ def checkify_while_body_jaxpr(
*body_jaxpr.in_avals])
closed_jaxpr = pe.close_jaxpr(jaxpr)
err_vals, err_tree = jtu.tree_flatten(error)
err_vals = map(get_shaped_aval, err_vals)
err_vals = map(core.get_aval, err_vals)
flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
@ -882,7 +879,7 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
cond_in_flat = [*err_vals, *c_consts, *carry]
cond_in_flat = map(get_shaped_aval, cond_in_flat)
cond_in_flat = map(core.get_aval, cond_in_flat)
checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
err_tree, *cond_in_flat)
compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
@ -906,7 +903,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
new_vals_in = [*err_vals, *vals_in]
in_avals = tuple(map(get_shaped_aval, new_vals_in))
in_avals = tuple(map(core.get_aval, new_vals_in))
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
err_tree, *in_avals)
@ -942,7 +939,7 @@ error_checks[pjit.pjit_p] = pjit_error_check
def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
err_vals, err_tree = jtu.tree_flatten(error)
new_vals_in = [*err_vals, *vals_in]
in_avals = tuple(map(get_shaped_aval, new_vals_in))
in_avals = tuple(map(core.get_aval, new_vals_in))
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts
@ -963,7 +960,7 @@ def shard_map_error_check(
# Replicated sharding for in errors.
new_in_names = (*([{}] * num_error_vals), *in_names)
new_vals_in = [*err_vals, *vals_in]
in_avals = list(map(get_shaped_aval, new_vals_in))
in_avals = list(map(core.get_aval, new_vals_in))
for i, v in enumerate(in_avals):
if not (sharder := core.shard_aval_handlers.get(type(v))):
raise ValueError(f'Unsupported aval type: {type(v)}')


@ -18,8 +18,6 @@ from __future__ import annotations
from collections.abc import Sequence
import logging
import os
import tempfile
import time
from typing import Any, Callable
import warnings
@ -36,7 +34,6 @@ from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np
@ -192,9 +189,8 @@ def get_compile_options(
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment
if xla_extension_version >= 294:
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
@ -350,7 +346,33 @@ def compile_or_get_cached(
use_compilation_cache = compilation_cache.is_cache_used(backend)
is_multi_process = (
len({device.process_index for device in devices.flatten()}) > 1
)
min_device_process_id = min(
devices.flatten(), key=lambda device: device.id
).process_index
is_auto_pgle_used = (
config.enable_pgle.value and config.pgle_profiling_runs.value > 0
)
if not use_compilation_cache:
if (
is_multi_process
and is_auto_pgle_used
and distributed.global_state.client is not None
):
compile_options.executable_build_options.fdo_profile = (
_share_fdo_profiles(
computation,
devices,
compile_options,
backend,
distributed.global_state.client,
min_device_process_id,
)
)
return backend_compile(backend, computation, compile_options,
host_callbacks)
@ -375,61 +397,18 @@ def compile_or_get_cached(
return backend_compile(backend, computation, compile_options,
host_callbacks)
is_multi_process = (
len({device.process_index for device in devices.flatten()}) > 1)
min_device_process_id = (
min(devices.flatten(), key=lambda device: device.id).process_index)
# When PGLE is enabled there might be 3 types of situations:
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
# in the persistent cache. In this case the module should be returned from
# cache and PGLE should be disabled for this module. Is module is stored in
# the persistent cache under the "pgle_profiled_module_key" which calculated
# with replacing FDO profile with flag which identify that module were PGLE
# profiled.
# 2. PGLE profiled module is not in the persistent cache and the module is
# getting built with an FDO profile. In this case we need to share FDO profile
# with other processes and store the result under the
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
# module.
# 3. PGLE profiled module is not in the persistent cache and the module is
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
# simply return the non-PGLE profiled module from the persistent cache.
if (config.enable_pgle.value
and config.pgle_profiling_runs.value > 0):
fdo_profile = compile_options.executable_build_options.fdo_profile
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
pgle_profiled_module_key = compilation_cache.get_cache_key(
if is_auto_pgle_used:
cache_key = _resolve_pgle_module_cache_key(
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
pgle_profiler,
is_multi_process,
cache_key,
module_name,
min_device_process_id,
)
compile_options.executable_build_options.fdo_profile = fdo_profile
if _is_executable_in_cache(backend, pgle_profiled_module_key):
# Load PGLE profiled module from the persistent cache.
cache_key = pgle_profiled_module_key
if pgle_profiler is not None:
pgle_profiler.disable()
elif fdo_profile is not None and len(fdo_profile) > 0:
# Store module under PGLE profiled module cache key.
cache_key = pgle_profiled_module_key
if is_multi_process and distributed.global_state.client is not None:
compile_options.executable_build_options.fdo_profile = _share_fdo_profiles(
computation, devices, compile_options, backend,
distributed.global_state.client,
min_device_process_id
)
else:
compile_options.executable_build_options.fdo_profile = fdo_profile
logger.debug(
"Compiling module %s with FDO profile: %s",
module_name,
compile_options.executable_build_options.fdo_profile,
)
cache_retrieval_start = time.monotonic()
retrieved_executable, retrieved_compile_time = _cache_read(
@ -468,22 +447,6 @@ def compile_or_get_cached(
cache_key,
min_device_process_id
)
elif (
config.share_autotune_config_between_hosts.value
and is_multi_process
and distributed.global_state.client is not None
):
log_persistent_cache_miss(module_name, cache_key)
return _compile_and_write_autotune_config(
backend,
computation,
compile_options,
host_callbacks,
distributed.global_state.client,
module_name,
cache_key,
min_device_process_id
)
else:
log_persistent_cache_miss(module_name, cache_key)
return _compile_and_write_cache(
@ -495,6 +458,75 @@ def compile_or_get_cached(
cache_key,
)
# When PGLE is enabled there are 3 possible situations:
# 1. The PGLE-profiled module (the one recompiled with an FDO profile) is in
# the persistent cache. In this case the module should be returned from the
# cache and PGLE should be disabled for this module. The module is stored in
# the persistent cache under the "pgle_profiled_module_key", which is
# calculated by replacing the FDO profile with a flag identifying that the
# module was PGLE-profiled.
# 2. The PGLE-profiled module is not in the persistent cache and the module is
# being built with an FDO profile. In this case we need to share the FDO
# profile with other processes and store the result under the
# "pgle_profiled_module_key", so that in case 1 we will later be able to find
# the module.
# 3. The PGLE-profiled module is not in the persistent cache and the module is
# being compiled to be PGLEd (the FDO profile is empty). In this case we
# simply return the non-PGLE-profiled module from the persistent cache.
def _resolve_pgle_module_cache_key(
computation: ir.Module,
devices: np.ndarray,
compile_options: xc.CompileOptions,
backend: xc.Client,
pgle_profiler: profiler.PGLEProfiler | None,
is_multi_process: bool,
cache_key: str,
module_name: str,
min_device_process_id: int,
) -> str:
fdo_profile = compile_options.executable_build_options.fdo_profile
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
pgle_profiled_module_key = compilation_cache.get_cache_key(
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
)
compile_options.executable_build_options.fdo_profile = fdo_profile
result_key = cache_key
if _is_executable_in_cache(backend, pgle_profiled_module_key):
# Load PGLE profiled module from the persistent cache.
result_key = pgle_profiled_module_key
if pgle_profiler is not None:
pgle_profiler.disable()
elif fdo_profile is not None and len(fdo_profile) > 0:
# Store module under PGLE profiled module cache key.
result_key = pgle_profiled_module_key
if is_multi_process and distributed.global_state.client is not None:
compile_options.executable_build_options.fdo_profile = (
_share_fdo_profiles(
computation,
devices,
compile_options,
backend,
distributed.global_state.client,
min_device_process_id,
)
)
else:
compile_options.executable_build_options.fdo_profile = fdo_profile
logger.debug(
"Compiling module %s with FDO profile of length %d",
module_name,
len(compile_options.executable_build_options.fdo_profile),
)
return result_key
# The process that has the lowest device ID should share FDO profile before
# compilation with other processes.
def _share_fdo_profiles(
@ -512,32 +544,39 @@ def _share_fdo_profiles(
return fdo_profile
compile_options.executable_build_options.fdo_profile = b""
profile_key = (
compilation_cache.get_cache_key(
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
)
+ "_fdo_sync"
)
try:
profile_key = (
compilation_cache.get_cache_key(
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
)
+ "_fdo_sync"
)
except xc._xla.XlaRuntimeError as ex:
logger.error(
"compile_or_get_cached: unable to generate cache key, "
"skipping the fdo profile sharing: %s",
ex,
)
return fdo_profile
if profile_key in _share_fdo_profiles.modules_profiles:
return _share_fdo_profiles.modules_profiles[profile_key]
share_timeout = config.share_binary_between_hosts_timeout_ms.value
if distributed.global_state.process_id == min_process_id:
logger.debug(
"Sharing FDO profile: %s. For module %s. Process %d.",
fdo_profile,
"Module %s. Sharing FDO profile. Process %d.",
module_name,
min_process_id,
)
global_client.key_value_set_bytes(profile_key, fdo_profile)
else:
logger.debug(
"Waiting for FDO profile: %s. For module %s. Should be set by process %d.",
fdo_profile,
"Module %s. Waiting for FDO profile which should be set by process %d.",
module_name,
min_process_id,
)
@ -551,113 +590,6 @@ def _share_fdo_profiles(
_share_fdo_profiles.modules_profiles = {}
# The process with the first_process_id should compile the module and write an
# autotune config to the K-V storage.
def _compile_and_write_autotune_config(
backend: xc.Client,
computation: ir.Module,
compile_options: xc.CompileOptions,
host_callbacks: Sequence[Any],
global_client: lib.xla_extension.DistributedRuntimeClient,
module_name: str,
cache_key: str,
first_process_id: int
) -> xc.LoadedExecutable:
share_timeout = config.share_binary_between_hosts_timeout_ms.value
debug_options = compile_options.executable_build_options.debug_options
if _compile_and_write_autotune_config.autotune_configs_dir is None:
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
autotune_tmp_file = os.path.join(
_compile_and_write_autotune_config.autotune_configs_dir, cache_key
)
if os.path.exists(autotune_tmp_file):
logger.debug(
"Compiling module: %s. Use existing autotune config file: %s",
module_name,
autotune_tmp_file,
)
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
return _compile_and_write_cache(
backend,
computation,
compile_options,
host_callbacks,
module_name,
cache_key,
)
if distributed.global_state.process_id == first_process_id:
debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
logger.debug("Process %d compiling and dumping autotune for module: %s",
first_process_id, module_name)
executable = _compile_and_write_cache(
backend,
computation,
compile_options,
host_callbacks,
module_name,
cache_key,
)
logger.debug(
"Writing autotune config for module %s to %s",
module_name,
autotune_tmp_file,
)
with open(autotune_tmp_file, "rb") as f:
autotune_config = f.read()
autotune_config = compilation_cache.compress_executable(autotune_config)
global_client.key_value_set_bytes(cache_key, autotune_config)
logger.debug(
"Autotune config for module %s with size %d shared by cache_key %s",
module_name,
len(autotune_config),
cache_key,
)
else:
logger.debug(
"Compiling module %s, waiting for config to be shared by cache_key %s"
"from process %d",
module_name,
cache_key,
first_process_id
)
autotune_config = global_client.blocking_key_value_get_bytes(
cache_key, share_timeout
)
logger.debug(
"Received autotune config for module %s of size %d",
module_name,
len(autotune_config),
)
autotune_config = compilation_cache.decompress_executable(autotune_config)
with open(autotune_tmp_file, "wb") as f:
f.write(autotune_config)
logger.debug(
"Compiling module %s, using autotune config from %s",
module_name,
autotune_tmp_file,
)
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
executable = _compile_and_write_cache(
backend,
computation,
compile_options,
host_callbacks,
module_name,
cache_key,
)
return executable
_compile_and_write_autotune_config.autotune_configs_dir = None
# The process with the first_process_id should compile the module and write it
# to the K-V storage.
def _compile_and_share_module(


@ -14,30 +14,21 @@
from __future__ import annotations
from collections.abc import Callable, Hashable, Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
import contextlib
import functools
import itertools
import logging
import os
import sys
import threading
from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast
from typing import Any, Generic, NoReturn, Optional, Protocol, TypeVar, cast
from jax._src import lib
from jax._src.lib import guard_lib
from jax._src.lib import jax_jit
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src import logging_config
# TODO(phawkins): reenable pytype after xla_extension_version >= 295
# pytype: skip-file
if xla_extension_version >= 295:
config_ext = xla_client._xla.config
else:
config_ext = None
config_ext = xla_client._xla.config
logger = logging.getLogger(__name__)
@ -200,91 +191,38 @@ class Config:
already_configured_with_absl = True
if xla_extension_version >= 295:
def trace_context():
"""Returns a tuple of configuration values that affect tracing.
def trace_context():
"""Returns a tuple of configuration values that affect tracing.
These values are included in the cache key for linear_util.cache.
These values are included in the cache key for linear_util.cache.
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately.
"""
return (axis_env_state.value, mesh_context_manager.value,
xla_metadata_context_manager.value,
abstract_mesh_context_manager.value,
device_context.value,
compute_on_context_manager.value, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value,
eager_constant_folding.value,
numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
sharding_in_types.value,
use_direct_linearize.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
debug_key_reuse.value,
jax_xla_profile_version.value,
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value,
pgle_profiling_runs.value,
enable_pgle.value,
use_shardy_partitioner.value)
else:
def trace_context():
"""Returns a tuple of configuration values that affect tracing.
These values are included in the cache key for linear_util.cache.
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately.
"""
tls = jax_jit.thread_local_state()
axis_env_state = ()
mesh_context_manager = ()
abstract_mesh_context_manager = ()
device_context = ()
xla_metadata_context_manager = ()
compute_on_context_manager = ()
context: Any = tls.extra_jit_context
if context and context.axis_env_state is not None:
axis_env_state = context.axis_env_state
if context and context.mesh_context_manager:
mesh_context_manager = context.mesh_context_manager
if context and context.abstract_mesh_context_manager:
abstract_mesh_context_manager = context.abstract_mesh_context_manager
if context and context.device_context:
device_context = context.device_context
if context and context.xla_metadata_context_manager:
xla_metadata_context_manager = context.xla_metadata_context_manager
if context and context.compute_on_context_manager:
compute_on_context_manager = context.compute_on_context_manager
return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager,
device_context, xla_metadata_context_manager,
compute_on_context_manager, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value,
eager_constant_folding.value,
numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
sharding_in_types.value,
use_direct_linearize.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
debug_key_reuse.value,
jax_xla_profile_version.value,
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value,
pgle_profiling_runs.value,
enable_pgle.value,
use_shardy_partitioner.value)
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately.
"""
return (axis_env_state.value, mesh_context_manager.value,
xla_metadata_context_manager.value,
abstract_mesh_context_manager.value,
device_context.value,
compute_on_context_manager.value, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value,
eager_constant_folding.value,
numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
sharding_in_types.value,
use_direct_linearize.value,
softmax_custom_jvp.value,
enable_memories.value,
disable_jit.value,
debug_key_reuse.value,
jax_xla_profile_version.value,
# Technically this affects jaxpr->stablehlo lowering, not tracing.
hlo_source_file_canonicalization_regex.value,
pgle_profiling_runs.value,
enable_pgle.value,
use_shardy_partitioner.value)
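As a hedged illustration (the names `_trace_cache` and `cached_trace` are hypothetical, not part of this module), a tracing cache can fold `trace_context()` into its key so that flipping any of the config values listed above invalidates previously cached traces:
_trace_cache: dict = {}
def cached_trace(fn, *args):
  key = (fn, args, trace_context())   # config state participates in the key
  if key not in _trace_cache:
    _trace_cache[key] = fn(*args)     # stand-in for the actual trace/lowering
  return _trace_cache[key]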
config = Config()
@ -296,185 +234,85 @@ parse_flags_with_absl = config.parse_flags_with_absl
class NoDefault: pass
no_default = NoDefault()
if xla_extension_version >= 295:
class State(config_ext.Config[_T]):
class State(config_ext.Config[_T]):
__slots__ = (
'_name', '_update_thread_local_hook', '_update_global_hook',
'_validator', '_default_context_manager_value', '__doc__', '__name__',
)
__slots__ = (
'_name', '_update_thread_local_hook', '_update_global_hook',
'_validator', '_default_context_manager_value', '__doc__', '__name__',
)
def __init__(
self,
name: str,
default: _T,
help,
update_global_hook: Callable[[_T], None] | None = None,
update_thread_local_hook: Callable[[_T | None], None] | None = None,
validator: Callable[[Any], None] | None = None,
extra_description: str = '',
default_context_manager_value: Any = no_default,
include_in_jit_key: bool = False,
):
super().__init__(default, include_in_jit_key)
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
self.__doc__ = (f"Context manager for `{name}` config option"
f"{extra_description}.\n\n{help}")
self._update_global_hook = update_global_hook
self._update_thread_local_hook = update_thread_local_hook
self._validator = validator
self._default_context_manager_value = default_context_manager_value
if self._validator:
self._validator(default)
if self._update_global_hook:
self._update_global_hook(default)
def __init__(
self,
name: str,
default: _T,
help,
update_global_hook: Callable[[_T], None] | None = None,
update_thread_local_hook: Callable[[_T | None], None] | None = None,
validator: Callable[[Any], None] | None = None,
extra_description: str = '',
default_context_manager_value: Any = no_default,
include_in_jit_key: bool = False,
):
super().__init__(default, include_in_jit_key)
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
self.__doc__ = (f"Context manager for `{name}` config option"
f"{extra_description}.\n\n{help}")
self._update_global_hook = update_global_hook
self._update_thread_local_hook = update_thread_local_hook
self._validator = validator
self._default_context_manager_value = default_context_manager_value
if self._validator:
self._validator(default)
if self._update_global_hook:
self._update_global_hook(default)
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))
def _set(self, value: _T) -> None:
if self._validator:
self._validator(value)
self.set_global(value)
if self._update_global_hook:
self._update_global_hook(value)
def _set(self, value: _T) -> None:
if self._validator:
self._validator(value)
self.set_global(value)
if self._update_global_hook:
self._update_global_hook(value)
@contextlib.contextmanager
def __call__(self, new_val: Any = no_default):
if new_val is no_default:
if self._default_context_manager_value is not no_default:
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
else:
# no default_value provided to constructor and no value provided as an
# argument, so we raise an error
raise TypeError(f"Context manager for {self.__name__} config option "
"requires an argument representing the new value for "
"the config option.")
if self._validator:
self._validator(new_val)
prev_val = self.swap_local(new_val)
@contextlib.contextmanager
def __call__(self, new_val: Any = no_default):
if new_val is no_default:
if self._default_context_manager_value is not no_default:
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
else:
# no default_value provided to constructor and no value provided as an
# argument, so we raise an error
raise TypeError(f"Context manager for {self.__name__} config option "
"requires an argument representing the new value for "
"the config option.")
if self._validator:
self._validator(new_val)
prev_val = self.swap_local(new_val)
if self._update_thread_local_hook:
self._update_thread_local_hook(new_val)
try:
yield
finally:
self.set_local(prev_val)
if self._update_thread_local_hook:
self._update_thread_local_hook(new_val)
try:
yield
finally:
self.set_local(prev_val)
if self._update_thread_local_hook:
if prev_val is config_ext.unset:
self._update_thread_local_hook(None)
else:
self._update_thread_local_hook(cast(Optional[Any], prev_val))
def _add_hooks(self, update_global_hook, update_thread_local_hook):
"""Private method that adds hooks to an existing context-manager.
Used to avoid cyclic import dependencies."""
self._update_thread_local_hook = update_thread_local_hook
self._update_global_hook = update_global_hook
update_global_hook(self.get_global())
else:
class _Unset: pass
unset = _Unset()
_thread_local_state = threading.local()
class State(Generic[_T]): # type: ignore[no-redef]
__slots__ = (
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',
'_validator', '_default_context_manager_value', '__doc__', '__name__',
)
def __init__(
self,
name: str,
default: _T,
help,
update_global_hook: Callable[[_T], None] | None = None,
update_thread_local_hook: Callable[[_T | None], None] | None = None,
validator: Callable[[Any], None] | None = None,
extra_description: str = '',
default_context_manager_value: Any = no_default,
include_in_jit_key: bool = False,
):
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
self.__doc__ = (f"Context manager for `{name}` config option"
f"{extra_description}.\n\n{help}")
if include_in_jit_key:
assert update_global_hook is None
assert update_thread_local_hook is None
update_global_hook = lambda val: _update_global_jit_state(
**{self.__name__: val})
update_thread_local_hook = lambda val: update_thread_local_jit_state(
**{self.__name__: val})
self._update_global_hook = update_global_hook
self._update_thread_local_hook = update_thread_local_hook
self._validator = validator
self._default_context_manager_value = default_context_manager_value
self._set(default)
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))
def _set(self, value: _T) -> None:
if self._validator:
self._validator(value)
self._value = value
if self._update_global_hook:
self._update_global_hook(value)
@property
def value(self) -> _T:
val = _thread_local_state.__dict__.get(self._name, unset)
return cast(_T, val) if val is not unset else self._value
def get_local(self) -> Any:
return _thread_local_state.__dict__.get(self._name, unset)
@contextlib.contextmanager
def __call__(self, new_val: Any = no_default):
if new_val is no_default:
if self._default_context_manager_value is not no_default:
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
if prev_val is config_ext.unset:
self._update_thread_local_hook(None)
else:
# no default_value provided to constructor and no value provided as an
# argument, so we raise an error
raise TypeError(f"Context manager for {self.__name__} config option "
"requires an argument representing the new value for "
"the config option.")
if self._validator:
self._validator(new_val)
prev_val = getattr(_thread_local_state, self._name, unset)
setattr(_thread_local_state, self._name, new_val)
if self._update_thread_local_hook:
self._update_thread_local_hook(new_val)
try:
yield
finally:
if prev_val is unset:
delattr(_thread_local_state, self._name)
if self._update_thread_local_hook:
self._update_thread_local_hook(None)
else:
setattr(_thread_local_state, self._name, prev_val)
if self._update_thread_local_hook:
self._update_thread_local_hook(cast(_T, prev_val))
self._update_thread_local_hook(cast(Optional[Any], prev_val))
def _add_hooks(self, update_global_hook, update_thread_local_hook):
"""Private method that adds hooks to an existing context-manager.
def _add_hooks(self, update_global_hook, update_thread_local_hook):
"""Private method that adds hooks to an existing context-manager.
Used to avoid cyclic import dependencies."""
self._update_thread_local_hook = update_thread_local_hook
self._update_global_hook = update_global_hook
update_global_hook(self._value)
Used to avoid cyclic import dependencies."""
self._update_thread_local_hook = update_thread_local_hook
self._update_global_hook = update_global_hook
update_global_hook(self.get_global())
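A hedged usage sketch (the option name below is hypothetical, purely for illustration): options built on this State machinery, for example via `bool_state` defined later in this module, act both as value holders and as context managers:
my_flag = bool_state(
    name='jax_my_hypothetical_flag',   # hypothetical option, for illustration
    default=False,
    help='Example flag.')
assert not my_flag.value               # global default
with my_flag(True):                    # thread-local override via __call__
  assert my_flag.value
assert not my_flag.value               # restored on exit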
UPGRADE_BOOL_HELP = (
@ -975,132 +813,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
already_configured_with_absl = False
if xla_extension_version >= 295:
trace_state = config_ext.Config(None, include_in_jit_key=True)
axis_env_state = config_ext.Config((), include_in_jit_key=True)
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
device_context = config_ext.Config((), include_in_jit_key=True)
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
else:
# The C++ JIT maintains its own copy of several configuration items as
# a global/thread-local state. These methods allow updates to part of the
# state when a configuration value changes.
class _GlobalExtraJitContext(NamedTuple):
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool = False
eager_constant_folding: bool = False
random_seed_offset: int = 0
threefry_partitionable: bool = False
threefry_gpu_kernel_lowering: bool = False
sharding_in_types: bool = False
use_direct_linearize: bool = False
softmax_custom_jvp: bool = False
xla_profile_version: int = 0
pgle_profiling_runs: int = 0
enable_pgle: bool = False
use_shardy_partitioner: bool = False
def _update_global_jit_state(**kw):
gs = jax_jit.global_state()
context = gs.extra_jit_context or _GlobalExtraJitContext()
gs.extra_jit_context = context._replace(**kw)
class _ThreadLocalExtraJitContext(NamedTuple):
"""A namedtuple containing states to add to the cache key.
Just in time compilation (for jit, pmap, etc) behavior is configurable through
global and thread-local options, used in the cache key.
The initialization, which uses both config.py and core.py is done using
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
trace_state: Any | None = None
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
abstract_mesh_context_manager: Hashable = ()
device_context: Hashable = ()
compute_on_context_manager: Hashable = ()
xla_metadata_context_manager: Hashable = ()
# Values set by _StateContextManager context managers.
# CAUTION: these must be initialized to `None`! The state context manager
# restores these to None on exit. If the object default is not `None`, the
# context manager is not a no-op, which leads to problems with stale state
# (e.g. spurious cache misses in tests).
numpy_rank_promotion: str | None = None
numpy_dtype_promotion: str | None = None
default_matmul_precision: Any | None = None
dynamic_shapes: bool | None = None
eager_constant_folding : bool | None = None
random_seed_offset: int | None = None
threefry_partitionable: bool | None = None
threefry_gpu_kernel_lowering: bool | None = None
sharding_in_types: bool | None = None
use_direct_linearize: bool | None = None
softmax_custom_jvp: bool | None = None
xla_profile_version: int | None = None
pgle_profiling_runs: int | None = None
enable_pgle: bool | None = None
use_shardy_partitioner: bool | None = None
class _ThreadLocalStateCache(threading.local):
""""A thread local cache for _ThreadLocalExtraJitContext
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
incurring dispatch overhead for comparing this python object during jit calls.
We want to deduplicate the objects that have the same hash/equality to also
have the same object ID, since the equality check is much faster if the object
IDs match.
"""
def __init__(self):
self.canonicalize = functools.lru_cache(128)(lambda x: x)
_thread_local_state_cache = _ThreadLocalStateCache()
def update_thread_local_jit_state(**kw):
tls = jax_jit.thread_local_state()
# After xla_client._version >= 70, the thread_local object will necessarily
# be initialized when accessed. The following line can be removed when the
# minimum jaxlib version is past version 70
context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
tmp = context._replace(**kw)
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
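A small illustration of why this canonicalization matters (the variable names are just for the example): equal contexts are mapped to one object, so the identity comparison on the dispatch path stays cheap:
canon = functools.lru_cache(128)(lambda x: x)
a = _ThreadLocalExtraJitContext(numpy_rank_promotion='warn')
b = _ThreadLocalExtraJitContext(numpy_rank_promotion='warn')
assert a == b and a is not b           # equal, but distinct objects
assert canon(a) is canon(b)            # deduplicated to a single identity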
class JitConfig:
def __init__(self, name):
self._name = name
@property
def value(self):
return self.get_local()
def get_local(self):
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
def set_local(self, value):
update_thread_local_jit_state(**{self._name: value})
def swap_local(self, new_value):
prev_value = self.value
self.set_local(new_value)
return prev_value
trace_state = JitConfig('trace_state')
axis_env_state = JitConfig('axis_env_state')
mesh_context_manager = JitConfig('mesh_context_manager')
abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager')
device_context = JitConfig('device_context')
compute_on_context_manager = JitConfig('compute_on_context_manager')
xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
trace_state = config_ext.Config(None, include_in_jit_key=True)
axis_env_state = config_ext.Config((), include_in_jit_key=True)
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
device_context = config_ext.Config((), include_in_jit_key=True)
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
# TODO(b/214340779): remove flag when XLA:CPU is improved.
@ -1254,10 +973,10 @@ pmap_shmap_merge = bool_state(
help='If True, pmap and shard_map API will be merged.')
def _update_jax_memories_global(val):
lib.jax_jit.global_state().enable_memories = val
jax_jit.global_state().enable_memories = val
def _update_jax_memories_thread_local(val):
lib.jax_jit.thread_local_state().enable_memories = val
jax_jit.thread_local_state().enable_memories = val
enable_memories = bool_state(
'jax_enable_memories',
@ -1450,20 +1169,6 @@ traceback_in_locations_limit = int_state(
),
)
share_autotune_config_between_hosts = bool_state(
name='jax_share_autotune_config_between_hosts',
default=False,
help=(
'If set to True, the coordinator process will share autotune configs '
'with other participants. This will increase overall compilation time, but '
'will lead to equal compiled modules in each process. '
'If both jax_share_binary_between_hosts and '
'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
"shared when it's possible and autotune config sharing will be used "
'as a fallback.'
),
)
share_binary_between_hosts = bool_state(
name='jax_share_binary_between_hosts',
default=False,
@ -1576,10 +1281,10 @@ disallow_mesh_context_manager = bool_state(
)
def _update_x64_global(val):
lib.jax_jit.global_state().enable_x64 = val
jax_jit.global_state().enable_x64 = val
def _update_x64_thread_local(val):
lib.jax_jit.thread_local_state().enable_x64 = val
jax_jit.thread_local_state().enable_x64 = val
enable_x64 = bool_state(
name='jax_enable_x64',
@ -1594,11 +1299,11 @@ config._contextmanager_flags.remove('jax_enable_x64')
setattr(Config, "x64_enabled", property(lambda _: enable_x64.value))
def _update_default_device_global(val):
lib.jax_jit.global_state().default_device = val
jax_jit.global_state().default_device = val
def _update_default_device_thread_local(val):
lib.jax_jit.thread_local_state().default_device = val
jax_jit.thread_local_state().default_device = val
def _validate_default_device(val):
@ -1632,10 +1337,10 @@ default_device = string_or_object_state(
validator=_validate_default_device)
def _update_disable_jit_global(val):
lib.jax_jit.global_state().disable_jit = val
jax_jit.global_state().disable_jit = val
def _update_disable_jit_thread_local(val):
lib.jax_jit.thread_local_state().disable_jit = val
jax_jit.thread_local_state().disable_jit = val
disable_jit = bool_state(
name='jax_disable_jit',

View File

@ -39,6 +39,7 @@ from jax._src import config
from jax._src import effects
from jax._src import compute_on
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@ -431,6 +432,8 @@ class Primitive:
call_primitive: bool = False
# set for map primitives processed in final style.
map_primitive: bool = False
# set for ref primitives
ref_primitive: bool = False
def __init__(self, name: str):
self.name = name
@ -1597,19 +1600,34 @@ def _invalid_shape_error(shape: Shape, context: str=""):
return TypeError(msg)
# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def _maybe_modify_sharding(sharding):
if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
return sharding
new_spec = []
for s in sharding.spec:
if s is None or isinstance(s, UnconstrainedSingleton):
new_spec.append(s)
else:
temp_s = s[0] if isinstance(s, tuple) else s
new_spec.append(
P.UNCONSTRAINED
if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
return sharding.with_spec(new_spec)
def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore
from jax._src.sharding_impls import NamedSharding # type: ignore
if sharding is not None:
assert len(sharding.spec) == ndim
return sharding
return _maybe_modify_sharding(sharding)
context_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
# code.
if not context_mesh:
return None
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
return NamedSharding(context_mesh, P(*[None] * ndim))
@ -1672,10 +1690,8 @@ class ShapedArray(UnshapedArray):
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
axis_types = self.sharding.mesh.axis_types
axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
return f'{dt_str}[{shapestr}]{axt}'
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore
return f'{dt_str}[{shapestr}]'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'
@ -1687,26 +1703,13 @@ class ShapedArray(UnshapedArray):
raise TypeError("len() of unsized object") from err # same as numpy error
def _get_axis_type_str(axis_types):
from jax._src.mesh import AxisTypes # type: ignore
out = []
for t, axes in axis_types.items():
a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
if t == AxisTypes.Collective:
out.append(f"C:{a}")
elif t == AxisTypes.User:
out.append(f"U:{a}")
else:
assert t == AxisTypes.Auto
out.append(f"A:{a}")
return f"{{{', '.join(out)}}}"
def _get_shape_sharding_str(shape, spec):
out = []
for s1, s2 in zip(shape, spec):
if s2 is None:
out.append(f"{s1}")
elif isinstance(s2, UnconstrainedSingleton):
out.append(f"{s1}")
elif isinstance(s2, tuple):
ss = ','.join(s for s in s2)
out.append(f"{s1}@({ss})")
@ -1882,6 +1885,7 @@ pytype_aval_mappings[MutableArray] = lambda x: x._aval
def mutable_array(init_val):
return mutable_array_p.bind(init_val)
mutable_array_p = Primitive('mutable_array')
mutable_array_p.ref_primitive = True
class InternalMutableArrayEffect(effects.Effect):
pass
@ -1899,6 +1903,18 @@ def _mutable_array_impl(init_val):
aval = get_aval(init_val)
return MutableArray(AbstractRef(aval), init_val)
def freeze(ref):
return freeze_p.bind(ref)
freeze_p = Primitive('freeze')
freeze_p.ref_primitive = True
@freeze_p.def_effectful_abstract_eval
def freeze_abstract_eval(ref_aval):
return ref_aval.inner_aval, {internal_mutable_array_effect}
@freeze_p.def_impl
def _freeze_impl(ref):
return ref[()]
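A hedged usage sketch of the two ref primitives above, run from user code rather than from this module (the eager impls shown here make this work outside jit as well):
import jax.numpy as jnp
from jax._src import core
ref = core.mutable_array(jnp.arange(3.0))  # MutableArray ref holding f32[3]
val = core.freeze(ref)                     # reads the ref, i.e. ref[()]
assert (val == jnp.arange(3.0)).all()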
class AbstractToken(AbstractValue):
def str_short(self, short_dtypes=False): return 'Tok'
@ -2516,10 +2532,11 @@ def _check_jaxpr(
# Check the computed effect type matches the eqn's annotation, and is
# included in the jaxpr's annotation.
if prim is mutable_array_p:
outvar, = eqn.outvars
in_idx[outvar] = None # type: ignore
mut_arrays.add(outvar)
if prim.ref_primitive:
if prim is mutable_array_p:
outvar, = eqn.outvars
in_idx[outvar] = None # type: ignore
mut_arrays.add(outvar)
if eqn.effects != eqn_effects:
raise JaxprTypeError("Inferred effects do not match equation effects. "
f"Equation effects: {eqn.effects}. "
@ -2639,16 +2656,10 @@ def _check_call(ctx_factory, prim, in_atoms, params):
return aval
for v, x in zip(call_jaxpr.invars, in_atoms):
if not typecompat(substitute(v.aval), x.aval):
# TODO(yashkatariya): Remove this once numpy array's aval has a sharding
# on it.
if (config.sharding_in_types.value and isinstance(x, Literal) and
v.aval.sharding is not None and x.val.ndim == 0):
pass
else:
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{substitute(v.aval)}")
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{substitute(v.aval)}")
env[v] = x if type(x) is Var else x.val
_check_jaxpr(ctx_factory, call_jaxpr)

View File

@ -18,8 +18,8 @@ import json
import math
import jax
from jax import core
from jax import dtypes
from jax._src import core
from jax._src import dispatch
from jax._src.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching

View File

@ -14,7 +14,7 @@
import functools
import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax.interpreters import mlir
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir

View File

@ -15,6 +15,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
import functools
import operator
@ -48,17 +49,93 @@ zip, unsafe_zip = util.safe_zip, zip
@custom_api_util.register_custom_decorator_type
class custom_vmap:
fun: Callable
vmap_rule: Callable | None
"""Customize the vmap behavior of a JAX-transformable function.
def __init__(self, fun: Callable):
This decorator is used to customize the behavior of a JAX function under the
:func:`jax.vmap` transformation. A ``custom_vmap``-decorated function will
mostly (see below for caveats) have the same behavior as the underlying
function, except when batched using :py:func:`jax.vmap`. When batched, the
rule defined using :py:func:`~jax.custom_batching.custom_vmap.def_vmap` will
be used.
For example:
>>> @jax.custom_batching.custom_vmap
... def f(x, y):
... return x + y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
... assert all(in_batched)
... assert xs.shape[0] == axis_size
... assert ys.shape[0] == axis_size
... out_batched = True
... return xs * ys, out_batched
...
>>> xs = jnp.arange(3)
>>> ys = jnp.arange(1, 4)
>>> jax.vmap(f)(xs, ys) # prints xs * ys instead of xs + ys
Array([0, 2, 6], dtype=int32)
Of note, ``custom_vmap`` functions do not support reverse-mode autodiff. To
customize both vmap and reverse-mode autodiff, combine ``custom_vmap`` with
:py:class:`jax.custom_vjp`. For example:
>>> @jax.custom_vjp
... @jax.custom_batching.custom_vmap
... def f(x, y):
... return jnp.sin(x) * y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
... return jnp.cos(xs) * ys, True
...
>>> def f_fwd(x, y):
... return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
... cos_x, sin_x, y = res
... return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)
>>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3))
Array([1., 1., 1.], dtype=float32)
>>> jax.grad(f)(jnp.zeros(()), jnp.ones(()))
Array(1., dtype=float32)
Note that the :py:class:`jax.custom_vjp` must be on the outside, wrapping the
``custom_vmap``-decorated function.
"""
fun: Callable[..., Any]
vmap_rule: Callable[..., tuple[Any, Any]] | None
def __init__(self, fun: Callable[..., Any]):
functools.update_wrapper(self, fun)
self.fun = fun
self.vmap_rule = None
__getattr__ = custom_api_util.forward_attr
def def_vmap(self, vmap_rule: Callable) -> Callable:
def def_vmap(
self,
vmap_rule: Callable[..., tuple[Any, Any]],
) -> Callable[..., tuple[Any, Any]]:
"""Define the vmap rule for this custom_vmap function.
Args:
vmap_rule: A function that implements the vmap rule. This function should
accept the following arguments: (1) an integer ``axis_size`` as its
first argument, (2) a pytree of booleans with the same structure as the
inputs to the function, specifying whether each argument is batched,
and (3) the batched arguments. It should return a tuple of the batched
output and a pytree of booleans with the same structure as the output,
specifying whether each output element is batched. See the documentation
for :py:func:`jax.custom_batching.custom_vmap` for some examples.
Returns:
This method passes the rule through, returning ``vmap_rule`` unchanged.
"""
self.vmap_rule = vmap_rule
return vmap_rule
@ -72,7 +149,7 @@ class custom_vmap:
"using def_vmap.")
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
in_avals = [core.get_aval(x) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
@ -272,6 +349,31 @@ def tree_merge(mask, lhs_tree, rhs_tree):
mask, lhs_tree, rhs_tree)
def sequential_vmap(f):
"""A special case of ``custom_vmap`` that uses a loop.
A function decorated with ``sequential_vmap`` will be called sequentially
within a loop when batched. This is useful for functions that don't natively
support batch dimensions.
For example:
>>> @jax.custom_batching.sequential_vmap
... def f(x):
... jax.debug.print("{}", x)
... return x + 1
...
>>> jax.vmap(f)(jnp.arange(3))
0
1
2
Array([1, 2, 3], dtype=int32)
Where the print statements demonstrate that this :py:func:`~jax.vmap` is being
generated using a loop.
See the documentation for :py:class:`~jax.custom_batching.custom_vmap` for
more details.
"""
f = custom_vmap(f)
@f.def_vmap

View File

@ -1051,7 +1051,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
from the closure.
"""
flat_args, in_tree = tree_flatten(example_args)
in_avals = tuple(map(abstractify, flat_args))
in_avals = tuple(map(core.get_aval, flat_args))
if config.check_tracer_leaks.value:
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
else:
@ -1111,9 +1111,6 @@ def partition_list(choice, lst):
return [next(i2 if snd else i1) for snd in which]
return out, merge
def abstractify(x):
return core.get_aval(x)
### Custom transposition
@ -1209,8 +1206,8 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
f_in_tree = treedef_tuple((res_tree, lin_tree))
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)
res_avals = map(abstractify, operands_res)
lin_avals = map(abstractify, operands_lin)
res_avals = map(core.get_aval, operands_res)
lin_avals = map(core.get_aval, operands_lin)
f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
f_jaxpr = _close_jaxpr(f_jaxpr)
out_avals = f_jaxpr.out_avals

View File

@ -455,7 +455,7 @@ class custom_partitioning:
f_, dyn_args = lu.wrap_init(self.fun), args
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
in_avals = [core.get_aval(x) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
"custom_partitioning")
mesh = mesh_lib.thread_resources.env.physical_mesh

View File

@ -20,16 +20,124 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
_CompoundFactor = tuple[str, ...]
_DimMapping = tuple[str | _CompoundFactor, ...]
# A single character replacement for ... to simplify parsing.
_ELLIPSIS: str = ""
BATCHING: str = ""
# A prefix for names of batching dimension factors, used for expanding the
# leading ... into factors.
_BATCHING_DIM_FACTOR_PREFIX = "?"
def _check_factor(factor:str):
"""Validates a factor.
A factor is a string starting with a letter and containing only letters,
digits, or underscores.
"""
if not factor[0].isalpha():
raise ValueError(f"Factor names have to start with a letter, but got '{factor[0]}'")
for char in factor[1:]:
if char != "_" and not char.isdigit() and not char.isalpha():
raise ValueError(f"Unknown character '{char}'")
class CompoundFactor(tuple):
"""Describes the factors for a compound factor.
A compound factor should contain at least two factors, e.g.
* CompoundFactor('b', 'c').
"""
def __init__(self, *factors):
if len(factors) < 2:
raise ValueError("A compound factor should contain at least two factors")
for factor in factors:
if not isinstance(factor, str):
raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
if factor == BATCHING:
raise ValueError("Ellipsis can't be used in a compound factor")
else:
_check_factor(factor)
def __new__(cls, *factors):
return tuple.__new__(CompoundFactor, factors)
class ArrayMapping(tuple):
"""Describes the factors for an operand or result.
Each element is either a factor or a CompoundFactor. A leading element can
also be BATCHING, which represents batching dimensions. Examples:
* ArrayMapping('a')
* ArrayMapping('b', 'c')
* ArrayMapping(CompoundFactor('b', 'c'), 'd')
* ArrayMapping(BATCHING, CompoundFactor('b', 'c'), 'd')
"""
def __init__(self, *dim_mappings):
for i, d in enumerate(dim_mappings):
if not isinstance(d, str) and not isinstance(d, CompoundFactor):
raise ValueError(
"Each element of ArrayMapping must be a str or CompoundFactor, but"
f" got {type(d)}")
if isinstance(d, str):
if d == BATCHING:
if i != 0:
raise ValueError("Ellipsis can only be used at the beginning of a dimension")
else:
_check_factor(d)
def __new__(cls, *dim_mappings):
return tuple.__new__(ArrayMapping, dim_mappings)
class SdyShardingRule:
"""Represents a Shardy sharding rule.
An SdyShardingRule contains the ArrayMappings for operands and results, and an
optional list of factor sizes. A factor is a name used in the ArrayMappings.
If a factor is only used in CompoundFactors, its size must be specified.
"""
operand_mappings: tuple[ArrayMapping, ...]
result_mappings: tuple[ArrayMapping, ...]
factor_sizes: dict[str, int]
def __init__(self, operand_mappings: tuple[ArrayMapping, ...],
result_mappings: tuple[ArrayMapping, ...], **factor_sizes):
# Find all factors and mark whether their size can be inferred.
factors_inferrable = dict()
for value in operand_mappings + result_mappings:
for dim in value:
if isinstance(dim, str):
factors_inferrable[dim] = True
else:
for factor in dim:
if factor not in factors_inferrable.keys():
factors_inferrable[factor] = False
# Check that factors in factor_sizes are used in the rule.
for factor in factor_sizes:
if factor not in factors_inferrable:
raise ValueError(
f"Factor {factor} is not used in the rule, but size is provided")
# Check that factors that are used for a whole dimension aren't in
# factor_sizes and factors that are never used for a whole dimension are
# in factor_sizes.
for factor, inferrable in factors_inferrable.items():
if factor not in factor_sizes and not inferrable:
raise ValueError(
f"Factor {factor} is only used in compound factors; must specify"
" its size")
if factor in factor_sizes and inferrable:
raise ValueError(
f"Factor {factor} represents a whole dimension; do not specify its"
" size")
self.operand_mappings = operand_mappings
self.result_mappings = result_mappings
self.factor_sizes = factor_sizes
def __str__(self):
return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})"
def _get_batching_dim_factor_name(batch_dim_order : int):
"""Constructs a factor name for a batching dimension.
@ -42,18 +150,18 @@ def _get_batching_dim_factor_name(batch_dim_order : int):
def _parse_values(
rule: str,
) -> tuple[_DimMapping, ...]:
) -> tuple[ArrayMapping, ...]:
"""Parses the LHS or RHS of an Einsum notation like string.
Converts each operand or result in the Einsum notation like string to a tuple
of _DimMapping. This very closely follows how einops parses their rules in
of ArrayMapping. This very closely follows how einops parses their rules in
einops/parsing.py.
Args:
rule: The Einsum notation for the operands or results of an operation.
Returns:
The tuple of values.
The tuple of ArrayMapping.
Raises:
ValueError: If the rule is not balanced or contains unknown characters.
@ -65,10 +173,10 @@ def _parse_values(
# Similar to einops rules, an empty LHS/RHS has a single scalar value.
if not rule:
return ((),)
return (ArrayMapping(),)
all_values = []
# Represent all dimensions of a value. When a value[0]==_ELLIPSIS, the
# Represent all dimensions of a value. When a value[0]==BATCHING, the
# value may have 0 or more leading dimensions.
value = []
current_factor = None
@ -84,12 +192,12 @@ def _parse_values(
current_compound_dim.append(x)
for char in rule:
if char == _ELLIPSIS:
if char == BATCHING:
if (current_factor is not None or current_compound_dim is not None
or value):
raise ValueError(
"Ellipsis can only be used at the beginning of a dimension")
add_factor(_ELLIPSIS)
add_factor(BATCHING)
continue
if char in "(), ":
if current_factor is not None:
@ -106,10 +214,10 @@ def _parse_values(
raise ValueError("Brackets are not balanced")
if len(current_compound_dim) <= 1:
raise ValueError("Brackets should contain at least two factors")
value.append(tuple(current_compound_dim))
value.append(CompoundFactor(*current_compound_dim))
current_compound_dim = None
elif char == ",":
all_values.append(tuple(value))
all_values.append(ArrayMapping(*value))
value = []
elif char == "_" or char.isdigit() or char.isalpha():
if current_factor is None:
@ -125,256 +233,203 @@ def _parse_values(
raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
if current_factor is not None:
add_factor(current_factor)
all_values.append(tuple(value))
all_values.append(ArrayMapping(*value))
return tuple(all_values)
def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule:
"""Constructs a SdyShardingRule object from the Einsum notation like string.
class SdyShardingRule:
"""A representation for Shardy sharding rule.
This is done by verifying that the input Einsum notation like string,
together with the optional factor sizes, represents a valid sharding rule and
converting it to an internal representation.
A SdyShardingRule includes an Einsum notation like string and an optional
list of factor sizes. A factor is a name in the Einsum notation. If a factor
is only used in compound factors, its size must be specified.
Args:
rule: The Einsum notation like string for an operation.
**factor_sizes: The optional factor sizes.
SdyShardingRule examples:
* Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k')
* Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k')
* A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j')
* Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)', i=4, j=2)
* An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...')
Raises:
ValueError: If there is any problem with the rule or factor_sizes.
"""
if not isinstance(rule, str):
raise TypeError(f"rule must be a str, but got {type(rule)}")
if not all(isinstance(size, int) for size in factor_sizes.values()):
raise TypeError(
f"factor_sizes must be a dict of str to int, but got {factor_sizes}")
def __init__(self, rule: str, **factor_sizes):
"""Constructs a SdyShardingRule object from the Einsum notation like string.
This is done by verifying that the input Einsum notation like string and
with optional factor sizes represents a valid sharding rule and converting
it to an internal representation.
Args:
rule: The Einsum notation like string for an operation.
**factor_sizes: The optional factor sizes.
Raises:
ValueError: If there is any problem with the rule or factor_sizes.
"""
if not isinstance(rule, str):
raise TypeError(f"rule must be a str, but got {type(rule)}")
if not all(isinstance(size, int) for size in factor_sizes.values()):
raise TypeError(
f"factor_sizes must be a dict of str to int, but got {factor_sizes}")
# Replace ... with a single char to simplify parsing.
if _ELLIPSIS in rule:
raise ValueError(f"Unknown character '{_ELLIPSIS}'")
# Replace ... with a single char to simplify parsing.
if BATCHING in rule:
raise ValueError(f"Unknown character '{BATCHING}'")
if "." in rule:
rule = rule.replace("...", BATCHING)
if "." in rule:
rule = rule.replace("...", _ELLIPSIS)
if "." in rule:
raise ValueError("Character '.' must be used inside ellipsis '...'")
raise ValueError("Character '.' must be used inside ellipsis '...'")
try:
operands, results = rule.split("->")
except ValueError as e:
raise ValueError(f"There is no -> in rule: '{rule}'") from e
try:
operands, results = rule.split("->")
except ValueError as e:
raise ValueError(f"There is no -> in rule: '{rule}'") from e
self.operands = _parse_values(operands)
self.results = _parse_values(results)
operand_mappings = _parse_values(operands)
result_mappings = _parse_values(results)
# Find all factors and mark whether their size can be inferred.
factors_inferrable = dict()
for value in self.operands + self.results:
for dim in value:
if dim == _ELLIPSIS:
continue
if isinstance(dim, str):
factors_inferrable[dim] = True
else:
for factor in dim:
if factor not in factors_inferrable.keys():
factors_inferrable[factor] = False
return SdyShardingRule(operand_mappings, result_mappings, **factor_sizes)
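A hedged usage sketch, reusing the Einsum-like strings from the examples above:
batched_matmul = str_to_sdy_sharding_rule('... i j, ... j k -> ... i k')
transposing_reshape = str_to_sdy_sharding_rule('(i j) -> (j i)', i=4, j=2)
elementwise_add = str_to_sdy_sharding_rule('..., ... -> ...')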
# Check that factors in factor_sizes are used in the rule.
for factor in factor_sizes:
if factor not in factors_inferrable:
raise ValueError(
f"Factor {factor} is not used in the rule, but size is provided")
# Check that factors that are used for a whole dimension aren't in
# factor_sizes and factors that are never used for a whole dimension are
# in factor_sizes.
for factor, inferrable in factors_inferrable.items():
if factor not in factor_sizes and not inferrable:
raise ValueError(
f"Factor {factor} is only used in compound factors; must specify"
" its size")
if factor in factor_sizes and inferrable:
raise ValueError(
f"Factor {factor} represents a whole dimension; do not specify its"
" size")
def sdy_sharding_rule_to_mlir(
rule: SdyShardingRule,
operand_types: list[ir.Type],
result_types: list[ir.Type],) -> ir.Attribute:
"""Builds the MLIR representation for the sharding rule.
self.factor_sizes = factor_sizes
This is done by verifying that the rule is consistent with the types of
the operation and converting the Einsum notation like string to
OpShardingRuleAttr.
"""
if len(rule.operand_mappings) != len(operand_types):
raise ValueError(
f"Sharding rule has {len(rule.operand_mappings)} operands, but the operation"
f" has {len(operand_types)} operands")
if len(rule.result_mappings) != len(result_types):
raise ValueError(
f"Sharding rule has {len(rule.result_mappings)} results, but the operation"
f" has {len(result_types)} results")
def __str__(self):
return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})"
factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
types = operand_types + result_types
UNKNOWN = -1 # Representation for unknown factor size or factor index.
def build(
self,
operand_types: list[ir.Type],
result_types: list[ir.Type],) -> ir.Attribute:
"""Builds the MLIR representation for the sharding rule.
def get_message_for_value(i):
if i >= len(operand_types):
return f"{i - len(operand_types)}th result"
else:
return f"{i}th operand"
This is done by verifying that the rule is consistent with the types of
the operation and converting the Einsum notation like string to
OpShardingRuleAttr.
def get_rank_for_value(i):
return ir.ShapedType(types[i]).rank
def get_size_for_value_dim(i, j):
return ir.ShapedType(types[i]).shape[j]
def add_factor(factor, size):
"""Adds a factor to factors_to_indices_sizes.
`size` may be a dimension size, a user specified factor size, or UNKNOWN
if a factor is first used in a compound factor and then used for a
whole dimension.
"""
if len(self.operands) != len(operand_types):
raise ValueError(
f"Sharding rule has {len(self.operands)} operands, but the operation"
f" has {len(operand_types)} operands"
)
if len(self.results) != len(result_types):
raise ValueError(
f"Sharding rule has {len(self.results)} results, but the operation"
f" has {len(result_types)} results"
)
factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
types = operand_types + result_types
UNKNOWN = -1 # Representation for unknown factor size or factor index.
def get_message_for_value(i):
if i >= len(operand_types):
return f"{i - len(operand_types)}th result"
else:
return f"{i}th operand"
def get_rank_for_value(i):
return ir.ShapedType(types[i]).rank
def get_size_for_value_dim(i, j):
return ir.ShapedType(types[i]).shape[j]
def add_factor(factor, size):
"""Adds a factor to factors_to_indices_sizes.
`size` may be a dimension size, a user specified factor size, or UNKNOWN
if a factor is first used in a compound factor and then used for a
whole dimension.
"""
factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
if factor_index != UNKNOWN:
# Not the first time seeing the factor.
if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
factor_or_batching_dim = (
f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
else f"Batching dimension {factor[1:]}")
raise ValueError(
f"{factor_or_batching_dim} corresponds to two sizes:"
f" {factor_size} and {size}")
if size != UNKNOWN and factor_size == UNKNOWN:
factors_to_indices_sizes[factor] = [factor_index, size]
else:
# First time seeing the factor.
factor_index = len(factors_to_indices_sizes)
factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
if factor_index != UNKNOWN:
# Not the first time seeing the factor.
if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
factor_or_batching_dim = (
f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
else f"Batching dimension {factor[1:]}")
raise ValueError(
f"{factor_or_batching_dim} corresponds to two sizes:"
f" {factor_size} and {size}")
if size != UNKNOWN and factor_size == UNKNOWN:
factors_to_indices_sizes[factor] = [factor_index, size]
else:
# First time seeing the factor.
factor_index = len(factors_to_indices_sizes)
factors_to_indices_sizes[factor] = [factor_index, size]
def add_batching_dim_factor(batch_dim_order, factor_size):
ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
add_factor(ellipsis_batch_dim_name, factor_size)
def add_batching_dim_factor(batch_dim_order, factor_size):
ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
add_factor(ellipsis_batch_dim_name, factor_size)
def build_dim_mapping_for_compound_factors(i, j, factors):
accumulated_size = 1
all_indices = []
for factor in factors:
factor_index, factor_size = factors_to_indices_sizes[factor]
accumulated_size *= factor_size
all_indices.append(factor_index)
def build_dim_mapping_for_compound_factors(i, j, factors):
accumulated_size = 1
all_indices = []
for factor in factors:
factor_index, factor_size = factors_to_indices_sizes[factor]
accumulated_size *= factor_size
all_indices.append(factor_index)
dim_size = get_size_for_value_dim(i, j)
if accumulated_size != dim_size:
dim_size = get_size_for_value_dim(i, j)
if accumulated_size != dim_size:
raise ValueError(
f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
f" the size {accumulated_size} derived from the compound factors"
f" {factors}")
return sdy.DimMappingAttr.get(factor_indices=all_indices)
# Add factors and their sizes in the order they appear in the rule,
# including the batching dimensions represented by ellipsis.
ellipsis_rank = None
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
value = tuple(mapping)
if value and value[0] == BATCHING:
has_batching = True
value = value[1:]
else:
has_batching = False
rule_rank = len(value)
op_rank = get_rank_for_value(i)
# The number of dimensions represented by ellipsis.
current_batching_rank = 0
if has_batching and op_rank >= rule_rank:
current_batching_rank = op_rank - rule_rank
if has_batching:
if ellipsis_rank is None:
ellipsis_rank = current_batching_rank
elif ellipsis_rank != current_batching_rank:
raise ValueError(
f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
f" the size {accumulated_size} derived from the compound factors"
f" {factors}")
"Ellipsis represents different number of leading dimensions"
f" {ellipsis_rank} and {current_batching_rank}")
rule_rank += current_batching_rank
if rule_rank != op_rank:
msg = get_message_for_value(i)
raise ValueError(
f"Sharding rule {msg} has rank {rule_rank}, but the operation"
f" {msg} has rank {op_rank}")
return sdy.DimMappingAttr.get(factor_indices=all_indices)
for j in range(current_batching_rank):
add_batching_dim_factor(j, get_size_for_value_dim(i, j))
# Add factors and their sizes in the order they appear in the rule,
# including the batching dimensions represented by ellipsis.
ellipsis_rank = None
for i, value in enumerate(self.operands + self.results):
if value and value[0] == _ELLIPSIS:
has_ellipsis = True
value = value[1:]
for j, dim in enumerate(value):
if isinstance(dim, str):
add_factor(dim, get_size_for_value_dim(i, j + current_batching_rank))
else:
has_ellipsis = False
rule_rank = len(value)
op_rank = get_rank_for_value(i)
# The number of dimensions represented by ellipsis.
current_ellipsis_rank = 0
if has_ellipsis and op_rank > rule_rank:
current_ellipsis_rank = op_rank - rule_rank
if has_ellipsis:
if ellipsis_rank is None:
ellipsis_rank = current_ellipsis_rank
elif ellipsis_rank != current_ellipsis_rank:
raise ValueError(
"Ellipsis represents different number of leading dimensions"
f" {ellipsis_rank} and {current_ellipsis_rank}")
rule_rank += current_ellipsis_rank
if rule_rank != op_rank:
msg = get_message_for_value(i)
raise ValueError(
f"Sharding rule {msg} has rank {rule_rank}, but the operation"
f" {msg} has rank {op_rank}")
for factor in dim:
add_factor(factor, rule.factor_sizes.get(factor, UNKNOWN))
for j in range(current_ellipsis_rank):
add_batching_dim_factor(j, get_size_for_value_dim(i, j))
# Build the tensor mappings for each operand and result.
tensor_mappings = []
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
value = tuple(mapping)
dim_mappings = []
for j, dim in enumerate(value):
if isinstance(dim, str):
add_factor(
dim, get_size_for_value_dim(i, j + current_ellipsis_rank))
else:
for factor in dim:
add_factor(factor, self.factor_sizes.get(factor, UNKNOWN))
# Build the tensor mappings for each operand and result.
tensor_mappings = []
for i, value in enumerate(self.operands + self.results):
dim_mappings = []
if value and value[0] == _ELLIPSIS:
value = value[1:]
if ellipsis_rank is None:
current_ellipsis_rank = 0
else:
current_ellipsis_rank = ellipsis_rank
if value and value[0] == BATCHING:
value = value[1:]
if ellipsis_rank is None:
current_batching_rank = 0
else:
current_ellipsis_rank = 0
current_batching_rank = ellipsis_rank
else:
current_batching_rank = 0
for j in range(current_ellipsis_rank):
for j in range(current_batching_rank):
dim_mappings.append(
sdy.DimMappingAttr.get(factor_indices=[
factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
for j, dim in enumerate(value):
if isinstance(dim, str):
dim_mappings.append(
sdy.DimMappingAttr.get(factor_indices=[
factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
sdy.DimMappingAttr.get(
factor_indices=[factors_to_indices_sizes[dim][0]]))
else:
dim_mappings.append(
build_dim_mapping_for_compound_factors(
i, j + current_batching_rank, dim))
for j, dim in enumerate(value):
if isinstance(dim, str):
dim_mappings.append(
sdy.DimMappingAttr.get(
factor_indices=[factors_to_indices_sizes[dim][0]]))
else:
dim_mappings.append(
build_dim_mapping_for_compound_factors(
i, j + current_ellipsis_rank, dim))
tensor_mappings.append(
sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))
tensor_mappings.append(
sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))
op_sharding_rule = sdy.OpShardingRuleAttr.get(
factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
operand_mappings=tensor_mappings[0:len(operand_types)],
result_mappings=tensor_mappings[len(operand_types):])
return op_sharding_rule
return sdy.OpShardingRuleAttr.get(
factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
operand_mappings=tensor_mappings[0:len(operand_types)],
result_mappings=tensor_mappings[len(operand_types):])

View File

@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params):
@util.cache()
def xla_primitive_callable(prim: core.Primitive, **params):
util.test_event("xla_primitive_callable_cache_miss")
def prim_fun(*args):
with config.eager_constant_folding(False):
return prim.bind(*args, **params)

View File

@ -419,7 +419,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}
# Otherwise, fall back to numpy.issubdtype
return np.issubdtype(a_sctype, b_sctype)
return bool(np.issubdtype(a_sctype, b_sctype))
can_cast = np.can_cast

View File

@ -203,6 +203,7 @@ class Exported:
_get_vjp: Callable[[Exported], Exported] | None
def mlir_module(self) -> str:
"""A string representation of the `mlir_module_serialized`."""
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
def __str__(self):
@ -211,8 +212,8 @@ class Exported:
return f"Exported(fun_name={self.fun_name}, ...)"
def in_shardings_jax(
self,
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
self,
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
"""Creates Shardings corresponding to self.in_shardings_hlo.
The Exported object stores `in_shardings_hlo` as HloShardings, which are
@ -221,30 +222,31 @@ class Exported:
`jax.device_put`.
Example usage:
>>> from jax import export
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
... )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
# Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>>
# Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
... exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
>>> from jax import export
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
... )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
... exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
"""
return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
for s in self.in_shardings_hlo)
@ -252,7 +254,7 @@ class Exported:
def out_shardings_jax(
self,
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
"""Creates Shardings corresponding to self.out_shardings_hlo.
"""Creates Shardings corresponding to `self.out_shardings_hlo`.
See documentation for in_shardings_jax.
"""
@ -289,6 +291,21 @@ class Exported:
return serialize(self, vjp_order=vjp_order)
def call(self, *args, **kwargs):
"""Call an exported function from a JAX program.
Args:
args: the positional arguments to pass to the exported function. This
should be a pytree of arrays with the same pytree structure as the
arguments for which the function was exported.
kwargs: the keyword arguments to pass to the exported function.
Returns: a pytree of result arrays, with the same structure as the
results of the exported function.
The invocation supports reverse-mode AD, and all the features supported
by exporting: shape polymorphism, multi-platform, device polymorphism.
See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html).
"""
return call_exported(self)(*args, **kwargs)
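A hedged end-to-end sketch of `call` (using the public `jax.export` entry point, which is assumed here): export a jitted function once, then invoke the exported artifact like a regular callable:
import jax
import numpy as np
from jax import export
exp = export.export(jax.jit(lambda x: x + 1))(np.arange(4.0))
y = exp.call(np.arange(4.0))   # also differentiable in reverse mode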
@ -501,7 +518,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
"""Returns the shape and dtype of a jax.Array or a j"""
if isinstance(a, jax.ShapeDtypeStruct):
return a.shape, a.dtype
aval = core.raise_to_shaped(core.get_aval(a))
aval = core.get_aval(a)
return aval.shape, aval.dtype
@ -997,7 +1014,10 @@ _CPU_FFI_KERNELS = [
"lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi",
"lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi",
"lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi",
"lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi",
"lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi",
"lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi",
"lapack_strsm_ffi", "lapack_dtrsm_ffi", "lapack_ctrsm_ffi", "lapack_ztrsm_ffi",
]
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
@ -1021,6 +1041,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
# schur on CPU
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
# tridiagonal on CPU
"lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd",
# hessenberg on CPU
"lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd",
# lu on GPU

View File

@ -92,8 +92,11 @@ class _SymbolicConstraint:
# Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2
cmp: Comparator
debug_str: str # The form in which the user expressed it, for error messages
e1: DimSize # This has been normalized w.r.t. previous constraints only
e2: DimSize # This has been normalized w.r.t. previous constraints only
# e1, e2, and diff == e1 - e2, are normalized w.r.t. previous constraints only
e1: DimSize
e2: DimSize
# we pre-compute diff to avoid having the normalization rule kick in later.
diff: DimSize
def __repr__(self):
return f"Constraint({self.debug_str})"
@ -126,6 +129,7 @@ class _DimFactor:
MOD = "mod"
MAX = "max"
MIN = "min"
# TODO(necula): remove non_negative
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
# max but kept here for backwards compatibility.
@ -764,13 +768,23 @@ class _DimExpr:
return _DimExpr._linear_combination(self, other, 0, 0, self.scope)
return _ensure_poly(other, "mul", self.scope).__mul__(self)
def __pow__(self, power, modulo=None):
assert modulo is None
try:
power = int(power)
except:
raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
return functools.reduce(op.mul, [self] * power)
def __pow__(self, power: core.DimSize, modulo=None):
if modulo is not None:
raise NotImplementedError("__pow__ modulo not implemented")
if is_symbolic_dim(power):
return power.__rpow__(self) # type: ignore
if power != int(power):
raise ValueError(f"Symbolic dimension cannot be raised to non-integer powers: '{self}' ** '{power}'")
if power >= 0:
return functools.reduce(op.mul, [self] * power, 1)
# We don't support negative powers, because JAX does not allow negative
# powers for integers
raise ValueError(f"Symbolic dimension cannot be raised to negative powers: '{self}' ** '{power}'")
def __rpow__(self, other, modulo=None):
if modulo is not None:
raise NotImplementedError("__rpow__ modulo not implemented")
return self.__jax_array__().__rpow__(other)
def __floordiv__(self, divisor):
if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
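# A short sketch of the __pow__ behavior above via the public symbolic-shapes
# API: only non-negative integer exponents are accepted.
from jax import export

a, = export.symbolic_shape("a")
print(a ** 2)   # expands to a*a
print(a ** 0)   # 1
# a ** -1 and a ** 0.5 raise ValueError, per the checks in __pow__ above.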
@ -1051,29 +1065,51 @@ class SymbolicScope:
if cmp == Comparator.GEQ and not is_geq:
e1, e2 = e2, e1
diff = e1 - e2
if (diff_const := _DimExpr._to_constant(diff)) is not None:
if ((cmp == Comparator.EQ and diff_const != 0) or
(cmp == Comparator.GEQ and diff_const < 0)):
raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
# Compute e1 - e2 before we add to normalization rules
constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2,
diff=e1 - e2)
self._process_explicit_constraint(constr)
def _process_explicit_constraint(self, constr: _SymbolicConstraint):
if (diff_const := _DimExpr._to_constant(constr.diff)) is not None:
if ((constr.cmp == Comparator.EQ and diff_const != 0) or
(constr.cmp == Comparator.GEQ and diff_const < 0)):
raise ValueError(f"Unsatisfiable explicit constraint: {constr.debug_str}")
return
if cmp == Comparator.EQ:
if not isinstance(e1, _DimExpr):
if constr.cmp == Comparator.EQ:
if not isinstance(constr.e1, _DimExpr):
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
"The left-hand-side must be of the form `term * coefficient`.")
(before, before_k), *rest = e1._sorted_terms
(before, before_k), *rest = constr.e1._sorted_terms
if rest:
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
"The left-hand-side must be of the form `term * coefficient`.")
after = _ensure_poly(e2, "parse_constraint", e1.scope) # type: ignore[name-error,unused-ignore]
after = _ensure_poly(constr.e2, "parse_constraint", constr.e1.scope) # type: ignore[name-error,unused-ignore]
if before in self._normalization_rules:
raise NotImplementedError(
f"Found multiple equality constraints with the same left-hand-side: {before}")
self._normalization_rules[before] = (after, before_k)
# Look for constraints of the form mod(before_e1, before_k2) * 1 == 0
if (before_k == 1 and
isinstance(constr.e2, int) and constr.e2 == 0 and
(before_f := before.to_factor()) and
before_f.operation == _DimFactor.MOD and
(before_k2 := _DimExpr._to_constant(before_f.operands[1])) is not None):
# Add before_k2*floordiv(before_e1, before_k2) == before_e1
k_times_floordiv = _DimExpr._from_term(
_DimTerm.from_operation(
_DimFactor.FLOORDIV, *before_f.operands, scope=constr.e1.scope),
before_k2, scope=constr.e1.scope)
before_e1 = before_f.operands[0]
self._process_explicit_constraint(
_SymbolicConstraint(cmp=Comparator.EQ,
e1=k_times_floordiv, e2=before_e1,
diff=k_times_floordiv - before_e1,
debug_str=f"{k_times_floordiv} == {before_e1}")
)
constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2)
self._explicit_constraints.append(constr)
def _check_same_scope(self, other: _DimExpr,
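# A hedged illustration of the constraint handling above, using the public
# symbolic-shapes API. A "mod(a, 4) == 0" equality should trigger the derived
# constraint 4*floordiv(a, 4) == a added by _process_explicit_constraint.
from jax import export

a, = export.symbolic_shape("a", constraints=("mod(a, 4) == 0",))
print((a // 4) * 4 == a)   # expected True, via the derived floordiv rule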
@ -1468,7 +1504,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
"""Returns the shape and dtype of a jax.Array or a j"""
if isinstance(a, jax.ShapeDtypeStruct):
return a.shape, a.dtype
aval = core.raise_to_shaped(core.get_aval(a))
aval = core.get_aval(a)
return aval.shape, aval.dtype
@ -1982,7 +2018,8 @@ def compute_dim_vars_from_arg_shapes(
generate the code for computing the dimension variables. It also generates
the shape assertions.
Returns: the values of the dimension variables, in the order determined by
Returns:
The values of the dimension variables, in the order determined by
`all_dim_vars(args_avals)`.
"""
dim_vars = all_dim_vars(args_avals)
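# A sketch of what this solving enables at the user level (public API only):
# dimension variables are recovered from concrete argument shapes when an
# export with symbolic shapes is called.
import jax
import jax.numpy as jnp
from jax import export

b, = export.symbolic_shape("b")
exp = export.export(jax.jit(lambda x: x.sum(axis=0)))(
    jax.ShapeDtypeStruct((b, 4), jnp.float32))
print(exp.call(jnp.ones((3, 4), jnp.float32)).shape)   # (4,); b is solved to 3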
@ -1996,8 +2033,7 @@ def compute_dim_vars_from_arg_shapes(
}
synthetic_eval = ShapeEvaluator(synthetic_env)
shape_constraints.shape_assertions(synthetic_eval)
dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
return tuple(dim_values)
return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars)
def _solve_dim_equations(
eqns: list[_DimEquation],
@ -2110,14 +2146,12 @@ def _solve_dim_equations(
for constr in scope._explicit_constraints:
# We can't just construct constr.e1 - constr.e2 because for an equality
# constraint it would be reduced to 0.
c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore
c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore
c_diff = c_e1 - c_e2
c_diff = constr.diff._evaluate(shape_env) if not core.is_constant_dim(constr.diff) else constr.diff # type: ignore
shape_constraints.add_constraint(
constr.cmp, c_diff, 0,
error_message_pieces=[
f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. "
f"Expected '{constr.e1} - {constr.e2}' to be "
f"Expected '{constr.diff}' to be "
f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, "
"but found ", c_diff,
@ -2131,7 +2165,8 @@ def _solve_dim_equations(
eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
if not eqns:
add_explicit_symbolic_constraints(shape_env)
return shape_env, shape_constraints # SUCCESS
# SUCCESS
return shape_env, shape_constraints # pytype: disable=bad-return-type
elif len(eqns) >= nr_eqns:
break

View File

@ -85,19 +85,11 @@ class _DecisionByElimination:
# the result (albeit, for now, without a good feedback loop to understand
# how the order matters for inequalities).
for constr in self.scope._explicit_constraints:
if not core.is_constant_dim(constr.e1):
self.add_implicit_constraints_expr(constr.e1) # type: ignore
if not core.is_constant_dim(constr.e2):
self.add_implicit_constraints_expr(constr.e2) # type: ignore
# The equality constraints are not needed for inequality decisions,
# because the LHS should always be rewritten in terms of the RHS.
# In fact, adding them may break the assumption that if we eliminate
# the leading term we end up with only smaller terms, because the LHS
# may appear in the rest and may be rewritten to something larger.
# However, we want to add the implicit constraints within.
if constr.cmp == Comparator.GEQ:
self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0,
constr.debug_str)
if not core.is_constant_dim(constr.diff):
self.add_implicit_constraints_expr(constr.diff) # type: ignore
self.combine_and_add_constraint(constr.cmp, constr.diff, 0,
constr.debug_str)
# Clear the cache, since we have added constraints.
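# A hedged sketch of the inequality reasoning this elimination feeds, via the
# public symbolic-shapes API: the explicit constraint is combined with the
# implicit b >= 1 bound to decide the comparison.
from jax import export

a, b = export.symbolic_shape("a, b", constraints=("a >= 2*b",))
print((a - b) >= b)   # expected True: a - b >= 2*b - b = b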
@ -197,7 +189,7 @@ class _DecisionByElimination:
Combine a term with existing constraints.
For input (t, t_k) the tuple (c_eq, c, c_s, t_s) is among the returned
tuples if there exists a constraint `c =[c_eq] 0` that can be combined
with `t*t_k` to eliminate `t`.
with `t*t_k` to eliminate `t`, and:
* `c =[c_eq] 0`
* The term `comb = t*t_k*t_s + c*c_s` does not contain `t`, and if
@ -207,7 +199,7 @@ class _DecisionByElimination:
"""
# TODO: maybe a generator is useful here instead of materializing the list
acc: list[tuple[Comparator, _DimExpr, int, int]] = []
# First combine with the existing term constraints
# First combine with the existing term bounds
t_lb, t_ub = self._term_bounds.get(t, (-np.inf, np.inf))
if t_lb == t_ub:
acc.append((Comparator.EQ, _DimExpr(((t, 1),), scope) - int(t_lb),

View File

@ -388,7 +388,7 @@ def ffi_call(
f"custom_call_api_version < 4; got {custom_call_api_version}.")
def wrapped(*args: ArrayLike, **kwargs: Any):
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
in_avals = [core.get_aval(x) for x in args]
if input_layouts is None:
static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
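# A hedged usage sketch of the ffi_call wrapper whose aval handling is shown
# above. "my_custom_op" is a hypothetical target that would have to be
# registered first; depending on the JAX version the entry point is jax.ffi
# or jax.extend.ffi.
import jax
import jax.numpy as jnp

out_type = jax.ShapeDtypeStruct((4,), jnp.float32)
call_my_op = jax.ffi.ffi_call("my_custom_op", out_type)
# y = call_my_op(jnp.arange(4, dtype=jnp.float32))  # runs once the target is registered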

View File

@ -241,3 +241,218 @@ module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0bOO\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\x06\x06\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\x0b\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgees\x00",
xla_call_module_version=6,
) # End paste
data_2024_11_29 = {}
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_11_29["c128"] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_zgees_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
[ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
[ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j],
[12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]]),),
expected_outputs=(array([[ 3.2464249196572972e+01+0.j, -1.3416407864998739e+01+0.j,
-1.2558842947806125e-14+0.j, -7.3490869705474997e-15+0.j],
[ 0.0000000000000000e+00+0.j, -2.4642491965729798e+00+0.j,
-2.5534994473279107e-15+0.j, -1.3671521621839345e-16+0.j],
[ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j,
-1.8779126463272594e-15+0.j, 7.2486619604759691e-16+0.j],
[ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j,
0.0000000000000000e+00+0.j, 4.8523679991768567e-16+0.j]]), array([[ 0.11417645138733863+0.j, -0.8288327563197511 +0.j,
0.5401354211381763 +0.j, -0.09085002384085737+0.j],
[ 0.33000459866554743+0.j, -0.43714638836388686+0.j,
-0.6524649518290251 +0.j, 0.5237265380279561 +0.j],
[ 0.545832745943757 +0.j, -0.04546002040802424-0.j,
-0.31547635975648136+0.j, -0.774903004533341 +0.j],
[ 0.7616608932219662 +0.j, 0.346226347547838 +0.j,
0.42780589044732925+0.j, 0.3420264903462419 +0.j]])),
mlir_module_text=r"""
#loc1 = loc("input")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xcomplex<f64>> loc("input")) -> (tensor<4x4xcomplex<f64>> {jax.result_info = "[0]"}, tensor<4x4xcomplex<f64>> {jax.result_info = "[1]"}) {
%cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @lapack_zgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex<f64>>) -> (tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<4xcomplex<f64>>, tensor<i32>, tensor<i32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc3)
return %6, %10 : tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0bO\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02>\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x0b\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_zgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_11_29["c64"] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_cgees_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
[ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
[ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j],
[12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]], dtype=complex64),),
expected_outputs=(array([[ 3.2464264e+01+0.j, -1.3416414e+01+0.j, -2.1337737e-06+0.j,
1.8261760e-06+0.j],
[ 0.0000000e+00+0.j, -2.4642489e+00+0.j, -6.0543999e-07+0.j,
4.8744488e-07+0.j],
[ 0.0000000e+00+0.j, 0.0000000e+00+0.j, -6.5878328e-07+0.j,
3.9895070e-07+0.j],
[ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 0.0000000e+00+0.j,
3.0199919e-07+0.j]], dtype=complex64), array([[ 0.11417647 +0.j, -0.8288329 +0.j, 0.5404726 +0.j,
-0.08882082 +0.j],
[ 0.3300045 +0.j, -0.4371462 +0.j, -0.6544272 +0.j,
0.52127254 +0.j],
[ 0.54583293 +0.j, -0.045460045-0.j, -0.312564 +0.j,
-0.77608234 +0.j],
[ 0.76166105 +0.j, 0.34622625 +0.j, 0.42651838 +0.j,
0.34363067 +0.j]], dtype=complex64)),
mlir_module_text=r"""
#loc1 = loc("input")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xcomplex<f32>> loc("input")) -> (tensor<4x4xcomplex<f32>> {jax.result_info = "[0]"}, tensor<4x4xcomplex<f32>> {jax.result_info = "[1]"}) {
%cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:5 = stablehlo.custom_call @lapack_cgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex<f32>>) -> (tensor<4x4xcomplex<f32>>, tensor<4x4xcomplex<f32>>, tensor<4xcomplex<f32>>, tensor<i32>, tensor<i32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc3)
return %6, %10 : tensor<4x4xcomplex<f32>>, tensor<4x4xcomplex<f32>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02\x1e\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\t\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_cgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_11_29["f32"] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_sgees_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]], dtype=float32),),
expected_outputs=(array([[ 3.2464233e+01, -1.3416398e+01, -1.6680369e-05, 4.0411728e-06],
[ 0.0000000e+00, -2.4642496e+00, -1.8640144e-06, 6.7429795e-07],
[ 0.0000000e+00, 0.0000000e+00, -7.2618576e-07, 3.9895073e-07],
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.0443638e-07]],
dtype=float32), array([[-0.11417632 , 0.8288333 , -0.5413438 , 0.08334288 ],
[-0.33000442 , 0.43714583 , 0.65967286 , -0.5146185 ],
[-0.54583275 , 0.045459934, 0.30468878 , 0.7792079 ],
[-0.7616609 , -0.34622616 , -0.4230168 , -0.34793234 ]],
dtype=float32)),
mlir_module_text=r"""
#loc1 = loc("input")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xf32> loc("input")) -> (tensor<4x4xf32> {jax.result_info = "[0]"}, tensor<4x4xf32> {jax.result_info = "[1]"}) {
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:6 = stablehlo.custom_call @lapack_sgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<i32>, tensor<i32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x4xf32> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x4xf32> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3)
return %6, %10 : tensor<4x4xf32>, tensor<4x4xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\n\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\t\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_sgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_11_29["f64"] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_dgees_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]]),),
expected_outputs=(array([[ 3.2464249196572979e+01, -1.3416407864998748e+01,
4.7128510442204522e-15, -8.6687960588453852e-15],
[ 0.0000000000000000e+00, -2.4642491965729767e+00,
1.8990547895861982e-15, -2.4680570671743780e-16],
[ 0.0000000000000000e+00, 0.0000000000000000e+00,
-1.8780225147134376e-15, -7.2486619604759710e-16],
[ 0.0000000000000000e+00, 0.0000000000000000e+00,
0.0000000000000000e+00, 4.8523923435746521e-16]]), array([[-0.1141764513873386 , 0.8288327563197505 , 0.5401360966805397 ,
0.09084600741204968],
[-0.3300045986655475 , 0.43714638836388714, -0.6524688462214561 ,
-0.5237216863090944 ],
[-0.5458327459437569 , 0.04546002040802441, -0.31547059759870844,
0.774905350382041 ],
[-0.7616608932219663 , -0.34622634754783793, 0.4278033471396243 ,
-0.3420296714849957 ]])),
mlir_module_text=r"""
#loc1 = loc("input")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xf64> loc("input")) -> (tensor<4x4xf64> {jax.result_info = "[0]"}, tensor<4x4xf64> {jax.result_info = "[1]"}) {
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
%0:6 = stablehlo.custom_call @lapack_dgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<i32>, tensor<i32>) loc(#loc3)
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
%2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc3)
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3)
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc3)
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3)
return %6, %10 : tensor<4x4xf64>, tensor<4x4xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\x1a\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x0b\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_dgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+",
xla_call_module_version=9,
nr_devices=1,
) # End paste

View File

@ -203,3 +203,281 @@ module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0bO\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02\x82\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x0b\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ztrsm\x00",
xla_call_module_version=6,
) # End paste
data_2024_12_02 = {}
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['c128'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_ztrsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array([
[5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j],
[12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j],
]),
array([
[0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j],
[5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j],
[10.0 + 0.0j, 11.0 + 0.0j, 12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j],
[15.0 + 0.0j, 16.0 + 0.0j, 17.0 + 0.0j, 18.0 + 0.0j, 19.0 + 0.0j],
]),
),
expected_outputs=(
array([
[
0.0 + 0.0j,
0.2 + 0.0j,
0.4 + 0.0j,
0.6000000000000001 + 0.0j,
0.8 + 0.0j,
],
[
0.5 + 0.0j,
0.52 + 0.0j,
0.54 + 0.0j,
0.5599999999999999 + 0.0j,
0.58 + 0.0j,
],
[
0.36666666666666664 + 0.0j,
0.3146666666666667 + 0.0j,
0.2626666666666667 + 0.0j,
0.21066666666666667 + 0.0j,
0.15866666666666665 + 0.0j,
],
[
0.16833333333333336 + 0.0j,
0.1217333333333333 + 0.0j,
0.07513333333333323 + 0.0j,
0.0285333333333333 + 0.0j,
-0.018066666666666675 + 0.0j,
],
]),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xcomplex<f64>> loc("a"), %arg1: tensor<4x5xcomplex<f64>> loc("b")) -> (tensor<4x5xcomplex<f64>> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc4)
%0 = stablehlo.custom_call @lapack_ztrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex<f64>>, tensor<4x5xcomplex<f64>>, tensor<complex<f64>>) -> tensor<4x5xcomplex<f64>> loc(#loc4)
return %0 : tensor<4x5xcomplex<f64>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0bO+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\x1e\x03\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\x0b)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ztrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['c64'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_ctrsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array(
[
[5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
[8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j],
[12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j],
],
dtype=complex64,
),
array(
[
[0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j],
[5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j],
[
10.0 + 0.0j,
11.0 + 0.0j,
12.0 + 0.0j,
13.0 + 0.0j,
14.0 + 0.0j,
],
[
15.0 + 0.0j,
16.0 + 0.0j,
17.0 + 0.0j,
18.0 + 0.0j,
19.0 + 0.0j,
],
],
dtype=complex64,
),
),
expected_outputs=(
array(
[
[0.0 + 0.0j, 0.2 + 0.0j, 0.4 + 0.0j, 0.6 + 0.0j, 0.8 + 0.0j],
[
0.5 + 0.0j,
0.52 + 0.0j,
0.54 + 0.0j,
0.56 + 0.0j,
0.58000004 + 0.0j,
],
[
0.36666667 + 0.0j,
0.31466666 + 0.0j,
0.26266667 + 0.0j,
0.21066667 + 0.0j,
0.15866666 + 0.0j,
],
[
0.16833334 + 0.0j,
0.12173338 + 0.0j,
0.0751333 + 0.0j,
0.02853328 + 0.0j,
-0.018066704 + 0.0j,
],
],
dtype=complex64,
),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xcomplex<f32>> loc("a"), %arg1: tensor<4x5xcomplex<f32>> loc("b")) -> (tensor<4x5xcomplex<f32>> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc4)
%0 = stablehlo.custom_call @lapack_ctrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex<f32>>, tensor<4x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<4x5xcomplex<f32>> loc(#loc4)
return %0 : tensor<4x5xcomplex<f32>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\xfe\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x80?\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\t)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ctrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['f32'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_strsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array(
[
[5.0, 0.0, 0.0, 0.0],
[4.0, 10.0, 0.0, 0.0],
[8.0, 9.0, 15.0, 0.0],
[12.0, 13.0, 14.0, 20.0],
],
dtype=float32,
),
array(
[
[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
],
dtype=float32,
),
),
expected_outputs=(
array(
[
[0.0, 0.2, 0.4, 0.6, 0.8],
[0.5, 0.52, 0.54, 0.56, 0.58000004],
[0.36666667, 0.31466666, 0.26266667, 0.21066667, 0.15866666],
[0.16833334, 0.12173338, 0.0751333, 0.02853328, -0.018066704],
],
dtype=float32,
),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xf32> loc("a"), %arg1: tensor<4x5xf32> loc("b")) -> (tensor<4x5xf32> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32> loc(#loc4)
%0 = stablehlo.custom_call @lapack_strsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf32>, tensor<4x5xf32>, tensor<f32>) -> tensor<4x5xf32> loc(#loc4)
return %0 : tensor<4x5xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x1f+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xe6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\t\x00\x00\x80?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\t)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_strsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_12_02['f64'] = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['lapack_dtrsm_ffi'],
serialized_date=datetime.date(2024, 12, 2),
inputs=(
array([
[5.0, 0.0, 0.0, 0.0],
[4.0, 10.0, 0.0, 0.0],
[8.0, 9.0, 15.0, 0.0],
[12.0, 13.0, 14.0, 20.0],
]),
array([
[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
]),
),
expected_outputs=(
array([
[0.0, 0.2, 0.4, 0.6000000000000001, 0.8],
[0.5, 0.52, 0.54, 0.5599999999999999, 0.58],
[
0.36666666666666664,
0.3146666666666667,
0.2626666666666667,
0.21066666666666667,
0.15866666666666665,
],
[
0.16833333333333336,
0.1217333333333333,
0.07513333333333323,
0.0285333333333333,
-0.018066666666666675,
],
]),
),
mlir_module_text=r"""
#loc1 = loc("a")
#loc2 = loc("b")
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x4xf64> loc("a"), %arg1: tensor<4x5xf64> loc("b")) -> (tensor<4x5xf64> {jax.result_info = ""}) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64> loc(#loc4)
%0 = stablehlo.custom_call @lapack_dtrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf64>, tensor<4x5xf64>, tensor<f64>) -> tensor<4x5xf64> loc(#loc4)
return %0 : tensor<4x5xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xf6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\xf0?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x0b)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_dtrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
xla_call_module_version=9,
nr_devices=1,
) # End paste

File diff suppressed because one or more lines are too long

View File

@ -143,21 +143,31 @@ def linearize_jaxpr(jaxpr, nonzeros):
def direct_linearize(traceable, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False)
assert not has_aux
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace()
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
with core.set_current_trace(linearize_trace):
ans = traceable.call_wrapped(*tracers)
if has_aux:
ans, aux = traceable.call_wrapped(*tracers)
aux_primals = [x.primal
if isinstance(x, LinearizeTracer)
and x._trace.tag is linearize_trace.tag
else x for x in aux]
else:
ans = traceable.call_wrapped(*tracers)
aux = None
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
out_tangents = map(instantiate_zeros, out_tangents)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
del attrs_tracked # TODO: attrs
return out_primals, out_tangents_pvals, jaxpr, consts
if has_aux:
return out_primals, out_tangents_pvals, jaxpr, consts, aux_primals
else:
return out_primals, out_tangents_pvals, jaxpr, consts
def linearize(traceable, *primals, **kwargs):
if config.use_direct_linearize.value:
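# A sketch of the public entry point that reaches direct_linearize when the
# config flag is enabled (the flag name below is an assumption inferred from
# config.use_direct_linearize).
import jax
import jax.numpy as jnp

# jax.config.update("jax_use_direct_linearize", True)  # assumed flag name
y, f_jvp = jax.linearize(jnp.sin, 1.0)
print(y)           # sin(1.0)
print(f_jvp(2.0))  # cos(1.0) * 2.0, evaluated from the linearized jaxpr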
@ -175,7 +185,11 @@ def linearize(traceable, *primals, **kwargs):
jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
if any(not out_primal_pval.is_known() for out_primal_pval in out_primals_pvals):
raise ValueError(
"Linearization failed to produce known values for all output primals. "
"This is typically caused by attempting to differentiate a function "
"uses an operation that does not support reverse-mode autodiff.")
out_primals_consts = [pval.get_known() for pval in out_primals_pvals]
if not has_aux:
return out_primals_consts, out_tangents_pvals, jaxpr, consts
@ -263,6 +277,20 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
if eqn.primitive.ref_primitive:
if eqn.primitive is core.mutable_array_p:
val_var, = eqn.invars
ref_var, = eqn.outvars
ref = read_primal(ref_var)
ct_out = core.freeze(ref)
write_cotangent(eqn.primitive, val_var, ct_out)
elif eqn.primitive is core.freeze_p:
val_var, = eqn.outvars # type: ignore
ref_var, = eqn.invars # type: ignore
ct_in = instantiate_zeros(read_cotangent(val_var))
write_primal(ref_var, core.mutable_array(ct_in))
continue
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
@ -514,22 +542,45 @@ class LinearizeTrace(Trace):
def process_primitive(self, primitive, args, params):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
tangent_nonzeros = [type(t) is not Zero for t in tangents_in]
tangent_nzs = [type(t) is not Zero for t in tangents_in]
if all(type(t) is Zero for t in tangents_in):
return primitive.bind_with_trace(self.parent_trace, primals_in, params)
lin = primitive_linearizations.get(primitive)
if lin is None:
lin = partial(fallback_linearize_rule, primitive)
fallback = partial(fallback_linearize_rule, primitive)
lin = primitive_linearizations.get(primitive, fallback)
with core.set_current_trace(self.parent_trace):
primal_out, tangent_nonzeros_out, residuals, linearized = lin(
tangent_nonzeros, *primals_in, **params)
primal_out, tangent_nzs_out, residuals, linearized = lin(
tangent_nzs, *primals_in, **params)
with core.set_current_trace(self.tangent_trace):
tangent_out = linearized(residuals, *tangents_in)
if primitive.multiple_results:
return [maybe_linearize_tracer(self, x, nz, t)
for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)]
for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)]
else:
return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out)
return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace,
(fun, fwd, bwd, *primals_in),
dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
fwd_in = [x for pair in fwd_in for x in pair] # flatten
with core.set_current_trace(self.parent_trace):
res_and_primals_out = fwd.call_wrapped(*fwd_in)
_, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
with core.set_current_trace(self.tangent_trace):
tangents_in = map(instantiate_zeros, tangents_in)
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)
def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
if is_nonzero:
@ -539,21 +590,50 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
assert type(tangent) is Zero
return primal
def fallback_linearize_rule(prim, _, *args, **kwargs):
assert not prim.multiple_results
def fallback_linearize_rule(prim, nonzeros, *primals, **params):
jvp = primitive_jvps.get(prim)
if not jvp:
msg = f"Differentiation rule for '{prim}' not implemented"
raise NotImplementedError(msg)
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
tangent_avals = [get_aval(p).to_tangent_aval() for p in primals]
tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else Zero(aval)
for aval, nz in zip(tangent_avals, nonzeros)]
with core.set_current_trace(trace):
out_primals, out_tangents = jvp(primals, tangent_args, **params)
def call_prim(*args_):
return [prim.bind(*args_, **kwargs)]
if not prim.multiple_results:
out_primals = [out_primals]
out_tangents = [out_tangents]
with config.use_direct_linearize(False):
(out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize(
lu.wrap_init(call_prim), *args, **kwargs)
out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
out_nzs = [type(r) is not Zero for r in out_tangents]
out_tangent_avals = [get_aval(p).to_tangent_aval() for p in out_primals]
out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t in tangent_args if type(t) is not Zero]
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)
def linearized(residuals, *tangents):
out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents)
return out_tangent
def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
nz_tangents_out = core.eval_jaxpr(jaxpr, residuals, *nz_tangents_in)
nz_tangents_out_iter = iter(nz_tangents_out)
all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval)
for (aval, nz) in zip(out_tangent_avals, out_nzs)]
if prim.multiple_results:
return all_out_tangents
else:
out_tangent, = all_out_tangents
return out_tangent
if prim.multiple_results:
return out_primals, out_nzs, out_consts, linearized
else:
out_primal, = out_primals
out_nz, = out_nzs
return out_primal, out_nz, out_consts, linearized
return out_primal, True, consts, linearized
class LinearizeTracer(Tracer):
__slots__ = ['primal', 'tangent']

View File

@ -50,6 +50,7 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import AUTO
from jax._src.partition_spec import UnconstrainedSingleton
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects, ir, passmanager
@ -1084,6 +1085,7 @@ def lower_jaxpr_to_module(
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
util.test_event("lower_jaxpr_to_module")
platforms = tuple(map(xb.canonicalize_platform, platforms))
in_avals = (jaxpr.in_avals if arg_shardings is None else
@ -1377,6 +1379,7 @@ def lower_jaxpr_to_fun(
Returns:
MLIR func op
"""
util.test_event("lower_jaxpr_to_fun", name)
# The first dimension variable may be the platform index
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
@ -1699,6 +1702,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The custom call below achieves the sharding shown in the example above.
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
if config.use_shardy_partitioner.value:
physical_ndim = core.physical_aval(aval).ndim
s = sharding_impls.SdyArraySharding(
@ -2523,12 +2527,19 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
# Don't emit a wsc under full manual mode to avoid increasing HLO size.
if aval.sharding.mesh._are_all_axes_collective:
return op
if aval.sharding.mesh._are_all_axes_auto:
return op
# TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
# `return op` early and avoid bloating HLO size.
proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
if sharding_proto is None else sharding_proto)
# TODO(yashkatariya): Enable this
# unspecified_dims = (set(range(aval.ndim))
# if aval.sharding.mesh._any_axis_collective else None)
return wrap_with_sharding_op(ctx, op, aval, proto)
unspecified_dims = None
if aval.sharding.mesh._any_axis_collective:
unspecified_dims = set(range(aval.ndim))
elif aval.sharding.mesh._any_axis_auto:
unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
if isinstance(s, UnconstrainedSingleton)}
return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
@ -2792,7 +2803,7 @@ def _emit_tpu_python_callback(
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
if minor_to_major is None:
# Needed for token layouts
layout = np.zeros((0,), dtype="int64")
layout: np.ndarray = np.zeros((0,), dtype="int64")
else:
layout = np.array(minor_to_major, dtype="int64")
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())

View File

@ -177,9 +177,12 @@ class JaxprTrace(Trace['JaxprTracer']):
if const is None:
aval = pval.get_aval()
if type(aval) is DShapedArray:
# TODO(dougalm): Fix the type error and remove the pytype pragmas.
# pytype: disable=attribute-error
shape = [self.new_instantiated_const(d)
if isinstance(d, Tracer) and d._trace.level < self.level else d
for d in aval.shape]
# pytype: enable=attribute-error
aval = aval.update(shape=tuple(shape))
return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding())
else:
@ -1006,7 +1009,7 @@ def partial_eval_jaxpr_stateful(
in_inst: bool | Sequence[bool],
ensure_out_unknowns: bool | Sequence[bool],
ensure_out_inst: bool | Sequence[bool],
saveable: Callable[..., RematCases_],
saveable: Callable[..., RematCases_] | None,
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
@ -1014,6 +1017,8 @@ def partial_eval_jaxpr_stateful(
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
if saveable is None:
saveable = everything_saveable
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
@ -1021,6 +1026,8 @@ def partial_eval_jaxpr_stateful(
tuple(ensure_out_inst), saveable)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref
everything_saveable = lambda *_, **__: True
@weakref_lru_cache
def _partial_eval_jaxpr_custom_cached(
jaxpr: Jaxpr,
@ -1776,6 +1783,9 @@ def _inline_literals(
newvars: dict[Var, Var] = {}
newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
lit_or_var = (
lambda a: a if isinstance(a, Literal) else (lit(a) or var(a))
)
dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
@ -1794,10 +1804,10 @@ def _inline_literals(
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(x) or var(x) for x in eqn.invars]
invars = [lit_or_var(x) for x in eqn.invars]
outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,

View File

@ -41,7 +41,6 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
@ -61,11 +60,11 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
from jax._src.sharding import Sharding as JSharding
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
UnspecifiedValue, get_array_mapping as _get_array_mapping,
@ -99,7 +98,6 @@ ShardedAxis = sharding_specs.ShardedAxis
Replicated = sharding_specs.Replicated
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = mesh_lib.Mesh
MeshAxisName = sharding_impls.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec
@ -108,8 +106,6 @@ ShardingSpec = sharding_specs.ShardingSpec
def to_xc_copy_semantics(copy_semantics):
if xla_extension_version < 296:
return [None] * len(copy_semantics)
out = []
for cs in copy_semantics:
if cs is None or cs == dispatch.CopySemantics.ALIAS:
@ -234,16 +230,20 @@ shard_arg_handlers[core.MutableArray] = _shard_mutable_array
def batched_device_put(aval: core.ShapedArray,
sharding: JSharding, xs: Sequence[Any],
devices: Sequence[jax.Device], committed: bool = True):
from jax._src import array
util.test_event("batched_device_put_start")
try:
from jax._src import array
bufs = [x for x, d in safe_zip(xs, devices)
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.devices() == {d})]
if len(bufs) == len(xs):
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
bufs = [x for x, d in safe_zip(xs, devices)
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.devices() == {d})]
if len(bufs) == len(xs):
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
finally:
util.test_event("batched_device_put_end")
def _shard_aval(size, axis: int, aval):
try:
@ -1722,20 +1722,19 @@ def _get_and_check_device_assignment(
devices: Sequence[xc.Device] | None,
) -> tuple[xc.Client, tuple[xc.Device, ...]]:
first_sharding_info = None
if devices is None:
devices = ()
else:
devices = tuple(devices)
devices = () if devices is None else tuple(devices)
for i, s_type, source_info in shardings:
if isinstance(i, UnspecifiedValue):
for sh, s_type, source_info in shardings:
if isinstance(sh, UnspecifiedValue):
continue
if isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh):
continue
if first_sharding_info is None:
first_sharding_info = (
(i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
else (i._device_assignment, s_type, source_info))
arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
(sh.mesh._flat_devices_tuple, s_type, source_info) if isinstance(sh, AUTO)
else (sh._device_assignment, s_type, source_info))
arr_device_assignment = (sh.mesh._flat_devices_tuple if isinstance(sh, AUTO)
else sh._device_assignment)
if not devices:
if first_sharding_info[0] != arr_device_assignment:
raise DeviceAssignmentMismatchError([
@ -1836,7 +1835,8 @@ class SemanticallyEqualShardings:
def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
avals: tuple[core.AbstractValue]):
gspmd_shardings = [
s if isinstance(s, (UnspecifiedValue, AUTO))
s if (isinstance(s, (UnspecifiedValue, AUTO)) or
(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)))
else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error
for s, a in zip(shardings, avals)]
self._gspmd_shardings = gspmd_shardings
@ -1894,7 +1894,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
propagated_out_mem_kinds: tuple[None | str, ...],
platforms: tuple[str, ...],
lowering_parameters: mlir.LoweringParameters,
abstract_mesh: mesh_lib.AbstractMesh | None):
abstract_mesh: AbstractMesh | None):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
@ -2000,6 +2000,8 @@ def jaxpr_transfer_mem_kinds(
def are_all_shardings_default_mem_kind(da_object, shardings):
if da_object is None:
return True
try:
default_mem_kind = da_object.default_memory_kind
except:
@ -2081,6 +2083,41 @@ def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
return tuple(safe_map(read, jaxpr.outvars))
def _get_num_devices(
shardings, device_assignment, lowering_platforms, prim_requires_devices
) -> tuple[int, tuple[xc.Device, ...] | None]:
ext_abstract_mesh, concrete_sharding = None, False
for s in shardings:
if isinstance(s, UnspecifiedValue):
continue
elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
raise ValueError("AbstractMesh should be the same across all "
f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
ext_abstract_mesh = s.mesh
else:
concrete_sharding = True
if (concrete_sharding and ext_abstract_mesh is not None and
len(device_assignment) != ext_abstract_mesh.size):
raise ValueError(
f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
f" device assignment size: {len(device_assignment)}")
if concrete_sharding:
return len(device_assignment), device_assignment
if ext_abstract_mesh is None:
return len(device_assignment), device_assignment
if lowering_platforms is None:
raise ValueError(
"Passing lowering_platforms via"
" jit(f).trace(*args).lower(lowering_platforms=...) is required when"
" only AbstractMesh exists in a jitted computation.")
if prim_requires_devices:
raise ValueError(
"AbstractMesh cannot be used when jaxpr contains primitives that"
" require devices to be present during lowering.")
return ext_abstract_mesh.size, None
MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
@ -2125,12 +2162,14 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
@lru_cache(maxsize=128)
def _abstract_to_concrete_mesh(abstract_mesh):
return mesh_lib.Mesh(
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
return Mesh(
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
axis_types=abstract_mesh.axis_types)
out = []
for s, a in zip(shardings, avals):
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:
@ -2150,7 +2189,7 @@ def lower_sharding_computation(
donated_invars: Sequence[bool],
*,
keep_unused: bool,
context_mesh: mesh_lib.Mesh | None,
context_mesh: Mesh | None,
compiler_options_kvs: tuple[tuple[str, Any], ...],
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
@ -2208,6 +2247,7 @@ def lower_sharding_computation(
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in unique_intermediate_shardings)),
devices_from_context)
unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]
if config.sharding_in_types.value:
out_shardings = _concretize_abstract_shardings(
@ -2218,21 +2258,33 @@ def lower_sharding_computation(
platforms = lowering_platforms or (
getattr(backend, "_raw_platform", backend.platform),)
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
da_object = _create_da_object(tuple(device_assignment))
# TODO(yashkatariya): All device specific logic should go in compilation
# but this requires a big refactor. The current `_get_num_devices` logic
# is good enough to lower with AbstractMesh but cannot be compiled. Once
# I refactor, this will also work well with mesh being provided at
# compile time.
# Sets device_assignment to None if only AbstractMesh and unspecified shardings exist.
num_devices, device_assignment = _get_num_devices( # type: ignore
it.chain(unique_in_shardings, unique_out_shardings,
unique_intermediate_shardings),
device_assignment, lowering_platforms, prim_requires_devices)
committed = bool(
devices_from_context
or num_devices > 1
or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))
da_object = (_create_da_object(tuple(device_assignment))
if device_assignment is not None else None)
transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
all_default_mem_kind = are_all_shardings_default_mem_kind(
da_object,
it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings],
transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
unique_intermediate_shardings, transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
if all_default_mem_kind:
propagated_out_mem_kinds = (None,) * len(global_out_avals)
@ -2241,12 +2293,12 @@ def lower_sharding_computation(
closed_jaxpr, in_shardings)
# 2. Build up the HLO
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
abstract_mesh = None
if prim_requires_devices:
assert da_object is not None
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings]):
unique_intermediate_shardings):
if isinstance(sharding, NamedSharding):
if (abstract_mesh is not None and
abstract_mesh != sharding.mesh.abstract_mesh):
@ -2264,7 +2316,7 @@ def lower_sharding_computation(
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, in_layouts, out_layouts, len(da_object),
semantic_out_shardings, in_layouts, out_layouts, num_devices,
tuple(da_object) if prim_requires_devices else None, donated_invars,
name_stack, all_default_mem_kind, inout_aliases,
propagated_out_mem_kinds, platforms,
@ -2307,7 +2359,7 @@ def lower_sharding_computation(
all_default_mem_kind=all_default_mem_kind,
all_args_info=all_args_info,
pgle_profiler=pgle_profiler,
intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
intermediate_shardings=unique_intermediate_shardings,
context_mesh=context_mesh)
@ -2477,7 +2529,7 @@ def _register_out_sharding_handler(
def _gspmd_to_named_sharding(
out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
assert isinstance(orig_in_s.mesh, Mesh)
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
@ -2529,7 +2581,7 @@ def _get_out_sharding_from_orig_sharding(
def maybe_recover_user_shardings(
old_shardings, new_shardings, old_avals, new_avals,
intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
intermediate_shardings=None, context_mesh: Mesh | None = None):
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
return new_shardings
@ -2817,7 +2869,7 @@ class UnloadedMeshExecutable:
keepalive: Any,
kept_var_idx: set[int],
backend: xb.XlaBackend,
device_assignment: xc.DeviceList | Sequence[xc.Device],
device_assignment: xc.DeviceList | Sequence[xc.Device] | None,
committed: bool,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
@ -2829,8 +2881,15 @@ class UnloadedMeshExecutable:
all_args_info: AllArgsInfo | None = None,
pgle_profiler: profiler.PGLEProfiler | None = None,
intermediate_shardings: Sequence[JSharding] | None = None,
context_mesh: mesh_lib.Mesh | None = None
context_mesh: Mesh | None = None,
) -> MeshExecutable:
if (device_assignment is None or
any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
for s in it.chain(in_shardings, out_shardings))):
raise RuntimeError(
"A jitted computation cannot contain AbstractMesh in in_shardings and"
" out_shardings during compilation. You can use `jax.export` to "
" lower with an AbstractMesh and later compile with concrete devices.")
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = mlir.refine_polymorphic_shapes(hlo)
if isinstance(device_assignment, xc.DeviceList):
@ -2851,6 +2910,7 @@ class UnloadedMeshExecutable:
mesh = i.mesh
break
util.test_event("pxla_cached_compilation")
xla_executable = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,

View File

@ -37,9 +37,6 @@ map, unsafe_map = safe_map, map
effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)
def _abstractify(x):
return core.raise_to_shaped(core.get_aval(x))
def _typecheck_param(prim, param, name, msg_required, pred):
if not pred:
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
@ -91,7 +88,7 @@ def _initial_style_jaxprs_with_common_consts(
return [], [], []
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
all_const_avals = [map(_abstractify, consts) for consts in all_consts]
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
# If we get a `Ref` in the consts, we know it must come from an outer
# `run_state`. We also know it shouldn't be boxed up in another tracer.
# We assert that it is in fact a DynamicJaxprTracer

View File

@ -49,7 +49,6 @@ from jax._src.lib.mlir.dialects import hlo
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_avals_short,
_check_tree_and_avals,
_initial_style_jaxprs_with_common_consts,
@ -135,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
return branches[int(index)](*operands)
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(_abstractify, ops))
ops_avals = tuple(map(core.get_aval, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
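
# A minimal usage sketch of `lax.switch` (the branch functions and operand
# values here are illustrative only): with a concrete index it reduces to the
# `branches[int(index)](*operands)` fallback shown above.
import jax.numpy as jnp
from jax import lax

branches = (lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x)
y = lax.switch(1, branches, jnp.float32(3.0))  # selects branches[1] -> 6.0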
@ -227,7 +226,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
return false_fun(*operands)
ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(map(_abstractify, ops))
ops_avals = tuple(map(core.get_aval, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
@ -513,7 +512,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# jaxpr for each branch.
branches_known_ : list[core.ClosedJaxpr] = []
branches_staged_: list[core.ClosedJaxpr] = []
branch_res_avals: list[core.AbstractValue] = []
branch_res_avals: list[list[core.AbstractValue]] = []
for jaxpr in branches:
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(

View File

@ -44,7 +44,7 @@ from jax._src.typing import Array
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
split_list, split_dict, weakref_lru_cache)
from jax._src.lax.control_flow import loops
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
from jax._src.lax.control_flow.common import _initial_style_jaxpr
import numpy as np
## JAX utilities
@ -196,7 +196,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init_flat = tree_leaves(init)
_, in_tree = tree_flatten((init, xs))
carry_avals = tuple(map(_abstractify, init_flat))
carry_avals = tuple(map(core.get_aval, init_flat))
jaxpr, _, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, "scan")
return jaxpr, out_tree

View File

@ -47,7 +47,7 @@ from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lax.control_flow.common import (
_abstractify, _avals_short, _initial_style_jaxpr,
_avals_short, _initial_style_jaxpr,
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
_typecheck_param)
from jax._src.lax.other import logaddexp
@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
init_flat, init_tree = tree_flatten(init)
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(_abstractify, init_flat))
carry_avals = tuple(_map(core.get_aval, init_flat))
jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
f, in_tree, (*carry_avals, *x_avals), "scan")
out_tree_children = out_tree.children()
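
# A minimal sketch of the scan carry/xs convention that the flattening above
# serves (the step function and inputs are illustrative only): the step returns
# (new_carry, per-step output), and scan stacks the per-step outputs.
import jax.numpy as jnp
from jax import lax

def step(carry, x):
  new_carry = carry + x
  return new_carry, new_carry

total, cumulative = lax.scan(step, jnp.float32(0.0), jnp.arange(4.0, dtype=jnp.float32))
# total == 6.0 and cumulative == [0., 1., 3., 6.]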
@ -361,7 +361,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
if p else 'the input carry')
leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry)
paths, in_carry_flat = unzip2(leaves_and_paths)
in_avals = _map(_abstractify, in_carry_flat)
in_avals = _map(core.get_aval, in_carry_flat)
if in_carry_tree != out_carry_tree:
try:
out_carry = tree_unflatten(out_carry_tree, out_avals)
@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'of the carry output is a {thing2}, so {explanation}'
for path, thing1, thing2, explanation
in equality_errors(in_carry, out_carry)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:
@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:
@ -1315,7 +1321,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(_abstractify, init_vals))
init_avals = tuple(_map(core.get_aval, init_vals))
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, "while_cond")
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(

View File

@ -32,7 +32,6 @@ from jax._src.util import split_list, safe_map
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_check_tree,
_initial_style_jaxpr,
)
@ -87,7 +86,7 @@ def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
"""
guess_flat, in_args_tree = tree_flatten((initial_guess,))
guess_avals = tuple(_map(_abstractify, guess_flat))
guess_avals = tuple(_map(core.get_aval, guess_flat))
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
f, in_args_tree, guess_avals)
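
# A minimal sketch of `lax.custom_root` (the function, solver, and guess are
# illustrative only): gradients flow by implicit differentiation under the
# assumption that f(solve(f, initial_guess)) == 0.
import jax
from jax import lax

def f(x):
  return x ** 2 - 2.0   # root at sqrt(2)

def solve(f, x0):
  x = x0
  for _ in range(10):   # a few Newton steps; plenty near x0 = 1.0
    x = x - f(x) / jax.grad(f)(x)
  return x

def tangent_solve(g, y):
  return y / g(1.0)     # g is linear in x, so g(x) == g(1.0) * x

root = lax.custom_root(f, 1.0, solve, tangent_solve)  # ~1.41421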
@ -230,7 +229,7 @@ def custom_linear_solve(
transpose_solve = solve
b_flat, in_args_tree = tree_flatten((b,))
b_avals = tuple(_map(_abstractify, b_flat))
b_avals = tuple(_map(core.get_aval, b_flat))
tree, = treedef_children(in_args_tree)

View File

@ -102,27 +102,31 @@ def _validate_shapes(shapes: Sequence[Shape]):
else:
map(_check_static_shape, shapes)
def _try_broadcast_shapes(
shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None:
if len(shapes) == 1: return shapes[0]
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
"""
Attempt to broadcast shapes, raising a TypeError if broadcasting fails.
"""
if not shapes:
raise TypeError(f"{name}: At least one shape is required.")
ranks = {len(shape) for shape in shapes}
if len(ranks) > 1: return None # must have consistent rank
rank = ranks.pop()
if not rank: return () # scalar case
if len(ranks) != 1:
raise TypeError(f'{name}: arrays must have the same number of dimensions,'
f' got {ranks}')
result_shape = []
for ds in unsafe_zip(*shapes):
for ds in zip(*shapes):
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
# if all axes are identical objects, the resulting size is the object
result_shape.append(ds[0])
else:
# if all dims are equal (or 1), the result is the non-1 size (or 1)
# if all dims are equal (or 1), the result is the non-1 size
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
if not non_1s:
result_shape.append(1)
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
result_shape.append(non_1s[0])
else:
return None
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
f'{", ".join(map(str, map(tuple, shapes)))}.')
return tuple(result_shape)
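
# A standalone, numpy-only illustration of the broadcast rule implemented above
# (this helper is illustrative, not JAX's internal one): per-dimension sizes
# must be equal or 1; here we also pad ranks with leading 1s, which the JAX
# caller does before invoking its helper.
import numpy as np

def broadcast_result_shape(*shapes):
  rank = max(map(len, shapes))
  padded = [(1,) * (rank - len(s)) + s for s in shapes]
  out = []
  for dims in zip(*padded):
    non_1s = {d for d in dims if d != 1}
    if len(non_1s) > 1:
      raise TypeError(f"incompatible shapes for broadcasting: {shapes}")
    out.append(non_1s.pop() if non_1s else 1)
  return tuple(out)

assert broadcast_result_shape((8, 1, 3), (4, 3)) == np.broadcast_shapes((8, 1, 3), (4, 3)) == (8, 4, 3)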
def asarray(x: ArrayLike) -> Array:
@ -159,24 +163,39 @@ def _broadcast_shapes_uncached(*shapes):
if not rst: return fst
# First check if we need only rank promotion (and not singleton-broadcasting).
try: return _reduce(_broadcast_ranks, rst, fst)
except ValueError: pass
result_shape = _max(shapes, key=len)
ndim = len(result_shape)
if ndim == 0 or all(core.definitely_equal_shape(result_shape[ndim - len(s):], s) for s in shapes):
return result_shape
# Next try singleton-broadcasting, padding out ranks using singletons.
ndim = _max(len(shape) for shape in shapes)
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
result_shape = _try_broadcast_shapes(shape_list)
if result_shape is None:
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
return result_shape
rank_promoted_shapes = tuple((*((1,) * (ndim - len(shape))), *shape) for shape in shapes)
try:
return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
except TypeError as err:
# Raise ValueError here for backward compatibility.
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
def _broadcast_ranks(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
assert len(s1) <= len(s2)
s1_ = s2[len(s2) - len(s1):]
if core.definitely_equal_shape(s1_, s1): return s2
else: raise ValueError
def broadcast_shardings(*avals) -> NamedSharding:
fst, *rst = avals
if not rst:
return fst.sharding
# First check if we need only rank promotion (and not singleton-broadcasting).
res_aval = _max(avals, key=lambda a: a.ndim)
ndim = res_aval.ndim
if ndim == 0 or all(
res_aval.sharding.spec[ndim - a.ndim:] == a.sharding.spec for a in avals):
return res_aval.sharding
# Next try singleton-broadcasting, padding out ranks using singletons.
aval_list = []
for a in avals:
new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec)
new_shape = (1,) * (ndim - a.ndim) + a.shape
aval_list.append(a.update(shape=new_shape,
sharding=a.sharding.with_spec(new_spec)))
return broadcasting_sharding_rule('broadcast_shardings', *aval_list)
def _identity(x): return x
@ -1471,7 +1490,7 @@ def reduce(operands: Any,
return _convert_element_type(monoid_reducer(*flat_operands, dimensions),
weak_type=weak_type)
else:
flat_init_avals = safe_map(_abstractify, flat_init_values)
flat_init_avals = safe_map(core.get_aval, flat_init_values)
closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree)
out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
@ -1730,6 +1749,9 @@ def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray:
val = ad_util.zeros_like_aval(aval.inner_aval)
return core.mutable_array(val)
# TODO(dougalm): this is nonsense but it's here because in places like
# custom_vjp we assume that all arguments have tangent spaces. We could have
# a distinct NotATangentType value instead.
ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref # type: ignore
def iota(dtype: DTypeLike, size: int) -> Array:
@ -2137,27 +2159,7 @@ def broadcasting_shape_rule(name, *avals):
shapes = [aval.shape for aval in avals if aval.shape]
if not shapes:
return ()
if len({len(shape) for shape in shapes}) != 1:
msg = '{}: arrays must have same number of dimensions, got {}.'
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
# TODO(mattjj): de-duplicate with _try_broadcast_shapes
result_shape = []
for ds in zip(*shapes):
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
# if all axes are identical objects, the resulting size is the object
result_shape.append(ds[0])
else:
# if all dims are equal (or 1), the result is the non-1 size
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
if not non_1s:
result_shape.append(1)
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
result_shape.append(non_1s[0])
else:
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
f'{", ".join(map(str, map(tuple, shapes)))}.')
return tuple(result_shape)
return _try_broadcast_shapes(*shapes, name=name)
def broadcasting_sharding_rule(name, *avals):
@ -2206,7 +2208,7 @@ def broadcasting_sharding_rule(name, *avals):
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
require_same_dtypes=False):
require_same_dtypes=True):
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
allow_extended_dtype=allow_extended_dtype,
require_same=require_same_dtypes)
@ -2759,8 +2761,8 @@ def _add_transpose(t, x, y):
# some places (e.g. in custom_jvp) it may not always hold. For example, see
# api_test.py's CustomJVPTest.test_jaxpr_zeros.
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
if type(t) is ad_util.Zero:
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
else:
@ -2790,8 +2792,8 @@ def _sub_transpose(t, x, y):
# Morally the following assertion is true, but see the comment in add_p's
# transpose rule.
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
if type(t) is ad_util.Zero:
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
else:
@ -3773,6 +3775,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
if platform == "cpu" and precision not in {
DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16,
DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64,
DotAlgorithmPreset.BF16_BF16_F32, DotAlgorithmPreset.BF16_BF16_F32_X3,
DotAlgorithmPreset.BF16_BF16_F32_X6,
}:
raise ValueError(
f"The precision '{precision}' is not supported by dot_general on CPU")
@ -5502,15 +5506,26 @@ def _operands_to_keys(*operands, num_keys=1):
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
shape = primals[0].shape
iotas = []
for dim, size in enumerate(shape):
iotas.append(broadcasted_iota(np.int64, shape, dim))
sorted_primals_and_idx = sort_p.bind(
*primals, iotas[dimension], dimension=dimension,
is_stable=is_stable, num_keys=num_keys)
idx = tuple(sorted_primals_and_idx[-1] if i == dimension else iotas[i]
for i in range(len(shape)))
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
*primals, broadcasted_iota(np.uint64, shape, dimension),
dimension=dimension, is_stable=is_stable, num_keys=num_keys)
batch_dims = tuple(np.delete(np.arange(len(shape), dtype=np.int64),
dimension))
dnums = slicing.GatherDimensionNumbers(
offset_dims=(),
collapsed_slice_dims=(dimension,),
start_index_map=(dimension,),
operand_batching_dims=batch_dims,
start_indices_batching_dims=batch_dims,
)
idx = expand_dims(sorted_primals_and_idx[-1], (len(shape),))
gather_idx = partial(
slicing.gather,
start_indices=idx, dimension_numbers=dnums, slice_sizes=(1,) * len(shape),
mode=slicing.GatherScatterMode.PROMISE_IN_BOUNDS
)
tangents_out = [t if type(t) is ad_util.Zero else gather_idx(t)
for t in tangents]
return tuple(sorted_primals_and_idx[:-1]), tangents_out
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
@ -6381,10 +6396,6 @@ def _eq_meet(a, b):
return eq(a, b)
def _abstractify(x):
return core.get_aval(x)
def empty(dtype):
return empty_p.bind(dtype=dtype)
empty_p = core.Primitive('empty')
@ -6494,3 +6505,7 @@ optimization_barrier_p.def_impl(
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)
def _optimization_barrier_batcher(batched_args, batch_dims, **params):
return optimization_barrier_p.bind(*batched_args, **params), batch_dims
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
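
# A small sketch of what the batching rule above enables: composing
# `lax.optimization_barrier` with `vmap` (the mapped function is illustrative).
import jax
import jax.numpy as jnp
from jax import lax

f = jax.vmap(lambda x: lax.optimization_barrier(x * 2.0))
out = f(jnp.arange(4.0))  # [0., 2., 4., 6.]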

View File

@ -48,6 +48,7 @@ from jax._src.lax.lax import (
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import lapack
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@ -1328,7 +1329,6 @@ def _triangular_solve_lowering(
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
def _triangular_solve_cpu_lower(
ctx, a, b, *, left_side, lower, transpose_a,
@ -1341,10 +1341,12 @@ def _triangular_solve_cpu_lower(
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
return lapack.trsm_hlo(
a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
b_shape_vals=b_shape_vals)
*ctx_args, a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
b_shape_vals=b_shape_vals)
else:
# Fall back to the HLO implementation for unsupported types or batching.
# TODO: Consider swapping XLA for LAPACK in batched case
@ -1357,6 +1359,8 @@ def _triangular_solve_cpu_lower(
ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
platform='cpu')
@ -2616,37 +2620,54 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
batch_dims = operand_aval.shape[:-2]
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
gees_result = lapack.gees_hlo(*ctx_args, operand_aval.dtype, operand,
jobvs=compute_schur_vectors,
sort=sort_eig_vals,
select=select_callable,
a_shape_vals=a_shape_vals)
# Number of return values depends on value of sort_eig_vals.
T, vs, *_, info = gees_result
if jaxlib_version >= (0, 4, 37) and not ctx.is_forward_compat():
schur_form, schur_vectors, _eig_vals, _selected_eig_vals, info = gees_result
else:
# Number of return values depends on value of sort_eig_vals.
schur_form, schur_vectors, *_, info = gees_result
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"EQ", "SIGNED")
select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
T = _broadcasting_select_hlo(
select_schur_form_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
schur_form = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_T_aval,
broadcast_dimensions=range(len(batch_dims))),
select_T_aval,
T, ctx.avals_out[0],_nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0])
output = [T]
mlir.broadcast_in_dim(
ctx,
ok,
select_schur_form_aval,
broadcast_dimensions=range(len(batch_dims)),
),
select_schur_form_aval,
schur_form,
ctx.avals_out[0],
_nan_like_hlo(ctx, ctx.avals_out[0]),
ctx.avals_out[0],
)
output = [schur_form]
if compute_schur_vectors:
select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
vs = _broadcasting_select_hlo(
schur_vectors = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_vs_aval,
broadcast_dimensions=range(len(batch_dims))),
mlir.broadcast_in_dim(
ctx, ok, select_vs_aval, broadcast_dimensions=range(len(batch_dims))
),
select_vs_aval,
vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1])
schur_vectors,
ctx.avals_out[1],
_nan_like_hlo(ctx, ctx.avals_out[1]),
ctx.avals_out[1],
)
output.append(vs)
output.append(schur_vectors)
return output
@ -2819,24 +2840,35 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return tridiagonal(x), 0
return tridiagonal(x, lower=lower), 0
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower):
def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower, platform):
a_aval, = ctx.avals_in
a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower)
cpu_args = []
if platform == "cpu":
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
cpu_args.extend(ctx_args)
a, d, e, taus, info = sytrd_impl(*cpu_args, a_aval.dtype, a, lower=lower)
return a, d, e, taus, info
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo),
platform='cpu')
tridiagonal_p,
partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo, platform="cpu"),
platform="cpu",
)
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd),
platform='cuda')
tridiagonal_p,
partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd, platform="cuda"),
platform="cuda",
)
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd),
platform='rocm')
tridiagonal_p,
partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd, platform="rocm"),
platform="rocm",
)
# Utilities

View File

@ -457,6 +457,55 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
return tree_util.tree_map(bind, x)
def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
"""Ragged version of :func:`all_to_all`.
For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent,
and both refer to the outermost (ragged) dimension. ``axis_index_groups``
defaults to all replicas (i.e. there is only one group, covering all axis indices).
Ragged arrays are defined by a set of three arrays:
* ``data``: the ``data`` array is "ragged" along its outermost dimension,
along which each indexed element has variable size.
* ``offsets``: the ``offsets`` array indexes the outermost dimension of the
``data`` array, and represents the starting offset of each ragged element of
the ``data`` array.
* ``sizes``: the ``sizes`` array represents the size of each ragged element of
the ``data`` array, where the size is specified in units of sub-elements. A
sub-element is defined as the suffix of the ``data`` array shape obtained by
removing the outermost "ragged" dimension.
The ``offsets`` and ``sizes`` arrays must have the same size.
# Example ragged tensor
data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
offsets: [3] = {0, 1, 4}
sizes: [3] = {1, 3, 4}
# Index 'data' at 'offsets'[0], 'sizes'[0]'
{a,b,c}
# Index 'data' at 'offsets'[1], 'sizes'[1]'
{d,e,f},{g,h,i},{j,k,l}
# Index 'data' at 'offsets'[2], 'sizes'[2]'
{m,n,o},{p,q,r},{s,t,u},{v,w,x}
Args:
operand: array with ragged dimension along its outermost dimension.
output: array of ragged output data.
input_offsets: array of ragged input offsets.
send_sizes: array of ragged input send sizes.
output_offsets: array of ragged output offsets.
recv_sizes: array of ragged output receive sizes.
Returns:
array with shape equal to ``output``.
"""
return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
output_offsets, recv_sizes)
ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
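
# A small numpy-only sketch of the ragged layout described in the docstring
# above: `offsets` and `sizes` index the outermost dimension of `data`, so
# ragged element i is data[offsets[i]:offsets[i] + sizes[i]]. The values mirror
# the docstring's example; no collective is issued here.
import numpy as np

data = np.arange(8 * 3).reshape(8, 3)      # outermost dim is the ragged one
offsets = np.array([0, 1, 4])
sizes = np.array([1, 3, 4])
ragged_elements = [data[o:o + s] for o, s in zip(offsets, sizes)]
assert [e.shape[0] for e in ragged_elements] == [1, 3, 4]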
def axis_index(axis_name):
"""Return the index along the mapped axis ``axis_name``.
@ -1052,6 +1101,64 @@ batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective
batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name')
def _ragged_all_to_all_lowering(ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
N = input_offsets.type.shape[0]
backend_config = ir.DictAttr.get({
'replica_groups': ir.DenseIntElementsAttr.get(
np.arange(0, N, 1, dtype=np.int64), shape=[1, N]
)
})
return hlo.CustomCallOp(
result=[output.type],
inputs=[operand, output, input_offsets, send_sizes, output_offsets,
recv_sizes],
call_target_name=ir.StringAttr.get('ragged_all_to_all'),
backend_config=backend_config,
api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4),
).results
@ragged_all_to_all_p.def_abstract_eval
def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
if operand.shape != output.shape:
raise ValueError('ragged_all_to_all input and output shapes must be equal.')
if not dtypes.issubdtype(input_offsets.dtype, np.integer):
raise ValueError("ragged_all_to_all input_offsets must be integer type.")
if not dtypes.issubdtype(send_sizes.dtype, np.integer):
raise ValueError("ragged_all_to_all send_sizes must be integer type.")
if not dtypes.issubdtype(output_offsets.dtype, np.integer):
raise ValueError("ragged_all_to_all output_offsets must be integer type.")
if not dtypes.issubdtype(recv_sizes.dtype, np.integer):
raise ValueError("ragged_all_to_all recv_sizes must be integer type.")
if len(input_offsets.shape) != 1 or input_offsets.shape[0] < 1:
raise ValueError(
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
" size, but got shape {}".format(input_offsets.shape)
)
if len(send_sizes.shape) != 1 or send_sizes.shape[0] < 1:
raise ValueError(
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
" size, but got shape {}".format(send_sizes.shape)
)
if len(output_offsets.shape) != 1 or output_offsets.shape[0] < 1:
raise ValueError(
"ragged_all_to_all output_offsets must be rank 1 with positive"
" dimension size, but got shape {}".format(output_offsets.shape)
)
if len(recv_sizes.shape) != 1 or recv_sizes.shape[0] < 1:
raise ValueError(
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
" size, but got shape {}".format(recv_sizes.shape)
)
return output.update(
shape=list(output.shape),
dtype=output.dtype,
weak_type=output.weak_type,
)
ragged_all_to_all_p.def_impl(partial(dispatch.apply_primitive, ragged_all_to_all_p))
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
"""Gather values of x across all replicas.
@ -1484,7 +1591,25 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
axis_name, = axis_name
if axis_name not in axis_env.names:
raise NameError(f"unbound axis name: {axis_name}")
axis_context = ctx.module_context.axis_context
axis_pos = list(axis_env.names).index(axis_name)
# For partial auto, lower using iota.
if (isinstance(axis_context, SPMDAxisContext) and
axis_context.manual_axes and
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
x = hlo.iota(ir.RankedTensorType.get(
[axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
sharding_proto = (
NamedSharding(axis_context.mesh, P(axis_name))
._to_xla_hlo_sharding(1).to_proto())
aval_in = ShapedArray((axis_env.sizes[axis_pos],), np.int32)
aval_out = ShapedArray((1,), np.int32)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto)
proto = pxla.manual_proto(aval_in, axis_context.manual_axes, axis_context.mesh)
x = mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto)
return hlo.reshape(ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), x)
nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
div = mlir.ir_constant(
np.array(
@ -1492,12 +1617,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
)
)
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
if isinstance(axis_context, (ShardingContext, SPMDAxisContext)):
device_id = hlo.partition_id()
else:
device_id = hlo.replica_id()

View File

@ -483,7 +483,7 @@ def scatter_add(
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.add,
lax._abstractify(lax._const(operand, 0)))
core.get_aval(lax._const(operand, 0)))
return scatter_add_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -536,7 +536,7 @@ def scatter_sub(
An array containing the difference of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(
lax.sub, lax._abstractify(lax._const(operand, 0))
lax.sub, core.get_aval(lax._const(operand, 0))
)
return scatter_sub_p.bind(
operand,
@ -591,7 +591,7 @@ def scatter_mul(
An array containing the product of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.mul,
lax._abstractify(lax._const(operand, 1)))
core.get_aval(lax._const(operand, 1)))
return scatter_mul_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -638,7 +638,7 @@ def scatter_min(
An array containing the elementwise minimum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.min,
lax._abstractify(lax._const(operand, 0)))
core.get_aval(lax._const(operand, 0)))
return scatter_min_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -685,7 +685,7 @@ def scatter_max(
An array containing the elementwise maximum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(lax.max,
lax._abstractify(lax._const(operand, 0)))
core.get_aval(lax._const(operand, 0)))
return scatter_max_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -748,7 +748,7 @@ def scatter_apply(
_apply = _scatter_apply_cache.setdefault(func, _apply)
except TypeError: # func is not weak referenceable
pass
jaxpr, consts = lax._reduction_jaxpr(_apply, lax._abstractify(lax._zero(operand)))
jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand)))
# TODO: implement this via its own primitive so we can define appropriate autodiff rules.
return scatter_p.bind(
operand, scatter_indices, unused, update_jaxpr=jaxpr,

View File

@ -90,7 +90,7 @@ def _reduce_window(
return monoid_reducer(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation)
else:
flat_init_avals = map(lax._abstractify, flat_init_values)
flat_init_avals = map(core.get_aval, flat_init_values)
jaxpr, out_tree = lax._variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree
)
@ -176,7 +176,7 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None) -> Array:
init_value = lax._const(operand, 1)
jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value))
jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(init_value))
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
@ -226,7 +226,7 @@ def _reduce_window_logaddexp(
base_dilation: Sequence[int] | None = None,
window_dilation: Sequence[int] | None = None) -> Array:
init_value = lax._const(operand, -np.inf)
jaxpr, consts = lax._reduction_jaxpr(logaddexp, lax._abstractify(init_value))
jaxpr, consts = lax._reduction_jaxpr(logaddexp, core.get_aval(init_value))
if base_dilation is None:
base_dilation = (1,) * len(window_dimensions)
if window_dilation is None:
@ -245,9 +245,9 @@ def _select_and_scatter(operand: Array, select: Callable,
padding: Sequence[tuple[int, int]], source: Array,
init_value: Array, scatter: Callable) -> Array:
select_jaxpr, select_consts = lax._reduction_jaxpr(
select, lax._abstractify(init_value))
select, core.get_aval(init_value))
scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
scatter, lax._abstractify(init_value))
scatter, core.get_aval(init_value))
return select_and_scatter_p.bind(
operand, source, init_value, select_jaxpr=select_jaxpr,
select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,

View File

@ -111,8 +111,6 @@ class AxisTypes(enum.Enum):
return self.name
def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
if axis_types is None:
return {}
d = {}
for t, names in axis_types.items():
if isinstance(names, tuple):
@ -124,6 +122,7 @@ def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
_mesh_object_dict = {} # type: ignore
MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]
class Mesh(contextlib.ContextDecorator):
"""Declare the hardware resources available in the scope of this manager.
@ -178,11 +177,11 @@ class Mesh(contextlib.ContextDecorator):
devices: np.ndarray
axis_names: tuple[MeshAxisName, ...]
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None
axis_types: MeshAxisType
def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
axis_names: str | Sequence[MeshAxisName],
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
axis_names: str | Sequence[MeshAxisName], *,
axis_types: MeshAxisType | None = None):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
@ -198,9 +197,9 @@ class Mesh(contextlib.ContextDecorator):
f"devices.ndim == {devices.ndim} and "
f"len(axis_names) == {len(axis_names)}.")
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
axis_types_tuple = (None if axis_types is None else
tuple(axis_types.items()))
axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
axis_types)
axis_types_tuple = tuple(axis_types.items())
key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
val = _mesh_object_dict.get(key, None)
if val is not None:
@ -216,7 +215,8 @@ class Mesh(contextlib.ContextDecorator):
return self
def __reduce__(self):
return (type(self), (self.devices, self.axis_names, self.axis_types))
return (type(self), (self.devices, self.axis_names),
{'axis_types': self.axis_types})
def __eq__(self, other):
if not isinstance(other, Mesh):
@ -335,7 +335,7 @@ class Mesh(contextlib.ContextDecorator):
def _repr(self):
if self.empty:
return "Mesh(device_ids=[], axis_names=())"
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
atr = f", axis_types={self.axis_types}"
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"
def __repr__(self):
@ -348,7 +348,7 @@ class Mesh(contextlib.ContextDecorator):
@functools.cached_property
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple, self.axis_types)
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@ -373,17 +373,16 @@ class AbstractMesh:
details.
"""
def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
axis_types: MeshAxisType | None = None):
self.shape_tuple = shape_tuple
self.axis_types = axis_types
if self.shape_tuple:
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
else:
self._axis_names, self._axis_sizes = (), ()
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
self._axis_types_tuple = (None if axis_types is None else
tuple(axis_types.items()))
self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
else axis_types)
self._axis_types_tuple = tuple(self.axis_types.items())
def __hash__(self):
return hash((self.shape_tuple, self._axis_types_tuple))
@ -397,7 +396,7 @@ class AbstractMesh:
self._axis_types_tuple == other._axis_types_tuple)
def __repr__(self):
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
atr = f", axis_types={self.axis_types}"
return f"AbstractMesh({self.shape_tuple}{atr})"
@property
@ -430,10 +429,20 @@ class AbstractMesh:
@functools.cached_property
def _are_all_axes_collective(self) -> bool:
if self.axis_types is None:
return False
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _are_all_axes_auto(self) -> bool:
return all(t == AxisTypes.Auto for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_collective(self) -> bool:
return any(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_auto(self) -> bool:
return any(t == AxisTypes.Auto for t in self.axis_types.keys())
@property
def devices(self):
_raise_value_error("devices")
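A hedged sketch of the keyword-only `axis_types` argument shown above, assuming a build that includes this change; `jax._src.mesh` is an internal module and the `AxisTypes` members are taken from the hunk:

import numpy as np
import jax
from jax._src import mesh as mesh_lib  # internal module (assumption)

devices = np.array(jax.devices()[:1])
# axis_types must now be passed by keyword and maps an AxisTypes member to axis names.
m = mesh_lib.Mesh(devices, ('x',), axis_types={mesh_lib.AxisTypes.Auto: ('x',)})
print(m.axis_types)                        # expected: {Auto: ('x',)}
# Omitting axis_types now defaults every axis to AxisTypes.Auto rather than None.
print(mesh_lib.Mesh(devices, ('x',)).axis_types)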

View File

@ -386,7 +386,7 @@ def _create_device_mesh_for_nd_torus_splitting_axes(
)
):
best_logical_axis_assignment = logical_axis_assignment
assignment[:, logical_axis] = best_logical_axis_assignment
assignment[:, logical_axis] = best_logical_axis_assignment # type: ignore # numpy 2.2
# Read out the assignment.
logical_mesh = _generate_logical_mesh(
@ -597,10 +597,10 @@ def _generate_logical_mesh(
zip(logical_indices, physical_indices, range(len(logical_indices)))
)
)
logical_mesh = np.transpose(logical_mesh, transpose_axes)
logical_mesh = np.transpose(logical_mesh, transpose_axes) # type: ignore # numpy 2.2
# Reshape to add the trivial dimensions back.
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape)
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) # type: ignore # numpy 2.2
return logical_mesh

View File

@ -67,7 +67,7 @@ def relu(x: ArrayLike) -> Array:
For more information see
`Numerical influence of ReLU(0) on backpropagation
<https://openreview.net/forum?id=urrcVI-_jRm>`_.
<https://dl.acm.org/doi/10.5555/3540261.3540297>`_.
Args:
x : input array
@ -84,7 +84,7 @@ def relu(x: ArrayLike) -> Array:
"""
return jnp.maximum(x, 0)
# For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
# For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
@jax.jit

View File

@ -9168,6 +9168,89 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
@export
@jit
def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Batched matrix-vector product.
JAX implementation of :func:`numpy.matvec`.
Args:
x1: array of shape ``(..., M, N)``
x2: array of shape ``(..., N)``. Leading dimensions must be broadcast-compatible
with leading dimensions of ``x1``.
Returns:
An array of shape ``(..., M)`` containing the batched matrix-vector product.
See also:
- :func:`jax.numpy.linalg.vecdot`: batched vector product.
- :func:`jax.numpy.vecmat`: vector-matrix product.
- :func:`jax.numpy.matmul`: general matrix multiplication.
Examples:
Simple matrix-vector product:
>>> x1 = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)
Batched matrix-vector product:
>>> x2 = jnp.array([[7, 8, 9],
... [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
[ 38, 92]], dtype=int32)
"""
util.check_arraylike("matvec", x1, x2)
return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2)
@export
@jit
def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Batched conjugate vector-matrix product.
JAX implementation of :func:`numpy.vecmat`.
Args:
x1: array of shape ``(..., M)``.
x2: array of shape ``(..., M, N)``. Leading dimensions must be broadcast-compatible
with leading dimensions of ``x1``.
Returns:
An array of shape ``(..., N)`` containing the batched conjugate vector-matrix product.
See also:
- :func:`jax.numpy.linalg.vecdot`: batched vector product.
- :func:`jax.numpy.matvec`: matrix-vector product.
- :func:`jax.numpy.matmul`: general matrix multiplication.
Examples:
Simple vector-matrix product:
>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
... [6, 7],
... [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)
Batched vector-matrix product:
>>> x1 = jnp.array([[1, 2, 3],
... [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40, 46],
[ 94, 109]], dtype=int32)
"""
util.check_arraylike("matvec", x1, x2)
return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2)
@export
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def vdot(
@ -9244,6 +9327,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
See Also:
- :func:`jax.numpy.vdot`: flattened vector product.
- :func:`jax.numpy.vecmat`: vector-matrix product.
- :func:`jax.numpy.matmul`: general matrix multiplication.
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.

View File

@ -1374,7 +1374,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
else:
if rcond is None:
rcond = jnp.finfo(dtype).eps * max(n, m)
rcond = float(jnp.finfo(dtype).eps) * max(n, m)
else:
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
u, s, vt = svd(a, full_matrices=False)
@ -1517,7 +1517,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@export
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str | int = 'fro') -> Array:
"""Compute the norm of a matrix or stack of matrices.
JAX implementation of :func:`numpy.linalg.matrix_norm`

View File

@ -246,7 +246,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
# set rcond
if rcond is None:
rcond = len(x_arr) * finfo(x_arr.dtype).eps
rcond = len(x_arr) * float(finfo(x_arr.dtype).eps)
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
# set up least squares equation for powers of x
lhs = vander(x_arr, order)

View File

@ -23,6 +23,7 @@ from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import api_util
from jax._src.lax import lax
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
@ -213,14 +214,18 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
@partial(api.jit, inline=True)
def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [np.shape(arg) for arg in args]
avals = [api_util.shaped_abstractify(arg) for arg in args]
shapes = [a.shape for a in avals]
if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
return [lax.asarray(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [_broadcast_to(arg, result_shape) for arg in args]
result_sharding = (lax.broadcast_shardings(*avals) # type: ignore
if config.sharding_in_types.value else None)
return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
) -> Array:
check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, Array) else lax.asarray(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
@ -240,7 +245,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
if nlead < 0 or not compatible:
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
raise ValueError(msg.format(arr_shape, shape))
return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))))
return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
sharding=sharding)
# The `jit` on `where` exists to avoid materializing constants in cases like
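The helper above implements standard NumPy-style shape broadcasting (now also threading a sharding through when `sharding_in_types` is enabled). The shape behavior it encodes mirrors the public `jnp.broadcast_arrays`:

import jax.numpy as jnp

a = jnp.ones((3, 1))
b = jnp.arange(4)
x, y = jnp.broadcast_arrays(a, b)  # shapes broadcast to a common (3, 4)
print(x.shape, y.shape)            # (3, 4) (3, 4)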

View File

@ -100,7 +100,7 @@ def op_sharding_to_numpy_indices(
for i, idxs in enumerate(itertools.product(*axis_indices)):
for _ in range(num_replicas):
indices[next(device_it)] = idxs
indices[next(device_it)] = idxs # type: ignore # numpy 2.2
return indices

View File

@ -13,14 +13,17 @@
# limitations under the License.
"""Helper tool for automatic cost estimation."""
import dataclasses
import functools
import math
from typing import Any, Sequence
import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src.state import discharge
from jax._src.pallas import core as pallas_core
from jax._src.interpreters import partial_eval as pe
from jax._src.util import safe_map
@ -87,10 +90,9 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate:
A pallas_core.CostEstimate object containing the cost estimate.
"""
flattened_args, treedef = jax.tree.flatten(args)
def _partial_fun(*flat_args):
return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs)
wrapped_fun = lu.wrap_init(
lambda *args, **kwargs: (_partial_fun(*args, **kwargs),))
partial_fun = functools.partial(fun, **kwargs)
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun),
treedef)
avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))
@ -243,3 +245,12 @@ def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_):
bytes_accessed=inner_cost.bytes_accessed,
)
register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule)
def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2):
inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr))
return CostEstimate(
flops=inner_cost.flops,
transcendentals=inner_cost.transcendentals,
bytes_accessed=inner_cost.bytes_accessed,
)
register_cost_rule(discharge.run_state_p, _run_state_rule)
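A hedged usage sketch for `estimate_cost` as reworked above; the import path (`jax._src.pallas.cost_estimate`) is an assumption based on this file's contents, and the tool is internal:

import jax.numpy as jnp
from jax._src.pallas import cost_estimate  # assumed module path, internal API

def f(x):
  return jnp.exp(x) + 2.0 * x

# Keyword arguments are now folded in via functools.partial before flattening.
est = cost_estimate.estimate_cost(f, jnp.ones((128,), jnp.float32))
print(est.flops, est.transcendentals, est.bytes_accessed)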

View File

@ -63,7 +63,7 @@ from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax._src.state.types import RefBitcaster, RefReshaper
from jax._src.state.utils import dtype_bitwidth
from jax._src.typing import DTypeLike
from jax._src.typing import Array, DTypeLike
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.util import split_list
@ -2295,7 +2295,49 @@ _cmpf_lowering_types = {
}
def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y):
# The relationship between comparison operations on booleans and boolean
# algebra is as follows:
# eq(x, y) = !(x ^ y)
# ne(x, y) = x ^ y
# lt(x, y) = !x && y
# le(x, y) = !x || y
# gt(x, y) = x && !y
# ge(x, y) = x || !y
def _cmp_boolean_lowering_helper(primitive, x: Array, y: Array):
"""A helper function for lowering comparison operations for boolean inputs.
Args:
primitive: A JAX primitive representing a comparison operation, which is
one of the following: `lax.eq_p` (equals), `lax.ne_p` (not equals),
`lax.lt_p` (less than), `lax.le_p` (less than or equal to),
`lax.gt_p` (greater than), or `lax.ge_p` (greater than or equal to).
x: A boolean array representing the first operand in the comparison.
y: A boolean array representing the second operand in the comparison.
Returns:
A boolean array that is the result of applying the comparison operation
between `x` and `y` based on the given primitive.
Raises:
ValueError: If an unsupported comparison primitive is provided.
"""
if primitive == lax.eq_p:
return jnp.logical_not(jnp.logical_xor(x, y))
elif primitive == lax.ne_p:
return jnp.logical_xor(x, y)
elif primitive == lax.lt_p:
return jnp.logical_and(jnp.logical_not(x), y)
elif primitive == lax.le_p:
return jnp.logical_or(jnp.logical_not(x), y)
elif primitive == lax.gt_p:
return jnp.logical_and(x, jnp.logical_not(y))
elif primitive == lax.ge_p:
return jnp.logical_or(x, jnp.logical_not(y))
else:
raise ValueError(f"Unsupported comparison primitive: {primitive}")
def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
x_aval, y_aval = ctx.avals_in
if x_aval.dtype != y_aval.dtype:
@ -2304,60 +2346,22 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y):
)
dtype = x_aval.dtype
# For boolean comparisons, we handle them in two different ways. For `ne`,
# we directly use the xor operation since they are equivalent. For all
# other comparisons, we convert the boolean values to `int32` and use select
# operations to perform the comparison.
#
# The relationship between comparison operations on booleans and boolean
# algebra is as follows:
#
# eq(a, b) = !(a ^ b)
# ne(a, b) = a ^ b
# lt(a, b) = !a && b
# le(a, b) = !a || b
# gt(a, b) = a && !b
# ge(a, b) = a || !b
#
# However, except for `ne`, all other operations require negation, which is
# currently not supported. At present, even if negation were supported,
# it would still need to be implemented using `select` operations, making
# it equivalent to our current approach. For more details on negation support,
# see https://github.com/jax-ml/jax/issues/24243.
if jnp.issubdtype(dtype, jnp.bool_):
if prim == lax.ne_p:
return arith.xori(x, y)
i32 = ir.IntegerType.get_signless(32)
vtype = ir.VectorType.get(x_aval.shape, i32)
# Convert `x` and `y` from `bool` to `int32` for comparison, with 2
# for true and 0 for false. For example, comparing `x > y` is equivalent
# to `(x ? 2 : 0) > (y ? 2 : 0)`.
#
# Note that we cannot use 1 for true because the select operation will be
    # mysteriously eliminated.
two = arith.constant(i32, 2)
zero = arith.constant(i32, 0)
out_aval, = ctx.avals_out
if out_aval.shape != ():
# broadcast to vectors if we are comparing vectors
two = vector.broadcast(vtype, two)
zero = vector.broadcast(vtype, zero)
x = arith.select(x, two, zero)
y = arith.select(y, two, zero)
dtype = jnp.int32
return lower_fun(
functools.partial(_cmp_boolean_lowering_helper, primitive),
multiple_results=False,
)(ctx, x, y)
if jnp.issubdtype(dtype, jnp.integer):
is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger)
pred = (_cmpui_lowering_types if is_uint else _cmpsi_lowering_types)[prim]
pred = (
_cmpui_lowering_types if is_uint else _cmpsi_lowering_types
)[primitive]
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
return arith.cmpi(predicate, x, y)
if jnp.issubdtype(dtype, jnp.floating):
pred = _cmpf_lowering_types[prim]
pred = _cmpf_lowering_types[primitive]
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
return arith.cmpf(predicate, x, y)
@ -2381,6 +2385,15 @@ lowering_rules[lax.and_p] = _and_lowering_rule
skip_mlir_conversions.add(lax.and_p)
def _is_finite_lowering_rule(ctx: LoweringRuleContext, x):
out_aval, = ctx.avals_out
out_type = aval_to_ir_type(out_aval)
return _not_lowering_rule(ctx, tpu.weird(out_type, x))
lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule
def _or_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
return arith.ori(x, y)
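The identities listed in the comment and helper above can be sanity-checked with plain `jax.numpy`, outside the lowering path; a small self-contained check:

import itertools
import jax.numpy as jnp

for a, b in itertools.product([False, True], repeat=2):
  x, y = jnp.asarray(a), jnp.asarray(b)
  assert bool(x == y) == bool(~(x ^ y))  # eq(x, y) = !(x ^ y)
  assert bool(x != y) == bool(x ^ y)     # ne(x, y) = x ^ y
  assert bool(x < y) == bool(~x & y)     # lt(x, y) = !x && y
  assert bool(x <= y) == bool(~x | y)    # le(x, y) = !x || y
  assert bool(x > y) == bool(x & ~y)     # gt(x, y) = x && !y
  assert bool(x >= y) == bool(x | ~y)    # ge(x, y) = x || !y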

View File

@ -21,10 +21,9 @@ import tempfile
from typing import Any
import jax
from jax import core as jax_core
from jax import dtypes
from jax._src import config
from jax._src import core as jax_src_core
from jax._src import core as jax_core
from jax._src import sharding_impls
from jax._src import tpu_custom_call
from jax._src.interpreters import mlir
@ -189,7 +188,7 @@ def pallas_call_tpu_lowering_rule(
# Replace in_avals to physical avals.
# This step is required for mapping logical types to physical types.
# (e.g. PRNG key -> uint32[2])
physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in]
ctx = ctx.replace(avals_in=physical_avals)
# Booleans are loaded into the kernel as integers.

View File

@ -139,7 +139,7 @@ def roll(
@roll_p.def_abstract_eval
def _roll_abstract_eval(x, shift, **_):
del shift
return jax_core.raise_to_shaped(x)
return x
def _roll_lowering_rule(

View File

@ -44,6 +44,7 @@ pytype_strict_library(
deps = [
":lowering",
"//jax",
"//jax:core",
"//jax:mlir",
"//jax:mosaic_gpu",
"//jax/_src/pallas",

View File

@ -1256,13 +1256,11 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
[x_aval] = ctx.avals_in
match x.layout:
case mgpu.WGStridedFragLayout():
if axes != (0,):
raise NotImplementedError("No support for axes other than 0 yet")
if set(axes) != set(range(x_aval.ndim)):
raise NotImplementedError("No support for axes yet")
scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)
with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]:
return mgpu.FragmentedArray.splat(
x.reduce_sum(scratch), (), is_signed=mgpu_utils.is_signed(x_aval.dtype)
)
return x.reduce_sum(scratch)
case mgpu.WGMMA_LAYOUT:
if axes != (x_aval.ndim - 1,):
raise NotImplementedError

View File

@ -23,7 +23,7 @@ from typing import Any
import warnings
import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering

View File

@ -112,7 +112,6 @@ def _pad_values_to_block_dimension(value,
return value
def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
return tuple(
primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals
)
@ -1151,7 +1150,7 @@ def checkify_pallas_kernel_body_jaxpr(
grid_mapping: GridMapping) -> tuple[
jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
err_vals, err_tree = tree_util.tree_flatten(error)
err_vals = map(checkify.get_shaped_aval, err_vals)
err_vals = map(jax_core.get_aval, err_vals)
flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
@ -1274,13 +1273,13 @@ def pallas_call_checkify_rule(error: checkify.Error,
closed_jaxpr, enabled_errors, error, grid_mapping)
error = error._add_placeholder_effects(error_effects)
err_vals, err_in_tree = jax.tree.flatten(error)
shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
shaped_err_avals = map(jax_core.get_aval, err_vals)
# Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
# all enabled errors removed, but have the error as inputs and return values.
input_avals = [v.aval for v in jaxpr.invars]
num_err_vals = len(err_vals)
shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
shaped_input_avals = tuple(input_avals)
checkify_in_avals = [*shaped_err_avals,
*shaped_input_avals]
closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
@ -1416,8 +1415,7 @@ def _trace_kernel_to_jaxpr(
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
if consts:
consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
for c in consts]
consts_avals = [jax_core.get_aval(c) for c in consts]
if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
raise ValueError(
f"The kernel function in the pallas_call {name_and_src_info} "
@ -1804,8 +1802,7 @@ def pallas_call(
def wrapped(*args):
flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
in_paths, flat_args = unzip2(flat_args_with_paths)
flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args)
flat_in_avals = tuple(jax_core.get_aval(a) for a in flat_args)
flat_out_avals = tuple(_convert_out_shape_to_aval(v)
for v in flat_out_shapes)

View File

@ -76,6 +76,7 @@ pytype_strict_library(
":lowering",
"//jax",
"//jax:config",
"//jax:core",
"//jax:mlir",
"//jax:util",
"//jax/_src/lib",

View File

@ -102,7 +102,7 @@ class LoweringRuleContext:
@dataclasses.dataclass
class LoweringResult:
"""Keeps pybind11 objects alive."""
"""Keeps python objects alive."""
module: ir.Module
grid: tuple[int, ...]
@ -1983,81 +1983,6 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation):
return tt_dialect.trans(x, permutation)
def _check_dot_operands(
x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any
):
# TODO(slebedev): Ensure that the dtypes are supported by CUDA.
return
def _dot(
x: ir.Value,
y: ir.Value,
acc: ir.Value | None = None,
*,
allow_tf32: bool = True,
max_num_imprecise_acc: int | None = None,
out_type: ir.Type | None = None,
) -> ir.Value:
if out_type is None:
out_type = ir.F32Type.get()
elif isinstance(out_type, ir.BF16Type):
raise NotImplementedError(f"unsupported output type: {out_type}")
x_type = ir.RankedTensorType(x.type)
y_type = ir.RankedTensorType(y.type)
if min(*x_type.shape, *y_type.shape) < 16:
raise ValueError("all dimensions of x and y must be >= 16 ")
if x_type.element_type != y_type.element_type:
raise ValueError(
"x and y must have the same element type, but got:"
f" {x_type.element_type} and {y_type.element_type}"
)
_check_dot_operands(x_type, y_type, object())
element_type = x_type.element_type
if isinstance(element_type, ir.IntegerType):
if element_type.width != 8:
raise TypeError(f"unsupported element type: {element_type}")
element_type = ir.IntegerType.get_signless(32)
elif isinstance(element_type, (ir.F32Type, ir.BF16Type)):
element_type = ir.F32Type.get()
else:
element_type = out_type
if element_type != out_type:
raise TypeError(
f"output type {out_type} does not match element type {element_type}"
)
m, _ = x_type.shape
_, n = y_type.shape
if acc is None:
acc = _full(ir.RankedTensorType.get([m, n], element_type), 0)
if max_num_imprecise_acc is None:
if isinstance(element_type, ir.FloatType) and element_type.width == 8:
# TODO(slebedev): Fill in from options.
raise NotImplementedError
else:
max_num_imprecise_acc = 0
# Ideally, replace all allow_tf32 usages with InputPrecision directly.
input_precision = tt_dialect.InputPrecision.IEEE
if allow_tf32:
input_precision = tt_dialect.InputPrecision.TF32
return tt_dialect.dot(
x,
y,
acc,
max_num_imprecise_acc=max_num_imprecise_acc,
input_precision=input_precision
)
_TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT)
@ -2081,27 +2006,63 @@ def _dot_general_lowering(
if b_contract_dim == 1:
b = tt_dialect.trans(b, (1, 0))
if precision is None:
allow_tf32 = True
else:
prec_a, prec_b = precision
allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS
a_aval, b_aval = ctx.avals_in
[out_aval] = ctx.avals_out
out_dtype = acc_dtype = out_aval.dtype
if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
acc_dtype = jnp.dtype(jnp.float32)
return _cast(
_dot(
a,
b,
allow_tf32=allow_tf32,
out_type=_dtype_to_ir_type(acc_dtype),
),
acc_dtype,
out_dtype,
)
if precision is None or (precision == lax.DotAlgorithmPreset.DEFAULT):
precision = (lax.Precision.DEFAULT, lax.Precision.DEFAULT)
if isinstance(precision, lax.DotAlgorithmPreset):
match precision:
case lax.DotAlgorithmPreset.TF32_TF32_F32:
input_precision = tt_dialect.InputPrecision.TF32
case lax.DotAlgorithmPreset.TF32_TF32_F32_X3:
input_precision = tt_dialect.InputPrecision.TF32x3
case lax.DotAlgorithmPreset.F32_F32_F32:
input_precision = tt_dialect.InputPrecision.IEEE
case (
lax.DotAlgorithmPreset.F16_F16_F16
| lax.DotAlgorithmPreset.F16_F16_F32
| lax.DotAlgorithmPreset.BF16_BF16_BF16
| lax.DotAlgorithmPreset.BF16_BF16_F32
):
input_precision = None
case _:
raise NotImplementedError(f"Unsupported dot algorithm: {precision}.")
a = _cast(a, a_aval.dtype, precision.supported_lhs_types[0])
b = _cast(b, b_aval.dtype, precision.supported_rhs_types[0])
acc_dtype = precision.accumulation_type
elif isinstance(precision, tuple):
a_precision, b_precision = precision
if a_precision in _TF32_PRECISIONS or b_precision in _TF32_PRECISIONS:
input_precision = tt_dialect.InputPrecision.TF32
elif a_aval.dtype == jnp.float32:
input_precision = tt_dialect.InputPrecision.IEEE
else:
input_precision = None
acc_dtype = out_aval.dtype
if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
acc_dtype = jnp.float32
else:
raise NotImplementedError(f"Unsupported dot precision: {precision}.")
a_type = ir.RankedTensorType(a.type)
b_type = ir.RankedTensorType(b.type)
if min(*a_type.shape, *b_type.shape) < 16:
raise ValueError("all dimensions of a and b must be >= 16 ")
if a_type.element_type != b_type.element_type:
raise ValueError(
"a and b must have the same element type, but got:"
f" {a_type.element_type} and {b_type.element_type}"
)
m, _ = a_type.shape
_, n = b_type.shape
acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0)
acc = tt_dialect.dot(a, b, acc, input_precision=input_precision)
return _cast(acc, acc_dtype, out_aval.dtype)
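The new preset handling above leans on two `lax.DotAlgorithmPreset` properties, `supported_lhs_types` and `accumulation_type`. A small hedged illustration (the exact dtypes printed may differ by preset and JAX version):

from jax import lax

preset = lax.DotAlgorithmPreset.BF16_BF16_F32
print(preset.supported_lhs_types)  # e.g. a tuple containing bfloat16
print(preset.accumulation_type)    # e.g. float32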
def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
@ -2623,7 +2584,8 @@ def _i64_constant(v: int) -> ir.Value:
return arith_dialect.constant(ir.IntegerType.get_signless(64), v)
def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
def _dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type:
dtype = jnp.dtype(dtype)
if jnp.issubdtype(dtype, np.integer):
# All integer types in Triton are signless.
return ir.IntegerType.get_signless(dtype.itemsize * 8)

View File

@ -19,7 +19,7 @@ from __future__ import annotations
import io
from typing import Any
from jax import core as jax_core
import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core

View File

@ -19,7 +19,7 @@ from __future__ import annotations
from collections.abc import Sequence
import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas.triton import lowering

View File

@ -14,7 +14,7 @@
from __future__ import annotations
class _UnconstrainedPartitionSingleton:
class UnconstrainedSingleton:
def __repr__(self):
return "UNCONSTRAINED"
@ -23,7 +23,7 @@ class _UnconstrainedPartitionSingleton:
# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()
class PartitionSpec(tuple):

View File

@ -498,7 +498,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
pgle_profiler=None)
pgle_profiler=None)
return stages.Traced(
p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)
@ -549,6 +549,7 @@ def _infer_params_impl(
kwargs: dict[str, Any],
in_avals: tuple[core.AbstractValue, ...] | None,
) -> tuple[PjitParams, list[Any]]:
util.test_event("pjit._infer_params_impl", fun)
have_kwargs = bool(kwargs)
if have_kwargs and ji.user_specified_in_shardings:
raise ValueError(
@ -698,9 +699,6 @@ def get_abstract_mesh_from_avals(in_avals):
return None
m = None
for a in in_avals:
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if a.sharding is None: # type: ignore
continue
if m is not None and m != a.sharding.mesh:
raise ValueError(
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
@ -1300,6 +1298,7 @@ def _create_pjit_jaxpr(
ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
util.test_event("create_pjit_jaxpr")
del ignored_inline # just for explain_cache_miss
if config.no_tracing.value:
raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
@ -1787,10 +1786,9 @@ def _pjit_lower(
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
util.test_event("pjit_lower")
if config.sharding_in_types.value:
cur_mesh = mesh_lib.get_concrete_mesh()
mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
api_name = 'jit'
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
else:
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
if resource_env is not None else (None, 'jit'))
@ -2156,8 +2154,18 @@ def _pjit_partial_eval(trace, *in_tracers,
known_ins = tuple(pv.is_known() for pv in in_pvals)
unknown_ins = tuple(not k for k in known_ins)
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect))
for e in jaxpr.effects):
known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \
pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins,
False, False, None)
if num_res_ref: raise NotImplementedError
known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts)
unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts)
res_avals = unknown_jaxpr.in_avals[:num_res_val]
else:
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
unknown_outs = tuple(unknown_outs)
known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(res_avals)
@ -2343,8 +2351,7 @@ def _pjit_transpose(cts_in, *primals_in,
*prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
*prune_type(ad.Zero, out_layouts, cts_in)
)
global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
for ct in primals_and_nz_cts_in)
global_cts_in_avals = tuple(core.get_aval(ct) for ct in primals_and_nz_cts_in)
transpose_jaxpr, attrs_tracked = _pjit_transpose_trace(
body, global_cts_in_avals)

View File

@ -1249,7 +1249,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
random_gamma_p.def_abstract_eval(lambda key, a, **_: a)
ad.defjvp2(
random_gamma_p, None,
lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))

View File

@ -130,7 +130,7 @@ class Rotation(typing.NamedTuple):
else:
return self.quat.shape[0]
def __mul__(self, other):
def __mul__(self, other) -> Rotation:
"""Compose this rotation with the other."""
return Rotation.from_quat(_compose_quat(self.quat, other.quat))
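A quick usage sketch for the return-type annotation added above (composition of two rotations via `*`); quaternions are in scalar-last (x, y, z, w) order, as in SciPy:

import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation

r1 = Rotation.from_quat(jnp.array([0.0, 0.0, 1.0, 0.0]))  # 180 degrees about z
r2 = Rotation.from_quat(jnp.array([0.0, 0.0, 0.0, 1.0]))  # identity rotation
r3 = r1 * r2                                              # compose; result is a Rotation
print(r3.as_quat())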

View File

@ -67,6 +67,19 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
f"is also found in manual_axes: {_manual_axes}.") from None
@util.cache(max_size=128, trace_context_in_key=False)
def _check_axis_type_consistency(mesh, parsed_pspec):
if mesh.axis_types is None:
return
for p in parsed_pspec:
if p is not None:
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
raise ValueError(
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
def hashed_index(x) -> int:
# This works for both `pjit` indices and `pmap` indices (which might
# have an integer instead of a slice).
@ -725,7 +738,7 @@ class PositionalSharding(sharding.Sharding):
ids = self._ids.copy()
platform_name = self._devices[0].platform.upper()
for idx, x in np.ndenumerate(ids):
ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))
ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x)) # type: ignore # numpy 2.2
body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
max_line_width=100)
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
@ -1084,6 +1097,7 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
PartitionSpec() if spec is None else spec,
"NamedSharding spec", allow_unconstrained_dims=True)
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
_check_axis_type_consistency(mesh, parsed_pspec)
return parsed_pspec
@ -1673,7 +1687,8 @@ def _gspmd_to_named_sharding_via_mesh(
def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
*, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh:
*, devices: Sequence[xc.Device] | None = None,
axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
"""Creates an efficient mesh with the shape and axis names specified.
This function attempts to automatically compute a good mapping from a set of
@ -1735,4 +1750,4 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
mesh_devices = mesh_utils.create_device_mesh(
new_axis_shapes, devices,
allow_split_physical_axes=allow_split_physical_axes)
return mesh_lib.Mesh(mesh_devices, axis_names)
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
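A hedged sketch of the new `axis_types` plumb-through; this assumes the helper above is exposed as `jax.make_mesh` and that `AxisTypes` lives in the internal `jax._src.mesh` module:

import jax
from jax._src import mesh as mesh_lib  # internal module (assumption)

m = jax.make_mesh((1,), ('data',),
                  axis_types={mesh_lib.AxisTypes.Auto: ('data',)})
print(m.axis_types)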

View File

@ -97,7 +97,7 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
# is used to extract the corresponding shard of the logical array.
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
for i, idxs in enumerate(itertools.product(*axis_indices)):
shard_indices[i] = idxs
shard_indices[i] = idxs # type: ignore # numpy 2.2
shard_indices = shard_indices.reshape(shard_indices_shape)
# Ensure that each sharded axis is used exactly once in the mesh mapping

View File

@ -533,6 +533,7 @@ class Compiled(Stage):
@staticmethod
def call(*args, **kwargs):
util.test_event("stages_compiled_call")
# This is because `__call__` passes in `self._params` as the first argument.
# Instead of making the call signature `call(params, *args, **kwargs)`
# extract it from args because `params` can be passed as a kwarg by users

View File

@ -153,6 +153,10 @@ def _eval_jaxpr_discharge_state(
[invar], [outvar] = eqn.invars, eqn.outvars
ans = env.read(invar)
refs_to_discharge.add(id(outvar.aval))
elif eqn.primitive is core.freeze_p:
[invar], [outvar] = eqn.invars, eqn.outvars
ans = env.read(invar)
refs_to_discharge.remove(id(invar.aval))
elif (any(should_discharge)
or core.internal_mutable_array_effect in eqn.effects
):
@ -364,7 +368,7 @@ def transform_swap_array(x, transforms, val):
case indexing.NDIndexer():
indexer = transform
if _is_trivial_indexer(indexer):
_results.append(None)
_results.append(_results[-1])
continue
# If everything in the indexer is a slice or ()-shaped, we can also
# use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.

View File

@ -654,3 +654,8 @@ def _broadcast_to_abstract_eval(aval, *, shape):
mlir.register_lowering(
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)
# === AD rules for mutable arrays ===
ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g))
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))

View File

@ -414,7 +414,7 @@ _ref_type_aval_mappings: dict[
def _default_value_to_ref_aval(x: Any) -> tuple[AbstractRef, Array]:
# Default type mapping just creates an AbstractRef from the array's aval.
aval = core.raise_to_shaped(core.get_aval(x))
aval = core.get_aval(x)
return AbstractRef(aval), x

Some files were not shown because too many files have changed in this diff.