Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00

Merge branch 'jax-ml:main' into main

commit 4de58e1af7

.bazelrc: 7 changed lines
@@ -96,6 +96,11 @@ build:avx_windows --copt=/arch:AVX

build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1

# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL).
build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true
build:mkl_aarch64_threadpool --@compute_library//:openmp=false
build:mkl_aarch64_threadpool -c opt

# Disable clang extension that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
@@ -104,6 +109,8 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
build:clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:clang --copt=-Qunused-arguments
# Error on struct/class mismatches, since this causes link failures on Windows.
build:clang --copt=-Werror=mismatched-tags

# Configs for CUDA
build:cuda --repo_env TF_NEED_CUDA=1
.github/workflows/bazel_cpu_rbe.yml (vendored): 12 changed lines
@@ -11,6 +11,9 @@ on:
options:
- 'yes'
- 'no'
pull_request:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -21,15 +24,18 @@ jobs:
if: github.event.repository.fork == false
strategy:
matrix:
runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"]
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
enable-x_64: [1, 0]

runs-on: ${{ matrix.runner }}
# TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }}
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}

env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}

name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"

steps:
- uses: actions/checkout@v3
.github/workflows/bazel_gpu_rbe.yml (vendored): 7 changed lines
@@ -11,6 +11,9 @@ on:
options:
- 'yes'
- 'no'
pull_request:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -22,12 +25,16 @@ jobs:
strategy:
matrix:
runner: ["linux-x86-n2-16"]
enable-x_64: [1, 0]

runs-on: ${{ matrix.runner }}
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'

env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}

name: "Bazel single accelerator GPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"

steps:
- uses: actions/checkout@v3
.github/workflows/ci-build.yaml (vendored): 12 changed lines
@@ -35,7 +35,7 @@ jobs:
with:
python-version: 3.11
- run: python -m pip install pre-commit
- uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
@@ -79,7 +79,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -126,7 +126,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -169,7 +169,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -204,7 +204,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -245,7 +245,7 @@ jobs:
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
.github/workflows/jax-array-api.yml (vendored): 4 changed lines
@@ -38,11 +38,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install .[ci]
python -m pip install -r array-api-tests/requirements.txt
python -m pip install pytest-xdist -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.numpy
JAX_ENABLE_X64: 'true'
run: |
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt
pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt
@@ -36,7 +36,7 @@ repos:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py|jax/numpy/__init__.py  # Use pyi instead
additional_dependencies: [types-requests==2.31.0, jaxlib]
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0]
args: [--config=pyproject.toml]

- repo: https://github.com/mwouts/jupytext
CHANGELOG.md: 30 changed lines
@@ -10,7 +10,35 @@ Remember to align the itemized text with the first line of an item within a list.
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.36
## jax 0.4.38

* Changes:
  * `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
    as shortcuts of the corresponding `tree_util` functions (see the example
    sketch after this entry).

* Deprecations
  * A number of APIs in the internal `jax.core` namespace have been deprecated.
    Most were no-ops, were little-used, or can be replaced by APIs of the same
    name in {mod}`jax.extend.core`; see the documentation for {mod}`jax.extend`
    for information on the compatibility guarantees of these semi-public extensions.
  * Several previously-deprecated APIs have been removed, including:
    * from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
      `non_negative_dim`.
    * from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
    * from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.

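As a quick illustration of the `jax.tree` shortcuts noted above, here is a minimal sketch; the `params` dictionary and the labeling lambda are hypothetical and only for illustration:

```python
import jax

params = {"layer": {"w": 1.0, "b": 2.0}}

# flatten_with_path returns (key_path, leaf) pairs plus the treedef.
path_leaf_pairs, treedef = jax.tree.flatten_with_path(params)

# map_with_path passes each leaf's key path to the mapped function.
labeled = jax.tree.map_with_path(
    lambda path, x: f"{jax.tree_util.keystr(path)}={x}", params)
print(labeled)
```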
## jax 0.4.37 (Dec 9, 2024)

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

* Bug fixes
  * Fixed a bug where `jit` would error if an argument was named `f` (#25329).
  * Fixed a bug that threw an `index out of range` error in
    {func}`jax.lax.while_loop` if the user registered a pytree node class with
    different aux data for `flatten` and `flatten_with_path`.
  * Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

## jax 0.4.36 (Dec 5, 2024)

* Breaking Changes
  * This release lands "stackless", an internal change to JAX's tracing
@@ -47,7 +47,7 @@ are instances of such transformations. Others are
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming of multiple accelerators, with more to come.

This is a research project, not an official Google product. Expect bugs and
This is a research project, not an official Google product. Expect
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
@@ -123,6 +123,15 @@ def add_global_arguments(parser: argparse.ArgumentParser):
help="Produce verbose output for debugging.",
)

parser.add_argument(
"--detailed_timestamped_log",
action="store_true",
help="""
Enable detailed logging of the Bazel command with timestamps. The logs
will be stored and can be accessed as artifacts.
""",
)


def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
"""Adds all the arguments that apply to the artifact subcommands."""
@@ -399,7 +408,7 @@ async def main():
else:
requirements_command.append("//build:requirements.update")

result = await executor.run(requirements_command.get_command_as_string(), args.dry_run)
result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
if result.return_code != 0:
raise RuntimeError(f"Command failed with return code {result.return_code}")
else:
@@ -476,7 +485,10 @@ async def main():

if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
wheel_build_command.append("--config=mkl_open_source_only")
if target_cpu == "aarch64":
wheel_build_command.append("--config=mkl_aarch64_threadpool")
else:
wheel_build_command.append("--config=mkl_open_source_only")

if args.target_cpu_features == "release":
if arch in ["x86_64", "AMD64"]:
@@ -597,7 +609,7 @@ async def main():

wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")

result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
# Exit with error if any wheel build fails.
if result.return_code != 0:
raise RuntimeError(f"Command failed with return code {result.return_code}")
@@ -75,7 +75,7 @@ class SubprocessExecutor:
"""
self.environment = environment or dict(os.environ)

async def run(self, cmd: str, dry_run: bool = False) -> CommandResult:
async def run(self, cmd: str, dry_run: bool = False, detailed_timestamped_log: bool = False) -> CommandResult:
"""
Executes a subprocess command.

@@ -96,14 +96,15 @@ class SubprocessExecutor:

process = await asyncio.create_subprocess_shell(
cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE if detailed_timestamped_log else None,
stderr=asyncio.subprocess.PIPE if detailed_timestamped_log else None,
env=self.environment,
)

await asyncio.gather(
_process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result)
)
if detailed_timestamped_log:
await asyncio.gather(
_process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result)
)

result.return_code = await process.wait()
result.end_time = datetime.datetime.now()
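The change above pipes the child process's output only when `detailed_timestamped_log` is set, and otherwise lets the subprocess inherit the parent's streams. Below is a standalone sketch of that conditional-piping pattern; the `run_shell` name and `capture` flag are hypothetical and not part of the build tooling:

```python
import asyncio

async def run_shell(cmd: str, capture: bool = False) -> int:
    # Pipe stdout/stderr only when we plan to read and annotate the output
    # ourselves; otherwise the child writes directly to the parent's streams.
    proc = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE if capture else None,
        stderr=asyncio.subprocess.PIPE if capture else None,
    )
    if capture:
        out, err = await proc.communicate()
        for line in (out + err).decode().splitlines():
            print(f"[captured] {line}")
    return await proc.wait()

print(asyncio.run(run_shell("echo hello", capture=True)))
```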
@@ -69,7 +69,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
fi

# Build the artifact.
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log

# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
docs/_static/pallas/vector_layout_example.svg (new file, vendored): 1 line; file diff suppressed because one or more lines are too long (size after: 26 KiB).
@@ -25,4 +25,3 @@ some of JAX's (extensible) internals.

autodidax
jep/index
jax_internal_api
@@ -689,22 +689,21 @@ minimization phase.
### Doctests

JAX uses pytest in doctest mode to test the code examples within the documentation.
You can run this using
You can find the up-to-date command to run doctests in
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml).
E.g., you can run:

```
pytest docs
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
```

Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
function docstrings will run correctly. You can run this locally using, for example:

```
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py
```

Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml)

## Type checking
@@ -70,3 +70,31 @@ Common causes of OOM failures
memory. Note, however, that the algorithm is basic and you can often get a better
trade-off between compute and memory by disabling the automatic remat pass and doing
it manually with `the jax.remat API <https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html>`_


Experimental features
---------------------

Features here are experimental and must be tried with caution.

``TF_GPU_ALLOCATOR=cuda_malloc_async``
This replaces XLA's own BFC memory allocator with `cudaMallocAsync
<https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`_.
This removes the big fixed pre-allocation and uses a memory pool that grows.
The expected benefit is that there is no need to set `XLA_PYTHON_CLIENT_MEM_FRACTION`.

The risks are:

- that memory fragmentation is different, so if you are close to the
  limit, the exact OOM case due to fragmentation will be different.
- The allocation time won't all be paid at the start, but will be incurred
  when the memory pool needs to be increased. So you could
  experience less speed stability at the start, and for benchmarks
  it will be even more important to ignore the first few iterations.

The risks can be mitigated by pre-allocating a significant chunk and
still getting the benefit of a growing memory pool. This can be
done with `TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N`. If N is `-1`,
it will preallocate the same amount as is allocated by
default. Otherwise, it is the size in bytes that you want to
preallocate.
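For reference, one way to exercise these settings is to export the variables before JAX initializes its GPU backend. This is only an illustrative sketch; the values shown are examples, not recommendations:

```python
import os

# Must be set before the first GPU allocation, i.e. before importing/initializing JAX.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
# -1 preallocates the same amount as the default allocator would;
# any other value is interpreted as a size in bytes.
os.environ["TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC"] = "-1"

import jax
print(jax.devices())
```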
@@ -44,11 +44,8 @@ example, we can add this to the top of a Python file:
```python
import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)
```

@@ -58,9 +55,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta

### Code generation flags

* **--xla_gpu_enable_triton_softmax_fusion** This flag enables an automatic
  softmax fusion, based on pattern-matching backed by Triton code generation.
  The default value is False.
* **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for
  any GEMM that it supports. The default value is False.

@@ -69,6 +63,20 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
* **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
  schedulers to overlap asynchronous communication with computation efficiently.
  The default value is False.
* **--xla_gpu_memory_limit_slop_factor** This flag serves as a multiplier applied
  to the total available memory, creating a threshold that guides the Latency Hiding
  Scheduler (LHS) in balancing memory reduction and latency hiding optimizations.
  The default value is 95.

  This factor effectively establishes a memory limit for compiler passes, determining
  when the scheduler should prioritize:
  1. Memory reduction: When memory usage approaches or exceeds the calculated threshold.
  2. Latency hiding: When memory usage is below the threshold, allowing for more
     aggressive optimizations that may temporarily increase memory usage but improve
     overall performance.

  By adjusting this factor, users can fine-tune the trade-off between memory efficiency
  and performance optimizations.
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
  this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
  i-th layer computation. It also enables overlapping (i+1)-th layer
@@ -253,18 +253,14 @@ simply run:
conda install jax -c conda-forge
```

To install it on a machine with an NVIDIA GPU, run:
If you run this command on a machine with an NVIDIA GPU, this should install a CUDA-enabled package of `jaxlib`.

To ensure that the jax version you are installing is indeed CUDA-enabled, run:

```bash
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
conda install "jaxlib=*=*cuda*" jax -c conda-forge
```

Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
JAX requires. You must therefore either install the `cuda-nvcc` package from
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
is in your path. The channel order above is important (`conda-forge` before
`nvidia`).

If you would like to override which release of CUDA is used by JAX, or to
install the CUDA build on a machine without GPUs, follow the instructions in the
[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
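A quick, hedged way to confirm that the resulting install actually uses CUDA is to probe the backend from Python (exact output varies by setup):

```python
import jax

# Should report "gpu" and list CUDA devices on a CUDA-enabled install.
print(jax.default_backend())
print(jax.devices())
```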
@@ -14,8 +14,11 @@ Classes
.. autosummary::
  :toctree: _autosummary

  Exported
  DisabledSafetyCheck
.. autoclass:: Exported
   :members:

.. autoclass:: DisabledSafetyCheck
   :members:

Functions
---------
docs/jax.extend.core.rst (new file): 18 lines
@@ -0,0 +1,18 @@
``jax.extend.core`` module
==========================

.. automodule:: jax.extend.core

.. autosummary::
  :toctree: _autosummary

  ClosedJaxpr
  Jaxpr
  JaxprEqn
  Literal
  Primitive
  Token
  Var
  array_types
  jaxpr_as_fun
  primitives
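To make the listing above concrete, here is a small sketch that inspects a jaxpr through these re-exports, assuming they behave like their `jax.core` counterparts; the function `f` is a hypothetical example:

```python
import jax
import jax.numpy as jnp
from jax.extend import core as jex_core

def f(x):
  return jnp.sin(x) * 2.0

closed = jax.make_jaxpr(f)(1.0)
# ClosedJaxpr / Jaxpr / JaxprEqn / Primitive are exposed under jax.extend.core.
assert isinstance(closed, jex_core.ClosedJaxpr)
for eqn in closed.jaxpr.eqns:
  print(eqn.primitive.name)
```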
@@ -11,6 +11,7 @@ Modules
.. toctree::
  :maxdepth: 1

  jax.extend.core
  jax.extend.ffi
  jax.extend.linear_util
  jax.extend.mlir
@@ -11,7 +11,6 @@ jax.lib.xla_bridge
.. autosummary::
  :toctree: _autosummary

  default_backend
  get_backend
  get_compile_options
@@ -274,6 +274,7 @@ namespace; they are listed below.
  mask_indices
  matmul
  matrix_transpose
  matvec
  max
  maximum
  mean
@@ -428,6 +429,7 @@ namespace; they are listed below.
  var
  vdot
  vecdot
  vecmat
  vectorize
  vsplit
  vstack
docs/jax.rst: 13 changed lines
@@ -102,6 +102,9 @@ Automatic differentiation
  closure_convert
  checkpoint

Customization
-------------

``custom_jvp``
~~~~~~~~~~~~~~

@@ -121,6 +124,16 @@ Automatic differentiation
  custom_vjp
  custom_vjp.defvjp

``custom_batching``
~~~~~~~~~~~~~~~~~~~

.. autosummary::
  :toctree: _autosummary

  custom_batching.custom_vmap
  custom_batching.custom_vmap.def_vmap
  custom_batching.sequential_vmap

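A brief sketch of the `custom_batching` entries listed above, following the rule signature documented for `custom_vmap`; the function names here are illustrative only:

```python
import jax
import jax.numpy as jnp

@jax.custom_batching.custom_vmap
def clip_norm(x):
  return x / jnp.maximum(jnp.linalg.norm(x), 1.0)

@clip_norm.def_vmap
def clip_norm_vmap_rule(axis_size, in_batched, xs):
  x_batched, = in_batched
  # Batched implementation: normalize each row independently.
  norms = jnp.linalg.norm(xs, axis=-1, keepdims=True)
  out = xs / jnp.maximum(norms, 1.0)
  return out, x_batched

y = jax.vmap(clip_norm)(jnp.ones((4, 3)))
```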
jax.Array (:code:`jax.Array`)
-----------------------------
@@ -13,8 +13,11 @@ List of Functions

  all
  flatten
  flatten_with_path
  leaves
  leaves_with_path
  map
  map_with_path
  reduce
  structure
  transpose
@@ -1,14 +0,0 @@
Internal API reference
======================

core
----

.. currentmodule:: jax.core
.. automodule:: jax.core

.. autosummary::
  :toctree: _autosummary

  Jaxpr
  ClosedJaxpr
@@ -11,7 +11,16 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->

## Released with jax 0.4.35
## Released with jax 0.4.37

* New functionality

  * Added support for `DotAlgorithmPreset` precision arguments for `dot`
    lowering on Triton backend.

## Released with jax 0.4.36 (December 6, 2024)

## Released with jax 0.4.35 (October 22, 2024)

* Removals
@@ -119,24 +119,44 @@ The output reference can be then used as an accumulator for partial results.
spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a
low-level compiler error message complaining about an out-of-memory error.

Dimension ordering is meaningful
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Array Layouts
^^^^^^^^^^^^^

Dimension ordering of arrays is meaningful in Pallas.
In JAX programs, the ordering of intermediate arrays inside ``jax.jit`` usually
has no impact on performance, as the compiler is free to rearrange them.
However, as Pallas is meant to expose lower-level capabilities, the dimension
order can have a great impact on the quality of generated code.

Recall that TPUs perform the bulk of the computation on 2D vector registers.
Pallas TPU will only ever consider mapping the last two dimensions of
intermediate arrays to those vector register dimensions (sublanes and lanes
respectively). An array of shape ``(n, 1, 1)`` is guaranteed to require at least
``n`` vector registers to represent. If ``n`` becomes too large, this can lead
to spills, and potential VMEM OOM errors due to an overly large memory footprint.
But it also might not --- the low-level compiler is free to rearrange the
instructions to lower the register pressure, and is in fact very good at it.
Still, it is a good rule of thumb to keep the last two dimensions large
(especially the last dimension), while keeping the leading dimensions small.
TPUs perform the bulk of the computation on 2D vector registers, which are typically of
size 8x128 for 32-bit values (as of TPU v6).
When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``),
the last two dimensions of the array will be tiled into the registers.
Pallas will only ever consider mapping the last two dimensions of
intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes
respectively).

Here is a graphical example of how a 12x320 array can be tiled using 6 8x128
tiles:

.. image:: ../../_static/pallas/vector_layout_example.svg

Tiled layouts have several important ramifications for kernel writers (a short
kernel sketch follows this list):

* The last two axes of an array are treated differently than other
  axes. For example, reductions, reshapes, and transposes are generally
  more expensive when involving the last two axes. Some reshapes
  involving the last two dimensions are not supported and will result in a compiler
  error, but are "free" and performed at compile time for other dimensions.
* While sometimes unavoidable, it is generally wasteful to have singleton
  dimensions in the last two axes, since they will occupy 1 element out of
  the entire tile dimension. Consuming too many registers can
  also potentially cause register spills into VMEM which degrades kernel
  performance.
* Related to the above point, all vector computation is padded up to the tile
  size. Adding two 1x1 arrays costs as much as adding two 8x128 arrays, and
  adding two 8x128x1x1 arrays will be 1024 times as expensive as adding two
  8x128 arrays, since the 8x128x1x1 array will be padded to 8x128x8x128.
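As a concrete illustration of keeping the last two dimensions aligned with the 8x128 tile, here is a minimal kernel sketch; the shapes are chosen for illustration and it assumes a TPU backend is available:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
  # The (8, 128) block maps directly onto one 8x128 vector register tile for
  # float32, so no capacity is wasted on singleton trailing dimensions.
  o_ref[...] = x_ref[...] * 2.0

x = jnp.ones((8, 128), jnp.float32)
y = pl.pallas_call(
    scale_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)
```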
Multicore TPU configurations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -196,18 +216,19 @@ for those arguments. But, the ``BlockSpec``\s for all subsequent arguments will
receive not only the grid indices, but also the SMEM references to the leading
operands.

.. note::
  We are working on implementing examples for this feature. Stay tuned!
  See :ref:`pallas_scalar_prefetch_guide` for examples on using this
  feature.

Supported data types
^^^^^^^^^^^^^^^^^^^^

At the moment Pallas TPU only supports the following data types:
At the moment Pallas TPU supports the following data types:

* ``jnp.float32``
* ``jnp.bfloat16``
* ``jnp.int*`` (all precisions, except for ``jnp.int4``)
* ``jnp.uint*`` (all precisions)
* ``jnp.bool_``

Computation placement
^^^^^^^^^^^^^^^^^^^^^
@@ -306,14 +327,13 @@ Array constructors
^^^^^^^^^^^^^^^^^^

All constant array constructors are supported (``jnp.ones``, ``jnp.zeros``,
``jnp.full``). Notably, the ``jax.random`` module is **not** compatible with
Pallas as of today.
``jnp.full``).

Reductions
^^^^^^^^^^

Sum, maximum and minimum reductions are supported, but only on a single array
axis at a time.
``sum``, ``max``, ``min`` (for floating point values) reductions are supported, as well
as ``any`` and ``all`` for boolean values. Integer reductions are not supported.

Reductions over the last array dimension are generally the slowest.
Reductions over the second last dimension are faster, but still slower than
@@ -338,6 +358,14 @@ of an array is when (1) some leading dimensions are flattened onto the second
to last dimension, or (2) it adds a dimension that was just removed by a
reduction.

Random Number Generation
^^^^^^^^^^^^^^^^^^^^^^^^

Pallas supports the most commonly used functions from the ``jax.random`` module,
such as ``uniform``, ``normal``, and ``bernoulli``. The key should be a ``threefry2x32`` key,
which is the default setting in JAX. Keys can be directly passed into a kernel,
or generated inside of a kernel.

Control flow
^^^^^^^^^^^^
@@ -6,6 +6,8 @@
"id": "ZHuzXqQ-9JUQ"
},
"source": [
"(pallas_scalar_prefetch_guide)=\n",
"\n",
"# Scalar Prefetch and Block-Sparse Computation\n",
"\n",
"In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory."
@@ -14,6 +14,8 @@ kernelspec:

+++ {"id": "ZHuzXqQ-9JUQ"}

(pallas_scalar_prefetch_guide)=

# Scalar Prefetch and Block-Sparse Computation

In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.
@@ -37,10 +37,10 @@ class AttrsTests(jtu.JaxTestCase):
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 1)  # compiles once the first time
self.assertEqual(count(), 1)  # compiles once the first time
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 0)  # cache hit
self.assertEqual(count(), 0)  # cache hit

def test_array_attr_no_jit(self):
with jax.disable_jit():
@@ -455,6 +455,7 @@ pytype_strict_library(
":dtypes",
":effects",
":mesh",
":partition_spec",
":pretty_printer",
":source_info_util",
":traceback_util",
@@ -558,6 +559,7 @@ pytype_strict_library(
":layout",
":op_shardings",
":partial_eval",
":partition_spec",
":path",
":pickle_util",
":sharding",
@@ -28,7 +28,6 @@ ShapedArray = core.ShapedArray
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
raise_to_shaped = core.raise_to_shaped

numpy_scalar_types: set[type] = {  # pylint: disable=g-bare-generic
dtypes.int4, np.int8, np.int16, np.int32, np.int64,
@@ -19,7 +19,7 @@ from typing import Any, TypeVar

from jax._src import core
from jax._src import traceback_util
from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval
from jax._src.core import Primitive, valid_jaxtype, get_aval
from jax._src.tree_util import register_pytree_node, tree_map
from jax._src.typing import Array, ArrayLike
from jax._src.util import safe_map
@@ -51,7 +51,7 @@ def zeros_like_aval(aval: core.AbstractValue) -> Array:
aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

def zeros_like_jaxval(val):
return zeros_like_aval(core.raise_to_shaped(core.get_aval(val)))
return zeros_like_aval(core.get_aval(val))

def instantiate(z: Zero | Array) -> Array:
if isinstance(z, Zero):
@@ -67,7 +67,7 @@ class Zero:
return f'Zero({self.aval})'
@staticmethod
def from_primal_value(val: Any) -> Zero:
return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval())
return Zero(get_aval(val).to_tangent_aval())

register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))
@@ -2356,7 +2356,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  #
f"len(devices) = {len(devices)}.")

def _device_put_sharded(*xs):
avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs]
avals = [core.get_aval(x) for x in xs]
if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
if a1 != a2)
@@ -2418,7 +2418,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
"a non-empty sequence.")
def _device_put_replicated(x):
aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
core.raise_to_shaped(core.get_aval(x)))
core.get_aval(x))
assert isinstance(aval, ShapedArray)
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
if config.pmap_no_rank_reduction.value:
@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
|
||||
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
||||
|
||||
@lu.transformation2
|
||||
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
|
||||
def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs):
|
||||
sentinel = object()
|
||||
args = [sentinel] * (len(fixed_args) + len(dyn_args))
|
||||
for i, arg in zip(dyn_argnums, dyn_args):
|
||||
args = [sentinel] * (len(_fixed_args) + len(dyn_args))
|
||||
for i, arg in zip(_dyn_argnums, dyn_args):
|
||||
args[i] = arg
|
||||
fixed_args_ = iter(fixed_args)
|
||||
fixed_args_ = iter(_fixed_args)
|
||||
args = [next(fixed_args_).val if x is sentinel else x for x in args]
|
||||
assert next(fixed_args_, sentinel) is sentinel
|
||||
return f(*args, **kwargs)
|
||||
return _fun(*args, **kwargs)
|
||||
|
||||
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
|
||||
kwargs: dict[str, Any]):
|
||||
@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
|
||||
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
||||
|
||||
@lu.transformation2
|
||||
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
|
||||
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
|
||||
return f(*args, **kwargs)
|
||||
def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
|
||||
kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs)
|
||||
return _fun(*args, **kwargs)
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
@ -438,9 +438,9 @@ def flat_out_axes(
|
||||
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
|
||||
|
||||
@lu.transformation_with_aux2
|
||||
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
|
||||
ans = f(*args, **kwargs)
|
||||
spec = tree_unflatten(treedef, leaves)
|
||||
def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs):
|
||||
ans = _fun(*args, **kwargs)
|
||||
spec = tree_unflatten(_treedef, _leaves)
|
||||
try:
|
||||
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
|
||||
except ValueError:
|
||||
@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
|
||||
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
|
||||
"pmapped function's output.")
|
||||
raise ValueError(msg) from None
|
||||
store.store(spec_flat)
|
||||
_store.store(spec_flat)
|
||||
return ans
|
||||
|
||||
def check_callable(fun):
|
||||
@ -587,8 +587,7 @@ def _dtype(x):
|
||||
|
||||
def _shaped_abstractify_slow(x):
|
||||
try:
|
||||
return core.raise_to_shaped(
|
||||
x if isinstance(x, core.AbstractValue) else core.get_aval(x))
|
||||
return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
@ -687,10 +686,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
|
||||
for path, l in generate_key_paths(x) if l is not static)
|
||||
|
||||
@lu.transformation_with_aux2
|
||||
def result_paths(f, store, *args, **kwargs):
|
||||
def result_paths(_fun, _store, *args, **kwargs):
|
||||
"linear_util transform to get output pytree paths of pre-flattened function."
|
||||
ans = f(*args, **kwargs)
|
||||
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
||||
ans = _fun(*args, **kwargs)
|
||||
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
||||
return ans
|
||||
|
||||
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
|
||||
|
@ -33,6 +33,7 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import errors
|
||||
from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
@ -40,7 +41,6 @@ from jax._src.interpreters import xla
|
||||
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension as xe
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding, NamedSharding,
|
||||
@ -1120,7 +1120,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
|
||||
bufs.append(buf)
|
||||
break
|
||||
else:
|
||||
bufs.append(buf)
|
||||
bufs.append(candidates_list[-1])
|
||||
return pxla.batched_device_put(x.aval, sharding, bufs, devices)
|
||||
|
||||
|
||||
@ -1132,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
|
||||
|
||||
|
||||
def _array_shard_arg(xs, shardings, layouts, copy_semantics):
|
||||
util.test_event("_array_shard_arg")
|
||||
results = []
|
||||
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
|
||||
batch_cs = []
|
||||
@ -1169,12 +1170,9 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
|
||||
results.append(
|
||||
shard_sharded_device_array_slow_path(x, devices, indices, sharding))
|
||||
|
||||
if xla_extension_version >= 296:
|
||||
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
|
||||
batch_xs, batch_devs, batch_shardings, batch_cs)
|
||||
else:
|
||||
copy_outs = xc.batched_copy_array_to_devices_with_sharding( # pytype: disable=missing-parameter
|
||||
batch_xs, batch_devs, batch_shardings)
|
||||
util.test_event("batched_copy_array")
|
||||
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
|
||||
batch_xs, batch_devs, batch_shardings, batch_cs)
|
||||
for i, copy_out in safe_zip(batch_indices, copy_outs):
|
||||
assert results[i] is None
|
||||
results[i] = copy_out
|
||||
|
@ -24,7 +24,6 @@ from typing import cast as type_cast
|
||||
from jax._src import config
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir import passmanager as pm
|
||||
import numpy as np
|
||||
@ -301,10 +300,7 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
|
||||
debug_options.xla_dump_hlo_as_long_text = False
|
||||
debug_options.xla_dump_disable_metadata = False
|
||||
debug_options.xla_dump_hlo_pipeline_re = ""
|
||||
|
||||
# "Requires jaxlib 0.4.36+"
|
||||
if xla_extension_version > 296:
|
||||
debug_options.xla_gpu_experimental_autotune_cache_mode = 0
|
||||
debug_options.xla_gpu_experimental_autotune_cache_mode = 0
|
||||
|
||||
# Optional way to specify the cuda install path to be used by the compiler.
|
||||
# This could possibly affect the cuda version compiled with, but this should
|
||||
|
@ -387,9 +387,6 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
|
||||
error = _reduce_any_error(error)
|
||||
return error, out_vals
|
||||
|
||||
def get_shaped_aval(val):
|
||||
return core.raise_to_shaped(core.get_aval(val))
|
||||
|
||||
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
|
||||
error: Error, *args) -> tuple[Error, list[core.Value]]:
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
@ -760,7 +757,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
|
||||
# Get the error-effects out of all branches so the cond can be called with
|
||||
# a merged error with all these effects.
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
in_avals = map(get_shaped_aval, [*err_vals, *ops])
|
||||
in_avals = map(core.get_aval, [*err_vals, *ops])
|
||||
def get_error_effects_from_jaxpr(jxpr):
|
||||
_, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
|
||||
*in_avals)
|
||||
@ -770,7 +767,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
|
||||
err_vals, err_tree = jtu.tree_flatten(merged_error)
|
||||
|
||||
# Update branch jaxprs to be checkified jaxprs.
|
||||
in_avals = map(get_shaped_aval, [*err_vals, *ops])
|
||||
in_avals = map(core.get_aval, [*err_vals, *ops])
|
||||
new_branches, out_trees, _ = unzip3(
|
||||
jaxpr_to_checkify_jaxpr(
|
||||
jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)
|
||||
@ -792,11 +789,11 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
|
||||
num_consts, num_carry, linear, unroll, _split_transpose):
|
||||
|
||||
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
|
||||
xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
|
||||
xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
|
||||
# Query body effects to create a merged error containing all effects (such
|
||||
# that in and out carried error are of the same type).
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
|
||||
new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
|
||||
_, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
|
||||
err_tree, *new_in_aval)
|
||||
|
||||
@ -804,7 +801,7 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
|
||||
err_vals, err_tree = jtu.tree_flatten(merged_error)
|
||||
|
||||
# Create checked-jaxpr, with the needed pre-processing on the inputs.
|
||||
new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
|
||||
new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
|
||||
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
|
||||
err_tree, *new_in_aval)
|
||||
|
||||
@ -840,7 +837,7 @@ def checkify_while_body_jaxpr(
|
||||
*body_jaxpr.in_avals])
|
||||
closed_jaxpr = pe.close_jaxpr(jaxpr)
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
err_vals = map(get_shaped_aval, err_vals)
|
||||
err_vals = map(core.get_aval, err_vals)
|
||||
flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
|
||||
jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
|
||||
closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
|
||||
@ -882,7 +879,7 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
|
||||
|
||||
cond_in_flat = [*err_vals, *c_consts, *carry]
|
||||
cond_in_flat = map(get_shaped_aval, cond_in_flat)
|
||||
cond_in_flat = map(core.get_aval, cond_in_flat)
|
||||
checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
|
||||
err_tree, *cond_in_flat)
|
||||
compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
|
||||
@ -906,7 +903,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
# jaxpr to checked_jaxpr
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
new_vals_in = [*err_vals, *vals_in]
|
||||
in_avals = tuple(map(get_shaped_aval, new_vals_in))
|
||||
in_avals = tuple(map(core.get_aval, new_vals_in))
|
||||
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
|
||||
err_tree, *in_avals)
|
||||
|
||||
@ -942,7 +939,7 @@ error_checks[pjit.pjit_p] = pjit_error_check
|
||||
def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
new_vals_in = [*err_vals, *vals_in]
|
||||
in_avals = tuple(map(get_shaped_aval, new_vals_in))
|
||||
in_avals = tuple(map(core.get_aval, new_vals_in))
|
||||
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
|
||||
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
|
||||
checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts
|
||||
@ -963,7 +960,7 @@ def shard_map_error_check(
|
||||
# Replicated sharding for in errors.
|
||||
new_in_names = (*([{}] * num_error_vals), *in_names)
|
||||
new_vals_in = [*err_vals, *vals_in]
|
||||
in_avals = list(map(get_shaped_aval, new_vals_in))
|
||||
in_avals = list(map(core.get_aval, new_vals_in))
|
||||
for i, v in enumerate(in_avals):
|
||||
if not (sharder := core.shard_aval_handlers.get(type(v))):
|
||||
raise ValueError(f'Unsupported aval type: {type(v)}')
|
||||
|
@ -18,8 +18,6 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
import warnings
|
||||
@ -36,7 +34,6 @@ from jax._src import traceback_util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
import numpy as np
|
||||
|
||||
@ -192,9 +189,8 @@ def get_compile_options(
|
||||
assert device_assignment.computation_count() == num_partitions
|
||||
compile_options.device_assignment = device_assignment
|
||||
|
||||
if xla_extension_version >= 294:
|
||||
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
|
||||
build_options.memory_fitting_effort = config.memory_fitting_effort.value
|
||||
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
|
||||
build_options.memory_fitting_effort = config.memory_fitting_effort.value
|
||||
|
||||
if env_options_overrides is not None:
|
||||
# Some overrides are passed directly on build_options.
|
||||
@ -350,7 +346,33 @@ def compile_or_get_cached(
|
||||
|
||||
use_compilation_cache = compilation_cache.is_cache_used(backend)
|
||||
|
||||
is_multi_process = (
|
||||
len({device.process_index for device in devices.flatten()}) > 1
|
||||
)
|
||||
min_device_process_id = min(
|
||||
devices.flatten(), key=lambda device: device.id
|
||||
).process_index
|
||||
is_auto_pgle_used = (
|
||||
config.enable_pgle.value and config.pgle_profiling_runs.value > 0
|
||||
)
|
||||
|
||||
if not use_compilation_cache:
|
||||
if (
|
||||
is_multi_process
|
||||
and is_auto_pgle_used
|
||||
and distributed.global_state.client is not None
|
||||
):
|
||||
compile_options.executable_build_options.fdo_profile = (
|
||||
_share_fdo_profiles(
|
||||
computation,
|
||||
devices,
|
||||
compile_options,
|
||||
backend,
|
||||
distributed.global_state.client,
|
||||
min_device_process_id,
|
||||
)
|
||||
)
|
||||
|
||||
return backend_compile(backend, computation, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
@ -375,61 +397,18 @@ def compile_or_get_cached(
|
||||
return backend_compile(backend, computation, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
is_multi_process = (
|
||||
len({device.process_index for device in devices.flatten()}) > 1)
|
||||
min_device_process_id = (
|
||||
min(devices.flatten(), key=lambda device: device.id).process_index)
|
||||
|
||||
# When PGLE is enabled there might be 3 types of situations:
|
||||
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
|
||||
# in the persistent cache. In this case the module should be returned from
|
||||
# cache and PGLE should be disabled for this module. Is module is stored in
|
||||
# the persistent cache under the "pgle_profiled_module_key" which calculated
|
||||
# with replacing FDO profile with flag which identify that module were PGLE
|
||||
# profiled.
|
||||
# 2. PGLE profiled module is not in the persistent cache and the module is
|
||||
# getting built with an FDO profile. In this case we need to share FDO profile
|
||||
# with other processes and store the result under the
|
||||
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
|
||||
# module.
|
||||
# 3. PGLE profiled module is not in the persistent cache and the module is
|
||||
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
|
||||
# simply return the non-PGLE profiled module from the persistent cache.
|
||||
if (config.enable_pgle.value
|
||||
and config.pgle_profiling_runs.value > 0):
|
||||
fdo_profile = compile_options.executable_build_options.fdo_profile
|
||||
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
|
||||
|
||||
pgle_profiled_module_key = compilation_cache.get_cache_key(
|
||||
if is_auto_pgle_used:
|
||||
cache_key = _resolve_pgle_module_cache_key(
|
||||
computation,
|
||||
devices,
|
||||
compile_options,
|
||||
backend,
|
||||
cache_key_type.IgnoreCallbacks.ALL,
|
||||
pgle_profiler,
|
||||
is_multi_process,
|
||||
cache_key,
|
||||
module_name,
|
||||
min_device_process_id,
|
||||
)
|
||||
compile_options.executable_build_options.fdo_profile = fdo_profile
|
||||
|
||||
if _is_executable_in_cache(backend, pgle_profiled_module_key):
|
||||
# Load PGLE profiled module from the persistent cache.
|
||||
cache_key = pgle_profiled_module_key
|
||||
if pgle_profiler is not None:
|
||||
pgle_profiler.disable()
|
||||
elif fdo_profile is not None and len(fdo_profile) > 0:
|
||||
# Store module under PGLE profiled module cache key.
|
||||
cache_key = pgle_profiled_module_key
|
||||
if is_multi_process and distributed.global_state.client is not None:
|
||||
compile_options.executable_build_options.fdo_profile = _share_fdo_profiles(
|
||||
computation, devices, compile_options, backend,
|
||||
distributed.global_state.client,
|
||||
min_device_process_id
|
||||
)
|
||||
else:
|
||||
compile_options.executable_build_options.fdo_profile = fdo_profile
|
||||
logger.debug(
|
||||
"Compiling module %s with FDO profile: %s",
|
||||
module_name,
|
||||
compile_options.executable_build_options.fdo_profile,
|
||||
)
|
||||
|
||||
cache_retrieval_start = time.monotonic()
|
||||
retrieved_executable, retrieved_compile_time = _cache_read(
|
||||
@ -468,22 +447,6 @@ def compile_or_get_cached(
|
||||
cache_key,
|
||||
min_device_process_id
|
||||
)
|
||||
elif (
|
||||
config.share_autotune_config_between_hosts.value
|
||||
and is_multi_process
|
||||
and distributed.global_state.client is not None
|
||||
):
|
||||
log_persistent_cache_miss(module_name, cache_key)
|
||||
return _compile_and_write_autotune_config(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
distributed.global_state.client,
|
||||
module_name,
|
||||
cache_key,
|
||||
min_device_process_id
|
||||
)
|
||||
else:
|
||||
log_persistent_cache_miss(module_name, cache_key)
|
||||
return _compile_and_write_cache(
|
||||
@ -495,6 +458,75 @@ def compile_or_get_cached(
|
||||
cache_key,
|
||||
)
|
||||
|
||||
|
||||
# When PGLE is enabled there are 3 possible situations:
# 1. The PGLE-profiled module (the one that was recompiled with an FDO profile)
# is in the persistent cache. In this case the module should be returned from
# the cache and PGLE should be disabled for this module. The module is stored
# in the persistent cache under the "pgle_profiled_module_key", which is
# calculated by replacing the FDO profile with a flag identifying that the
# module was PGLE profiled.
# 2. The PGLE-profiled module is not in the persistent cache and the module is
# getting built with an FDO profile. In this case we need to share the FDO
# profile with other processes and store the result under the
# "pgle_profiled_module_key" so that in case 1 we will later be able to find
# the module.
# 3. The PGLE-profiled module is not in the persistent cache and the module is
# getting compiled to be PGLEd (the FDO profile is empty). In this case we need
# to simply return the non-PGLE-profiled module from the persistent cache.
def _resolve_pgle_module_cache_key(
|
||||
computation: ir.Module,
|
||||
devices: np.ndarray,
|
||||
compile_options: xc.CompileOptions,
|
||||
backend: xc.Client,
|
||||
pgle_profiler: profiler.PGLEProfiler | None,
|
||||
is_multi_process: bool,
|
||||
cache_key: str,
|
||||
module_name: str,
|
||||
min_device_process_id: int,
|
||||
) -> str:
|
||||
fdo_profile = compile_options.executable_build_options.fdo_profile
|
||||
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
|
||||
|
||||
pgle_profiled_module_key = compilation_cache.get_cache_key(
|
||||
computation,
|
||||
devices,
|
||||
compile_options,
|
||||
backend,
|
||||
cache_key_type.IgnoreCallbacks.ALL,
|
||||
)
|
||||
compile_options.executable_build_options.fdo_profile = fdo_profile
|
||||
|
||||
result_key = cache_key
|
||||
if _is_executable_in_cache(backend, pgle_profiled_module_key):
|
||||
# Load PGLE profiled module from the persistent cache.
|
||||
result_key = pgle_profiled_module_key
|
||||
if pgle_profiler is not None:
|
||||
pgle_profiler.disable()
|
||||
elif fdo_profile is not None and len(fdo_profile) > 0:
|
||||
# Store module under PGLE profiled module cache key.
|
||||
result_key = pgle_profiled_module_key
|
||||
if is_multi_process and distributed.global_state.client is not None:
|
||||
compile_options.executable_build_options.fdo_profile = (
|
||||
_share_fdo_profiles(
|
||||
computation,
|
||||
devices,
|
||||
compile_options,
|
||||
backend,
|
||||
distributed.global_state.client,
|
||||
min_device_process_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
compile_options.executable_build_options.fdo_profile = fdo_profile
|
||||
logger.debug(
|
||||
"Compiling module %s with FDO profile of length %d",
|
||||
module_name,
|
||||
len(compile_options.executable_build_options.fdo_profile),
|
||||
)
|
||||
return result_key
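# A minimal sketch (illustration only, not part of this change) of the key
# swap performed above: the FDO profile is temporarily replaced by a fixed
# sentinel so that every PGLE-profiled build of the same module resolves to
# one stable "pgle profiled" cache key. `make_key` and `in_cache` are
# hypothetical stand-ins for compilation_cache.get_cache_key and
# _is_executable_in_cache.
def _pgle_key_sketch(build_options, base_key, make_key, in_cache):
  saved = build_options.fdo_profile
  build_options.fdo_profile = b"pgle profiled"  # sentinel shared by all PGLE builds
  pgle_key = make_key(build_options)
  build_options.fdo_profile = saved             # restore the real profile
  if in_cache(pgle_key) or (saved is not None and len(saved) > 0):
    return pgle_key   # cases 1 and 2: reuse or store under the PGLE key
  return base_key     # case 3: plain, to-be-profiled build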
|
||||
|
||||
|
||||
# The process that has the lowest device ID should share FDO profile before
|
||||
# compilation with other processes.
|
||||
def _share_fdo_profiles(
|
||||
@ -512,32 +544,39 @@ def _share_fdo_profiles(
|
||||
return fdo_profile
|
||||
|
||||
compile_options.executable_build_options.fdo_profile = b""
|
||||
profile_key = (
|
||||
compilation_cache.get_cache_key(
|
||||
computation,
|
||||
devices,
|
||||
compile_options,
|
||||
backend,
|
||||
cache_key_type.IgnoreCallbacks.ALL,
|
||||
)
|
||||
+ "_fdo_sync"
|
||||
)
|
||||
try:
|
||||
profile_key = (
|
||||
compilation_cache.get_cache_key(
|
||||
computation,
|
||||
devices,
|
||||
compile_options,
|
||||
backend,
|
||||
cache_key_type.IgnoreCallbacks.ALL,
|
||||
)
|
||||
+ "_fdo_sync"
|
||||
)
|
||||
except xc._xla.XlaRuntimeError as ex:
|
||||
logger.error(
|
||||
"compile_or_get_cached: unable to generate cache key, "
|
||||
"skipping the fdo profile sharing: %s",
|
||||
ex,
|
||||
)
|
||||
return fdo_profile
|
||||
|
||||
if profile_key in _share_fdo_profiles.modules_profiles:
|
||||
return _share_fdo_profiles.modules_profiles[profile_key]
|
||||
|
||||
share_timeout = config.share_binary_between_hosts_timeout_ms.value
|
||||
if distributed.global_state.process_id == min_process_id:
|
||||
logger.debug(
|
||||
"Sharing FDO profile: %s. For module %s. Process %d.",
|
||||
fdo_profile,
|
||||
"Module %s. Sharing FDO profile. Process %d.",
|
||||
module_name,
|
||||
min_process_id,
|
||||
)
|
||||
global_client.key_value_set_bytes(profile_key, fdo_profile)
|
||||
else:
|
||||
logger.debug(
|
||||
"Waiting for FDO profile: %s. For module %s. Should be set by process %d.",
|
||||
fdo_profile,
|
||||
"Module %s. Waiting for FDO profile which should be set by process %d.",
|
||||
module_name,
|
||||
min_process_id,
|
||||
)
|
||||
@ -551,113 +590,6 @@ def _share_fdo_profiles(
|
||||
|
||||
_share_fdo_profiles.modules_profiles = {}
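# Hedged sketch (illustration only) of the leader/follower exchange used by
# _share_fdo_profiles: the process owning the lowest device ID publishes the
# profile to the distributed key-value store and every other process blocks
# on the same key until it appears or the timeout expires. `kv` stands for
# distributed.global_state.client.
def _share_bytes_sketch(kv, key, payload, process_id, leader_id, timeout_ms):
  if process_id == leader_id:
    kv.key_value_set_bytes(key, payload)  # leader publishes once
    return payload
  # followers wait for the leader's value (raises on timeout)
  return kv.blocking_key_value_get_bytes(key, timeout_ms)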
|
||||
|
||||
|
||||
# The process with the first_process_id should compile the module and write an
|
||||
# autotune config to the K-V storage.
|
||||
def _compile_and_write_autotune_config(
|
||||
backend: xc.Client,
|
||||
computation: ir.Module,
|
||||
compile_options: xc.CompileOptions,
|
||||
host_callbacks: Sequence[Any],
|
||||
global_client: lib.xla_extension.DistributedRuntimeClient,
|
||||
module_name: str,
|
||||
cache_key: str,
|
||||
first_process_id: int
|
||||
) -> xc.LoadedExecutable:
|
||||
share_timeout = config.share_binary_between_hosts_timeout_ms.value
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
|
||||
if _compile_and_write_autotune_config.autotune_configs_dir is None:
|
||||
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
|
||||
|
||||
autotune_tmp_file = os.path.join(
|
||||
_compile_and_write_autotune_config.autotune_configs_dir, cache_key
|
||||
)
|
||||
|
||||
if os.path.exists(autotune_tmp_file):
|
||||
logger.debug(
|
||||
"Compiling module: %s. Use existing autotune config file: %s",
|
||||
module_name,
|
||||
autotune_tmp_file,
|
||||
)
|
||||
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
|
||||
return _compile_and_write_cache(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
module_name,
|
||||
cache_key,
|
||||
)
|
||||
|
||||
if distributed.global_state.process_id == first_process_id:
|
||||
debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
|
||||
logger.debug("Process %d compiling and dumping autotune for module: %s",
|
||||
first_process_id, module_name)
|
||||
executable = _compile_and_write_cache(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
module_name,
|
||||
cache_key,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Writing autotune config for module %s to %s",
|
||||
module_name,
|
||||
autotune_tmp_file,
|
||||
)
|
||||
with open(autotune_tmp_file, "rb") as f:
|
||||
autotune_config = f.read()
|
||||
|
||||
autotune_config = compilation_cache.compress_executable(autotune_config)
|
||||
global_client.key_value_set_bytes(cache_key, autotune_config)
|
||||
logger.debug(
|
||||
"Autotune config for module %s with size %d shared by cache_key %s",
|
||||
module_name,
|
||||
len(autotune_config),
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Compiling module %s, waiting for config to be shared by cache_key %s"
|
||||
"from process %d",
|
||||
module_name,
|
||||
cache_key,
|
||||
first_process_id
|
||||
)
|
||||
autotune_config = global_client.blocking_key_value_get_bytes(
|
||||
cache_key, share_timeout
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Received autotune config for module %s of size %d",
|
||||
module_name,
|
||||
len(autotune_config),
|
||||
)
|
||||
autotune_config = compilation_cache.decompress_executable(autotune_config)
|
||||
with open(autotune_tmp_file, "wb") as f:
|
||||
f.write(autotune_config)
|
||||
|
||||
logger.debug(
|
||||
"Compiling module %s, using autotune config from %s",
|
||||
module_name,
|
||||
autotune_tmp_file,
|
||||
)
|
||||
debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
|
||||
executable = _compile_and_write_cache(
|
||||
backend,
|
||||
computation,
|
||||
compile_options,
|
||||
host_callbacks,
|
||||
module_name,
|
||||
cache_key,
|
||||
)
|
||||
return executable
|
||||
|
||||
_compile_and_write_autotune_config.autotune_configs_dir = None
|
||||
|
||||
# The process with the first_process_id should compile the module and write it
|
||||
# to the K-V storage.
|
||||
def _compile_and_share_module(
|
||||
|
@ -14,30 +14,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Hashable, Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast
|
||||
from typing import Any, Generic, NoReturn, Optional, Protocol, TypeVar, cast
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import guard_lib
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import logging_config
|
||||
|
||||
# TODO(phawkins): reenable pytype after xla_extension_version >= 295
|
||||
# pytype: skip-file
|
||||
|
||||
if xla_extension_version >= 295:
|
||||
config_ext = xla_client._xla.config
|
||||
else:
|
||||
config_ext = None
|
||||
config_ext = xla_client._xla.config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -200,91 +191,38 @@ class Config:
|
||||
already_configured_with_absl = True
|
||||
|
||||
|
||||
if xla_extension_version >= 295:
|
||||
def trace_context():
|
||||
"""Returns a tuple of configuration values that affect tracing.
|
||||
def trace_context():
|
||||
"""Returns a tuple of configuration values that affect tracing.
|
||||
|
||||
These values are included in the cache key for linear_util.cache.
|
||||
These values are included in the cache key for linear_util.cache.
|
||||
|
||||
Values included in this set should also most likely be included in
|
||||
the C++ JIT state, which is handled separately.
|
||||
"""
|
||||
return (axis_env_state.value, mesh_context_manager.value,
|
||||
xla_metadata_context_manager.value,
|
||||
abstract_mesh_context_manager.value,
|
||||
device_context.value,
|
||||
compute_on_context_manager.value, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
eager_constant_folding.value,
|
||||
numpy_dtype_promotion.value,
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
sharding_in_types.value,
|
||||
use_direct_linearize.value,
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
debug_key_reuse.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value,
|
||||
pgle_profiling_runs.value,
|
||||
enable_pgle.value,
|
||||
use_shardy_partitioner.value)
|
||||
else:
|
||||
def trace_context():
|
||||
"""Returns a tuple of configuration values that affect tracing.
|
||||
|
||||
These values are included in the cache key for linear_util.cache.
|
||||
|
||||
Values included in this set should also most likely be included in
|
||||
the C++ JIT state, which is handled separately.
|
||||
"""
|
||||
tls = jax_jit.thread_local_state()
|
||||
axis_env_state = ()
|
||||
mesh_context_manager = ()
|
||||
abstract_mesh_context_manager = ()
|
||||
device_context = ()
|
||||
xla_metadata_context_manager = ()
|
||||
compute_on_context_manager = ()
|
||||
|
||||
context: Any = tls.extra_jit_context
|
||||
if context and context.axis_env_state is not None:
|
||||
axis_env_state = context.axis_env_state
|
||||
if context and context.mesh_context_manager:
|
||||
mesh_context_manager = context.mesh_context_manager
|
||||
if context and context.abstract_mesh_context_manager:
|
||||
abstract_mesh_context_manager = context.abstract_mesh_context_manager
|
||||
if context and context.device_context:
|
||||
device_context = context.device_context
|
||||
if context and context.xla_metadata_context_manager:
|
||||
xla_metadata_context_manager = context.xla_metadata_context_manager
|
||||
if context and context.compute_on_context_manager:
|
||||
compute_on_context_manager = context.compute_on_context_manager
|
||||
return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager,
|
||||
device_context, xla_metadata_context_manager,
|
||||
compute_on_context_manager, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
eager_constant_folding.value,
|
||||
numpy_dtype_promotion.value,
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
sharding_in_types.value,
|
||||
use_direct_linearize.value,
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
debug_key_reuse.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value,
|
||||
pgle_profiling_runs.value,
|
||||
enable_pgle.value,
|
||||
use_shardy_partitioner.value)
|
||||
Values included in this set should also most likely be included in
|
||||
the C++ JIT state, which is handled separately.
|
||||
"""
|
||||
return (axis_env_state.value, mesh_context_manager.value,
|
||||
xla_metadata_context_manager.value,
|
||||
abstract_mesh_context_manager.value,
|
||||
device_context.value,
|
||||
compute_on_context_manager.value, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
eager_constant_folding.value,
|
||||
numpy_dtype_promotion.value,
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
sharding_in_types.value,
|
||||
use_direct_linearize.value,
|
||||
softmax_custom_jvp.value,
|
||||
enable_memories.value,
|
||||
disable_jit.value,
|
||||
debug_key_reuse.value,
|
||||
jax_xla_profile_version.value,
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
hlo_source_file_canonicalization_regex.value,
|
||||
pgle_profiling_runs.value,
|
||||
enable_pgle.value,
|
||||
use_shardy_partitioner.value)
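# Hedged illustration (not part of this change): every value listed above
# feeds the jit cache key, so overriding any of them with its context manager
# changes trace_context() and forces a retrace instead of reusing a stale
# compilation. numpy_rank_promotion is the config state defined later in this
# module.
def _trace_context_changes_under_override():
  before = trace_context()
  with numpy_rank_promotion("warn"):  # any config listed above would do
    return trace_context() != before  # True: the cache key differs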
|
||||
|
||||
config = Config()
|
||||
|
||||
@ -296,185 +234,85 @@ parse_flags_with_absl = config.parse_flags_with_absl
|
||||
class NoDefault: pass
|
||||
no_default = NoDefault()
|
||||
|
||||
if xla_extension_version >= 295:
|
||||
class State(config_ext.Config[_T]):
|
||||
class State(config_ext.Config[_T]):
|
||||
|
||||
__slots__ = (
|
||||
'_name', '_update_thread_local_hook', '_update_global_hook',
|
||||
'_validator', '_default_context_manager_value', '__doc__', '__name__',
|
||||
)
|
||||
__slots__ = (
|
||||
'_name', '_update_thread_local_hook', '_update_global_hook',
|
||||
'_validator', '_default_context_manager_value', '__doc__', '__name__',
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
default: _T,
|
||||
help,
|
||||
update_global_hook: Callable[[_T], None] | None = None,
|
||||
update_thread_local_hook: Callable[[_T | None], None] | None = None,
|
||||
validator: Callable[[Any], None] | None = None,
|
||||
extra_description: str = '',
|
||||
default_context_manager_value: Any = no_default,
|
||||
include_in_jit_key: bool = False,
|
||||
):
|
||||
super().__init__(default, include_in_jit_key)
|
||||
self._name = name
|
||||
self.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
self.__doc__ = (f"Context manager for `{name}` config option"
|
||||
f"{extra_description}.\n\n{help}")
|
||||
self._update_global_hook = update_global_hook
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._validator = validator
|
||||
self._default_context_manager_value = default_context_manager_value
|
||||
if self._validator:
|
||||
self._validator(default)
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(default)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
default: _T,
|
||||
help,
|
||||
update_global_hook: Callable[[_T], None] | None = None,
|
||||
update_thread_local_hook: Callable[[_T | None], None] | None = None,
|
||||
validator: Callable[[Any], None] | None = None,
|
||||
extra_description: str = '',
|
||||
default_context_manager_value: Any = no_default,
|
||||
include_in_jit_key: bool = False,
|
||||
):
|
||||
super().__init__(default, include_in_jit_key)
|
||||
self._name = name
|
||||
self.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
self.__doc__ = (f"Context manager for `{name}` config option"
|
||||
f"{extra_description}.\n\n{help}")
|
||||
self._update_global_hook = update_global_hook
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._validator = validator
|
||||
self._default_context_manager_value = default_context_manager_value
|
||||
if self._validator:
|
||||
self._validator(default)
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(default)
|
||||
|
||||
def __bool__(self) -> NoReturn:
|
||||
raise TypeError(
|
||||
"bool() not supported for instances of type '{0}' "
|
||||
"(did you mean to use '{0}.value' instead?)".format(
|
||||
type(self).__name__))
|
||||
def __bool__(self) -> NoReturn:
|
||||
raise TypeError(
|
||||
"bool() not supported for instances of type '{0}' "
|
||||
"(did you mean to use '{0}.value' instead?)".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _set(self, value: _T) -> None:
|
||||
if self._validator:
|
||||
self._validator(value)
|
||||
self.set_global(value)
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(value)
|
||||
def _set(self, value: _T) -> None:
|
||||
if self._validator:
|
||||
self._validator(value)
|
||||
self.set_global(value)
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(value)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def __call__(self, new_val: Any = no_default):
|
||||
if new_val is no_default:
|
||||
if self._default_context_manager_value is not no_default:
|
||||
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
|
||||
else:
|
||||
# no default_value provided to constructor and no value provided as an
|
||||
# argument, so we raise an error
|
||||
raise TypeError(f"Context manager for {self.__name__} config option "
|
||||
"requires an argument representing the new value for "
|
||||
"the config option.")
|
||||
if self._validator:
|
||||
self._validator(new_val)
|
||||
prev_val = self.swap_local(new_val)
|
||||
@contextlib.contextmanager
|
||||
def __call__(self, new_val: Any = no_default):
|
||||
if new_val is no_default:
|
||||
if self._default_context_manager_value is not no_default:
|
||||
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
|
||||
else:
|
||||
# no default_value provided to constructor and no value provided as an
|
||||
# argument, so we raise an error
|
||||
raise TypeError(f"Context manager for {self.__name__} config option "
|
||||
"requires an argument representing the new value for "
|
||||
"the config option.")
|
||||
if self._validator:
|
||||
self._validator(new_val)
|
||||
prev_val = self.swap_local(new_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.set_local(prev_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.set_local(prev_val)
|
||||
if self._update_thread_local_hook:
|
||||
if prev_val is config_ext.unset:
|
||||
self._update_thread_local_hook(None)
|
||||
else:
|
||||
self._update_thread_local_hook(cast(Optional[Any], prev_val))
|
||||
|
||||
def _add_hooks(self, update_global_hook, update_thread_local_hook):
|
||||
"""Private method that adds hooks to an existing context-manager.
|
||||
|
||||
Used to avoid cyclic import dependencies."""
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._update_global_hook = update_global_hook
|
||||
update_global_hook(self.get_global())
|
||||
|
||||
else:
|
||||
class _Unset: pass
|
||||
unset = _Unset()
|
||||
|
||||
_thread_local_state = threading.local()
|
||||
|
||||
class State(Generic[_T]): # type: ignore[no-redef]
|
||||
|
||||
__slots__ = (
|
||||
'_name', '_value', '_update_thread_local_hook', '_update_global_hook',
|
||||
'_validator', '_default_context_manager_value', '__doc__', '__name__',
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
default: _T,
|
||||
help,
|
||||
update_global_hook: Callable[[_T], None] | None = None,
|
||||
update_thread_local_hook: Callable[[_T | None], None] | None = None,
|
||||
validator: Callable[[Any], None] | None = None,
|
||||
extra_description: str = '',
|
||||
default_context_manager_value: Any = no_default,
|
||||
include_in_jit_key: bool = False,
|
||||
):
|
||||
self._name = name
|
||||
self.__name__ = name[4:] if name.startswith('jax_') else name
|
||||
self.__doc__ = (f"Context manager for `{name}` config option"
|
||||
f"{extra_description}.\n\n{help}")
|
||||
if include_in_jit_key:
|
||||
assert update_global_hook is None
|
||||
assert update_thread_local_hook is None
|
||||
update_global_hook = lambda val: _update_global_jit_state(
|
||||
**{self.__name__: val})
|
||||
update_thread_local_hook = lambda val: update_thread_local_jit_state(
|
||||
**{self.__name__: val})
|
||||
self._update_global_hook = update_global_hook
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._validator = validator
|
||||
self._default_context_manager_value = default_context_manager_value
|
||||
self._set(default)
|
||||
def __bool__(self) -> NoReturn:
|
||||
raise TypeError(
|
||||
"bool() not supported for instances of type '{0}' "
|
||||
"(did you mean to use '{0}.value' instead?)".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _set(self, value: _T) -> None:
|
||||
if self._validator:
|
||||
self._validator(value)
|
||||
self._value = value
|
||||
if self._update_global_hook:
|
||||
self._update_global_hook(value)
|
||||
|
||||
@property
|
||||
def value(self) -> _T:
|
||||
val = _thread_local_state.__dict__.get(self._name, unset)
|
||||
return cast(_T, val) if val is not unset else self._value
|
||||
|
||||
def get_local(self) -> Any:
|
||||
return _thread_local_state.__dict__.get(self._name, unset)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def __call__(self, new_val: Any = no_default):
|
||||
if new_val is no_default:
|
||||
if self._default_context_manager_value is not no_default:
|
||||
new_val = self._default_context_manager_value # default_context_manager_value provided to constructor
|
||||
if prev_val is config_ext.unset:
|
||||
self._update_thread_local_hook(None)
|
||||
else:
|
||||
# no default_value provided to constructor and no value provided as an
|
||||
# argument, so we raise an error
|
||||
raise TypeError(f"Context manager for {self.__name__} config option "
|
||||
"requires an argument representing the new value for "
|
||||
"the config option.")
|
||||
if self._validator:
|
||||
self._validator(new_val)
|
||||
prev_val = getattr(_thread_local_state, self._name, unset)
|
||||
setattr(_thread_local_state, self._name, new_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(new_val)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev_val is unset:
|
||||
delattr(_thread_local_state, self._name)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(None)
|
||||
else:
|
||||
setattr(_thread_local_state, self._name, prev_val)
|
||||
if self._update_thread_local_hook:
|
||||
self._update_thread_local_hook(cast(_T, prev_val))
|
||||
self._update_thread_local_hook(cast(Optional[Any], prev_val))
|
||||
|
||||
def _add_hooks(self, update_global_hook, update_thread_local_hook):
|
||||
"""Private method that adds hooks to an existing context-manager.
|
||||
def _add_hooks(self, update_global_hook, update_thread_local_hook):
|
||||
"""Private method that adds hooks to an existing context-manager.
|
||||
|
||||
Used to avoid cyclic import dependencies."""
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._update_global_hook = update_global_hook
|
||||
update_global_hook(self._value)
|
||||
Used to avoid cyclic import dependencies."""
|
||||
self._update_thread_local_hook = update_thread_local_hook
|
||||
self._update_global_hook = update_global_hook
|
||||
update_global_hook(self.get_global())
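# Hedged usage sketch for the State defined above (illustration only;
# "jax_example_flag" is a hypothetical option name): a State is both a value
# holder and a context manager, and the previous value is restored on exit.
def _state_usage_sketch():
  example_flag = State("jax_example_flag", default=False, help="illustrative flag")
  assert example_flag.value is False
  with example_flag(True):            # thread-local override
    assert example_flag.value is True
  assert example_flag.value is False  # restored after the with block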
|
||||
|
||||
|
||||
UPGRADE_BOOL_HELP = (
|
||||
@ -975,132 +813,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
|
||||
already_configured_with_absl = False
|
||||
|
||||
|
||||
if xla_extension_version >= 295:
|
||||
trace_state = config_ext.Config(None, include_in_jit_key=True)
|
||||
axis_env_state = config_ext.Config((), include_in_jit_key=True)
|
||||
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
device_context = config_ext.Config((), include_in_jit_key=True)
|
||||
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
else:
|
||||
# The C++ JIT maintains its own copy of several configuration items as
|
||||
# a global/thread-local state. These methods allow updates to part of the
|
||||
# state when a configuration value changes.
|
||||
class _GlobalExtraJitContext(NamedTuple):
|
||||
numpy_rank_promotion: str | None = None
|
||||
numpy_dtype_promotion: str | None = None
|
||||
default_matmul_precision: Any | None = None
|
||||
dynamic_shapes: bool = False
|
||||
eager_constant_folding: bool = False
|
||||
random_seed_offset: int = 0
|
||||
threefry_partitionable: bool = False
|
||||
threefry_gpu_kernel_lowering: bool = False
|
||||
sharding_in_types: bool = False
|
||||
use_direct_linearize: bool = False
|
||||
softmax_custom_jvp: bool = False
|
||||
xla_profile_version: int = 0
|
||||
pgle_profiling_runs: int = 0
|
||||
enable_pgle: bool = False
|
||||
use_shardy_partitioner: bool = False
|
||||
|
||||
|
||||
def _update_global_jit_state(**kw):
|
||||
gs = jax_jit.global_state()
|
||||
context = gs.extra_jit_context or _GlobalExtraJitContext()
|
||||
gs.extra_jit_context = context._replace(**kw)
|
||||
|
||||
|
||||
class _ThreadLocalExtraJitContext(NamedTuple):
|
||||
"""A namedtuple containing states to add to the cache key.
|
||||
|
||||
Just in time compilation (for jit, pmap, etc) behavior is configurable through
|
||||
global and thread-local options, used in the cache key.
|
||||
|
||||
The initialization, which uses both config.py and core.py is done using
|
||||
`_update_thread_local_jit_state` in core.py to prevent circular imports.
|
||||
"""
|
||||
trace_state: Any | None = None
|
||||
axis_env_state: Hashable = ()
|
||||
mesh_context_manager: Hashable = ()
|
||||
abstract_mesh_context_manager: Hashable = ()
|
||||
device_context: Hashable = ()
|
||||
compute_on_context_manager: Hashable = ()
|
||||
xla_metadata_context_manager: Hashable = ()
|
||||
|
||||
# Values set by _StateContextManager context managers.
|
||||
# CAUTION: these must be initialized to `None`! The state context manager
|
||||
# restores these to None on exit. If the object default is not `None`, the
|
||||
# context manager is not a no-op, which leads to problems with stale state
|
||||
# (e.g. spurious cache misses in tests).
|
||||
numpy_rank_promotion: str | None = None
|
||||
numpy_dtype_promotion: str | None = None
|
||||
default_matmul_precision: Any | None = None
|
||||
dynamic_shapes: bool | None = None
|
||||
eager_constant_folding : bool | None = None
|
||||
random_seed_offset: int | None = None
|
||||
threefry_partitionable: bool | None = None
|
||||
threefry_gpu_kernel_lowering: bool | None = None
|
||||
sharding_in_types: bool | None = None
|
||||
use_direct_linearize: bool | None = None
|
||||
softmax_custom_jvp: bool | None = None
|
||||
xla_profile_version: int | None = None
|
||||
pgle_profiling_runs: int | None = None
|
||||
enable_pgle: bool | None = None
|
||||
use_shardy_partitioner: bool | None = None
|
||||
|
||||
|
||||
class _ThreadLocalStateCache(threading.local):
|
||||
""""A thread local cache for _ThreadLocalExtraJitContext
|
||||
|
||||
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
|
||||
incurring dispatch overhead for comparing this python object during jit calls.
|
||||
We want to deduplicate the objects that have the same hash/equality to also
|
||||
have the same object ID, since the equality check is much faster if the object
|
||||
IDs match.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.canonicalize = functools.lru_cache(128)(lambda x: x)
|
||||
|
||||
|
||||
_thread_local_state_cache = _ThreadLocalStateCache()
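# Hedged illustration of the deduplication above (not part of this change):
# equal contexts canonicalize to the identical object, so the fast identity
# comparison on the dispatch path succeeds.
def _canonicalize_sketch():
  a = _ThreadLocalExtraJitContext(numpy_rank_promotion="warn")
  b = _ThreadLocalExtraJitContext(numpy_rank_promotion="warn")
  assert a == b and a is not b
  cache = _ThreadLocalStateCache()
  assert cache.canonicalize(a) is cache.canonicalize(b)  # same object ID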
|
||||
|
||||
|
||||
def update_thread_local_jit_state(**kw):
|
||||
tls = jax_jit.thread_local_state()
|
||||
# After xla_client._version >= 70, the thread_local object will necessarily
|
||||
# be initialized when accessed. The following line can be removed when the
|
||||
# minimum jaxlib version is past version 70
|
||||
context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
|
||||
tmp = context._replace(**kw)
|
||||
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
|
||||
|
||||
class JitConfig:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.get_local()
|
||||
|
||||
def get_local(self):
|
||||
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
|
||||
|
||||
def set_local(self, value):
|
||||
update_thread_local_jit_state(**{self._name: value})
|
||||
|
||||
def swap_local(self, new_value):
|
||||
prev_value = self.value
|
||||
self.set_local(new_value)
|
||||
return prev_value
|
||||
|
||||
trace_state = JitConfig('trace_state')
|
||||
axis_env_state = JitConfig('axis_env_state')
|
||||
mesh_context_manager = JitConfig('mesh_context_manager')
|
||||
abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager')
|
||||
device_context = JitConfig('device_context')
|
||||
compute_on_context_manager = JitConfig('compute_on_context_manager')
|
||||
xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
|
||||
trace_state = config_ext.Config(None, include_in_jit_key=True)
|
||||
axis_env_state = config_ext.Config((), include_in_jit_key=True)
|
||||
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
device_context = config_ext.Config((), include_in_jit_key=True)
|
||||
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
|
||||
|
||||
# TODO(b/214340779): remove flag when XLA:CPU is improved.
|
||||
@ -1254,10 +973,10 @@ pmap_shmap_merge = bool_state(
|
||||
help='If True, pmap and shard_map API will be merged.')
|
||||
|
||||
def _update_jax_memories_global(val):
|
||||
lib.jax_jit.global_state().enable_memories = val
|
||||
jax_jit.global_state().enable_memories = val
|
||||
|
||||
def _update_jax_memories_thread_local(val):
|
||||
lib.jax_jit.thread_local_state().enable_memories = val
|
||||
jax_jit.thread_local_state().enable_memories = val
|
||||
|
||||
enable_memories = bool_state(
|
||||
'jax_enable_memories',
|
||||
@ -1450,20 +1169,6 @@ traceback_in_locations_limit = int_state(
|
||||
),
|
||||
)
|
||||
|
||||
share_autotune_config_between_hosts = bool_state(
|
||||
name='jax_share_autotune_config_between_hosts',
|
||||
default=False,
|
||||
help=(
|
||||
'If set to True, the coordinator process will share autotune configs '
|
||||
'other participants. This will increase overall compilation time, but '
|
||||
'will lead to equal compiled modules in each process. '
|
||||
'If both jax_share_binary_between_hosts and '
|
||||
'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
|
||||
"shared when it's possible and autotune config sharing will be used "
|
||||
'as a fallback.'
|
||||
),
|
||||
)
|
||||
|
||||
share_binary_between_hosts = bool_state(
|
||||
name='jax_share_binary_between_hosts',
|
||||
default=False,
|
||||
@ -1576,10 +1281,10 @@ disallow_mesh_context_manager = bool_state(
|
||||
)
|
||||
|
||||
def _update_x64_global(val):
|
||||
lib.jax_jit.global_state().enable_x64 = val
|
||||
jax_jit.global_state().enable_x64 = val
|
||||
|
||||
def _update_x64_thread_local(val):
|
||||
lib.jax_jit.thread_local_state().enable_x64 = val
|
||||
jax_jit.thread_local_state().enable_x64 = val
|
||||
|
||||
enable_x64 = bool_state(
|
||||
name='jax_enable_x64',
|
||||
@ -1594,11 +1299,11 @@ config._contextmanager_flags.remove('jax_enable_x64')
|
||||
setattr(Config, "x64_enabled", property(lambda _: enable_x64.value))
|
||||
|
||||
def _update_default_device_global(val):
|
||||
lib.jax_jit.global_state().default_device = val
|
||||
jax_jit.global_state().default_device = val
|
||||
|
||||
|
||||
def _update_default_device_thread_local(val):
|
||||
lib.jax_jit.thread_local_state().default_device = val
|
||||
jax_jit.thread_local_state().default_device = val
|
||||
|
||||
|
||||
def _validate_default_device(val):
|
||||
@ -1632,10 +1337,10 @@ default_device = string_or_object_state(
|
||||
validator=_validate_default_device)
|
||||
|
||||
def _update_disable_jit_global(val):
|
||||
lib.jax_jit.global_state().disable_jit = val
|
||||
jax_jit.global_state().disable_jit = val
|
||||
|
||||
def _update_disable_jit_thread_local(val):
|
||||
lib.jax_jit.thread_local_state().disable_jit = val
|
||||
jax_jit.thread_local_state().disable_jit = val
|
||||
|
||||
disable_jit = bool_state(
|
||||
name='jax_disable_jit',
|
||||
|
@ -39,6 +39,7 @@ from jax._src import config
|
||||
from jax._src import effects
|
||||
from jax._src import compute_on
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
|
||||
from jax._src.errors import (
|
||||
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
|
||||
TracerIntegerConversionError, UnexpectedTracerError)
|
||||
@ -431,6 +432,8 @@ class Primitive:
|
||||
call_primitive: bool = False
|
||||
# set for map primitives processed in final style.
|
||||
map_primitive: bool = False
|
||||
# set for ref primitives
|
||||
ref_primitive: bool = False
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
@ -1597,19 +1600,34 @@ def _invalid_shape_error(shape: Shape, context: str=""):
|
||||
|
||||
return TypeError(msg)
|
||||
|
||||
# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
|
||||
# Collective too.
|
||||
def _maybe_modify_sharding(sharding):
|
||||
if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
|
||||
return sharding
|
||||
|
||||
new_spec = []
|
||||
for s in sharding.spec:
|
||||
if s is None or isinstance(s, UnconstrainedSingleton):
|
||||
new_spec.append(s)
|
||||
else:
|
||||
temp_s = s[0] if isinstance(s, tuple) else s
|
||||
new_spec.append(
|
||||
P.UNCONSTRAINED
|
||||
if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
|
||||
return sharding.with_spec(new_spec)
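# Hedged illustration (not part of this change): on a mesh whose axis "x" has
# AxisTypes.Auto while axis "y" is user-specified, a spec such as P("x", "y")
# is rewritten to P(P.UNCONSTRAINED, "y"), so only explicitly user-controlled
# axes remain constrained.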
|
||||
|
||||
|
||||
def get_sharding(sharding, ndim):
|
||||
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore
|
||||
from jax._src.sharding_impls import NamedSharding # type: ignore
|
||||
|
||||
if sharding is not None:
|
||||
assert len(sharding.spec) == ndim
|
||||
return sharding
|
||||
return _maybe_modify_sharding(sharding)
|
||||
|
||||
context_mesh = mesh_lib.get_abstract_mesh()
|
||||
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
|
||||
# code.
|
||||
if not context_mesh:
|
||||
return None
|
||||
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
|
||||
assert sharding is None
|
||||
return NamedSharding(context_mesh, P(*[None] * ndim))
|
||||
|
||||
@ -1672,10 +1690,8 @@ class ShapedArray(UnshapedArray):
|
||||
self.dtype.name)
|
||||
dt_str = dt_str.replace('void', 'float0')
|
||||
if hasattr(self, 'sharding') and self.sharding is not None:
|
||||
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
|
||||
axis_types = self.sharding.mesh.axis_types
|
||||
axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
|
||||
return f'{dt_str}[{shapestr}]{axt}'
|
||||
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
else:
|
||||
shapestr = ','.join(map(str, self.shape))
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
@ -1687,26 +1703,13 @@ class ShapedArray(UnshapedArray):
|
||||
raise TypeError("len() of unsized object") from err # same as numpy error
|
||||
|
||||
|
||||
def _get_axis_type_str(axis_types):
|
||||
from jax._src.mesh import AxisTypes # type: ignore
|
||||
|
||||
out = []
|
||||
for t, axes in axis_types.items():
|
||||
a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
|
||||
if t == AxisTypes.Collective:
|
||||
out.append(f"C:{a}")
|
||||
elif t == AxisTypes.User:
|
||||
out.append(f"U:{a}")
|
||||
else:
|
||||
assert t == AxisTypes.Auto
|
||||
out.append(f"A:{a}")
|
||||
return f"{{{', '.join(out)}}}"
|
||||
|
||||
def _get_shape_sharding_str(shape, spec):
|
||||
out = []
|
||||
for s1, s2 in zip(shape, spec):
|
||||
if s2 is None:
|
||||
out.append(f"{s1}")
|
||||
elif isinstance(s2, UnconstrainedSingleton):
|
||||
out.append(f"{s1}")
|
||||
elif isinstance(s2, tuple):
|
||||
ss = ','.join(s for s in s2)
|
||||
out.append(f"{s1}@({ss})")
|
||||
@ -1882,6 +1885,7 @@ pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
||||
def mutable_array(init_val):
|
||||
return mutable_array_p.bind(init_val)
|
||||
mutable_array_p = Primitive('mutable_array')
|
||||
mutable_array_p.ref_primitive = True
|
||||
|
||||
class InternalMutableArrayEffect(effects.Effect):
|
||||
pass
|
||||
@ -1899,6 +1903,18 @@ def _mutable_array_impl(init_val):
|
||||
aval = get_aval(init_val)
|
||||
return MutableArray(AbstractRef(aval), init_val)
|
||||
|
||||
def freeze(ref):
|
||||
return freeze_p.bind(ref)
|
||||
freeze_p = Primitive('freeze')
|
||||
freeze_p.ref_primitive = True
|
||||
|
||||
@freeze_p.def_effectful_abstract_eval
|
||||
def freeze_abstract_eval(ref_aval):
|
||||
return ref_aval.inner_aval, {internal_mutable_array_effect}
|
||||
|
||||
@freeze_p.def_impl
|
||||
def _freeze_impl(ref):
|
||||
return ref[()]
|
||||
|
||||
class AbstractToken(AbstractValue):
|
||||
def str_short(self, short_dtypes=False): return 'Tok'
|
||||
@ -2516,10 +2532,11 @@ def _check_jaxpr(
|
||||
|
||||
# Check the computed effect type matches the eqn's annotation, and is
|
||||
# included in the jaxpr's annotation.
|
||||
if prim is mutable_array_p:
|
||||
outvar, = eqn.outvars
|
||||
in_idx[outvar] = None # type: ignore
|
||||
mut_arrays.add(outvar)
|
||||
if prim.ref_primitive:
|
||||
if prim is mutable_array_p:
|
||||
outvar, = eqn.outvars
|
||||
in_idx[outvar] = None # type: ignore
|
||||
mut_arrays.add(outvar)
|
||||
if eqn.effects != eqn_effects:
|
||||
raise JaxprTypeError("Inferred effects do not match equation effects. "
|
||||
f"Equation effects: {eqn.effects}. "
|
||||
@ -2639,16 +2656,10 @@ def _check_call(ctx_factory, prim, in_atoms, params):
|
||||
return aval
|
||||
for v, x in zip(call_jaxpr.invars, in_atoms):
|
||||
if not typecompat(substitute(v.aval), x.aval):
|
||||
# TODO(yashkatariya): Remove this once numpy array's aval has a sharding
|
||||
# on it.
|
||||
if (config.sharding_in_types.value and isinstance(x, Literal) and
|
||||
v.aval.sharding is not None and x.val.ndim == 0):
|
||||
pass
|
||||
else:
|
||||
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
|
||||
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
|
||||
f"{x.aval} to jaxpr expecting type "
|
||||
f"{substitute(v.aval)}")
|
||||
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
|
||||
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
|
||||
f"{x.aval} to jaxpr expecting type "
|
||||
f"{substitute(v.aval)}")
|
||||
env[v] = x if type(x) is Var else x.val
|
||||
|
||||
_check_jaxpr(ctx_factory, call_jaxpr)
|
||||
|
@ -18,8 +18,8 @@ import json
|
||||
import math
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src.custom_partitioning import custom_partitioning
|
||||
from jax._src.interpreters import batching
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import functools
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax._src import core as jax_core
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters.mlir import hlo
|
||||
from jax.interpreters.mlir import ir
|
||||
|
@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
import functools
|
||||
import operator
|
||||
|
||||
@ -48,17 +49,93 @@ zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
@custom_api_util.register_custom_decorator_type
|
||||
class custom_vmap:
|
||||
fun: Callable
|
||||
vmap_rule: Callable | None
|
||||
"""Customize the vmap behavior of a JAX-transformable function.
|
||||
|
||||
def __init__(self, fun: Callable):
|
||||
This decorator is used to customize the behavior of a JAX function under the
|
||||
:func:`jax.vmap` transformation. A ``custom_vmap``-decorated function will
|
||||
mostly (see below for caveats) have the same behavior as the underlying
|
||||
function, except when batched using :py:func:`jax.vmap`. When batched, the
|
||||
rule defined using :py:func:`~jax.custom_batching.custom_vmap.def_vmap` will
|
||||
be used.
|
||||
|
||||
For example:
|
||||
|
||||
>>> @jax.custom_batching.custom_vmap
|
||||
... def f(x, y):
|
||||
... return x + y
|
||||
...
|
||||
>>> @f.def_vmap
|
||||
... def f_vmap_rule(axis_size, in_batched, xs, ys):
|
||||
... assert all(in_batched)
|
||||
... assert xs.shape[0] == axis_size
|
||||
... assert ys.shape[0] == axis_size
|
||||
... out_batched = True
|
||||
... return xs * ys, out_batched
|
||||
...
|
||||
>>> xs = jnp.arange(3)
|
||||
>>> ys = jnp.arange(1, 4)
|
||||
>>> jax.vmap(f)(xs, ys) # prints xs * ys instead of xs + ys
|
||||
Array([0, 2, 6], dtype=int32)
|
||||
|
||||
Of note, ``custom_vmap`` functions do not support reverse-mode autodiff. To
|
||||
customize both vmap and reverse-mode autodiff, combine ``custom_vmap`` with
|
||||
:py:class:`jax.custom_vjp`. For example:
|
||||
|
||||
>>> @jax.custom_vjp
|
||||
... @jax.custom_batching.custom_vmap
|
||||
... def f(x, y):
|
||||
... return jnp.sin(x) * y
|
||||
...
|
||||
>>> @f.def_vmap
|
||||
... def f_vmap_rule(axis_size, in_batched, xs, ys):
|
||||
... return jnp.cos(xs) * ys, True
|
||||
...
|
||||
>>> def f_fwd(x, y):
|
||||
... return f(x, y), (jnp.cos(x), jnp.sin(x), y)
|
||||
...
|
||||
>>> def f_bwd(res, g):
|
||||
... cos_x, sin_x, y = res
|
||||
... return (cos_x * g * y, sin_x * g)
|
||||
...
|
||||
>>> f.defvjp(f_fwd, f_bwd)
|
||||
>>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3))
|
||||
Array([1., 1., 1.], dtype=float32)
|
||||
>>> jax.grad(f)(jnp.zeros(()), jnp.ones(()))
|
||||
Array(1., dtype=float32)
|
||||
|
||||
Note that the :py:class:`jax.custom_vjp` must be on the outside, wrapping the
``custom_vmap``-decorated function.
|
||||
"""
|
||||
|
||||
fun: Callable[..., Any]
|
||||
vmap_rule: Callable[..., tuple[Any, Any]] | None
|
||||
|
||||
def __init__(self, fun: Callable[..., Any]):
|
||||
functools.update_wrapper(self, fun)
|
||||
self.fun = fun
|
||||
self.vmap_rule = None
|
||||
|
||||
__getattr__ = custom_api_util.forward_attr
|
||||
|
||||
def def_vmap(self, vmap_rule: Callable) -> Callable:
|
||||
def def_vmap(
|
||||
self,
|
||||
vmap_rule: Callable[..., tuple[Any, Any]],
|
||||
) -> Callable[..., tuple[Any, Any]]:
|
||||
"""Define the vmap rule for this custom_vmap function.
|
||||
|
||||
Args:
|
||||
vmap_rule: A function that implements the vmap rule. This function should
|
||||
accept the following arguments: (1) an integer ``axis_size`` as its
|
||||
first argument, (2) a pytree of booleans with the same structure as the
|
||||
inputs to the function, specifying whether each argument is batched,
|
||||
and (3) the batched arguments. It should return a tuple of the batched
|
||||
output and a pytree of booleans with the same structure as the output,
|
||||
specifying whether each output element is batched. See the documentation
|
||||
for :py:func:`jax.custom_batching.custom_vmap` for some examples.
|
||||
|
||||
Returns:
|
||||
This method passes the rule through, returning ``vmap_rule`` unchanged.
|
||||
"""
|
||||
self.vmap_rule = vmap_rule
|
||||
return vmap_rule
|
||||
|
||||
@ -72,7 +149,7 @@ class custom_vmap:
|
||||
"using def_vmap.")
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
@ -272,6 +349,31 @@ def tree_merge(mask, lhs_tree, rhs_tree):
|
||||
mask, lhs_tree, rhs_tree)
|
||||
|
||||
def sequential_vmap(f):
|
||||
"""A special case of ``custom_vmap`` that uses a loop.
|
||||
|
||||
A function decorated with ``sequential_vmap`` will be called sequentially
|
||||
within a loop when batched. This is useful for functions that don't natively
|
||||
support batch dimensions.
|
||||
|
||||
For example:
|
||||
|
||||
>>> @jax.custom_batching.sequential_vmap
|
||||
... def f(x):
|
||||
... jax.debug.print("{}", x)
|
||||
... return x + 1
|
||||
...
|
||||
>>> jax.vmap(f)(jnp.arange(3))
|
||||
0
|
||||
1
|
||||
2
|
||||
Array([1, 2, 3], dtype=int32)
|
||||
|
||||
Where the print statements demonstrate that this :py:func:`~jax.vmap` is being
|
||||
generated using a loop.
|
||||
|
||||
See the documentation for :py:class:`~jax.custom_batching.custom_vmap` for
|
||||
more details.
|
||||
"""
|
||||
f = custom_vmap(f)
|
||||
|
||||
@f.def_vmap
|
||||
|
@ -1051,7 +1051,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
|
||||
from the closure.
|
||||
"""
|
||||
flat_args, in_tree = tree_flatten(example_args)
|
||||
in_avals = tuple(map(abstractify, flat_args))
|
||||
in_avals = tuple(map(core.get_aval, flat_args))
|
||||
if config.check_tracer_leaks.value:
|
||||
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
|
||||
else:
|
||||
@ -1111,9 +1111,6 @@ def partition_list(choice, lst):
|
||||
return [next(i2 if snd else i1) for snd in which]
|
||||
return out, merge
|
||||
|
||||
def abstractify(x):
|
||||
return core.get_aval(x)
|
||||
|
||||
|
||||
### Custom transposition
|
||||
|
||||
@ -1209,8 +1206,8 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
|
||||
f_in_tree = treedef_tuple((res_tree, lin_tree))
|
||||
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)
|
||||
|
||||
res_avals = map(abstractify, operands_res)
|
||||
lin_avals = map(abstractify, operands_lin)
|
||||
res_avals = map(core.get_aval, operands_res)
|
||||
lin_avals = map(core.get_aval, operands_lin)
|
||||
f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
|
||||
f_jaxpr = _close_jaxpr(f_jaxpr)
|
||||
out_avals = f_jaxpr.out_avals
|
||||
|
@ -455,7 +455,7 @@ class custom_partitioning:
|
||||
f_, dyn_args = lu.wrap_init(self.fun), args
|
||||
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
|
||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
in_avals = [core.get_aval(x) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
|
||||
"custom_partitioning")
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
|
@ -20,16 +20,124 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
|
||||
|
||||
_CompoundFactor = tuple[str, ...]
|
||||
_DimMapping = tuple[str | _CompoundFactor, ...]
|
||||
|
||||
# A single character replacement for ... to simplify parsing.
|
||||
_ELLIPSIS: str = "…"
|
||||
BATCHING: str = "…"
|
||||
|
||||
# A prefix for names of batching dimension factors, used for expanding the
|
||||
# leading ... into factors.
|
||||
_BATCHING_DIM_FACTOR_PREFIX = "?"
|
||||
|
||||
def _check_factor(factor: str):
|
||||
"""Validates a factor.
|
||||
|
||||
A factor is a string starting with a letter and containing only letters,
|
||||
digits, or underscores.
|
||||
"""
|
||||
if not factor[0].isalpha():
|
||||
raise ValueError(f"Factor names have to start with a letter, but got '{factor[0]}'")
|
||||
for char in factor[1:]:
|
||||
if char != "_" and not char.isdigit() and not char.isalpha():
|
||||
raise ValueError(f"Unknown character '{char}'")
|
||||
|
||||
class CompoundFactor(tuple):
|
||||
"""Describes the factors for a compound factor.
|
||||
|
||||
A compound factor should contain at least two factors, e.g.
|
||||
* CompoundFactor('b', 'c').
|
||||
"""
|
||||
def __init__(self, *factors):
|
||||
if len(factors) < 2:
|
||||
raise ValueError("A compound factor should contain at least two factors")
|
||||
for factor in factors:
|
||||
if not isinstance(factor, str):
|
||||
raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
|
||||
if factor == BATCHING:
|
||||
raise ValueError("Ellipsis can't be used in a compound factor")
|
||||
else:
|
||||
_check_factor(factor)
|
||||
|
||||
def __new__(cls, *factors):
|
||||
return tuple.__new__(CompoundFactor, factors)
|
||||
|
||||
|
||||
class ArrayMapping(tuple):
|
||||
"""Describes the factors for an operand or result.
|
||||
|
||||
Each element is either a factor or a CompoundFactor. A leading element can
|
||||
also be BATCHING, which represents batching dimensions. examples:
|
||||
* ArrayMapping('a')
|
||||
* ArrayMapping('b', 'c')
|
||||
* ArrayMapping(CompoundFactor('b', 'c'), 'd')
|
||||
* ArrayMapping(BATCHING, CompoundFactor('b', 'c'), 'd')
|
||||
"""
|
||||
def __init__(self, *dim_mappings):
|
||||
for i, d in enumerate(dim_mappings):
|
||||
if not isinstance(d, str) and not isinstance(d, CompoundFactor):
|
||||
raise ValueError(
|
||||
"Each element of ArrayMapping must be a str or CompoundFactor, but"
|
||||
f" got {type(d)}")
|
||||
if isinstance(d, str):
|
||||
if d == BATCHING:
|
||||
if i != 0:
|
||||
raise ValueError("Ellipsis can only be used at the beginning of a dimension")
|
||||
else:
|
||||
_check_factor(d)
|
||||
|
||||
def __new__(cls, *dim_mappings):
|
||||
return tuple.__new__(ArrayMapping, dim_mappings)
|
||||
|
||||
|
||||
class SdyShardingRule:
|
||||
"""Represents a Shardy sharding rule.
|
||||
|
||||
An SdyShardingRule contains the ArrayMappings for operands and results, and an
|
||||
optional list of factor sizes. A factor is a name used in the ArrayMappings.
|
||||
If a factor is only used in CompoundFactors, its size must be specified.
|
||||
"""
|
||||
operand_mappings: tuple[ArrayMapping, ...]
|
||||
result_mappings: tuple[ArrayMapping, ...]
|
||||
factor_sizes: dict[str, int]
|
||||
|
||||
def __init__(self, operand_mappings: tuple[ArrayMapping, ...],
|
||||
result_mappings: tuple[ArrayMapping, ...], **factor_sizes):
|
||||
# Find all factors and mark whether their size can be inferred.
|
||||
factors_inferrable = dict()
|
||||
for value in operand_mappings + result_mappings:
|
||||
for dim in value:
|
||||
if isinstance(dim, str):
|
||||
factors_inferrable[dim] = True
|
||||
else:
|
||||
for factor in dim:
|
||||
if factor not in factors_inferrable.keys():
|
||||
factors_inferrable[factor] = False
|
||||
|
||||
# Check that factors in factor_sizes are used in the rule.
|
||||
for factor in factor_sizes:
|
||||
if factor not in factors_inferrable:
|
||||
raise ValueError(
|
||||
f"Factor {factor} is not used in the rule, but size is provided")
|
||||
|
||||
# Check that factors that are used for a whole dimension aren't in
|
||||
# factor_sizes and factors that are never used for a whole dimension are
|
||||
# in factor_sizes.
|
||||
for factor, inferrable in factors_inferrable.items():
|
||||
if factor not in factor_sizes and not inferrable:
|
||||
raise ValueError(
|
||||
f"Factor {factor} is only used in compound factors; must specify"
|
||||
" its size")
|
||||
if factor in factor_sizes and inferrable:
|
||||
raise ValueError(
|
||||
f"Factor {factor} represents a whole dimension; do not specify its"
|
||||
" size")
|
||||
|
||||
self.operand_mappings = operand_mappings
|
||||
self.result_mappings = result_mappings
|
||||
self.factor_sizes = factor_sizes
|
||||
|
||||
def __str__(self):
|
||||
return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})"
|
||||
|
||||
|
||||
def _get_batching_dim_factor_name(batch_dim_order: int):
|
||||
"""Constructs a factor name for a batching dimension.
|
||||
|
||||
@ -42,18 +150,18 @@ def _get_batching_dim_factor_name(batch_dim_order : int):
|
||||
|
||||
def _parse_values(
|
||||
rule: str,
|
||||
) -> tuple[_DimMapping, ...]:
|
||||
) -> tuple[ArrayMapping, ...]:
|
||||
"""Parses the LHS or RHS of an Einsum notation like string.
|
||||
|
||||
Converts each operand or result in the Einsum notation like string to a tuple
|
||||
of _DimMapping. This very closely follows how einops parses their rules in
|
||||
of ArrayMapping. This very closely follows how einops parses their rules in
|
||||
einops/parsing.py.
|
||||
|
||||
Args:
|
||||
rule: The Einsum notation for the operands or results of an operation.
|
||||
|
||||
Returns:
|
||||
The tuple of values.
|
||||
The tuple of ArrayMapping.
|
||||
|
||||
Raises:
|
||||
ValueError: If the rule is not balanced or contains unknown characters.
|
||||
@ -65,10 +173,10 @@ def _parse_values(
|
||||
|
||||
# Similar to einops rules, an empty LHS/RHS has a single scalar value.
|
||||
if not rule:
|
||||
return ((),)
|
||||
return (ArrayMapping(),)
|
||||
|
||||
all_values = []
|
||||
# Represents all dimensions of a value. When value[0]==_ELLIPSIS, the
# Represents all dimensions of a value. When value[0]==BATCHING, the
# value may have 0 or more leading dimensions.
|
||||
value = []
|
||||
current_factor = None
|
||||
@ -84,12 +192,12 @@ def _parse_values(
|
||||
current_compound_dim.append(x)
|
||||
|
||||
for char in rule:
|
||||
if char == _ELLIPSIS:
|
||||
if char == BATCHING:
|
||||
if (current_factor is not None or current_compound_dim is not None
|
||||
or value):
|
||||
raise ValueError(
|
||||
"Ellipsis can only be used at the beginning of a dimension")
|
||||
add_factor(_ELLIPSIS)
|
||||
add_factor(BATCHING)
|
||||
continue
|
||||
if char in "(), ":
|
||||
if current_factor is not None:
|
||||
@ -106,10 +214,10 @@ def _parse_values(
|
||||
raise ValueError("Brackets are not balanced")
|
||||
if len(current_compound_dim) <= 1:
|
||||
raise ValueError("Brackets should contain at least two factors")
|
||||
value.append(tuple(current_compound_dim))
|
||||
value.append(CompoundFactor(*current_compound_dim))
|
||||
current_compound_dim = None
|
||||
elif char == ",":
|
||||
all_values.append(tuple(value))
|
||||
all_values.append(ArrayMapping(*value))
|
||||
value = []
|
||||
elif char == "_" or char.isdigit() or char.isalpha():
|
||||
if current_factor is None:
|
||||
@ -125,256 +233,203 @@ def _parse_values(
|
||||
raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
|
||||
if current_factor is not None:
|
||||
add_factor(current_factor)
|
||||
all_values.append(tuple(value))
|
||||
all_values.append(ArrayMapping(*value))
|
||||
|
||||
return tuple(all_values)
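# Hedged illustration (not part of this change) of the parser's output. The
# caller replaces "..." with the single-character BATCHING marker before this
# function sees the string; bracketed groups come back as CompoundFactors.
#
#   _parse_values("(i j), k")
#   # -> (ArrayMapping(CompoundFactor("i", "j")), ArrayMapping("k"))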
|
||||
|
||||
def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule:
|
||||
"""Constructs a SdyShardingRule object from the Einsum notation like string.
|
||||
|
||||
class SdyShardingRule:
|
||||
"""A representation for Shardy sharding rule.
|
||||
This is done by verifying that the input Einsum-notation-like string,
together with the optional factor sizes, represents a valid sharding rule,
and converting it to an internal representation.
|
||||
|
||||
A SdyShardingRule includes an Einsum-notation-like string and an optional
|
||||
list of factor sizes. A factor is a name in the Einsum notation. If a factor
|
||||
is only used in compound factors, its size must be specified.
|
||||
Args:
|
||||
rule: The Einsum notation like string for an operation.
|
||||
**factor_sizes: The optional factor sizes.
|
||||
|
||||
SdyShardingRule examples:
|
||||
|
||||
* Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k')
|
||||
* Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k')
|
||||
* A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j')
|
||||
* Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)', i=4, j=2)
|
||||
* An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...')
|
||||
Raises:
|
||||
ValueError: If there is any problem with the rule or factor_sizes.
|
||||
"""
|
||||
if not isinstance(rule, str):
|
||||
raise TypeError(f"rule must be a str, but got {type(rule)}")
|
||||
if not all(isinstance(size, int) for size in factor_sizes.values()):
|
||||
raise TypeError(
|
||||
f"factor_sizes must be a dict of str to int, but got {factor_sizes}")
|
||||
|
||||
def __init__(self, rule: str, **factor_sizes):
|
||||
"""Constructs a SdyShardingRule object from the Einsum notation like string.
|
||||
|
||||
This is done by verifying that the input Einsum-notation-like string,
together with the optional factor sizes, represents a valid sharding rule,
and converting it to an internal representation.
|
||||
|
||||
Args:
|
||||
rule: The Einsum notation like string for an operation.
|
||||
**factor_sizes: The optional factor sizes.
|
||||
|
||||
Raises:
|
||||
ValueError: If there is any problem with the rule or factor_sizes.
|
||||
"""
|
||||
if not isinstance(rule, str):
|
||||
raise TypeError(f"rule must be a str, but got {type(rule)}")
|
||||
if not all(isinstance(size, int) for size in factor_sizes.values()):
|
||||
raise TypeError(
|
||||
f"factor_sizes must be a dict of str to int, but got {factor_sizes}")
|
||||
|
||||
# Replace ... with a single char to simplify parsing.
|
||||
if _ELLIPSIS in rule:
|
||||
raise ValueError(f"Unknown character '{_ELLIPSIS}'")
|
||||
# Replace ... with a single char to simplify parsing.
|
||||
if BATCHING in rule:
|
||||
raise ValueError(f"Unknown character '{BATCHING}'")
|
||||
if "." in rule:
|
||||
rule = rule.replace("...", BATCHING)
|
||||
if "." in rule:
|
||||
rule = rule.replace("...", _ELLIPSIS)
|
||||
if "." in rule:
|
||||
raise ValueError("Character '.' must be used inside ellipsis '...'")
|
||||
raise ValueError("Character '.' must be used inside ellipsis '...'")
|
||||
|
||||
try:
|
||||
operands, results = rule.split("->")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"There is no -> in rule: '{rule}'") from e
|
||||
try:
|
||||
operands, results = rule.split("->")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"There is no -> in rule: '{rule}'") from e
|
||||
|
||||
self.operands = _parse_values(operands)
|
||||
self.results = _parse_values(results)
|
||||
operand_mappings = _parse_values(operands)
|
||||
result_mappings = _parse_values(results)
|
||||
|
||||
# Find all factors and mark whether their size can be inferred.
|
||||
factors_inferrable = dict()
|
||||
for value in self.operands + self.results:
|
||||
for dim in value:
|
||||
if dim == _ELLIPSIS:
|
||||
continue
|
||||
if isinstance(dim, str):
|
||||
factors_inferrable[dim] = True
|
||||
else:
|
||||
for factor in dim:
|
||||
if factor not in factors_inferrable.keys():
|
||||
factors_inferrable[factor] = False
|
||||
return SdyShardingRule(operand_mappings, result_mappings, **factor_sizes)
|
||||
|
||||
# Check that factors in factor_sizes are used in the rule.
|
||||
for factor in factor_sizes:
|
||||
if factor not in factors_inferrable:
|
||||
raise ValueError(
|
||||
f"Factor {factor} is not used in the rule, but size is provided")
|
||||
|
||||
# Check that factors that are used for a whole dimension aren't in
|
||||
# factor_sizes and factors that are never used for a whole dimension are
|
||||
# in factor_sizes.
|
||||
for factor, inferrable in factors_inferrable.items():
|
||||
if factor not in factor_sizes and not inferrable:
|
||||
raise ValueError(
|
||||
f"Factor {factor} is only used in compound factors; must specify"
|
||||
" its size")
|
||||
if factor in factor_sizes and inferrable:
|
||||
raise ValueError(
|
||||
f"Factor {factor} represents a whole dimension; do not specify its"
|
||||
" size")
|
||||
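The size checks above can be summarized with a small standalone sketch (check_factor_sizes is a hypothetical helper, not the JAX implementation): a factor that appears somewhere as a whole dimension is inferrable and must not be given a size, while a factor that only ever appears inside brackets must be given one.

# Standalone sketch of the factor-size rule described above.
def check_factor_sizes(mappings, factor_sizes):
  """mappings: factor tuples per tensor; a tuple-valued dim is a compound factor."""
  inferrable = {}
  for value in mappings:
    for dim in value:
      if isinstance(dim, str):
        inferrable[dim] = True              # used for a whole dimension
      else:
        for factor in dim:
          inferrable.setdefault(factor, False)
  for factor, ok in inferrable.items():
    if not ok and factor not in factor_sizes:
      raise ValueError(f"Factor {factor} is only used in compound factors; must specify its size")
    if ok and factor in factor_sizes:
      raise ValueError(f"Factor {factor} represents a whole dimension; do not specify its size")

# '(i j) -> i j': i and j appear as whole result dimensions, so no sizes are needed.
check_factor_sizes([(("i", "j"),), ("i", "j")], {})
# '(i j) -> (j i)': i and j only ever appear inside brackets, so sizes are required.
check_factor_sizes([(("i", "j"),), (("j", "i"),)], {"i": 4, "j": 2})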
def sdy_sharding_rule_to_mlir(
|
||||
rule: SdyShardingRule,
|
||||
operand_types: list[ir.Type],
|
||||
result_types: list[ir.Type],) -> ir.Attribute:
|
||||
"""Builds the MLIR representation for the sharding rule.
|
||||
|
||||
self.factor_sizes = factor_sizes
|
||||
This is done by verifying that the rule is consistent with the types of
|
||||
the operation and converting the Einsum notation like string to
|
||||
OpShardingRuleAttr.
|
||||
"""
|
||||
if len(rule.operand_mappings) != len(operand_types):
|
||||
raise ValueError(
|
||||
f"Sharding rule has {len(rule.operand_mappings)} operands, but the operation"
|
||||
f" has {len(operand_types)} operands")
|
||||
if len(rule.result_mappings) != len(result_types):
|
||||
raise ValueError(
|
||||
f"Sharding rule has {len(rule.result_mappings)} results, but the operation"
|
||||
f" has {len(result_types)} results")
|
||||
|
||||
def __str__(self):
|
||||
return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})"
|
||||
factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
|
||||
types = operand_types + result_types
|
||||
UNKNOWN = -1 # Representation for unknown factor size or factor index.
|
||||
|
||||
def build(
|
||||
self,
|
||||
operand_types: list[ir.Type],
|
||||
result_types: list[ir.Type],) -> ir.Attribute:
|
||||
"""Builds the MLIR representation for the sharding rule.
|
||||
def get_message_for_value(i):
|
||||
if i >= len(operand_types):
|
||||
return f"{i - len(operand_types)}th result"
|
||||
else:
|
||||
return f"{i}th operand"
|
||||
|
||||
This is done by verifying that the rule is consistent with the types of
|
||||
the operation and converting the Einsum notation like string to
|
||||
OpShardingRuleAttr.
|
||||
def get_rank_for_value(i):
|
||||
return ir.ShapedType(types[i]).rank
|
||||
|
||||
def get_size_for_value_dim(i, j):
|
||||
return ir.ShapedType(types[i]).shape[j]
|
||||
|
||||
def add_factor(factor, size):
|
||||
"""Adds a factor to factors_to_indices_sizes.
|
||||
|
||||
`size` may be a dimension's size, a user-specified factor size, or UNKNOWN
if a factor is first used in a compound factor and then used for a
whole dimension.
|
||||
"""
|
||||
if len(self.operands) != len(operand_types):
|
||||
raise ValueError(
|
||||
f"Sharding rule has {len(self.operands)} operands, but the operation"
|
||||
f" has {len(operand_types)} operands"
|
||||
)
|
||||
if len(self.results) != len(result_types):
|
||||
raise ValueError(
|
||||
f"Sharding rule has {len(self.results)} results, but the operation"
|
||||
f" has {len(result_types)} results"
|
||||
)
|
||||
|
||||
factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
|
||||
types = operand_types + result_types
|
||||
UNKNOWN = -1 # Representation for unknown factor size or factor index.
|
||||
|
||||
def get_message_for_value(i):
|
||||
if i >= len(operand_types):
|
||||
return f"{i - len(operand_types)}th result"
|
||||
else:
|
||||
return f"{i}th operand"
|
||||
|
||||
def get_rank_for_value(i):
|
||||
return ir.ShapedType(types[i]).rank
|
||||
|
||||
def get_size_for_value_dim(i, j):
|
||||
return ir.ShapedType(types[i]).shape[j]
|
||||
|
||||
def add_factor(factor, size):
|
||||
"""Adds a factor to factors_to_indices_sizes.
|
||||
|
||||
`size` may be a dimension's size, a user-specified factor size, or UNKNOWN
if a factor is first used in a compound factor and then used for a
whole dimension.
|
||||
"""
|
||||
factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
|
||||
if factor_index != UNKNOWN:
|
||||
# Not the first time seeing the factor.
|
||||
if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
|
||||
factor_or_batching_dim = (
|
||||
f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
|
||||
else f"Batching dimension {factor[1:]}")
|
||||
raise ValueError(
|
||||
f"{factor_or_batching_dim} corresponds to two sizes:"
|
||||
f" {factor_size} and {size}")
|
||||
if size != UNKNOWN and factor_size == UNKNOWN:
|
||||
factors_to_indices_sizes[factor] = [factor_index, size]
|
||||
else:
|
||||
# First time seeing the factor.
|
||||
factor_index = len(factors_to_indices_sizes)
|
||||
factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
|
||||
if factor_index != UNKNOWN:
|
||||
# Not the first time seeing the factor.
|
||||
if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
|
||||
factor_or_batching_dim = (
|
||||
f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
|
||||
else f"Batching dimension {factor[1:]}")
|
||||
raise ValueError(
|
||||
f"{factor_or_batching_dim} corresponds to two sizes:"
|
||||
f" {factor_size} and {size}")
|
||||
if size != UNKNOWN and factor_size == UNKNOWN:
|
||||
factors_to_indices_sizes[factor] = [factor_index, size]
|
||||
else:
|
||||
# First time seeing the factor.
|
||||
factor_index = len(factors_to_indices_sizes)
|
||||
factors_to_indices_sizes[factor] = [factor_index, size]
|
||||
|
||||
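A minimal standalone sketch of the bookkeeping above (record_factor is hypothetical, not the JAX code): each factor gets an index in the order it is first seen, and its size is filled in or cross-checked as further occurrences are visited.

# Standalone sketch of the factor index/size bookkeeping.
from collections import OrderedDict

UNKNOWN = -1  # placeholder for a size that has not been determined yet
factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()

def record_factor(factor: str, size: int) -> None:
  index, known = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
  if index == UNKNOWN:
    # First occurrence: the factor gets the next index, in rule order.
    factors_to_indices_sizes[factor] = [len(factors_to_indices_sizes), size]
  elif size != UNKNOWN and known != UNKNOWN and known != size:
    raise ValueError(f"Factor {factor} corresponds to two sizes: {known} and {size}")
  elif size != UNKNOWN and known == UNKNOWN:
    factors_to_indices_sizes[factor] = [index, size]

record_factor("i", 4)        # first seen with a concrete size
record_factor("j", UNKNOWN)  # seen inside a compound factor, size not yet known
record_factor("j", 2)        # later seen with a concrete size; fills it in
print(dict(factors_to_indices_sizes))  # {'i': [0, 4], 'j': [1, 2]}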
def add_batching_dim_factor(batch_dim_order, factor_size):
|
||||
ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
|
||||
add_factor(ellipsis_batch_dim_name, factor_size)
|
||||
def add_batching_dim_factor(batch_dim_order, factor_size):
|
||||
ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
|
||||
add_factor(ellipsis_batch_dim_name, factor_size)
|
||||
|
||||
def build_dim_mapping_for_compound_factors(i, j, factors):
|
||||
accumulated_size = 1
|
||||
all_indices = []
|
||||
for factor in factors:
|
||||
factor_index, factor_size = factors_to_indices_sizes[factor]
|
||||
accumulated_size *= factor_size
|
||||
all_indices.append(factor_index)
|
||||
def build_dim_mapping_for_compound_factors(i, j, factors):
|
||||
accumulated_size = 1
|
||||
all_indices = []
|
||||
for factor in factors:
|
||||
factor_index, factor_size = factors_to_indices_sizes[factor]
|
||||
accumulated_size *= factor_size
|
||||
all_indices.append(factor_index)
|
||||
|
||||
dim_size = get_size_for_value_dim(i, j)
|
||||
if accumulated_size != dim_size:
|
||||
dim_size = get_size_for_value_dim(i, j)
|
||||
if accumulated_size != dim_size:
|
||||
raise ValueError(
|
||||
f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
|
||||
f" the size {accumulated_size} derived from the compound factors"
|
||||
f" {factors}")
|
||||
|
||||
return sdy.DimMappingAttr.get(factor_indices=all_indices)
|
||||
|
||||
# Add factors and their sizes in the order they appear in the rule,
|
||||
# including the batching dimensions represented by ellipsis.
|
||||
ellipsis_rank = None
|
||||
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
|
||||
value = tuple(mapping)
|
||||
if value and value[0] == BATCHING:
|
||||
has_batching = True
|
||||
value = value[1:]
|
||||
else:
|
||||
has_batching = False
|
||||
rule_rank = len(value)
|
||||
op_rank = get_rank_for_value(i)
|
||||
# The number of dimensions represented by ellipsis.
|
||||
current_batching_rank = 0
|
||||
if has_batching and op_rank >= rule_rank:
|
||||
current_batching_rank = op_rank - rule_rank
|
||||
if has_batching:
|
||||
if ellipsis_rank is None:
|
||||
ellipsis_rank = current_batching_rank
|
||||
elif ellipsis_rank != current_batching_rank:
|
||||
raise ValueError(
|
||||
f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
|
||||
f" the size {accumulated_size} derived from the compound factors"
|
||||
f" {factors}")
|
||||
"Ellipsis represents different number of leading dimensions"
|
||||
f" {ellipsis_rank} and {current_batching_rank}")
|
||||
rule_rank += current_batching_rank
|
||||
if rule_rank != op_rank:
|
||||
msg = get_message_for_value(i)
|
||||
raise ValueError(
|
||||
f"Sharding rule {msg} has rank {rule_rank}, but the operation"
|
||||
f" {msg} has rank {op_rank}")
|
||||
|
||||
return sdy.DimMappingAttr.get(factor_indices=all_indices)
|
||||
for j in range(current_batching_rank):
|
||||
add_batching_dim_factor(j, get_size_for_value_dim(i, j))
|
||||
|
||||
# Add factors and their sizes in the order they appear in the rule,
|
||||
# including the batching dimensions represented by ellipsis.
|
||||
ellipsis_rank = None
|
||||
for i, value in enumerate(self.operands + self.results):
|
||||
if value and value[0] == _ELLIPSIS:
|
||||
has_ellipsis = True
|
||||
value = value[1:]
|
||||
for j, dim in enumerate(value):
|
||||
if isinstance(dim, str):
|
||||
add_factor(dim, get_size_for_value_dim(i, j + current_batching_rank))
|
||||
else:
|
||||
has_ellipsis = False
|
||||
rule_rank = len(value)
|
||||
op_rank = get_rank_for_value(i)
|
||||
# The number of dimensions represented by ellipsis.
|
||||
current_ellipsis_rank = 0
|
||||
if has_ellipsis and op_rank > rule_rank:
|
||||
current_ellipsis_rank = op_rank - rule_rank
|
||||
if has_ellipsis:
|
||||
if ellipsis_rank is None:
|
||||
ellipsis_rank = current_ellipsis_rank
|
||||
elif ellipsis_rank != current_ellipsis_rank:
|
||||
raise ValueError(
|
||||
"Ellipsis represents different number of leading dimensions"
|
||||
f" {ellipsis_rank} and {current_ellipsis_rank}")
|
||||
rule_rank += current_ellipsis_rank
|
||||
if rule_rank != op_rank:
|
||||
msg = get_message_for_value(i)
|
||||
raise ValueError(
|
||||
f"Sharding rule {msg} has rank {rule_rank}, but the operation"
|
||||
f" {msg} has rank {op_rank}")
|
||||
for factor in dim:
|
||||
add_factor(factor, rule.factor_sizes.get(factor, UNKNOWN))
|
||||
|
||||
for j in range(current_ellipsis_rank):
|
||||
add_batching_dim_factor(j, get_size_for_value_dim(i, j))
|
||||
# Build the tensor mappings for each operand and result.
|
||||
tensor_mappings = []
|
||||
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
|
||||
value = tuple(mapping)
|
||||
dim_mappings = []
|
||||
|
||||
for j, dim in enumerate(value):
|
||||
if isinstance(dim, str):
|
||||
add_factor(
|
||||
dim, get_size_for_value_dim(i, j + current_ellipsis_rank))
|
||||
else:
|
||||
for factor in dim:
|
||||
add_factor(factor, self.factor_sizes.get(factor, UNKNOWN))
|
||||
|
||||
# Build the tensor mappings for each operand and result.
|
||||
tensor_mappings = []
|
||||
for i, value in enumerate(self.operands + self.results):
|
||||
dim_mappings = []
|
||||
|
||||
if value and value[0] == _ELLIPSIS:
|
||||
value = value[1:]
|
||||
if ellipsis_rank is None:
|
||||
current_ellipsis_rank = 0
|
||||
else:
|
||||
current_ellipsis_rank = ellipsis_rank
|
||||
if value and value[0] == BATCHING:
|
||||
value = value[1:]
|
||||
if ellipsis_rank is None:
|
||||
current_batching_rank = 0
|
||||
else:
|
||||
current_ellipsis_rank = 0
|
||||
current_batching_rank = ellipsis_rank
|
||||
else:
|
||||
current_batching_rank = 0
|
||||
|
||||
for j in range(current_ellipsis_rank):
|
||||
for j in range(current_batching_rank):
|
||||
dim_mappings.append(
|
||||
sdy.DimMappingAttr.get(factor_indices=[
|
||||
factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
|
||||
|
||||
for j, dim in enumerate(value):
|
||||
if isinstance(dim, str):
|
||||
dim_mappings.append(
|
||||
sdy.DimMappingAttr.get(factor_indices=[
|
||||
factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
|
||||
sdy.DimMappingAttr.get(
|
||||
factor_indices=[factors_to_indices_sizes[dim][0]]))
|
||||
else:
|
||||
dim_mappings.append(
|
||||
build_dim_mapping_for_compound_factors(
|
||||
i, j + current_batching_rank, dim))
|
||||
|
||||
for j, dim in enumerate(value):
|
||||
if isinstance(dim, str):
|
||||
dim_mappings.append(
|
||||
sdy.DimMappingAttr.get(
|
||||
factor_indices=[factors_to_indices_sizes[dim][0]]))
|
||||
else:
|
||||
dim_mappings.append(
|
||||
build_dim_mapping_for_compound_factors(
|
||||
i, j + current_ellipsis_rank, dim))
|
||||
tensor_mappings.append(
|
||||
sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))
|
||||
|
||||
tensor_mappings.append(
|
||||
sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))
|
||||
|
||||
op_sharding_rule = sdy.OpShardingRuleAttr.get(
|
||||
factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
|
||||
operand_mappings=tensor_mappings[0:len(operand_types)],
|
||||
result_mappings=tensor_mappings[len(operand_types):])
|
||||
return op_sharding_rule
|
||||
return sdy.OpShardingRuleAttr.get(
|
||||
factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
|
||||
operand_mappings=tensor_mappings[0:len(operand_types)],
|
||||
result_mappings=tensor_mappings[len(operand_types):])
|
||||
|
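Putting the pieces together, a small standalone sketch (factor_table is hypothetical, not the Shardy lowering) shows how a rule plus the operand and result shapes determines one index and one size per factor, which is the information OpShardingRuleAttr ultimately encodes:

# Standalone sketch: derive factor indices and sizes from a rule and shapes.
def factor_table(mappings, shapes):
  """mappings: one factor tuple per operand/result; shapes: the matching shapes."""
  table = {}  # factor -> (index, size)
  for dims, shape in zip(mappings, shapes):
    for dim_size, factor in zip(shape, dims):
      index, size = table.get(factor, (len(table), dim_size))
      if size != dim_size:
        raise ValueError(f"Factor {factor} corresponds to two sizes: {size} and {dim_size}")
      table[factor] = (index, size)
  return table

# 'i j, j k -> i k' applied to a (4, 8) x (8, 16) matmul with a (4, 16) result:
mappings = [("i", "j"), ("j", "k"), ("i", "k")]
shapes = [(4, 8), (8, 16), (4, 16)]
print(factor_table(mappings, shapes))  # {'i': (0, 4), 'j': (1, 8), 'k': (2, 16)}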
@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params):
|
||||
|
||||
@util.cache()
|
||||
def xla_primitive_callable(prim: core.Primitive, **params):
|
||||
util.test_event("xla_primitive_callable_cache_miss")
|
||||
def prim_fun(*args):
|
||||
with config.eager_constant_folding(False):
|
||||
return prim.bind(*args, **params)
|
||||
|
@ -419,7 +419,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
|
||||
return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}
|
||||
|
||||
# Otherwise, fall back to numpy.issubdtype
|
||||
return np.issubdtype(a_sctype, b_sctype)
|
||||
return bool(np.issubdtype(a_sctype, b_sctype))
|
||||
|
||||
can_cast = np.can_cast
|
||||
|
||||
|
@ -203,6 +203,7 @@ class Exported:
|
||||
_get_vjp: Callable[[Exported], Exported] | None
|
||||
|
||||
def mlir_module(self) -> str:
|
||||
"""A string representation of the `mlir_module_serialized`."""
|
||||
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
|
||||
|
||||
def __str__(self):
|
||||
@ -211,8 +212,8 @@ class Exported:
|
||||
return f"Exported(fun_name={self.fun_name}, ...)"
|
||||
|
||||
def in_shardings_jax(
|
||||
self,
|
||||
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
|
||||
self,
|
||||
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
|
||||
"""Creates Shardings corresponding to self.in_shardings_hlo.
|
||||
|
||||
The Exported object stores `in_shardings_hlo` as HloShardings, which are
|
||||
@ -221,30 +222,31 @@ class Exported:
|
||||
`jax.device_put`.
|
||||
|
||||
Example usage:
|
||||
>>> from jax import export
|
||||
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
|
||||
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
|
||||
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
|
||||
... )(np.arange(jax.device_count()))
|
||||
>>> exp.in_shardings_hlo
|
||||
({devices=[8]<=[8]},)
|
||||
|
||||
# Create a mesh for running the exported object
|
||||
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
|
||||
>>>
|
||||
# Put the args and kwargs on the appropriate devices
|
||||
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
|
||||
... exp.in_shardings_jax(run_mesh)[0])
|
||||
>>> res = exp.call(run_arg)
|
||||
>>> res.addressable_shards
|
||||
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
|
||||
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
|
||||
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
|
||||
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
|
||||
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
|
||||
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
|
||||
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
|
||||
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
|
||||
>>> from jax import export
|
||||
>>> # Prepare the exported object:
|
||||
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
|
||||
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
|
||||
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
|
||||
... )(np.arange(jax.device_count()))
|
||||
>>> exp.in_shardings_hlo
|
||||
({devices=[8]<=[8]},)
|
||||
>>> # Create a mesh for running the exported object
|
||||
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
|
||||
>>> # Put the args and kwargs on the appropriate devices
|
||||
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
|
||||
... exp.in_shardings_jax(run_mesh)[0])
|
||||
>>> res = exp.call(run_arg)
|
||||
>>> res.addressable_shards
|
||||
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
|
||||
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
|
||||
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
|
||||
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
|
||||
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
|
||||
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
|
||||
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
|
||||
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
|
||||
|
||||
"""
|
||||
return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
|
||||
for s in self.in_shardings_hlo)
|
||||
@ -252,7 +254,7 @@ class Exported:
|
||||
def out_shardings_jax(
|
||||
self,
|
||||
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
|
||||
"""Creates Shardings corresponding to self.out_shardings_hlo.
|
||||
"""Creates Shardings corresponding to `self.out_shardings_hlo`.
|
||||
|
||||
See documentation for in_shardings_jax.
|
||||
"""
|
||||
@ -289,6 +291,21 @@ class Exported:
|
||||
return serialize(self, vjp_order=vjp_order)
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
"""Call an exported function from a JAX program.
|
||||
|
||||
Args:
|
||||
args: the positional arguments to pass to the exported function. This
|
||||
should be a pytree of arrays with the same pytree structure as the
|
||||
arguments for which the function was exported.
|
||||
kwargs: the keyword arguments to pass to the exported function.
|
||||
|
||||
Returns: a pytree of result arrays, with the same structure as the
results of the exported function.
|
||||
|
||||
The invocation supports reverse-mode AD, and all the features supported
|
||||
by exporting: shape polymorphism, multi-platform, device polymorphism.
|
||||
See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html).
|
||||
"""
|
||||
return call_exported(self)(*args, **kwargs)
|
||||
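A minimal usage sketch of the flow described in this docstring, assuming the jax.export API shown above; the function and shapes are illustrative:

# Export a jitted function, then invoke it again through exp.call.
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(lambda x: jnp.sin(x)))(
    jax.ShapeDtypeStruct((4,), jnp.float32))

x = np.arange(4, dtype=np.float32)
y = exp.call(x)                    # runs the exported computation
print(np.allclose(y, np.sin(x)))   # True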
|
||||
|
||||
@ -501,7 +518,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
|
||||
"""Returns the shape and dtype of a jax.Array or a j"""
|
||||
if isinstance(a, jax.ShapeDtypeStruct):
|
||||
return a.shape, a.dtype
|
||||
aval = core.raise_to_shaped(core.get_aval(a))
|
||||
aval = core.get_aval(a)
|
||||
return aval.shape, aval.dtype
|
||||
|
||||
|
||||
@ -997,7 +1014,10 @@ _CPU_FFI_KERNELS = [
|
||||
"lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi",
|
||||
"lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi",
|
||||
"lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi",
|
||||
"lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi",
|
||||
"lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi",
|
||||
"lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi",
|
||||
"lapack_strsm_ffi", "lapack_dtrsm_ffi", "lapack_ctrsm_ffi", "lapack_ztrsm_ffi",
|
||||
]
|
||||
# These are the JAX custom call target names that are guaranteed to be stable.
|
||||
# Their backwards compatibility is tested by back_compat_test.py.
|
||||
@ -1021,6 +1041,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
||||
"blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm",
|
||||
# schur on CPU
|
||||
"lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees",
|
||||
# tridiagonal on CPU
|
||||
"lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd",
|
||||
# hessenberg on CPU
|
||||
"lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd",
|
||||
# lu on GPU
|
||||
|
@ -92,8 +92,11 @@ class _SymbolicConstraint:
|
||||
# Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2
|
||||
cmp: Comparator
|
||||
debug_str: str # The form in which the user expressed it, for error messages
|
||||
e1: DimSize # This has been normalized w.r.t. previous constraints only
|
||||
e2: DimSize # This has been normalized w.r.t. previous constraints only
|
||||
# e1, e2, and diff == e1 - e2, are normalized w.r.t. previous constraints only
|
||||
e1: DimSize
|
||||
e2: DimSize
|
||||
# we pre-compute diff to avoid having the normalization rule kick in later.
|
||||
diff: DimSize
|
||||
|
||||
def __repr__(self):
|
||||
return f"Constraint({self.debug_str})"
|
||||
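A minimal sketch of the design choice above, with a hypothetical stand-in class rather than the JAX dataclass: the difference e1 - e2 is captured before the constraint is installed as a normalization rule, since recomputing it afterwards for an equality constraint would simply rewrite to 0.

# Hypothetical stand-in illustrating the precomputed diff field.
import dataclasses
from typing import Any

@dataclasses.dataclass
class Constraint:
  cmp: str          # "EQ" or "GEQ"
  debug_str: str    # the form the user wrote, kept for error messages
  e1: Any
  e2: Any
  diff: Any         # e1 - e2, computed *before* the constraint becomes a rewrite rule

e1, e2 = 12, 7      # stand-ins for symbolic expressions
c = Constraint(cmp="GEQ", debug_str=f"{e1} >= {e2}", e1=e1, e2=e2, diff=e1 - e2)
print(c.diff >= 0)  # True: the constraint is satisfiable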
@ -126,6 +129,7 @@ class _DimFactor:
|
||||
MOD = "mod"
|
||||
MAX = "max"
|
||||
MIN = "min"
|
||||
# TODO(necula): remove non_negative
|
||||
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
|
||||
# max but kept here for backwards compatibility.
|
||||
|
||||
@ -764,13 +768,23 @@ class _DimExpr:
|
||||
return _DimExpr._linear_combination(self, other, 0, 0, self.scope)
|
||||
return _ensure_poly(other, "mul", self.scope).__mul__(self)
|
||||
|
||||
def __pow__(self, power, modulo=None):
|
||||
assert modulo is None
|
||||
try:
|
||||
power = int(power)
|
||||
except:
|
||||
raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
|
||||
return functools.reduce(op.mul, [self] * power)
|
||||
def __pow__(self, power: core.DimSize, modulo=None):
|
||||
if modulo is not None:
|
||||
raise NotImplementedError("__pow__ modulo not implemented")
|
||||
if is_symbolic_dim(power):
|
||||
return power.__rpow__(self) # type: ignore
|
||||
if power != int(power):
|
||||
raise ValueError(f"Symbolic dimension cannot be raised to non-integer powers: '{self}' ** '{power}'")
|
||||
if power >= 0:
|
||||
return functools.reduce(op.mul, [self] * power, 1)
|
||||
# We don't support negative powers, because JAX does not allow negative
|
||||
# powers for integers
|
||||
raise ValueError(f"Symbolic dimension cannot be raised to negative powers: '{self}' ** '{power}'")
|
||||
|
||||
def __rpow__(self, other, modulo=None):
|
||||
if modulo is not None:
|
||||
raise NotImplementedError("__rpow__ modulo not implemented")
|
||||
return self.__jax_array__().__rpow__(other)
|
||||
|
||||
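A small check of the behaviour introduced above, assuming the public jax.export.symbolic_shape API; output formatting may differ:

# Integer powers of a symbolic dimension work; negative powers are rejected.
from jax import export

a, = export.symbolic_shape("a")
print(a ** 2)            # the symbolic product a*a
print(a ** 0)            # 1
try:
  a ** -1                # raises, matching the code above
except ValueError as err:
  print("rejected:", err)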
def __floordiv__(self, divisor):
|
||||
if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
|
||||
@ -1051,29 +1065,51 @@ class SymbolicScope:
|
||||
if cmp == Comparator.GEQ and not is_geq:
|
||||
e1, e2 = e2, e1
|
||||
|
||||
diff = e1 - e2
|
||||
if (diff_const := _DimExpr._to_constant(diff)) is not None:
|
||||
if ((cmp == Comparator.EQ and diff_const != 0) or
|
||||
(cmp == Comparator.GEQ and diff_const < 0)):
|
||||
raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
|
||||
# Compute e1 - e2 before we add to normalization rules
|
||||
constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2,
|
||||
diff=e1 - e2)
|
||||
self._process_explicit_constraint(constr)
|
||||
|
||||
def _process_explicit_constraint(self, constr: _SymbolicConstraint):
|
||||
if (diff_const := _DimExpr._to_constant(constr.diff)) is not None:
|
||||
if ((constr.cmp == Comparator.EQ and diff_const != 0) or
|
||||
(constr.cmp == Comparator.GEQ and diff_const < 0)):
|
||||
raise ValueError(f"Unsatisfiable explicit constraint: {constr.debug_str}")
|
||||
return
|
||||
|
||||
if cmp == Comparator.EQ:
|
||||
if not isinstance(e1, _DimExpr):
|
||||
if constr.cmp == Comparator.EQ:
|
||||
if not isinstance(constr.e1, _DimExpr):
|
||||
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
|
||||
"The left-hand-side must be of the form `term * coefficient`.")
|
||||
(before, before_k), *rest = e1._sorted_terms
|
||||
(before, before_k), *rest = constr.e1._sorted_terms
|
||||
if rest:
|
||||
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
|
||||
"The left-hand-side must be of the form `term * coefficient`.")
|
||||
|
||||
after = _ensure_poly(e2, "parse_constraint", e1.scope) # type: ignore[name-error,unused-ignore]
|
||||
after = _ensure_poly(constr.e2, "parse_constraint", constr.e1.scope) # type: ignore[name-error,unused-ignore]
|
||||
if before in self._normalization_rules:
|
||||
raise NotImplementedError(
|
||||
f"Found multiple equality constraints with the same left-hand-side: {before}")
|
||||
self._normalization_rules[before] = (after, before_k)
|
||||
# Look for constraints of the form mod(before_e1, before_k2) * 1 == 0
|
||||
if (before_k == 1 and
|
||||
isinstance(constr.e2, int) and constr.e2 == 0 and
|
||||
(before_f := before.to_factor()) and
|
||||
before_f.operation == _DimFactor.MOD and
|
||||
(before_k2 := _DimExpr._to_constant(before_f.operands[1])) is not None):
|
||||
# Add before_k2*floordiv(before_e1, before_k2) == before_e1
|
||||
k_times_floordiv = _DimExpr._from_term(
|
||||
_DimTerm.from_operation(
|
||||
_DimFactor.FLOORDIV, *before_f.operands, scope=constr.e1.scope),
|
||||
before_k2, scope=constr.e1.scope)
|
||||
before_e1 = before_f.operands[0]
|
||||
self._process_explicit_constraint(
|
||||
_SymbolicConstraint(cmp=Comparator.EQ,
|
||||
e1=k_times_floordiv, e2=before_e1,
|
||||
diff=k_times_floordiv - before_e1,
|
||||
debug_str=f"{k_times_floordiv} == {before_e1}")
|
||||
)
|
||||
|
||||
constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2)
|
||||
self._explicit_constraints.append(constr)
|
||||
|
||||
def _check_same_scope(self, other: _DimExpr,
|
||||
@ -1468,7 +1504,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
|
||||
"""Returns the shape and dtype of a jax.Array or a j"""
|
||||
if isinstance(a, jax.ShapeDtypeStruct):
|
||||
return a.shape, a.dtype
|
||||
aval = core.raise_to_shaped(core.get_aval(a))
|
||||
aval = core.get_aval(a)
|
||||
return aval.shape, aval.dtype
|
||||
|
||||
|
||||
@ -1982,7 +2018,8 @@ def compute_dim_vars_from_arg_shapes(
|
||||
generate the code for computing the dimension variables. It also generates
|
||||
the shape assertions.
|
||||
|
||||
Returns: the values of the dimension variables, in the order determined by
|
||||
Returns:
|
||||
The values of the dimension variables, in the order determined by
|
||||
`all_dim_vars(args_avals)`.
|
||||
"""
|
||||
dim_vars = all_dim_vars(args_avals)
|
||||
@ -1996,8 +2033,7 @@ def compute_dim_vars_from_arg_shapes(
|
||||
}
|
||||
synthetic_eval = ShapeEvaluator(synthetic_env)
|
||||
shape_constraints.shape_assertions(synthetic_eval)
|
||||
dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
|
||||
return tuple(dim_values)
|
||||
return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars)
|
||||
|
||||
def _solve_dim_equations(
|
||||
eqns: list[_DimEquation],
|
||||
@ -2110,14 +2146,12 @@ def _solve_dim_equations(
|
||||
for constr in scope._explicit_constraints:
|
||||
# We can't just construct constr.e1 - constr.e2 because for an equality
|
||||
# constraint it would be reduced to 0.
|
||||
c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore
|
||||
c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore
|
||||
c_diff = c_e1 - c_e2
|
||||
c_diff = constr.diff._evaluate(shape_env) if not core.is_constant_dim(constr.diff) else constr.diff # type: ignore
|
||||
shape_constraints.add_constraint(
|
||||
constr.cmp, c_diff, 0,
|
||||
error_message_pieces=[
|
||||
f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. "
|
||||
f"Expected '{constr.e1} - {constr.e2}' to be "
|
||||
f"Expected '{constr.diff}' to be "
|
||||
f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, "
|
||||
"but found ", c_diff,
|
||||
|
||||
@ -2131,7 +2165,8 @@ def _solve_dim_equations(
|
||||
eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
|
||||
if not eqns:
|
||||
add_explicit_symbolic_constraints(shape_env)
|
||||
return shape_env, shape_constraints # SUCCESS
|
||||
# SUCCESS
|
||||
return shape_env, shape_constraints # pytype: disable=bad-return-type
|
||||
elif len(eqns) >= nr_eqns:
|
||||
break
|
||||
|
||||
|
@ -85,19 +85,11 @@ class _DecisionByElimination:
|
||||
# the result (albeit, for now, without a good feedback loop to understand
|
||||
# how the order matters for inequalities).
|
||||
for constr in self.scope._explicit_constraints:
|
||||
if not core.is_constant_dim(constr.e1):
|
||||
self.add_implicit_constraints_expr(constr.e1) # type: ignore
|
||||
if not core.is_constant_dim(constr.e2):
|
||||
self.add_implicit_constraints_expr(constr.e2) # type: ignore
|
||||
# The equality constraints are not needed for inequality decisions,
|
||||
# because the LHS should always be rewritten in terms of the RHS.
|
||||
# In fact, adding them may break the assumption that if we eliminate
|
||||
# the leading term we end up with only smaller terms, because the LHS
|
||||
# may appear in the rest and may be rewritten to something larger.
|
||||
# However, we want to add the implicit constraints within.
|
||||
if constr.cmp == Comparator.GEQ:
|
||||
self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0,
|
||||
constr.debug_str)
|
||||
if not core.is_constant_dim(constr.diff):
|
||||
self.add_implicit_constraints_expr(constr.diff) # type: ignore
|
||||
|
||||
self.combine_and_add_constraint(constr.cmp, constr.diff, 0,
|
||||
constr.debug_str)
|
||||
|
||||
|
||||
# Clear the cache, since we have added constraints.
|
||||
@ -197,7 +189,7 @@ class _DecisionByElimination:
|
||||
Combine a term with existing constraints.
|
||||
For input (t, t_k) the tuple (c_eq, c, c_s, t_s) is among the returned
|
||||
tuples if there exists a constraint `c =[c_eq] 0` that can be combined
|
||||
with `t*t_k` to eliminate `t`.
|
||||
with `t*t_k` to eliminate `t`, and:
|
||||
|
||||
* `c =[c_eq] 0`
|
||||
* The term `comb = t*t_k*t_s + c*c_s` does not contain `t`, and if
|
||||
@ -207,7 +199,7 @@ class _DecisionByElimination:
|
||||
"""
|
||||
# TODO: maybe a generator is useful here instead of materializing the list
|
||||
acc: list[tuple[Comparator, _DimExpr, int, int]] = []
|
||||
# First combine with the existing term constraints
|
||||
# First combine with the existing term bounds
|
||||
t_lb, t_ub = self._term_bounds.get(t, (-np.inf, np.inf))
|
||||
if t_lb == t_ub:
|
||||
acc.append((Comparator.EQ, _DimExpr(((t, 1),), scope) - int(t_lb),
|
||||
|
@ -388,7 +388,7 @@ def ffi_call(
|
||||
f"custom_call_api_version < 4; got {custom_call_api_version}.")
|
||||
|
||||
def wrapped(*args: ArrayLike, **kwargs: Any):
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
|
||||
in_avals = [core.get_aval(x) for x in args]
|
||||
|
||||
if input_layouts is None:
|
||||
static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
|
||||
|
@ -241,3 +241,218 @@ module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
|
||||
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0bOO\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\x06\x06\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\x0b\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgees\x00",
|
||||
xla_call_module_version=6,
|
||||
) # End paste
|
||||
|
||||
data_2024_11_29 = {}
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_11_29["c128"] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_zgees_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
|
||||
[ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
|
||||
[ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j],
|
||||
[12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]]),),
|
||||
expected_outputs=(array([[ 3.2464249196572972e+01+0.j, -1.3416407864998739e+01+0.j,
|
||||
-1.2558842947806125e-14+0.j, -7.3490869705474997e-15+0.j],
|
||||
[ 0.0000000000000000e+00+0.j, -2.4642491965729798e+00+0.j,
|
||||
-2.5534994473279107e-15+0.j, -1.3671521621839345e-16+0.j],
|
||||
[ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j,
|
||||
-1.8779126463272594e-15+0.j, 7.2486619604759691e-16+0.j],
|
||||
[ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j,
|
||||
0.0000000000000000e+00+0.j, 4.8523679991768567e-16+0.j]]), array([[ 0.11417645138733863+0.j, -0.8288327563197511 +0.j,
|
||||
0.5401354211381763 +0.j, -0.09085002384085737+0.j],
|
||||
[ 0.33000459866554743+0.j, -0.43714638836388686+0.j,
|
||||
-0.6524649518290251 +0.j, 0.5237265380279561 +0.j],
|
||||
[ 0.545832745943757 +0.j, -0.04546002040802424-0.j,
|
||||
-0.31547635975648136+0.j, -0.774903004533341 +0.j],
|
||||
[ 0.7616608932219662 +0.j, 0.346226347547838 +0.j,
|
||||
0.42780589044732925+0.j, 0.3420264903462419 +0.j]])),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("input")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xcomplex<f64>> loc("input")) -> (tensor<4x4xcomplex<f64>> {jax.result_info = "[0]"}, tensor<4x4xcomplex<f64>> {jax.result_info = "[1]"}) {
|
||||
%cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:5 = stablehlo.custom_call @lapack_zgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex<f64>>) -> (tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<4xcomplex<f64>>, tensor<i32>, tensor<i32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
|
||||
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc3)
|
||||
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc3)
|
||||
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc3)
|
||||
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc3)
|
||||
return %6, %10 : tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0bO\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02>\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x0b\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_zgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
||||
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_11_29["c64"] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_cgees_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
|
||||
[ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
|
||||
[ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j],
|
||||
[12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]], dtype=complex64),),
|
||||
expected_outputs=(array([[ 3.2464264e+01+0.j, -1.3416414e+01+0.j, -2.1337737e-06+0.j,
|
||||
1.8261760e-06+0.j],
|
||||
[ 0.0000000e+00+0.j, -2.4642489e+00+0.j, -6.0543999e-07+0.j,
|
||||
4.8744488e-07+0.j],
|
||||
[ 0.0000000e+00+0.j, 0.0000000e+00+0.j, -6.5878328e-07+0.j,
|
||||
3.9895070e-07+0.j],
|
||||
[ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 0.0000000e+00+0.j,
|
||||
3.0199919e-07+0.j]], dtype=complex64), array([[ 0.11417647 +0.j, -0.8288329 +0.j, 0.5404726 +0.j,
|
||||
-0.08882082 +0.j],
|
||||
[ 0.3300045 +0.j, -0.4371462 +0.j, -0.6544272 +0.j,
|
||||
0.52127254 +0.j],
|
||||
[ 0.54583293 +0.j, -0.045460045-0.j, -0.312564 +0.j,
|
||||
-0.77608234 +0.j],
|
||||
[ 0.76166105 +0.j, 0.34622625 +0.j, 0.42651838 +0.j,
|
||||
0.34363067 +0.j]], dtype=complex64)),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("input")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xcomplex<f32>> loc("input")) -> (tensor<4x4xcomplex<f32>> {jax.result_info = "[0]"}, tensor<4x4xcomplex<f32>> {jax.result_info = "[1]"}) {
|
||||
%cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor<complex<f32>> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:5 = stablehlo.custom_call @lapack_cgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex<f32>>) -> (tensor<4x4xcomplex<f32>>, tensor<4x4xcomplex<f32>>, tensor<4xcomplex<f32>>, tensor<i32>, tensor<i32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
|
||||
%2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc3)
|
||||
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc3)
|
||||
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<4x4xcomplex<f32>> loc(#loc3)
|
||||
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex<f32>> loc(#loc3)
|
||||
return %6, %10 : tensor<4x4xcomplex<f32>>, tensor<4x4xcomplex<f32>> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02\x1e\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\t\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_cgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
||||
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_11_29["f32"] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_sgees_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(array([[ 0., 1., 2., 3.],
|
||||
[ 4., 5., 6., 7.],
|
||||
[ 8., 9., 10., 11.],
|
||||
[12., 13., 14., 15.]], dtype=float32),),
|
||||
expected_outputs=(array([[ 3.2464233e+01, -1.3416398e+01, -1.6680369e-05, 4.0411728e-06],
|
||||
[ 0.0000000e+00, -2.4642496e+00, -1.8640144e-06, 6.7429795e-07],
|
||||
[ 0.0000000e+00, 0.0000000e+00, -7.2618576e-07, 3.9895073e-07],
|
||||
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.0443638e-07]],
|
||||
dtype=float32), array([[-0.11417632 , 0.8288333 , -0.5413438 , 0.08334288 ],
|
||||
[-0.33000442 , 0.43714583 , 0.65967286 , -0.5146185 ],
|
||||
[-0.54583275 , 0.045459934, 0.30468878 , 0.7792079 ],
|
||||
[-0.7616609 , -0.34622616 , -0.4230168 , -0.34793234 ]],
|
||||
dtype=float32)),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("input")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xf32> loc("input")) -> (tensor<4x4xf32> {jax.result_info = "[0]"}, tensor<4x4xf32> {jax.result_info = "[1]"}) {
|
||||
%cst = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:6 = stablehlo.custom_call @lapack_sgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<i32>, tensor<i32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
|
||||
%2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x4xf32> loc(#loc3)
|
||||
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3)
|
||||
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<4x4xf32> loc(#loc3)
|
||||
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3)
|
||||
return %6, %10 : tensor<4x4xf32>, tensor<4x4xf32> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\n\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\t\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_sgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
||||
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_11_29["f64"] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_dgees_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(array([[ 0., 1., 2., 3.],
|
||||
[ 4., 5., 6., 7.],
|
||||
[ 8., 9., 10., 11.],
|
||||
[12., 13., 14., 15.]]),),
|
||||
expected_outputs=(array([[ 3.2464249196572979e+01, -1.3416407864998748e+01,
|
||||
4.7128510442204522e-15, -8.6687960588453852e-15],
|
||||
[ 0.0000000000000000e+00, -2.4642491965729767e+00,
|
||||
1.8990547895861982e-15, -2.4680570671743780e-16],
|
||||
[ 0.0000000000000000e+00, 0.0000000000000000e+00,
|
||||
-1.8780225147134376e-15, -7.2486619604759710e-16],
|
||||
[ 0.0000000000000000e+00, 0.0000000000000000e+00,
|
||||
0.0000000000000000e+00, 4.8523923435746521e-16]]), array([[-0.1141764513873386 , 0.8288327563197505 , 0.5401360966805397 ,
|
||||
0.09084600741204968],
|
||||
[-0.3300045986655475 , 0.43714638836388714, -0.6524688462214561 ,
|
||||
-0.5237216863090944 ],
|
||||
[-0.5458327459437569 , 0.04546002040802441, -0.31547059759870844,
|
||||
0.774905350382041 ],
|
||||
[-0.7616608932219663 , -0.34622634754783793, 0.4278033471396243 ,
|
||||
-0.3420296714849957 ]])),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("input")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xf64> loc("input")) -> (tensor<4x4xf64> {jax.result_info = "[0]"}, tensor<4x4xf64> {jax.result_info = "[1]"}) {
|
||||
%cst = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc)
|
||||
%c = stablehlo.constant dense<0> : tensor<i32> loc(#loc)
|
||||
%0:6 = stablehlo.custom_call @lapack_dgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<i32>, tensor<i32>) loc(#loc3)
|
||||
%1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc3)
|
||||
%2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc3)
|
||||
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc3)
|
||||
%5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3)
|
||||
%7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc3)
|
||||
%8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<4x4xf64> loc(#loc3)
|
||||
%9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3)
|
||||
%10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3)
|
||||
return %6, %10 : tensor<4x4xf64>, tensor<4x4xf64> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13)
|
||||
#loc3 = loc("jit(func)/jit(main)/schur"(#loc2))
|
||||
""",
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\x1a\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x0b\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_dgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+",
xla_call_module_version=9,
nr_devices=1,
) # End paste
@ -203,3 +203,281 @@ module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas =
|
||||
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0bO\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02\x82\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x0b\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ztrsm\x00",
|
||||
xla_call_module_version=6,
|
||||
) # End paste
|
||||
|
||||
data_2024_12_02 = {}
|
||||
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_12_02['c128'] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_ztrsm_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(
|
||||
array([
|
||||
[5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
|
||||
[4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
|
||||
[8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j],
|
||||
[12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j],
|
||||
]),
|
||||
array([
|
||||
[0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j],
|
||||
[5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j],
|
||||
[10.0 + 0.0j, 11.0 + 0.0j, 12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j],
|
||||
[15.0 + 0.0j, 16.0 + 0.0j, 17.0 + 0.0j, 18.0 + 0.0j, 19.0 + 0.0j],
|
||||
]),
|
||||
),
|
||||
expected_outputs=(
|
||||
array([
|
||||
[
|
||||
0.0 + 0.0j,
|
||||
0.2 + 0.0j,
|
||||
0.4 + 0.0j,
|
||||
0.6000000000000001 + 0.0j,
|
||||
0.8 + 0.0j,
|
||||
],
|
||||
[
|
||||
0.5 + 0.0j,
|
||||
0.52 + 0.0j,
|
||||
0.54 + 0.0j,
|
||||
0.5599999999999999 + 0.0j,
|
||||
0.58 + 0.0j,
|
||||
],
|
||||
[
|
||||
0.36666666666666664 + 0.0j,
|
||||
0.3146666666666667 + 0.0j,
|
||||
0.2626666666666667 + 0.0j,
|
||||
0.21066666666666667 + 0.0j,
|
||||
0.15866666666666665 + 0.0j,
|
||||
],
|
||||
[
|
||||
0.16833333333333336 + 0.0j,
|
||||
0.1217333333333333 + 0.0j,
|
||||
0.07513333333333323 + 0.0j,
|
||||
0.0285333333333333 + 0.0j,
|
||||
-0.018066666666666675 + 0.0j,
|
||||
],
|
||||
]),
|
||||
),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("a")
|
||||
#loc2 = loc("b")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xcomplex<f64>> loc("a"), %arg1: tensor<4x5xcomplex<f64>> loc("b")) -> (tensor<4x5xcomplex<f64>> {jax.result_info = ""}) {
|
||||
%cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc4)
|
||||
%0 = stablehlo.custom_call @lapack_ztrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex<f64>>, tensor<4x5xcomplex<f64>>, tensor<complex<f64>>) -> tensor<4x5xcomplex<f64>> loc(#loc4)
|
||||
return %0 : tensor<4x5xcomplex<f64>> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
|
||||
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0bO+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\x1e\x03\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\x0b)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ztrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
||||
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_12_02['c64'] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_ctrsm_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(
|
||||
array(
|
||||
[
|
||||
[5.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
|
||||
[4.0 + 0.0j, 10.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
|
||||
[8.0 + 0.0j, 9.0 + 0.0j, 15.0 + 0.0j, 0.0 + 0.0j],
|
||||
[12.0 + 0.0j, 13.0 + 0.0j, 14.0 + 0.0j, 20.0 + 0.0j],
|
||||
],
|
||||
dtype=complex64,
|
||||
),
|
||||
array(
|
||||
[
|
||||
[0.0 + 0.0j, 1.0 + 0.0j, 2.0 + 0.0j, 3.0 + 0.0j, 4.0 + 0.0j],
|
||||
[5.0 + 0.0j, 6.0 + 0.0j, 7.0 + 0.0j, 8.0 + 0.0j, 9.0 + 0.0j],
|
||||
[
|
||||
10.0 + 0.0j,
|
||||
11.0 + 0.0j,
|
||||
12.0 + 0.0j,
|
||||
13.0 + 0.0j,
|
||||
14.0 + 0.0j,
|
||||
],
|
||||
[
|
||||
15.0 + 0.0j,
|
||||
16.0 + 0.0j,
|
||||
17.0 + 0.0j,
|
||||
18.0 + 0.0j,
|
||||
19.0 + 0.0j,
|
||||
],
|
||||
],
|
||||
dtype=complex64,
|
||||
),
|
||||
),
|
||||
expected_outputs=(
|
||||
array(
|
||||
[
|
||||
[0.0 + 0.0j, 0.2 + 0.0j, 0.4 + 0.0j, 0.6 + 0.0j, 0.8 + 0.0j],
|
||||
[
|
||||
0.5 + 0.0j,
|
||||
0.52 + 0.0j,
|
||||
0.54 + 0.0j,
|
||||
0.56 + 0.0j,
|
||||
0.58000004 + 0.0j,
|
||||
],
|
||||
[
|
||||
0.36666667 + 0.0j,
|
||||
0.31466666 + 0.0j,
|
||||
0.26266667 + 0.0j,
|
||||
0.21066667 + 0.0j,
|
||||
0.15866666 + 0.0j,
|
||||
],
|
||||
[
|
||||
0.16833334 + 0.0j,
|
||||
0.12173338 + 0.0j,
|
||||
0.0751333 + 0.0j,
|
||||
0.02853328 + 0.0j,
|
||||
-0.018066704 + 0.0j,
|
||||
],
|
||||
],
|
||||
dtype=complex64,
|
||||
),
|
||||
),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("a")
|
||||
#loc2 = loc("b")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xcomplex<f32>> loc("a"), %arg1: tensor<4x5xcomplex<f32>> loc("b")) -> (tensor<4x5xcomplex<f32>> {jax.result_info = ""}) {
|
||||
%cst = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc4)
|
||||
%0 = stablehlo.custom_call @lapack_ctrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xcomplex<f32>>, tensor<4x5xcomplex<f32>>, tensor<complex<f32>>) -> tensor<4x5xcomplex<f32>> loc(#loc4)
|
||||
return %0 : tensor<4x5xcomplex<f32>> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
|
||||
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x87[\x19\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x15\x17\x0b\x17\x0f\x07\x07\x1b\x07\x13\x13\x02\xfe\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x80?\x00\x00\x00\x00\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x17\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x03\x13)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05\t)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_ctrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
||||
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_12_02['f32'] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_strsm_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(
|
||||
array(
|
||||
[
|
||||
[5.0, 0.0, 0.0, 0.0],
|
||||
[4.0, 10.0, 0.0, 0.0],
|
||||
[8.0, 9.0, 15.0, 0.0],
|
||||
[12.0, 13.0, 14.0, 20.0],
|
||||
],
|
||||
dtype=float32,
|
||||
),
|
||||
array(
|
||||
[
|
||||
[0.0, 1.0, 2.0, 3.0, 4.0],
|
||||
[5.0, 6.0, 7.0, 8.0, 9.0],
|
||||
[10.0, 11.0, 12.0, 13.0, 14.0],
|
||||
[15.0, 16.0, 17.0, 18.0, 19.0],
|
||||
],
|
||||
dtype=float32,
|
||||
),
|
||||
),
|
||||
expected_outputs=(
|
||||
array(
|
||||
[
|
||||
[0.0, 0.2, 0.4, 0.6, 0.8],
|
||||
[0.5, 0.52, 0.54, 0.56, 0.58000004],
|
||||
[0.36666667, 0.31466666, 0.26266667, 0.21066667, 0.15866666],
|
||||
[0.16833334, 0.12173338, 0.0751333, 0.02853328, -0.018066704],
|
||||
],
|
||||
dtype=float32,
|
||||
),
|
||||
),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("a")
|
||||
#loc2 = loc("b")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xf32> loc("a"), %arg1: tensor<4x5xf32> loc("b")) -> (tensor<4x5xf32> {jax.result_info = ""}) {
|
||||
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32> loc(#loc4)
|
||||
%0 = stablehlo.custom_call @lapack_strsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf32>, tensor<4x5xf32>, tensor<f32>) -> tensor<4x5xf32> loc(#loc4)
|
||||
return %0 : tensor<4x5xf32> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
|
||||
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b\x1f+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xe6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\t\x00\x00\x80?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\t)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_strsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
||||
|
||||
|
||||
# Pasted from the test output (see export_back_compat_test_util.py module docstring)
|
||||
data_2024_12_02['f64'] = dict(
|
||||
testdata_version=1,
|
||||
platform='cpu',
|
||||
custom_call_targets=['lapack_dtrsm_ffi'],
|
||||
serialized_date=datetime.date(2024, 12, 2),
|
||||
inputs=(
|
||||
array([
|
||||
[5.0, 0.0, 0.0, 0.0],
|
||||
[4.0, 10.0, 0.0, 0.0],
|
||||
[8.0, 9.0, 15.0, 0.0],
|
||||
[12.0, 13.0, 14.0, 20.0],
|
||||
]),
|
||||
array([
|
||||
[0.0, 1.0, 2.0, 3.0, 4.0],
|
||||
[5.0, 6.0, 7.0, 8.0, 9.0],
|
||||
[10.0, 11.0, 12.0, 13.0, 14.0],
|
||||
[15.0, 16.0, 17.0, 18.0, 19.0],
|
||||
]),
|
||||
),
|
||||
expected_outputs=(
|
||||
array([
|
||||
[0.0, 0.2, 0.4, 0.6000000000000001, 0.8],
|
||||
[0.5, 0.52, 0.54, 0.5599999999999999, 0.58],
|
||||
[
|
||||
0.36666666666666664,
|
||||
0.3146666666666667,
|
||||
0.2626666666666667,
|
||||
0.21066666666666667,
|
||||
0.15866666666666665,
|
||||
],
|
||||
[
|
||||
0.16833333333333336,
|
||||
0.1217333333333333,
|
||||
0.07513333333333323,
|
||||
0.0285333333333333,
|
||||
-0.018066666666666675,
|
||||
],
|
||||
]),
|
||||
),
|
||||
mlir_module_text=r"""
|
||||
#loc1 = loc("a")
|
||||
#loc2 = loc("b")
|
||||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<4x4xf64> loc("a"), %arg1: tensor<4x5xf64> loc("b")) -> (tensor<4x5xf64> {jax.result_info = ""}) {
|
||||
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64> loc(#loc4)
|
||||
%0 = stablehlo.custom_call @lapack_dtrsm_ffi(%arg0, %arg1, %cst) {mhlo.backend_config = {diag = 78 : ui8, side = 76 : ui8, trans_x = 78 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<4x4xf64>, tensor<4x5xf64>, tensor<f64>) -> tensor<4x5xf64> loc(#loc4)
|
||||
return %0 : tensor<4x5xf64> loc(#loc)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc = loc(unknown)
|
||||
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":715:13)
|
||||
#loc4 = loc("jit(func)/jit(main)/triangular_solve"(#loc3))
|
||||
""",
|
||||
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.8.1\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85[\x17\x01%\x07\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0b\x13\x0b\x037O\x0b\x0b\x0f\x0f\x13\x0b\x0f\x13\x0b\x0b\x0b/+\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0f\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x13\x17\x07\x17\x0f\x07\x07\x1b\x13\x13\x02\xf6\x02\x1f\x11\x03\x05\x1d\x1b\x1d\x03\x07\t\x0b\r\x03\x0f\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x15\x01\x05\x17\x1d\x19\x01\x05\x19\x05\x1b\x17\x1f.\x0b\x1b\x05\x1d\x03\x03#?\x05\x1f\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\r\x01\x1d!\x13\rN\x13\rL\x03\x05''#\x11\x03\x035\r\x037)\x1d#\x1d%\x1d'\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\xf0?\r\tA+C-E+G-\x1d)\x1d+\x1d-\x1d/\x0b\x03\x1d1\x03\x01\x05\x01\x03\x07%%S\x1f\x15\x01\x03\x03W\x15\x01\x05\x01\x03\x03%\x01\t\x01\x02\x02)\x05\x11\x15\x07\x0b)\x05\x11\x11\x07)\x01\x07!\x13\x11\x05\t\x05\x03\x05)\x03\t\x0f)\x03\x01\x0f\x04e\x05\x01Q\x01\x07\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x13\x13\x0b\x17\x00\x05B\x05\x05\x03\x0b\x07G\x05!\x07\x03\x05\x07\x01\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00r\x053#\x0b\x11\x0b\x0b\x0f\x0b!\x03)iK\x05\x05\x13%)9\x15\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00constant_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00a\x00b\x00jit(func)/jit(main)/triangular_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.backend_config\x00\x00jax.result_info\x00main\x00public\x00diag\x00side\x00trans_x\x00uplo\x00lapack_dtrsm_ffi\x00\x08+\t\x05#\x01\x0b/139;\x03=\x11I)KMOQUY",
|
||||
xla_call_module_version=9,
|
||||
nr_devices=1,
|
||||
) # End paste
|
||||
|
File diff suppressed because one or more lines are too long
@ -143,21 +143,31 @@ def linearize_jaxpr(jaxpr, nonzeros):

def direct_linearize(traceable, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False)
assert not has_aux
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace()
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
with core.set_current_trace(linearize_trace):
ans = traceable.call_wrapped(*tracers)

if has_aux:
ans, aux = traceable.call_wrapped(*tracers)
aux_primals = [x.primal
if isinstance(x, LinearizeTracer)
and x._trace.tag is linearize_trace.tag
else x for x in aux]
else:
ans = traceable.call_wrapped(*tracers)
aux = None
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
out_tangents = map(instantiate_zeros, out_tangents)
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
del attrs_tracked # TODO: attrs
return out_primals, out_tangents_pvals, jaxpr, consts
if has_aux:
return out_primals, out_tangents_pvals, jaxpr, consts, aux_primals
else:
return out_primals, out_tangents_pvals, jaxpr, consts

def linearize(traceable, *primals, **kwargs):
if config.use_direct_linearize.value:
@ -175,7 +185,11 @@ def linearize(traceable, *primals, **kwargs):
jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
if any(not out_primal_pval.is_known() for out_primal_pval in out_primals_pvals):
raise ValueError(
"Linearization failed to produce known values for all output primals. "
"This is typically caused by attempting to differentiate a function "
|
||||
"uses an operation that does not support reverse-mode autodiff.")
|
||||
out_primals_consts = [pval.get_known() for pval in out_primals_pvals]
if not has_aux:
return out_primals_consts, out_tangents_pvals, jaxpr, consts
@ -263,6 +277,20 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
if eqn.primitive.ref_primitive:
if eqn.primitive is core.mutable_array_p:
val_var, = eqn.invars
ref_var, = eqn.outvars
ref = read_primal(ref_var)
ct_out = core.freeze(ref)
write_cotangent(eqn.primitive, val_var, ct_out)
elif eqn.primitive is core.freeze_p:
val_var, = eqn.outvars # type: ignore
ref_var, = eqn.invars # type: ignore
ct_in = instantiate_zeros(read_cotangent(val_var))
write_primal(ref_var, core.mutable_array(ct_in))
continue

invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
@ -514,22 +542,45 @@ class LinearizeTrace(Trace):

def process_primitive(self, primitive, args, params):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
tangent_nonzeros = [type(t) is not Zero for t in tangents_in]
tangent_nzs = [type(t) is not Zero for t in tangents_in]
if all(type(t) is Zero for t in tangents_in):
return primitive.bind_with_trace(self.parent_trace, primals_in, params)
lin = primitive_linearizations.get(primitive)
if lin is None:
lin = partial(fallback_linearize_rule, primitive)
fallback = partial(fallback_linearize_rule, primitive)
lin = primitive_linearizations.get(primitive, fallback)
with core.set_current_trace(self.parent_trace):
primal_out, tangent_nonzeros_out, residuals, linearized = lin(
tangent_nonzeros, *primals_in, **params)
primal_out, tangent_nzs_out, residuals, linearized = lin(
tangent_nzs, *primals_in, **params)
with core.set_current_trace(self.tangent_trace):
tangent_out = linearized(residuals, *tangents_in)
if primitive.multiple_results:
return [maybe_linearize_tracer(self, x, nz, t)
for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)]
for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)]
else:
return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out)
return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)

def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace,
(fun, fwd, bwd, *primals_in),
dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
fwd_in = [x for pair in fwd_in for x in pair] # flatten
with core.set_current_trace(self.parent_trace):
res_and_primals_out = fwd.call_wrapped(*fwd_in)

_, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]

with core.set_current_trace(self.tangent_trace):
tangents_in = map(instantiate_zeros, tangents_in)
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)

def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
if is_nonzero:
@ -539,21 +590,50 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
assert type(tangent) is Zero
return primal

def fallback_linearize_rule(prim, _, *args, **kwargs):
assert not prim.multiple_results
def fallback_linearize_rule(prim, nonzeros, *primals, **params):
jvp = primitive_jvps.get(prim)
if not jvp:
msg = f"Differentiation rule for '{prim}' not implemented"
raise NotImplementedError(msg)
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
tangent_avals = [get_aval(p).to_tangent_aval() for p in primals]
tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else Zero(aval)
for aval, nz in zip(tangent_avals, nonzeros)]
with core.set_current_trace(trace):
out_primals, out_tangents = jvp(primals, tangent_args, **params)

def call_prim(*args_):
return [prim.bind(*args_, **kwargs)]
if not prim.multiple_results:
out_primals = [out_primals]
out_tangents = [out_tangents]

with config.use_direct_linearize(False):
(out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize(
lu.wrap_init(call_prim), *args, **kwargs)
out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
out_nzs = [type(r) is not Zero for r in out_tangents]
out_tangent_avals = [get_aval(p).to_tangent_aval() for p in out_primals]
out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t in tangent_args if type(t) is not Zero]
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)

def linearized(residuals, *tangents):
out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents)
return out_tangent
def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
nz_tangents_out = core.eval_jaxpr(jaxpr, residuals, *nz_tangents_in)
nz_tangents_out_iter = iter(nz_tangents_out)
all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval)
for (aval, nz) in zip(out_tangent_avals, out_nzs)]
if prim.multiple_results:
return all_out_tangents
else:
out_tangent, = all_out_tangents
return out_tangent

if prim.multiple_results:
return out_primals, out_nzs, out_consts, linearized
else:
out_primal, = out_primals
out_nz, = out_nzs
return out_primal, out_nz, out_consts, linearized

return out_primal, True, consts, linearized

class LinearizeTracer(Tracer):
__slots__ = ['primal', 'tangent']
@ -50,6 +50,7 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import AUTO
from jax._src.partition_spec import UnconstrainedSingleton
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects, ir, passmanager
@ -1084,6 +1085,7 @@ def lower_jaxpr_to_module(
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
util.test_event("lower_jaxpr_to_module")
platforms = tuple(map(xb.canonicalize_platform, platforms))

in_avals = (jaxpr.in_avals if arg_shardings is None else
@ -1377,6 +1379,7 @@ def lower_jaxpr_to_fun(
Returns:
MLIR func op
"""
util.test_event("lower_jaxpr_to_fun", name)

# The first dimension variable may be the platform index
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
@ -1699,6 +1702,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The custom call below achieves the sharding described in the example above.
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
if config.use_shardy_partitioner.value:
physical_ndim = core.physical_aval(aval).ndim
s = sharding_impls.SdyArraySharding(
@ -2523,12 +2527,19 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
# Don't emit a wsc under full manual mode to avoid increasing HLO size.
if aval.sharding.mesh._are_all_axes_collective:
return op
if aval.sharding.mesh._are_all_axes_auto:
return op
# TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
# `return op` early and avoid bloating HLO size.
proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
if sharding_proto is None else sharding_proto)
# TODO(yashkatariya): Enable this
# unspecified_dims = (set(range(aval.ndim))
# if aval.sharding.mesh._any_axis_collective else None)
return wrap_with_sharding_op(ctx, op, aval, proto)
unspecified_dims = None
if aval.sharding.mesh._any_axis_collective:
unspecified_dims = set(range(aval.ndim))
elif aval.sharding.mesh._any_axis_auto:
unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
if isinstance(s, UnconstrainedSingleton)}
return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
@ -2792,7 +2803,7 @@ def _emit_tpu_python_callback(
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
if minor_to_major is None:
# Needed for token layouts
layout = np.zeros((0,), dtype="int64")
layout: np.ndarray = np.zeros((0,), dtype="int64")
else:
layout = np.array(minor_to_major, dtype="int64")
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
@ -177,9 +177,12 @@ class JaxprTrace(Trace['JaxprTracer']):
if const is None:
aval = pval.get_aval()
if type(aval) is DShapedArray:
# TODO(dougalm): Fix the type error and remove the pytype pragmas.
# pytype: disable=attribute-error
shape = [self.new_instantiated_const(d)
if isinstance(d, Tracer) and d._trace.level < self.level else d
for d in aval.shape]
# pytype: enable=attribute-error
aval = aval.update(shape=tuple(shape))
return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding())
else:
@ -1006,7 +1009,7 @@ def partial_eval_jaxpr_stateful(
in_inst: bool | Sequence[bool],
ensure_out_unknowns: bool | Sequence[bool],
ensure_out_inst: bool | Sequence[bool],
saveable: Callable[..., RematCases_],
saveable: Callable[..., RematCases_] | None,
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
@ -1014,6 +1017,8 @@ def partial_eval_jaxpr_stateful(
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
if saveable is None:
saveable = everything_saveable
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
@ -1021,6 +1026,8 @@ def partial_eval_jaxpr_stateful(
tuple(ensure_out_inst), saveable)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref

everything_saveable = lambda *_, **__: True

@weakref_lru_cache
def _partial_eval_jaxpr_custom_cached(
jaxpr: Jaxpr,
@ -1776,6 +1783,9 @@ def _inline_literals(
newvars: dict[Var, Var] = {}
newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
lit_or_var = (
lambda a: a if isinstance(a, Literal) else (lit(a) or var(a))
)
dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))

def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
@ -1794,10 +1804,10 @@ def _inline_literals(
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(x) or var(x) for x in eqn.invars]
invars = [lit_or_var(x) for x in eqn.invars]
outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
new_eqns)
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
@ -41,7 +41,6 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import profiler
|
||||
@ -61,11 +60,11 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
|
||||
from jax._src.sharding import Sharding as JSharding
|
||||
from jax._src.mesh import AbstractMesh, Mesh
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
|
||||
UnspecifiedValue, get_array_mapping as _get_array_mapping,
|
||||
@ -99,7 +98,6 @@ ShardedAxis = sharding_specs.ShardedAxis
|
||||
Replicated = sharding_specs.Replicated
|
||||
|
||||
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
|
||||
Mesh = mesh_lib.Mesh
|
||||
MeshAxisName = sharding_impls.MeshAxisName
|
||||
MeshDimAssignment = Union[ShardedAxis, Replicated]
|
||||
ShardingSpec = sharding_specs.ShardingSpec
|
||||
@ -108,8 +106,6 @@ ShardingSpec = sharding_specs.ShardingSpec
|
||||
|
||||
|
||||
def to_xc_copy_semantics(copy_semantics):
|
||||
if xla_extension_version < 296:
|
||||
return [None] * len(copy_semantics)
|
||||
out = []
|
||||
for cs in copy_semantics:
|
||||
if cs is None or cs == dispatch.CopySemantics.ALIAS:
|
||||
@ -234,16 +230,20 @@ shard_arg_handlers[core.MutableArray] = _shard_mutable_array
|
||||
def batched_device_put(aval: core.ShapedArray,
|
||||
sharding: JSharding, xs: Sequence[Any],
|
||||
devices: Sequence[jax.Device], committed: bool = True):
|
||||
from jax._src import array
|
||||
util.test_event("batched_device_put_start")
|
||||
try:
|
||||
from jax._src import array
|
||||
|
||||
bufs = [x for x, d in safe_zip(xs, devices)
|
||||
if (isinstance(x, array.ArrayImpl) and
|
||||
dispatch.is_single_device_sharding(x.sharding) and
|
||||
x.devices() == {d})]
|
||||
if len(bufs) == len(xs):
|
||||
return array.ArrayImpl(
|
||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
|
||||
bufs = [x for x, d in safe_zip(xs, devices)
|
||||
if (isinstance(x, array.ArrayImpl) and
|
||||
dispatch.is_single_device_sharding(x.sharding) and
|
||||
x.devices() == {d})]
|
||||
if len(bufs) == len(xs):
|
||||
return array.ArrayImpl(
|
||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
|
||||
finally:
|
||||
util.test_event("batched_device_put_end")
|
||||
|
||||
def _shard_aval(size, axis: int, aval):
|
||||
try:
|
||||
@ -1722,20 +1722,19 @@ def _get_and_check_device_assignment(
|
||||
devices: Sequence[xc.Device] | None,
|
||||
) -> tuple[xc.Client, tuple[xc.Device, ...]]:
|
||||
first_sharding_info = None
|
||||
if devices is None:
|
||||
devices = ()
|
||||
else:
|
||||
devices = tuple(devices)
|
||||
devices = () if devices is None else tuple(devices)
|
||||
|
||||
for i, s_type, source_info in shardings:
|
||||
if isinstance(i, UnspecifiedValue):
|
||||
for sh, s_type, source_info in shardings:
|
||||
if isinstance(sh, UnspecifiedValue):
|
||||
continue
|
||||
if isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh):
|
||||
continue
|
||||
|
||||
if first_sharding_info is None:
|
||||
first_sharding_info = (
|
||||
(i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
|
||||
else (i._device_assignment, s_type, source_info))
|
||||
arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
|
||||
(sh.mesh._flat_devices_tuple, s_type, source_info) if isinstance(sh, AUTO)
|
||||
else (sh._device_assignment, s_type, source_info))
|
||||
arr_device_assignment = (sh.mesh._flat_devices_tuple if isinstance(sh, AUTO)
|
||||
else sh._device_assignment)
|
||||
if not devices:
|
||||
if first_sharding_info[0] != arr_device_assignment:
|
||||
raise DeviceAssignmentMismatchError([
|
||||
@ -1836,7 +1835,8 @@ class SemanticallyEqualShardings:
|
||||
def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
|
||||
avals: tuple[core.AbstractValue]):
|
||||
gspmd_shardings = [
|
||||
s if isinstance(s, (UnspecifiedValue, AUTO))
|
||||
s if (isinstance(s, (UnspecifiedValue, AUTO)) or
|
||||
(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)))
|
||||
else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error
|
||||
for s, a in zip(shardings, avals)]
|
||||
self._gspmd_shardings = gspmd_shardings
|
||||
@ -1894,7 +1894,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
propagated_out_mem_kinds: tuple[None | str, ...],
|
||||
platforms: tuple[str, ...],
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
abstract_mesh: mesh_lib.AbstractMesh | None):
|
||||
abstract_mesh: AbstractMesh | None):
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
in_shardings = semantic_in_shardings.shardings
|
||||
out_shardings = semantic_out_shardings.shardings
|
||||
@ -2000,6 +2000,8 @@ def jaxpr_transfer_mem_kinds(
|
||||
|
||||
|
||||
def are_all_shardings_default_mem_kind(da_object, shardings):
|
||||
if da_object is None:
|
||||
return True
|
||||
try:
|
||||
default_mem_kind = da_object.default_memory_kind
|
||||
except:
|
||||
@ -2081,6 +2083,41 @@ def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
|
||||
return tuple(safe_map(read, jaxpr.outvars))
|
||||
|
||||
|
||||
def _get_num_devices(
|
||||
shardings, device_assignment, lowering_platforms, prim_requires_devices
|
||||
) -> tuple[int, tuple[xc.Device, ...] | None]:
|
||||
ext_abstract_mesh, concrete_sharding = None, False
|
||||
for s in shardings:
|
||||
if isinstance(s, UnspecifiedValue):
|
||||
continue
|
||||
elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
|
||||
if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
|
||||
raise ValueError("AbstractMesh should be the same across all "
|
||||
f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
|
||||
ext_abstract_mesh = s.mesh
|
||||
else:
|
||||
concrete_sharding = True
|
||||
if (concrete_sharding and ext_abstract_mesh is not None and
|
||||
len(device_assignment) != ext_abstract_mesh.size):
|
||||
raise ValueError(
|
||||
f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
|
||||
f" device assignment size: {len(device_assignment)}")
|
||||
if concrete_sharding:
|
||||
return len(device_assignment), device_assignment
|
||||
if ext_abstract_mesh is None:
|
||||
return len(device_assignment), device_assignment
|
||||
if lowering_platforms is None:
|
||||
raise ValueError(
|
||||
"Passing lowering_platforms via"
|
||||
" jit(f).trace(*args).lower(lowering_platforms=...) is required when"
|
||||
" only AbstractMesh exists in a jitted computation.")
|
||||
if prim_requires_devices:
|
||||
raise ValueError(
|
||||
"AbstractMesh cannot be used when jaxpr contains primitives that"
|
||||
" require devices to be present during lowering.")
|
||||
return ext_abstract_mesh.size, None
|
||||
|
||||
|
||||
MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]
|
||||
|
||||
|
||||
@ -2125,12 +2162,14 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _abstract_to_concrete_mesh(abstract_mesh):
|
||||
return mesh_lib.Mesh(
|
||||
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
|
||||
return Mesh(
|
||||
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
|
||||
axis_types=abstract_mesh.axis_types)
|
||||
|
||||
out = []
|
||||
for s, a in zip(shardings, avals):
|
||||
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
|
||||
if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
|
||||
all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
|
||||
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
|
||||
a.sharding.spec))
|
||||
else:
|
||||
@ -2150,7 +2189,7 @@ def lower_sharding_computation(
|
||||
donated_invars: Sequence[bool],
|
||||
*,
|
||||
keep_unused: bool,
|
||||
context_mesh: mesh_lib.Mesh | None,
|
||||
context_mesh: Mesh | None,
|
||||
compiler_options_kvs: tuple[tuple[str, Any], ...],
|
||||
lowering_platforms: tuple[str, ...] | None,
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
@ -2208,6 +2247,7 @@ def lower_sharding_computation(
|
||||
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
|
||||
for js, source_info in unique_intermediate_shardings)),
|
||||
devices_from_context)
|
||||
unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]
|
||||
|
||||
if config.sharding_in_types.value:
|
||||
out_shardings = _concretize_abstract_shardings(
|
||||
@ -2218,21 +2258,33 @@ def lower_sharding_computation(
|
||||
platforms = lowering_platforms or (
|
||||
getattr(backend, "_raw_platform", backend.platform),)
|
||||
|
||||
committed = bool(
|
||||
devices_from_context or
|
||||
len(device_assignment) > 1 or
|
||||
any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
|
||||
any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
|
||||
any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))
|
||||
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
|
||||
|
||||
da_object = _create_da_object(tuple(device_assignment))
|
||||
# TODO(yashkatariya): All device specific logic should go in compilation
|
||||
# but this requires a big refactor. The current `_get_num_devices` logic
|
||||
# is good enough to lower with AbstractMesh but cannot be compiled. Once
|
||||
# I refactor, this will also work well with mesh being provided at
|
||||
# compile time.
|
||||
# Sets device_assignment to None if only abstractMesh and unspecified exists.
|
||||
num_devices, device_assignment = _get_num_devices( # type: ignore
|
||||
it.chain(unique_in_shardings, unique_out_shardings,
|
||||
unique_intermediate_shardings),
|
||||
device_assignment, lowering_platforms, prim_requires_devices)
|
||||
|
||||
committed = bool(
|
||||
devices_from_context
|
||||
or num_devices > 1
|
||||
or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
|
||||
unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))
|
||||
|
||||
da_object = (_create_da_object(tuple(device_assignment))
|
||||
if device_assignment is not None else None)
|
||||
|
||||
transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
|
||||
all_default_mem_kind = are_all_shardings_default_mem_kind(
|
||||
da_object,
|
||||
it.chain(unique_in_shardings, unique_out_shardings,
|
||||
[js for js, _ in unique_intermediate_shardings],
|
||||
transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
|
||||
unique_intermediate_shardings, transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
|
||||
|
||||
if all_default_mem_kind:
|
||||
propagated_out_mem_kinds = (None,) * len(global_out_avals)
|
||||
@ -2241,12 +2293,12 @@ def lower_sharding_computation(
|
||||
closed_jaxpr, in_shardings)
|
||||
|
||||
# 2. Build up the HLO
|
||||
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
|
||||
|
||||
abstract_mesh = None
|
||||
if prim_requires_devices:
|
||||
assert da_object is not None
|
||||
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
|
||||
[js for js, _ in unique_intermediate_shardings]):
|
||||
unique_intermediate_shardings):
|
||||
if isinstance(sharding, NamedSharding):
|
||||
if (abstract_mesh is not None and
|
||||
abstract_mesh != sharding.mesh.abstract_mesh):
|
||||
@ -2264,7 +2316,7 @@ def lower_sharding_computation(
|
||||
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
|
||||
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
|
||||
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
|
||||
semantic_out_shardings, in_layouts, out_layouts, len(da_object),
|
||||
semantic_out_shardings, in_layouts, out_layouts, num_devices,
|
||||
tuple(da_object) if prim_requires_devices else None, donated_invars,
|
||||
name_stack, all_default_mem_kind, inout_aliases,
|
||||
propagated_out_mem_kinds, platforms,
|
||||
@ -2307,7 +2359,7 @@ def lower_sharding_computation(
|
||||
all_default_mem_kind=all_default_mem_kind,
|
||||
all_args_info=all_args_info,
|
||||
pgle_profiler=pgle_profiler,
|
||||
intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
|
||||
intermediate_shardings=unique_intermediate_shardings,
|
||||
context_mesh=context_mesh)
|
||||
|
||||
|
||||
@ -2477,7 +2529,7 @@ def _register_out_sharding_handler(
|
||||
|
||||
def _gspmd_to_named_sharding(
|
||||
out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
|
||||
assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
|
||||
assert isinstance(orig_in_s.mesh, Mesh)
|
||||
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
|
||||
|
||||
_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
|
||||
@ -2529,7 +2581,7 @@ def _get_out_sharding_from_orig_sharding(
|
||||
|
||||
def maybe_recover_user_shardings(
|
||||
old_shardings, new_shardings, old_avals, new_avals,
|
||||
intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
|
||||
intermediate_shardings=None, context_mesh: Mesh | None = None):
|
||||
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
|
||||
return new_shardings
|
||||
|
||||
@ -2817,7 +2869,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive: Any,
|
||||
kept_var_idx: set[int],
|
||||
backend: xb.XlaBackend,
|
||||
device_assignment: xc.DeviceList | Sequence[xc.Device],
|
||||
device_assignment: xc.DeviceList | Sequence[xc.Device] | None,
|
||||
committed: bool,
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
@ -2829,8 +2881,15 @@ class UnloadedMeshExecutable:
|
||||
all_args_info: AllArgsInfo | None = None,
|
||||
pgle_profiler: profiler.PGLEProfiler | None = None,
|
||||
intermediate_shardings: Sequence[JSharding] | None = None,
|
||||
context_mesh: mesh_lib.Mesh | None = None
|
||||
context_mesh: Mesh | None = None,
|
||||
) -> MeshExecutable:
|
||||
if (device_assignment is None or
|
||||
any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
|
||||
for s in it.chain(in_shardings, out_shardings))):
|
||||
raise RuntimeError(
|
||||
"A jitted computation cannot contain AbstractMesh in in_shardings and"
|
||||
" out_shardings during compilation. You can use `jax.export` to "
|
||||
" lower with an AbstractMesh and later compile with concrete devices.")
|
||||
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
|
||||
hlo = mlir.refine_polymorphic_shapes(hlo)
|
||||
if isinstance(device_assignment, xc.DeviceList):
|
||||
@ -2851,6 +2910,7 @@ class UnloadedMeshExecutable:
|
||||
mesh = i.mesh
|
||||
break
|
||||
|
||||
util.test_event("pxla_cached_compilation")
|
||||
xla_executable = _cached_compilation(
|
||||
hlo, name, mesh, spmd_lowering,
|
||||
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
|
||||
|
@ -37,9 +37,6 @@ map, unsafe_map = safe_map, map
|
||||
effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)
|
||||
|
||||
|
||||
def _abstractify(x):
|
||||
return core.raise_to_shaped(core.get_aval(x))
|
||||
|
||||
def _typecheck_param(prim, param, name, msg_required, pred):
|
||||
if not pred:
|
||||
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
|
||||
@ -91,7 +88,7 @@ def _initial_style_jaxprs_with_common_consts(
|
||||
return [], [], []
|
||||
|
||||
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
|
||||
all_const_avals = [map(_abstractify, consts) for consts in all_consts]
|
||||
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
|
||||
# If we get a `Ref` in the consts, we know it must come from an outer
|
||||
# `run_state`. We also know if shouldn't be boxed up in another tracer.
|
||||
# We assert that it is in fact a DynamicJaxprTracer
|
||||
|
@ -49,7 +49,6 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lax.control_flow.common import (
|
||||
_abstractify,
|
||||
_avals_short,
|
||||
_check_tree_and_avals,
|
||||
_initial_style_jaxprs_with_common_consts,
|
||||
@ -135,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
|
||||
return branches[int(index)](*operands)
|
||||
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(_abstractify, ops))
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
branches, ops_tree, ops_avals, primitive_name='switch')
|
||||
@ -227,7 +226,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
return false_fun(*operands)
|
||||
|
||||
ops, ops_tree = tree_flatten(operands)
|
||||
ops_avals = tuple(map(_abstractify, ops))
|
||||
ops_avals = tuple(map(core.get_aval, ops))
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
@ -513,7 +512,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
# jaxpr for each branch.
|
||||
branches_known_ : list[core.ClosedJaxpr] = []
|
||||
branches_staged_: list[core.ClosedJaxpr] = []
|
||||
branch_res_avals: list[core.AbstractValue] = []
|
||||
branch_res_avals: list[list[core.AbstractValue]] = []
|
||||
for jaxpr in branches:
|
||||
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
|
||||
pe.partial_eval_jaxpr_custom(
|
||||
|
@ -44,7 +44,7 @@ from jax._src.typing import Array
|
||||
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
|
||||
split_list, split_dict, weakref_lru_cache)
|
||||
from jax._src.lax.control_flow import loops
|
||||
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
|
||||
from jax._src.lax.control_flow.common import _initial_style_jaxpr
|
||||
import numpy as np
|
||||
|
||||
## JAX utilities
|
||||
@ -196,7 +196,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
init_flat = tree_leaves(init)
|
||||
_, in_tree = tree_flatten((init, xs))
|
||||
|
||||
carry_avals = tuple(map(_abstractify, init_flat))
|
||||
carry_avals = tuple(map(core.get_aval, init_flat))
|
||||
jaxpr, _, out_tree = _initial_style_jaxpr(
|
||||
f, in_tree, carry_avals + x_avals, "scan")
|
||||
return jaxpr, out_tree
|
||||
|
@ -47,7 +47,7 @@ from jax._src.lax import lax
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lax import windowed_reductions
|
||||
from jax._src.lax.control_flow.common import (
|
||||
_abstractify, _avals_short, _initial_style_jaxpr,
|
||||
_avals_short, _initial_style_jaxpr,
|
||||
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
|
||||
_typecheck_param)
|
||||
from jax._src.lax.other import logaddexp
|
||||
@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
init_flat, init_tree = tree_flatten(init)
|
||||
in_flat, in_tree = tree_flatten((init, xs))
|
||||
|
||||
carry_avals = tuple(_map(_abstractify, init_flat))
|
||||
carry_avals = tuple(_map(core.get_aval, init_flat))
|
||||
jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
|
||||
f, in_tree, (*carry_avals, *x_avals), "scan")
|
||||
out_tree_children = out_tree.children()
|
||||
@ -361,7 +361,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
|
||||
if p else 'the input carry')
|
||||
leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry)
|
||||
paths, in_carry_flat = unzip2(leaves_and_paths)
|
||||
in_avals = _map(_abstractify, in_carry_flat)
|
||||
in_avals = _map(core.get_aval, in_carry_flat)
|
||||
if in_carry_tree != out_carry_tree:
|
||||
try:
|
||||
out_carry = tree_unflatten(out_carry_tree, out_avals)
|
||||
@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
|
||||
f'of the carry output is a {thing2}, so {explanation}'
|
||||
for path, thing1, thing2, explanation
|
||||
in equality_errors(in_carry, out_carry)]
|
||||
if len(diffs) == 0:
|
||||
# The trees may have different aux data but structures are the same.
|
||||
return
|
||||
if len(diffs) == 1:
|
||||
differences = f'{diffs[0]}.\n'.capitalize()
|
||||
else:
|
||||
@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
|
||||
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
|
||||
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
|
||||
if not core.typematch(in_aval, out_aval)]
|
||||
if len(diffs) == 0:
|
||||
# The trees may have different aux data but structures are the same.
|
||||
return
|
||||
if len(diffs) == 1:
|
||||
differences = f'{diffs[0]}.\n'.capitalize()
|
||||
else:
|
||||
@ -1315,7 +1321,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
|
||||
def _create_jaxpr(init_val):
|
||||
init_vals, in_tree = tree_flatten((init_val,))
|
||||
init_avals = tuple(_map(_abstractify, init_vals))
|
||||
init_avals = tuple(_map(core.get_aval, init_vals))
|
||||
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
|
||||
cond_fun, in_tree, init_avals, "while_cond")
|
||||
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
|
||||
|
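The carry type checking in `_check_carry_type` above is what surfaces mismatched scan/while carries to users. A minimal sketch of the kind of failure it reports (illustrative only; the exact error message may differ):

import jax
import jax.numpy as jnp

def body(c, x):
  # The carry comes in as int32 but is returned as float32, so the carry
  # type check fails at trace time.
  return c.astype(jnp.float32), x

try:
  jax.lax.scan(body, jnp.int32(0), jnp.arange(3))
except TypeError as e:
  print(type(e).__name__, str(e)[:100])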
@ -32,7 +32,6 @@ from jax._src.util import split_list, safe_map
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lax.control_flow.common import (
|
||||
_abstractify,
|
||||
_check_tree,
|
||||
_initial_style_jaxpr,
|
||||
)
|
||||
@ -87,7 +86,7 @@ def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
|
||||
implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
|
||||
"""
|
||||
guess_flat, in_args_tree = tree_flatten((initial_guess,))
|
||||
guess_avals = tuple(_map(_abstractify, guess_flat))
|
||||
guess_avals = tuple(_map(core.get_aval, guess_flat))
|
||||
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
|
||||
f, in_args_tree, guess_avals)
|
||||
|
||||
@ -230,7 +229,7 @@ def custom_linear_solve(
|
||||
transpose_solve = solve
|
||||
|
||||
b_flat, in_args_tree = tree_flatten((b,))
|
||||
b_avals = tuple(_map(_abstractify, b_flat))
|
||||
b_avals = tuple(_map(core.get_aval, b_flat))
|
||||
|
||||
tree, = treedef_children(in_args_tree)
@ -102,27 +102,31 @@ def _validate_shapes(shapes: Sequence[Shape]):
|
||||
else:
|
||||
map(_check_static_shape, shapes)
|
||||
|
||||
def _try_broadcast_shapes(
|
||||
shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...] | None:
|
||||
if len(shapes) == 1: return shapes[0]
|
||||
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
|
||||
"""
|
||||
Attempt to broadcast shapes, raising a TypeError if broadcasting fails.
|
||||
"""
|
||||
if not shapes:
|
||||
raise TypeError(f"{name}: At least one shape is required.")
|
||||
ranks = {len(shape) for shape in shapes}
|
||||
if len(ranks) > 1: return None # must have consistent rank
|
||||
rank = ranks.pop()
|
||||
if not rank: return () # scalar case
|
||||
if len(ranks) != 1:
|
||||
raise TypeError(f'{name}: arrays must have the same number of dimensions,'
|
||||
f' got {ranks}')
|
||||
result_shape = []
|
||||
for ds in unsafe_zip(*shapes):
|
||||
for ds in zip(*shapes):
|
||||
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
|
||||
# if all axes are identical objects, the resulting size is the object
|
||||
result_shape.append(ds[0])
|
||||
else:
|
||||
# if all dims are equal (or 1), the result is the non-1 size (or 1)
|
||||
# if all dims are equal (or 1), the result is the non-1 size
|
||||
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
|
||||
if not non_1s:
|
||||
result_shape.append(1)
|
||||
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
|
||||
result_shape.append(non_1s[0])
|
||||
else:
|
||||
return None
|
||||
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
|
||||
f'{", ".join(map(str, map(tuple, shapes)))}.')
|
||||
return tuple(result_shape)
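For reference, the per-dimension rule implemented above (identical rank required; along each axis all sizes must agree or be 1) can be restated as a tiny standalone sketch. This is illustrative only and skips the symbolic-dimension handling (`core.same_referent` / `core.definitely_equal`) of the real helper:

def _toy_broadcast_shapes(*shapes, name="broadcast"):
  # Illustrative restatement of the rule above for concrete integer shapes.
  ranks = {len(s) for s in shapes}
  if len(ranks) != 1:
    raise TypeError(f"{name}: arrays must have the same number of dimensions, got {ranks}")
  result = []
  for dims in zip(*shapes):
    non_1s = {d for d in dims if d != 1}
    if len(non_1s) > 1:
      raise TypeError(f"{name} got incompatible shapes for broadcasting: {shapes}")
    result.append(non_1s.pop() if non_1s else 1)
  return tuple(result)

assert _toy_broadcast_shapes((3, 1, 5), (1, 4, 5), (3, 4, 1)) == (3, 4, 5)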
|
||||
|
||||
def asarray(x: ArrayLike) -> Array:
|
||||
@ -159,24 +163,39 @@ def _broadcast_shapes_uncached(*shapes):
|
||||
if not rst: return fst
|
||||
|
||||
# First check if we need only rank promotion (and not singleton-broadcasting).
|
||||
try: return _reduce(_broadcast_ranks, rst, fst)
|
||||
except ValueError: pass
|
||||
result_shape = _max(shapes, key=len)
|
||||
ndim = len(result_shape)
|
||||
if ndim == 0 or all(core.definitely_equal_shape(result_shape[ndim - len(s):], s) for s in shapes):
|
||||
return result_shape
|
||||
|
||||
# Next try singleton-broadcasting, padding out ranks using singletons.
|
||||
ndim = _max(len(shape) for shape in shapes)
|
||||
shape_list = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
|
||||
result_shape = _try_broadcast_shapes(shape_list)
|
||||
if result_shape is None:
|
||||
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
|
||||
return result_shape
|
||||
rank_promoted_shapes = tuple((*((1,) * (ndim - len(shape))), *shape) for shape in shapes)
|
||||
try:
|
||||
return _try_broadcast_shapes(*rank_promoted_shapes, name='broadcast_shapes')
|
||||
except TypeError as err:
|
||||
# Raise ValueError here for backward compatibility.
|
||||
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
|
||||
|
||||
def _broadcast_ranks(s1, s2):
|
||||
if len(s1) > len(s2):
|
||||
s1, s2 = s2, s1
|
||||
assert len(s1) <= len(s2)
|
||||
s1_ = s2[len(s2) - len(s1):]
|
||||
if core.definitely_equal_shape(s1_, s1): return s2
|
||||
else: raise ValueError
|
||||
def broadcast_shardings(*avals) -> NamedSharding:
|
||||
fst, *rst = avals
|
||||
if not rst:
|
||||
return fst.sharding
|
||||
|
||||
# First check if we need only rank promotion (and not singleton-broadcasting).
|
||||
res_aval = _max(avals, key=lambda a: a.ndim)
|
||||
ndim = res_aval.ndim
|
||||
if ndim == 0 or all(
|
||||
res_aval.sharding.spec[ndim - a.ndim:] == a.sharding.spec for a in avals):
|
||||
return res_aval.sharding
|
||||
|
||||
# Next try singleton-broadcasting, padding out ranks using singletons.
|
||||
aval_list = []
|
||||
for a in avals:
|
||||
new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec)
|
||||
new_shape = (1,) * (ndim - a.ndim) + a.shape
|
||||
aval_list.append(a.update(shape=new_shape,
|
||||
sharding=a.sharding.with_spec(new_spec)))
|
||||
return broadcasting_sharding_rule('broadcast_shardings', *aval_list)
|
||||
|
||||
def _identity(x): return x
|
||||
|
||||
@ -1471,7 +1490,7 @@ def reduce(operands: Any,
|
||||
return _convert_element_type(monoid_reducer(*flat_operands, dimensions),
|
||||
weak_type=weak_type)
|
||||
else:
|
||||
flat_init_avals = safe_map(_abstractify, flat_init_values)
|
||||
flat_init_avals = safe_map(core.get_aval, flat_init_values)
|
||||
closed_jaxpr, out_tree = _variadic_reduction_jaxpr(
|
||||
computation, tuple(flat_init_avals), init_value_tree)
|
||||
out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
|
||||
@ -1730,6 +1749,9 @@ def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray:
|
||||
val = ad_util.zeros_like_aval(aval.inner_aval)
|
||||
return core.mutable_array(val)
|
||||
|
||||
# TODO(dougalm): this is nonsense but it's here because in places like
|
||||
# custom_vjp we assume that all arguments have tangent spaces. We could have
|
||||
# a distinct NotATangentType value instead.
|
||||
ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref # type: ignore
|
||||
|
||||
def iota(dtype: DTypeLike, size: int) -> Array:
|
||||
@ -2137,27 +2159,7 @@ def broadcasting_shape_rule(name, *avals):
|
||||
shapes = [aval.shape for aval in avals if aval.shape]
|
||||
if not shapes:
|
||||
return ()
|
||||
if len({len(shape) for shape in shapes}) != 1:
|
||||
msg = '{}: arrays must have same number of dimensions, got {}.'
|
||||
raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
|
||||
# TODO(mattjj): de-duplicate with _try_broadcast_shapes
|
||||
result_shape = []
|
||||
for ds in zip(*shapes):
|
||||
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
|
||||
# if all axes are identical objects, the resulting size is the object
|
||||
result_shape.append(ds[0])
|
||||
else:
|
||||
# if all dims are equal (or 1), the result is the non-1 size
|
||||
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
|
||||
if not non_1s:
|
||||
result_shape.append(1)
|
||||
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
|
||||
result_shape.append(non_1s[0])
|
||||
else:
|
||||
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
|
||||
f'{", ".join(map(str, map(tuple, shapes)))}.')
|
||||
|
||||
return tuple(result_shape)
|
||||
return _try_broadcast_shapes(*shapes, name=name)
|
||||
|
||||
|
||||
def broadcasting_sharding_rule(name, *avals):
|
||||
@ -2206,7 +2208,7 @@ def broadcasting_sharding_rule(name, *avals):
|
||||
|
||||
|
||||
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
|
||||
require_same_dtypes=False):
|
||||
require_same_dtypes=True):
|
||||
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
|
||||
allow_extended_dtype=allow_extended_dtype,
|
||||
require_same=require_same_dtypes)
|
||||
@ -2759,8 +2761,8 @@ def _add_transpose(t, x, y):
|
||||
# some places (e.g. in custom_jvp) it may not always hold. For example, see
|
||||
# api_test.py's CustomJVPTest.test_jaxpr_zeros.
|
||||
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
|
||||
if type(t) is ad_util.Zero:
|
||||
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
||||
else:
|
||||
@ -2790,8 +2792,8 @@ def _sub_transpose(t, x, y):
|
||||
# Morally the following assertion is true, but see the comment in add_p's
|
||||
# transpose rule.
|
||||
# assert ad.is_undefined_primal(x) and ad.is_undefined_primal(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else _abstractify(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else _abstractify(y)
|
||||
x_aval = x.aval if ad.is_undefined_primal(x) else core.get_aval(x)
|
||||
y_aval = y.aval if ad.is_undefined_primal(y) else core.get_aval(y)
|
||||
if type(t) is ad_util.Zero:
|
||||
return [ad_util.Zero(x_aval), ad_util.Zero(y_aval)]
|
||||
else:
|
||||
@ -3773,6 +3775,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
if platform == "cpu" and precision not in {
|
||||
DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16,
|
||||
DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64,
|
||||
DotAlgorithmPreset.BF16_BF16_F32, DotAlgorithmPreset.BF16_BF16_F32_X3,
|
||||
DotAlgorithmPreset.BF16_BF16_F32_X6,
|
||||
}:
|
||||
raise ValueError(
|
||||
f"The precision '{precision}' is not supported by dot_general on CPU")
|
||||
@ -5502,15 +5506,26 @@ def _operands_to_keys(*operands, num_keys=1):
|
||||
|
||||
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
|
||||
shape = primals[0].shape
|
||||
iotas = []
|
||||
for dim, size in enumerate(shape):
|
||||
iotas.append(broadcasted_iota(np.int64, shape, dim))
|
||||
sorted_primals_and_idx = sort_p.bind(
|
||||
*primals, iotas[dimension], dimension=dimension,
|
||||
is_stable=is_stable, num_keys=num_keys)
|
||||
idx = tuple(sorted_primals_and_idx[-1] if i == dimension else iotas[i]
|
||||
for i in range(len(shape)))
|
||||
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
|
||||
*primals, broadcasted_iota(np.uint64, shape, dimension),
|
||||
dimension=dimension, is_stable=is_stable, num_keys=num_keys)
|
||||
batch_dims = tuple(np.delete(np.arange(len(shape), dtype=np.int64),
|
||||
dimension))
|
||||
dnums = slicing.GatherDimensionNumbers(
|
||||
offset_dims=(),
|
||||
collapsed_slice_dims=(dimension,),
|
||||
start_index_map=(dimension,),
|
||||
operand_batching_dims=batch_dims,
|
||||
start_indices_batching_dims=batch_dims,
|
||||
)
|
||||
idx = expand_dims(sorted_primals_and_idx[-1], (len(shape),))
|
||||
gather_idx = partial(
|
||||
slicing.gather,
|
||||
start_indices=idx, dimension_numbers=dnums, slice_sizes=(1,) * len(shape),
|
||||
mode=slicing.GatherScatterMode.PROMISE_IN_BOUNDS
|
||||
)
|
||||
tangents_out = [t if type(t) is ad_util.Zero else gather_idx(t)
|
||||
for t in tangents]
|
||||
return tuple(sorted_primals_and_idx[:-1]), tangents_out
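The net effect of the JVP rule above is that tangents are carried along by the sorting permutation. A small end-to-end check through the public API:

import jax
import jax.numpy as jnp

x = jnp.array([3.0, 1.0, 2.0])
dx = jnp.array([30.0, 10.0, 20.0])

# Tangents are permuted the same way the primal values are sorted.
primal_out, tangent_out = jax.jvp(jnp.sort, (x,), (dx,))
print(primal_out)   # [1. 2. 3.]
print(tangent_out)  # [10. 20. 30.]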
|
||||
|
||||
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):
|
||||
@ -6381,10 +6396,6 @@ def _eq_meet(a, b):
|
||||
return eq(a, b)
|
||||
|
||||
|
||||
def _abstractify(x):
|
||||
return core.get_aval(x)
|
||||
|
||||
|
||||
def empty(dtype):
|
||||
return empty_p.bind(dtype=dtype)
|
||||
empty_p = core.Primitive('empty')
|
||||
@ -6494,3 +6505,7 @@ optimization_barrier_p.def_impl(
|
||||
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
|
||||
mlir.register_lowering(optimization_barrier_p,
|
||||
_optimization_barrier_lowering_rule)
|
||||
|
||||
def _optimization_barrier_batcher(batched_args, batch_dims, **params):
|
||||
return optimization_barrier_p.bind(*batched_args, **params), batch_dims
|
||||
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
|
||||
|
@ -48,6 +48,7 @@ from jax._src.lax.lax import (
|
||||
from jax._src.lib import gpu_solver
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -1328,7 +1329,6 @@ def _triangular_solve_lowering(
|
||||
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
|
||||
hlo.TransposeAttr.get(transpose))]
|
||||
|
||||
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
|
||||
|
||||
def _triangular_solve_cpu_lower(
|
||||
ctx, a, b, *, left_side, lower, transpose_a,
|
||||
@ -1341,10 +1341,12 @@ def _triangular_solve_cpu_lower(
|
||||
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
|
||||
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
|
||||
b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
|
||||
# TODO(b/344892332): Remove the conditional after the compatibility period.
|
||||
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
|
||||
return lapack.trsm_hlo(
|
||||
a_aval.dtype, alpha,
|
||||
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
|
||||
b_shape_vals=b_shape_vals)
|
||||
*ctx_args, a_aval.dtype, alpha,
|
||||
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
|
||||
b_shape_vals=b_shape_vals)
|
||||
else:
|
||||
# Fall back to the HLO implementation for unsupported types or batching.
|
||||
# TODO: Consider swapping XLA for LAPACK in batched case
|
||||
@ -1357,6 +1359,8 @@ def _triangular_solve_cpu_lower(
|
||||
ir.BoolAttr.get(unit_diagonal),
|
||||
hlo.TransposeAttr.get(transpose))]
|
||||
|
||||
|
||||
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
|
||||
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
|
||||
platform='cpu')
|
||||
|
||||
@ -2616,37 +2620,54 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
|
||||
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
|
||||
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
|
||||
# TODO(b/344892332): Remove the conditional after the compatibility period.
|
||||
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
|
||||
gees_result = lapack.gees_hlo(*ctx_args, operand_aval.dtype, operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable,
|
||||
a_shape_vals=a_shape_vals)
|
||||
|
||||
# Number of return values depends on value of sort_eig_vals.
|
||||
T, vs, *_, info = gees_result
|
||||
if jaxlib_version >= (0, 4, 37) and not ctx.is_forward_compat():
|
||||
schur_form, schur_vectors, _eig_vals, _selected_eig_vals, info = gees_result
|
||||
else:
|
||||
# Number of return values depends on value of sort_eig_vals.
|
||||
schur_form, schur_vectors, *_, info = gees_result
|
||||
|
||||
ok = mlir.compare_hlo(
|
||||
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"EQ", "SIGNED")
|
||||
|
||||
select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
T = _broadcasting_select_hlo(
|
||||
select_schur_form_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
schur_form = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_T_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_T_aval,
|
||||
T, ctx.avals_out[0],_nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0])
|
||||
output = [T]
|
||||
mlir.broadcast_in_dim(
|
||||
ctx,
|
||||
ok,
|
||||
select_schur_form_aval,
|
||||
broadcast_dimensions=range(len(batch_dims)),
|
||||
),
|
||||
select_schur_form_aval,
|
||||
schur_form,
|
||||
ctx.avals_out[0],
|
||||
_nan_like_hlo(ctx, ctx.avals_out[0]),
|
||||
ctx.avals_out[0],
|
||||
)
|
||||
output = [schur_form]
|
||||
if compute_schur_vectors:
|
||||
select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
vs = _broadcasting_select_hlo(
|
||||
schur_vectors = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_vs_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
mlir.broadcast_in_dim(
|
||||
ctx, ok, select_vs_aval, broadcast_dimensions=range(len(batch_dims))
|
||||
),
|
||||
select_vs_aval,
|
||||
vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1])
|
||||
schur_vectors,
|
||||
ctx.avals_out[1],
|
||||
_nan_like_hlo(ctx, ctx.avals_out[1]),
|
||||
ctx.avals_out[1],
|
||||
)
|
||||
|
||||
output.append(vs)
|
||||
output.append(schur_vectors)
|
||||
|
||||
return output
|
||||
|
||||
@ -2819,24 +2840,35 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
return tridiagonal(x), 0
|
||||
return tridiagonal(x, lower=lower), 0
|
||||
|
||||
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
|
||||
|
||||
def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower):
|
||||
def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower, platform):
|
||||
a_aval, = ctx.avals_in
|
||||
a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower)
|
||||
cpu_args = []
|
||||
if platform == "cpu":
|
||||
# TODO(b/344892332): Remove the conditional after the compatibility period.
|
||||
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
|
||||
cpu_args.extend(ctx_args)
|
||||
a, d, e, taus, info = sytrd_impl(*cpu_args, a_aval.dtype, a, lower=lower)
|
||||
return a, d, e, taus, info
|
||||
|
||||
mlir.register_lowering(
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo),
|
||||
platform='cpu')
|
||||
tridiagonal_p,
|
||||
partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo, platform="cpu"),
|
||||
platform="cpu",
|
||||
)
|
||||
mlir.register_lowering(
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd),
|
||||
platform='cuda')
|
||||
tridiagonal_p,
|
||||
partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd, platform="cuda"),
|
||||
platform="cuda",
|
||||
)
|
||||
mlir.register_lowering(
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd),
|
||||
platform='rocm')
|
||||
tridiagonal_p,
|
||||
partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd, platform="rocm"),
|
||||
platform="rocm",
|
||||
)
|
||||
|
||||
# Utilities
|
||||
|
||||
|
@ -457,6 +457,55 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
|
||||
|
||||
return tree_util.tree_map(bind, x)
|
||||
|
||||
def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
|
||||
"""Ragged version of :func:`all_to_all`.
|
||||
|
||||
For now, ``split_axis`` and ``concat_axis`` from :func:`all_to_all` are both
|
||||
equivalent to the outermost (ragged) dimension. ``axis_index_groups`` defaults
|
||||
to all replicas (e.g. there is only one group which covers all axis indices).
|
||||
|
||||
Ragged arrays are defined by a set of three arrays:
|
||||
* ``data``: the ``data`` array is "ragged" along its outermost dimension,
|
||||
along which each indexed element has variable size.
|
||||
* ``offsets``: the ``offsets`` array indexes the outermost dimension of the
|
||||
``data`` array, and represents the starting offset of each ragged element of
|
||||
the ``data`` array.
|
||||
* ``sizes``: the ``sizes`` array represents the size of each ragged element of
|
||||
the ``data`` array, where the size is specified in units of sub-elements. A
|
||||
sub-element is defined as the suffix of the ``data`` array shape obtained by
|
||||
removing the outermost "ragged" dimension.
|
||||
The ``offsets`` and ``sizes`` arrays must have the same size.
|
||||
|
||||
# Example ragged tensor
|
||||
data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}}
|
||||
offsets: [3] = {0, 1, 4}
|
||||
sizes: [3] = {1, 3, 4}
|
||||
|
||||
# Index 'data' at 'offsets'[0], 'sizes'[0]
|
||||
{a,b,c}
|
||||
|
||||
# Index 'data' at 'offsets'[1], 'sizes'[1]
|
||||
{d,e,f},{g,h,i},{j,k,l}
|
||||
|
||||
# Index 'data' at 'offsets'[2], 'sizes'[2]
|
||||
{m,n,o},{p,q,r},{s,t,u},{v,w,x}
|
||||
|
||||
Args:
|
||||
operand: array with ragged dimension along its outermost dimension.
|
||||
output: array of ragged input offsets.
|
||||
input_offsets: array of ragged input send sizes.
|
||||
send_sizes: array of ragged output data.
|
||||
output_offsets: array of ragged output offsets.
|
||||
recv_sizes: array of ragged output receive sizes.
|
||||
Returns:
|
||||
array with shape equal to ``output``.
|
||||
"""
|
||||
return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
|
||||
output_offsets, recv_sizes)
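A plain-numpy illustration of the ``data``/``offsets``/``sizes`` convention from the docstring above (row indices stand in for the letters; this does not invoke the collective itself):

import numpy as np

data = np.arange(8 * 3).reshape(8, 3)   # stands in for {{a,b,c}, ..., {v,w,x}}
offsets = np.array([0, 1, 4])
sizes = np.array([1, 3, 4])

for i in range(len(offsets)):
  # Each ragged element is a contiguous block of rows of `data`.
  print(f"element {i}:", data[offsets[i]:offsets[i] + sizes[i]].tolist())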
|
||||
|
||||
ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
|
||||
|
||||
|
||||
def axis_index(axis_name):
|
||||
"""Return the index along the mapped axis ``axis_name``.
|
||||
|
||||
@ -1052,6 +1101,64 @@ batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective
|
||||
batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name')
|
||||
|
||||
|
||||
def _ragged_all_to_all_lowering(ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
|
||||
N = input_offsets.type.shape[0]
|
||||
backend_config = ir.DictAttr.get({
|
||||
'replica_groups': ir.DenseIntElementsAttr.get(
|
||||
np.arange(0, N, 1, dtype=np.int64), shape=[1, N]
|
||||
)
|
||||
})
|
||||
return hlo.CustomCallOp(
|
||||
result=[output.type],
|
||||
inputs=[operand, output, input_offsets, send_sizes, output_offsets,
|
||||
recv_sizes],
|
||||
call_target_name=ir.StringAttr.get('ragged_all_to_all'),
|
||||
backend_config=backend_config,
|
||||
api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4),
|
||||
).results
|
||||
|
||||
@ragged_all_to_all_p.def_abstract_eval
|
||||
def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes):
|
||||
if operand.shape != output.shape:
|
||||
raise ValueError('ragged_all_to_all input and output shapes must be equal.')
|
||||
if not dtypes.issubdtype(input_offsets.dtype, np.integer):
|
||||
raise ValueError("ragged_all_to_all input_offsets must be integer type.")
|
||||
if not dtypes.issubdtype(send_sizes.dtype, np.integer):
|
||||
raise ValueError("ragged_all_to_all send_sizes must be integer type.")
|
||||
if not dtypes.issubdtype(output_offsets.dtype, np.integer):
|
||||
raise ValueError("ragged_all_to_all output_offsets must be integer type.")
|
||||
if not dtypes.issubdtype(recv_sizes.dtype, np.integer):
|
||||
raise ValueError("ragged_all_to_all recv_sizes must be integer type.")
|
||||
if len(input_offsets.shape) != 1 or input_offsets.shape[0] < 1:
|
||||
raise ValueError(
|
||||
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
|
||||
" size, but got shape {}".format(input_offsets.shape)
|
||||
)
|
||||
if len(send_sizes.shape) != 1 or send_sizes.shape[0] < 1:
|
||||
raise ValueError(
|
||||
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape {}".format(send_sizes.shape)
|
||||
)
|
||||
if len(output_offsets.shape) != 1 or output_offsets.shape[0] < 1:
|
||||
raise ValueError(
|
||||
"ragged_all_to_all output_offsets must be rank 1 with positive"
|
||||
" dimension size, but got shape {}".format(output_offsets.shape)
|
||||
)
|
||||
if len(recv_sizes.shape) != 1 or recv_sizes.shape[0] < 1:
|
||||
raise ValueError(
|
||||
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
|
||||
" size, but got shape {}".format(recv_sizes.shape)
|
||||
)
|
||||
return output.update(
|
||||
shape=list(output.shape),
|
||||
dtype=output.dtype,
|
||||
weak_type=output.weak_type,
|
||||
)
|
||||
|
||||
ragged_all_to_all_p.def_impl(partial(dispatch.apply_primitive, ragged_all_to_all_p))
|
||||
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
|
||||
|
||||
|
||||
def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
|
||||
"""Gather values of x across all replicas.
|
||||
|
||||
@ -1484,7 +1591,25 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
|
||||
axis_name, = axis_name
|
||||
if axis_name not in axis_env.names:
|
||||
raise NameError(f"unbound axis name: {axis_name}")
|
||||
axis_context = ctx.module_context.axis_context
|
||||
axis_pos = list(axis_env.names).index(axis_name)
|
||||
|
||||
# For partial auto, lower using iota.
|
||||
if (isinstance(axis_context, SPMDAxisContext) and
|
||||
axis_context.manual_axes and
|
||||
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
|
||||
x = hlo.iota(ir.RankedTensorType.get(
|
||||
[axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
|
||||
sharding_proto = (
|
||||
NamedSharding(axis_context.mesh, P(axis_name))
|
||||
._to_xla_hlo_sharding(1).to_proto())
|
||||
aval_in = ShapedArray((axis_env.sizes[axis_pos],), np.int32)
|
||||
aval_out = ShapedArray((1,), np.int32)
|
||||
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto)
|
||||
proto = pxla.manual_proto(aval_in, axis_context.manual_axes, axis_context.mesh)
|
||||
x = mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto)
|
||||
return hlo.reshape(ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), x)
|
||||
|
||||
nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
|
||||
div = mlir.ir_constant(
|
||||
np.array(
|
||||
@ -1492,12 +1617,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
|
||||
)
|
||||
)
|
||||
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
|
||||
axis_context = ctx.module_context.axis_context
|
||||
is_spmd = isinstance(
|
||||
axis_context,
|
||||
(SPMDAxisContext, ShardingContext),
|
||||
)
|
||||
if is_spmd:
|
||||
if isinstance(axis_context, (ShardingContext, SPMDAxisContext)):
|
||||
device_id = hlo.partition_id()
|
||||
else:
|
||||
device_id = hlo.replica_id()
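For context, the lowering above materializes each device's position along a named mesh axis; the public entry point is `jax.lax.axis_index`. A minimal sketch using `pmap`, which works with any number of local devices:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
out = jax.pmap(lambda x: x * jax.lax.axis_index("i"), axis_name="i")(jnp.ones((n, 3)))
print(out)   # row k is filled with the value k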
|
||||
|
@ -483,7 +483,7 @@ def scatter_add(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.add,
|
||||
lax._abstractify(lax._const(operand, 0)))
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
return scatter_add_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -536,7 +536,7 @@ def scatter_sub(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(
|
||||
lax.sub, lax._abstractify(lax._const(operand, 0))
|
||||
lax.sub, core.get_aval(lax._const(operand, 0))
|
||||
)
|
||||
return scatter_sub_p.bind(
|
||||
operand,
|
||||
@ -591,7 +591,7 @@ def scatter_mul(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.mul,
|
||||
lax._abstractify(lax._const(operand, 1)))
|
||||
core.get_aval(lax._const(operand, 1)))
|
||||
return scatter_mul_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -638,7 +638,7 @@ def scatter_min(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.min,
|
||||
lax._abstractify(lax._const(operand, 0)))
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
return scatter_min_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -685,7 +685,7 @@ def scatter_max(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.max,
|
||||
lax._abstractify(lax._const(operand, 0)))
|
||||
core.get_aval(lax._const(operand, 0)))
|
||||
return scatter_max_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -748,7 +748,7 @@ def scatter_apply(
|
||||
_apply = _scatter_apply_cache.setdefault(func, _apply)
|
||||
except TypeError: # func is not weak referenceable
|
||||
pass
|
||||
jaxpr, consts = lax._reduction_jaxpr(_apply, lax._abstractify(lax._zero(operand)))
|
||||
jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand)))
|
||||
# TODO: implement this via its own primitive so we can define appropriate autodiff rules.
|
||||
return scatter_p.bind(
|
||||
operand, scatter_indices, unused, update_jaxpr=jaxpr,
|
||||
|
@ -90,7 +90,7 @@ def _reduce_window(
|
||||
return monoid_reducer(operand, window_dimensions, window_strides, padding,
|
||||
base_dilation, window_dilation)
|
||||
else:
|
||||
flat_init_avals = map(lax._abstractify, flat_init_values)
|
||||
flat_init_avals = map(core.get_aval, flat_init_values)
|
||||
jaxpr, out_tree = lax._variadic_reduction_jaxpr(
|
||||
computation, tuple(flat_init_avals), init_value_tree
|
||||
)
|
||||
@ -176,7 +176,7 @@ def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
|
||||
base_dilation: Sequence[int] | None = None,
|
||||
window_dilation: Sequence[int] | None = None) -> Array:
|
||||
init_value = lax._const(operand, 1)
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value))
|
||||
jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(init_value))
|
||||
if base_dilation is None:
|
||||
base_dilation = (1,) * len(window_dimensions)
|
||||
if window_dilation is None:
|
||||
@ -226,7 +226,7 @@ def _reduce_window_logaddexp(
|
||||
base_dilation: Sequence[int] | None = None,
|
||||
window_dilation: Sequence[int] | None = None) -> Array:
|
||||
init_value = lax._const(operand, -np.inf)
|
||||
jaxpr, consts = lax._reduction_jaxpr(logaddexp, lax._abstractify(init_value))
|
||||
jaxpr, consts = lax._reduction_jaxpr(logaddexp, core.get_aval(init_value))
|
||||
if base_dilation is None:
|
||||
base_dilation = (1,) * len(window_dimensions)
|
||||
if window_dilation is None:
|
||||
@ -245,9 +245,9 @@ def _select_and_scatter(operand: Array, select: Callable,
|
||||
padding: Sequence[tuple[int, int]], source: Array,
|
||||
init_value: Array, scatter: Callable) -> Array:
|
||||
select_jaxpr, select_consts = lax._reduction_jaxpr(
|
||||
select, lax._abstractify(init_value))
|
||||
select, core.get_aval(init_value))
|
||||
scatter_jaxpr, scatter_consts = lax._reduction_jaxpr(
|
||||
scatter, lax._abstractify(init_value))
|
||||
scatter, core.get_aval(init_value))
|
||||
return select_and_scatter_p.bind(
|
||||
operand, source, init_value, select_jaxpr=select_jaxpr,
|
||||
select_consts=select_consts, scatter_jaxpr=scatter_jaxpr,
|
||||
|
@ -111,8 +111,6 @@ class AxisTypes(enum.Enum):
|
||||
return self.name
|
||||
|
||||
def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
|
||||
if axis_types is None:
|
||||
return {}
|
||||
d = {}
|
||||
for t, names in axis_types.items():
|
||||
if isinstance(names, tuple):
|
||||
@ -124,6 +122,7 @@ def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
|
||||
|
||||
_mesh_object_dict = {} # type: ignore
|
||||
|
||||
MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]
|
||||
|
||||
class Mesh(contextlib.ContextDecorator):
|
||||
"""Declare the hardware resources available in the scope of this manager.
|
||||
@ -178,11 +177,11 @@ class Mesh(contextlib.ContextDecorator):
|
||||
|
||||
devices: np.ndarray
|
||||
axis_names: tuple[MeshAxisName, ...]
|
||||
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None
|
||||
axis_types: MeshAxisType
|
||||
|
||||
def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
|
||||
axis_names: str | Sequence[MeshAxisName],
|
||||
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
|
||||
axis_names: str | Sequence[MeshAxisName], *,
|
||||
axis_types: MeshAxisType | None = None):
|
||||
if not isinstance(devices, np.ndarray):
|
||||
devices = np.array(devices)
|
||||
if isinstance(axis_names, str):
|
||||
@ -198,9 +197,9 @@ class Mesh(contextlib.ContextDecorator):
|
||||
f"devices.ndim == {devices.ndim} and "
|
||||
f"len(axis_names) == {len(axis_names)}.")
|
||||
|
||||
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
|
||||
axis_types_tuple = (None if axis_types is None else
|
||||
tuple(axis_types.items()))
|
||||
axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
|
||||
axis_types)
|
||||
axis_types_tuple = tuple(axis_types.items())
|
||||
key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
|
||||
val = _mesh_object_dict.get(key, None)
|
||||
if val is not None:
|
||||
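With the change above, ``axis_types`` is keyword-only and defaults to marking every axis as ``Auto``. A basic construction sketch showing only the default (the explicit ``axis_types`` forms are still evolving):

import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())    # 1-D array of devices for a single mesh axis
mesh = Mesh(devices, ("x",))         # axis_types defaults to Auto for "x"
print(mesh.axis_names, mesh.devices.shape)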
@ -216,7 +215,8 @@ class Mesh(contextlib.ContextDecorator):
|
||||
return self
|
||||
|
||||
def __reduce__(self):
|
||||
return (type(self), (self.devices, self.axis_names, self.axis_types))
|
||||
return (type(self), (self.devices, self.axis_names),
|
||||
{'axis_types': self.axis_types})
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Mesh):
|
||||
@ -335,7 +335,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
def _repr(self):
|
||||
if self.empty:
|
||||
return "Mesh(device_ids=[], axis_names=())"
|
||||
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
|
||||
atr = f", axis_types={self.axis_types}"
|
||||
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"
|
||||
|
||||
def __repr__(self):
|
||||
@ -348,7 +348,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
|
||||
@functools.cached_property
|
||||
def abstract_mesh(self):
|
||||
return AbstractMesh(self.shape_tuple, self.axis_types)
|
||||
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)
|
||||
|
||||
|
||||
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
|
||||
@ -373,17 +373,16 @@ class AbstractMesh:
|
||||
details.
|
||||
"""
|
||||
|
||||
def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
|
||||
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
|
||||
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
|
||||
axis_types: MeshAxisType | None = None):
|
||||
self.shape_tuple = shape_tuple
|
||||
self.axis_types = axis_types
|
||||
if self.shape_tuple:
|
||||
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
|
||||
else:
|
||||
self._axis_names, self._axis_sizes = (), ()
|
||||
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
|
||||
self._axis_types_tuple = (None if axis_types is None else
|
||||
tuple(axis_types.items()))
|
||||
self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
|
||||
else axis_types)
|
||||
self._axis_types_tuple = tuple(self.axis_types.items())
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.shape_tuple, self._axis_types_tuple))
|
||||
@ -397,7 +396,7 @@ class AbstractMesh:
|
||||
self._axis_types_tuple == other._axis_types_tuple)
|
||||
|
||||
def __repr__(self):
|
||||
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
|
||||
atr = f", axis_types={self.axis_types}"
|
||||
return f"AbstractMesh({self.shape_tuple}{atr})"
|
||||
|
||||
@property
|
||||
@ -430,10 +429,20 @@ class AbstractMesh:
|
||||
|
||||
@functools.cached_property
|
||||
def _are_all_axes_collective(self) -> bool:
|
||||
if self.axis_types is None:
|
||||
return False
|
||||
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
|
||||
|
||||
@functools.cached_property
|
||||
def _are_all_axes_auto(self) -> bool:
|
||||
return all(t == AxisTypes.Auto for t in self.axis_types.keys())
|
||||
|
||||
@functools.cached_property
|
||||
def _any_axis_collective(self) -> bool:
|
||||
return any(t == AxisTypes.Collective for t in self.axis_types.keys())
|
||||
|
||||
@functools.cached_property
|
||||
def _any_axis_auto(self) -> bool:
|
||||
return any(t == AxisTypes.Auto for t in self.axis_types.keys())
|
||||
|
||||
@property
|
||||
def devices(self):
|
||||
_raise_value_error("devices")
|
||||
|
@ -386,7 +386,7 @@ def _create_device_mesh_for_nd_torus_splitting_axes(
|
||||
)
|
||||
):
|
||||
best_logical_axis_assignment = logical_axis_assignment
|
||||
assignment[:, logical_axis] = best_logical_axis_assignment
|
||||
assignment[:, logical_axis] = best_logical_axis_assignment # type: ignore # numpy 2.2
|
||||
|
||||
# Read out the assignment.
|
||||
logical_mesh = _generate_logical_mesh(
|
||||
@ -597,10 +597,10 @@ def _generate_logical_mesh(
|
||||
zip(logical_indices, physical_indices, range(len(logical_indices)))
|
||||
)
|
||||
)
|
||||
logical_mesh = np.transpose(logical_mesh, transpose_axes)
|
||||
logical_mesh = np.transpose(logical_mesh, transpose_axes) # type: ignore # numpy 2.2
|
||||
|
||||
# Reshape to add the trivial dimensions back.
|
||||
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape)
|
||||
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) # type: ignore # numpy 2.2
|
||||
|
||||
return logical_mesh
|
||||
|
||||
|
@ -67,7 +67,7 @@ def relu(x: ArrayLike) -> Array:
|
||||
|
||||
For more information see
|
||||
`Numerical influence of ReLU’(0) on backpropagation
|
||||
<https://openreview.net/forum?id=urrcVI-_jRm>`_.
|
||||
<https://dl.acm.org/doi/10.5555/3540261.3540297>`_.
|
||||
|
||||
Args:
|
||||
x : input array
|
||||
@ -84,7 +84,7 @@ def relu(x: ArrayLike) -> Array:
|
||||
|
||||
"""
|
||||
return jnp.maximum(x, 0)
|
||||
# For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
|
||||
# For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297
|
||||
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
|
||||
|
||||
@jax.jit
|
||||
|
@ -9168,6 +9168,89 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
|
||||
return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)
|
||||
|
||||
|
||||
@export
|
||||
@jit
|
||||
def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
"""Batched matrix-vector product.
|
||||
|
||||
JAX implementation of :func:`numpy.matvec`.
|
||||
|
||||
Args:
|
||||
x1: array of shape ``(..., M, N)``
|
||||
x2: array of shape ``(..., N)``. Leading dimensions must be broadcast-compatible
|
||||
with leading dimensions of ``x1``.
|
||||
|
||||
Returns:
|
||||
An array of shape ``(..., M)`` containing the batched matrix-vector product.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.linalg.vecdot`: batched vector product.
|
||||
- :func:`jax.numpy.vecmat`: vector-matrix product.
|
||||
- :func:`jax.numpy.matmul`: general matrix multiplication.
|
||||
|
||||
Examples:
|
||||
Simple matrix-vector product:
|
||||
|
||||
>>> x1 = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
>>> x2 = jnp.array([7, 8, 9])
|
||||
>>> jnp.matvec(x1, x2)
|
||||
Array([ 50, 122], dtype=int32)
|
||||
|
||||
Batched matrix-vector product:
|
||||
|
||||
>>> x2 = jnp.array([[7, 8, 9],
|
||||
... [5, 6, 7]])
|
||||
>>> jnp.matvec(x1, x2)
|
||||
Array([[ 50, 122],
|
||||
[ 38, 92]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("matvec", x1, x2)
|
||||
return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2)
|
||||
|
||||
|
||||
@export
|
||||
@jit
|
||||
def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
"""Batched conjugate vector-matrix product.
|
||||
|
||||
JAX implementation of :func:`numpy.vecmat`.
|
||||
|
||||
Args:
|
||||
x1: array of shape ``(..., M)``.
|
||||
x2: array of shape ``(..., M, N)``. Leading dimensions must be broadcast-compatible
|
||||
with leading dimensions of ``x1``.
|
||||
|
||||
Returns:
|
||||
An array of shape ``(..., N)`` containing the batched conjugate vector-matrix product.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.linalg.vecdot`: batched vector product.
|
||||
- :func:`jax.numpy.matvec`: matrix-vector product.
|
||||
- :func:`jax.numpy.matmul`: general matrix multiplication.
|
||||
|
||||
Examples:
|
||||
Simple vector-matrix product:
|
||||
|
||||
>>> x1 = jnp.array([[1, 2, 3]])
|
||||
>>> x2 = jnp.array([[4, 5],
|
||||
... [6, 7],
|
||||
... [8, 9]])
|
||||
>>> jnp.vecmat(x1, x2)
|
||||
Array([[40, 46]], dtype=int32)
|
||||
|
||||
Batched vector-matrix product:
|
||||
|
||||
>>> x1 = jnp.array([[1, 2, 3],
|
||||
... [4, 5, 6]])
|
||||
>>> jnp.vecmat(x1, x2)
|
||||
Array([[ 40, 46],
|
||||
[ 94, 109]], dtype=int32)
|
||||
"""
|
||||
util.check_arraylike("matvec", x1, x2)
|
||||
return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2)
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
|
||||
def vdot(
|
||||
@ -9244,6 +9327,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.vdot`: flattened vector product.
|
||||
- :func:`jax.numpy.vecmat`: vector-matrix product.
|
||||
- :func:`jax.numpy.matmul`: general matrix multiplication.
|
||||
- :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
|
||||
|
||||
|
@ -1374,7 +1374,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
|
||||
x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
|
||||
else:
|
||||
if rcond is None:
|
||||
rcond = jnp.finfo(dtype).eps * max(n, m)
|
||||
rcond = float(jnp.finfo(dtype).eps) * max(n, m)
|
||||
else:
|
||||
rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
|
||||
u, s, vt = svd(a, full_matrices=False)
|
||||
@ -1517,7 +1517,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
|
||||
|
||||
@export
|
||||
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
|
||||
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str | int = 'fro') -> Array:
|
||||
"""Compute the norm of a matrix or stack of matrices.
|
||||
|
||||
JAX implementation of :func:`numpy.linalg.matrix_norm`
|
||||
|
@ -246,7 +246,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
|
||||
|
||||
# set rcond
|
||||
if rcond is None:
|
||||
rcond = len(x_arr) * finfo(x_arr.dtype).eps
|
||||
rcond = len(x_arr) * float(finfo(x_arr.dtype).eps)
|
||||
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
|
||||
# set up least squares equation for powers of x
|
||||
lhs = vander(x_arr, order)
|
||||
|
@ -23,6 +23,7 @@ from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import api_util
|
||||
from jax._src.lax import lax
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
|
||||
@ -213,14 +214,18 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
|
||||
@partial(api.jit, inline=True)
|
||||
def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
avals = [api_util.shaped_abstractify(arg) for arg in args]
|
||||
shapes = [a.shape for a in avals]
|
||||
if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
|
||||
return [lax.asarray(arg) for arg in args]
|
||||
result_shape = lax.broadcast_shapes(*shapes)
|
||||
return [_broadcast_to(arg, result_shape) for arg in args]
|
||||
result_sharding = (lax.broadcast_shardings(*avals) # type: ignore
|
||||
if config.sharding_in_types.value else None)
|
||||
return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
|
||||
|
||||
|
||||
def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
|
||||
def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
|
||||
) -> Array:
|
||||
check_arraylike("broadcast_to", arr)
|
||||
arr = arr if isinstance(arr, Array) else lax.asarray(arr)
|
||||
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
|
||||
@ -240,7 +245,8 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape) -> Array:
|
||||
if nlead < 0 or not compatible:
|
||||
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
|
||||
raise ValueError(msg.format(arr_shape, shape))
|
||||
return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))))
|
||||
return lax.broadcast_in_dim(arr, shape, tuple(range(nlead, len(shape))),
|
||||
sharding=sharding)
|
||||
|
||||
|
||||
# The `jit` on `where` exists to avoid materializing constants in cases like
|
||||
|
@ -100,7 +100,7 @@ def op_sharding_to_numpy_indices(
|
||||
|
||||
for i, idxs in enumerate(itertools.product(*axis_indices)):
|
||||
for _ in range(num_replicas):
|
||||
indices[next(device_it)] = idxs
|
||||
indices[next(device_it)] = idxs # type: ignore # numpy 2.2
|
||||
return indices
|
||||
|
||||
|
||||
|
@ -13,14 +13,17 @@
|
||||
# limitations under the License.
|
||||
"""Helper tool for automatic cost estimation."""
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
from typing import Any, Sequence
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src.state import discharge
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.util import safe_map
|
||||
@ -87,10 +90,9 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate:
|
||||
A pallas_core.CostEstimate object containing the cost estimate.
|
||||
"""
|
||||
flattened_args, treedef = jax.tree.flatten(args)
|
||||
def _partial_fun(*flat_args):
|
||||
return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs)
|
||||
wrapped_fun = lu.wrap_init(
|
||||
lambda *args, **kwargs: (_partial_fun(*args, **kwargs),))
|
||||
partial_fun = functools.partial(fun, **kwargs)
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun),
|
||||
treedef)
|
||||
avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args]
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
|
||||
estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts))
|
||||
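A hypothetical invocation of the `estimate_cost` helper above; the module path is an assumption, while the field names come from the returned `CostEstimate`:

import jax.numpy as jnp
from jax._src.pallas import cost_estimate   # assumed module path

def f(x):
  return jnp.sin(x) @ x.T

est = cost_estimate.estimate_cost(f, jnp.ones((8, 8)))
print(est.flops, est.transcendentals, est.bytes_accessed)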
@ -243,3 +245,12 @@ def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_):
|
||||
bytes_accessed=inner_cost.bytes_accessed,
|
||||
)
|
||||
register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule)
|
||||
|
||||
def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2):
|
||||
inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr))
|
||||
return CostEstimate(
|
||||
flops=inner_cost.flops,
|
||||
transcendentals=inner_cost.transcendentals,
|
||||
bytes_accessed=inner_cost.bytes_accessed,
|
||||
)
|
||||
register_cost_rule(discharge.run_state_p, _run_state_rule)
|
||||
|
@ -63,7 +63,7 @@ from jax._src.state import indexing
|
||||
from jax._src.state import primitives as state_primitives
|
||||
from jax._src.state.types import RefBitcaster, RefReshaper
|
||||
from jax._src.state.utils import dtype_bitwidth
|
||||
from jax._src.typing import DTypeLike
|
||||
from jax._src.typing import Array, DTypeLike
|
||||
from jax._src.util import safe_map
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import split_list
|
||||
@ -2295,7 +2295,49 @@ _cmpf_lowering_types = {
|
||||
}
|
||||
|
||||
|
||||
def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y):
|
||||
# The relationship between comparison operations on booleans and boolean
|
||||
# algebra is as follows:
|
||||
# eq(x, y) = !(x ^ y)
|
||||
# ne(x, y) = x ^ y
|
||||
# lt(x, y) = !x && y
|
||||
# le(x, y) = !x || y
|
||||
# gt(x, y) = x && !y
|
||||
# ge(x, y) = x || !y
|
||||
def _cmp_boolean_lowering_helper(primitive, x: Array, y: Array):
|
||||
"""A helper function for lowering comparison operations for boolean inputs.
|
||||
|
||||
Args:
|
||||
primitive: A JAX primitive representing a comparison operation, which is
|
||||
one of the following: `lax.eq_p` (equals), `lax.ne_p` (not equals),
|
||||
`lax.lt_p` (less than), `lax.le_p` (less than or equal to),
|
||||
`lax.gt_p` (greater than), or `lax.ge_p` (greater than or equal to).
|
||||
x: A boolean array representing the first operand in the comparison.
|
||||
y: A boolean array representing the second operand in the comparison.
|
||||
|
||||
Returns:
|
||||
A boolean array that is the result of applying the comparison operation
|
||||
between `x` and `y` based on the given primitive.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported comparison primitive is provided.
|
||||
"""
|
||||
if primitive == lax.eq_p:
|
||||
return jnp.logical_not(jnp.logical_xor(x, y))
|
||||
elif primitive == lax.ne_p:
|
||||
return jnp.logical_xor(x, y)
|
||||
elif primitive == lax.lt_p:
|
||||
return jnp.logical_and(jnp.logical_not(x), y)
|
||||
elif primitive == lax.le_p:
|
||||
return jnp.logical_or(jnp.logical_not(x), y)
|
||||
elif primitive == lax.gt_p:
|
||||
return jnp.logical_and(x, jnp.logical_not(y))
|
||||
elif primitive == lax.ge_p:
|
||||
return jnp.logical_or(x, jnp.logical_not(y))
|
||||
else:
|
||||
raise ValueError(f"Unsupported comparison primitive: {primitive}")
|
||||
|
||||
|
||||
def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y):
|
||||
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
||||
x_aval, y_aval = ctx.avals_in
|
||||
if x_aval.dtype != y_aval.dtype:
|
||||
@ -2304,60 +2346,22 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y):
|
||||
)
|
||||
dtype = x_aval.dtype
|
||||
|
||||
# For boolean comparisons, we handle them in two different ways. For `ne`,
|
||||
# we directly use the xor operation since they are equivalent. For all
|
||||
# other comparisons, we convert the boolean values to `int32` and use select
|
||||
# operations to perform the comparison.
|
||||
#
|
||||
# The relationship between comparison operations on booleans and boolean
|
||||
# algebra is as follows:
|
||||
#
|
||||
# eq(a, b) = !(a ^ b)
|
||||
# ne(a, b) = a ^ b
|
||||
# lt(a, b) = !a && b
|
||||
# le(a, b) = !a || b
|
||||
# gt(a, b) = a && !b
|
||||
# ge(a, b) = a || !b
|
||||
#
|
||||
# However, except for `ne`, all other operations require negation, which is
|
||||
# currently not supported. At present, even if negation were supported,
|
||||
# it would still need to be implemented using `select` operations, making
|
||||
# it equivalent to our current approach. For more details on negation support,
|
||||
# see https://github.com/jax-ml/jax/issues/24243.
|
||||
if jnp.issubdtype(dtype, jnp.bool_):
|
||||
if prim == lax.ne_p:
|
||||
return arith.xori(x, y)
|
||||
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
vtype = ir.VectorType.get(x_aval.shape, i32)
|
||||
|
||||
# Convert `x` and `y` from `bool` to `int32` for comparison, with 2
|
||||
# for true and 0 for false. For example, comparing `x > y` is equivalent
|
||||
# to `(x ? 2 : 0) > (y ? 2 : 0)`.
|
||||
#
|
||||
# Note that we cannot use 1 for true because the select operation will be
|
||||
# misteriously eliminated.
|
||||
two = arith.constant(i32, 2)
|
||||
zero = arith.constant(i32, 0)
|
||||
|
||||
out_aval, = ctx.avals_out
|
||||
if out_aval.shape != ():
|
||||
# broadcast to vectors if we are comparing vectors
|
||||
two = vector.broadcast(vtype, two)
|
||||
zero = vector.broadcast(vtype, zero)
|
||||
|
||||
x = arith.select(x, two, zero)
|
||||
y = arith.select(y, two, zero)
|
||||
dtype = jnp.int32
|
||||
return lower_fun(
|
||||
functools.partial(_cmp_boolean_lowering_helper, primitive),
|
||||
multiple_results=False,
|
||||
)(ctx, x, y)
|
||||
|
||||
if jnp.issubdtype(dtype, jnp.integer):
|
||||
is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger)
|
||||
pred = (_cmpui_lowering_types if is_uint else _cmpsi_lowering_types)[prim]
|
||||
pred = (
|
||||
_cmpui_lowering_types if is_uint else _cmpsi_lowering_types
|
||||
)[primitive]
|
||||
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
|
||||
return arith.cmpi(predicate, x, y)
|
||||
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
pred = _cmpf_lowering_types[prim]
|
||||
pred = _cmpf_lowering_types[primitive]
|
||||
predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
|
||||
return arith.cmpf(predicate, x, y)
|
||||
|
||||
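As a quick sanity check of the boolean algebra used above (a standalone sketch, not part of the diff), the six identities can be verified elementwise with jax.numpy:

import itertools
import jax.numpy as jnp

def check_boolean_identities():
  # Enumerate all four combinations of boolean scalars and verify that the
  # logical formulations match Python's comparison operators on booleans.
  for x, y in itertools.product([False, True], repeat=2):
    xb, yb = jnp.bool_(x), jnp.bool_(y)
    assert bool(jnp.logical_not(jnp.logical_xor(xb, yb))) == (x == y)   # eq
    assert bool(jnp.logical_xor(xb, yb)) == (x != y)                    # ne
    assert bool(jnp.logical_and(jnp.logical_not(xb), yb)) == (x < y)    # lt
    assert bool(jnp.logical_or(jnp.logical_not(xb), yb)) == (x <= y)    # le
    assert bool(jnp.logical_and(xb, jnp.logical_not(yb))) == (x > y)    # gt
    assert bool(jnp.logical_or(xb, jnp.logical_not(yb))) == (x >= y)    # ge

check_boolean_identities()
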
@ -2381,6 +2385,15 @@ lowering_rules[lax.and_p] = _and_lowering_rule
skip_mlir_conversions.add(lax.and_p)


def _is_finite_lowering_rule(ctx: LoweringRuleContext, x):
  out_aval, = ctx.avals_out
  out_type = aval_to_ir_type(out_aval)
  return _not_lowering_rule(ctx, tpu.weird(out_type, x))


lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule


def _or_lowering_rule(ctx: LoweringRuleContext, x, y):
  x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return arith.ori(x, y)
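Illustrative aside: assuming `tpu.weird` flags non-finite values (NaN/Inf), negating it matches the user-level semantics of jnp.isfinite, as in this small sketch:

import jax.numpy as jnp

x = jnp.array([1.0, jnp.inf, -jnp.inf, jnp.nan])
print(jnp.isfinite(x))  # [ True False False False]
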
@ -21,10 +21,9 @@ import tempfile
from typing import Any

import jax
from jax import core as jax_core
from jax import dtypes
from jax._src import config
from jax._src import core as jax_src_core
from jax._src import core as jax_core
from jax._src import sharding_impls
from jax._src import tpu_custom_call
from jax._src.interpreters import mlir
@ -189,7 +188,7 @@ def pallas_call_tpu_lowering_rule(
  # Replace in_avals to physical avals.
  # This step is required for mapping logical types to physical types.
  # (e.g. PRNG key -> uint32[2])
  physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
  physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in]
  ctx = ctx.replace(avals_in=physical_avals)

  # Booleans are loaded into the kernel as integers.

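A small sketch of the logical-to-physical mapping mentioned in the comment above (not part of the diff): a PRNG key is a logical scalar whose physical representation is a small uint32 array; the exact shape depends on the configured PRNG implementation (two uint32 words for the default threefry2x32).

import jax

key = jax.random.key(0)
print(key.shape, key.dtype)      # logical view: a () array with a key dtype
raw = jax.random.key_data(key)
print(raw.shape, raw.dtype)      # physical view, e.g. (2,) uint32
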
@ -139,7 +139,7 @@ def roll(
@roll_p.def_abstract_eval
def _roll_abstract_eval(x, shift, **_):
  del shift
  return jax_core.raise_to_shaped(x)
  return x


def _roll_lowering_rule(
@ -44,6 +44,7 @@ pytype_strict_library(
    deps = [
        ":lowering",
        "//jax",
        "//jax:core",
        "//jax:mlir",
        "//jax:mosaic_gpu",
        "//jax/_src/pallas",
@ -1256,13 +1256,11 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  [x_aval] = ctx.avals_in
  match x.layout:
    case mgpu.WGStridedFragLayout():
      if axes != (0,):
        raise NotImplementedError("No support for axes other than 0 yet")
      if set(axes) != set(range(x_aval.ndim)):
        raise NotImplementedError("No support for axes yet")
      scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)
      with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]:
        return mgpu.FragmentedArray.splat(
            x.reduce_sum(scratch), (), is_signed=mgpu_utils.is_signed(x_aval.dtype)
        )
        return x.reduce_sum(scratch)
    case mgpu.WGMMA_LAYOUT:
      if axes != (x_aval.ndim - 1,):
        raise NotImplementedError
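Aside (not part of the diff): the new check `set(axes) == set(range(x_aval.ndim))` accepts only full reductions. At the jax.numpy level that is the case where every axis is summed:

import jax.numpy as jnp

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
full = jnp.sum(x, axis=(0, 1))   # axes == (0, 1) == tuple(range(x.ndim))
print(full)                      # 66.0, same as jnp.sum(x)
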
@ -23,7 +23,7 @@ from typing import Any
import warnings

import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering
@ -112,7 +112,6 @@ def _pad_values_to_block_dimension(value,
  return value

def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
  scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
  return tuple(
      primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals
  )
@ -1151,7 +1150,7 @@ def checkify_pallas_kernel_body_jaxpr(
    grid_mapping: GridMapping) -> tuple[
        jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
  err_vals, err_tree = tree_util.tree_flatten(error)
  err_vals = map(checkify.get_shaped_aval, err_vals)
  err_vals = map(jax_core.get_aval, err_vals)
  flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]

  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
@ -1274,13 +1273,13 @@ def pallas_call_checkify_rule(error: checkify.Error,
      closed_jaxpr, enabled_errors, error, grid_mapping)
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_in_tree = jax.tree.flatten(error)
  shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
  shaped_err_avals = map(jax_core.get_aval, err_vals)

  # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
  # all enabled errors removed, but have the error as inputs and return values.
  input_avals = [v.aval for v in jaxpr.invars]
  num_err_vals = len(err_vals)
  shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
  shaped_input_avals = tuple(input_avals)
  checkify_in_avals = [*shaped_err_avals,
                       *shaped_input_avals]
  closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
@ -1416,8 +1415,7 @@ def _trace_kernel_to_jaxpr(
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                                                   kernel_avals, debug)
  if consts:
    consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
                    for c in consts]
    consts_avals = [jax_core.get_aval(c) for c in consts]
    if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
      raise ValueError(
          f"The kernel function in the pallas_call {name_and_src_info} "
@ -1804,8 +1802,7 @@ def pallas_call(
  def wrapped(*args):
    flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
    in_paths, flat_args = unzip2(flat_args_with_paths)
    flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
                          for a in flat_args)
    flat_in_avals = tuple(jax_core.get_aval(a) for a in flat_args)

    flat_out_avals = tuple(_convert_out_shape_to_aval(v)
                           for v in flat_out_shapes)
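Aside (not part of the diff): the checkify plumbing above is driven by the public jax.experimental.checkify API. A minimal, generic sketch of the user-level flow whose error values these avals describe:

import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0, "x must be positive, got {x}", x=x)
  return jnp.log(x)

checked_f = checkify.checkify(f, errors=checkify.user_checks)
err, out = checked_f(jnp.float32(-1.0))
print(err.get())  # the formatted error message, or None if no check failed
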
@ -76,6 +76,7 @@ pytype_strict_library(
        ":lowering",
        "//jax",
        "//jax:config",
        "//jax:core",
        "//jax:mlir",
        "//jax:util",
        "//jax/_src/lib",
@ -102,7 +102,7 @@ class LoweringRuleContext:

@dataclasses.dataclass
class LoweringResult:
  """Keeps pybind11 objects alive."""
  """Keeps python objects alive."""

  module: ir.Module
  grid: tuple[int, ...]
@ -1983,81 +1983,6 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation):
  return tt_dialect.trans(x, permutation)


def _check_dot_operands(
    x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any
):
  # TODO(slebedev): Ensure that the dtypes are supported by CUDA.
  return


def _dot(
    x: ir.Value,
    y: ir.Value,
    acc: ir.Value | None = None,
    *,
    allow_tf32: bool = True,
    max_num_imprecise_acc: int | None = None,
    out_type: ir.Type | None = None,
) -> ir.Value:
  if out_type is None:
    out_type = ir.F32Type.get()
  elif isinstance(out_type, ir.BF16Type):
    raise NotImplementedError(f"unsupported output type: {out_type}")

  x_type = ir.RankedTensorType(x.type)
  y_type = ir.RankedTensorType(y.type)
  if min(*x_type.shape, *y_type.shape) < 16:
    raise ValueError("all dimensions of x and y must be >= 16 ")
  if x_type.element_type != y_type.element_type:
    raise ValueError(
        "x and y must have the same element type, but got:"
        f" {x_type.element_type} and {y_type.element_type}"
    )

  _check_dot_operands(x_type, y_type, object())

  element_type = x_type.element_type
  if isinstance(element_type, ir.IntegerType):
    if element_type.width != 8:
      raise TypeError(f"unsupported element type: {element_type}")
    element_type = ir.IntegerType.get_signless(32)
  elif isinstance(element_type, (ir.F32Type, ir.BF16Type)):
    element_type = ir.F32Type.get()
  else:
    element_type = out_type

  if element_type != out_type:
    raise TypeError(
        f"output type {out_type} does not match element type {element_type}"
    )

  m, _ = x_type.shape
  _, n = y_type.shape

  if acc is None:
    acc = _full(ir.RankedTensorType.get([m, n], element_type), 0)

  if max_num_imprecise_acc is None:
    if isinstance(element_type, ir.FloatType) and element_type.width == 8:
      # TODO(slebedev): Fill in from options.
      raise NotImplementedError
    else:
      max_num_imprecise_acc = 0

  # Ideally, replace all allow_tf32 usages with InputPrecision directly.
  input_precision = tt_dialect.InputPrecision.IEEE
  if allow_tf32:
    input_precision = tt_dialect.InputPrecision.TF32

  return tt_dialect.dot(
      x,
      y,
      acc,
      max_num_imprecise_acc=max_num_imprecise_acc,
      input_precision=input_precision
  )


_TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT)

@ -2081,27 +2006,63 @@ def _dot_general_lowering(
  if b_contract_dim == 1:
    b = tt_dialect.trans(b, (1, 0))

  if precision is None:
    allow_tf32 = True
  else:
    prec_a, prec_b = precision
    allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS

  a_aval, b_aval = ctx.avals_in
  [out_aval] = ctx.avals_out
  out_dtype = acc_dtype = out_aval.dtype
  if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
    acc_dtype = jnp.dtype(jnp.float32)

  return _cast(
      _dot(
          a,
          b,
          allow_tf32=allow_tf32,
          out_type=_dtype_to_ir_type(acc_dtype),
      ),
      acc_dtype,
      out_dtype,
  )
  if precision is None or (precision == lax.DotAlgorithmPreset.DEFAULT):
    precision = (lax.Precision.DEFAULT, lax.Precision.DEFAULT)

  if isinstance(precision, lax.DotAlgorithmPreset):
    match precision:
      case lax.DotAlgorithmPreset.TF32_TF32_F32:
        input_precision = tt_dialect.InputPrecision.TF32
      case lax.DotAlgorithmPreset.TF32_TF32_F32_X3:
        input_precision = tt_dialect.InputPrecision.TF32x3
      case lax.DotAlgorithmPreset.F32_F32_F32:
        input_precision = tt_dialect.InputPrecision.IEEE
      case (
          lax.DotAlgorithmPreset.F16_F16_F16
          | lax.DotAlgorithmPreset.F16_F16_F32
          | lax.DotAlgorithmPreset.BF16_BF16_BF16
          | lax.DotAlgorithmPreset.BF16_BF16_F32
      ):
        input_precision = None
      case _:
        raise NotImplementedError(f"Unsupported dot algorithm: {precision}.")

    a = _cast(a, a_aval.dtype, precision.supported_lhs_types[0])
    b = _cast(b, b_aval.dtype, precision.supported_rhs_types[0])
    acc_dtype = precision.accumulation_type
  elif isinstance(precision, tuple):
    a_precision, b_precision = precision
    if a_precision in _TF32_PRECISIONS or b_precision in _TF32_PRECISIONS:
      input_precision = tt_dialect.InputPrecision.TF32
    elif a_aval.dtype == jnp.float32:
      input_precision = tt_dialect.InputPrecision.IEEE
    else:
      input_precision = None

    acc_dtype = out_aval.dtype
    if acc_dtype != jnp.int32 and acc_dtype != jnp.float16:
      acc_dtype = jnp.float32
  else:
    raise NotImplementedError(f"Unsupported dot precision: {precision}.")

  a_type = ir.RankedTensorType(a.type)
  b_type = ir.RankedTensorType(b.type)
  if min(*a_type.shape, *b_type.shape) < 16:
    raise ValueError("all dimensions of a and b must be >= 16 ")
  if a_type.element_type != b_type.element_type:
    raise ValueError(
        "a and b must have the same element type, but got:"
        f" {a_type.element_type} and {b_type.element_type}"
    )

  m, _ = a_type.shape
  _, n = b_type.shape
  acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0)
  acc = tt_dialect.dot(a, b, acc, input_precision=input_precision)
  return _cast(acc, acc_dtype, out_aval.dtype)


def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes):
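Aside (not part of the diff): the presets matched above originate from the user-facing precision argument. A hedged sketch, assuming a JAX version in which dot operations accept a lax.DotAlgorithmPreset via precision (this is how the algorithm reaches the lowering rule when the dot runs inside a Pallas Triton kernel):

import jax.numpy as jnp
from jax import lax

a = jnp.ones((128, 128), dtype=jnp.float32)
b = jnp.ones((128, 128), dtype=jnp.float32)

# Request TF32 inputs with an f32 accumulator; on the Triton path this is
# the case that maps to InputPrecision.TF32 above.
out = jnp.dot(a, b, precision=lax.DotAlgorithmPreset.TF32_TF32_F32)
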
@ -2623,7 +2584,8 @@ def _i64_constant(v: int) -> ir.Value:
  return arith_dialect.constant(ir.IntegerType.get_signless(64), v)


def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
def _dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type:
  dtype = jnp.dtype(dtype)
  if jnp.issubdtype(dtype, np.integer):
    # All integer types in Triton are signless.
    return ir.IntegerType.get_signless(dtype.itemsize * 8)
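Aside (not part of the diff): the widened DTypeLike signature works because jnp.dtype normalizes dtype objects, scalar types, and strings to the same thing:

import jax.numpy as jnp
import numpy as np

for d in (jnp.int32, np.dtype("int32"), "int32"):
  dt = jnp.dtype(d)   # all three normalize to dtype('int32')
  assert jnp.issubdtype(dt, np.integer) and dt.itemsize * 8 == 32
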
@ -19,7 +19,7 @@ from __future__ import annotations
import io
from typing import Any

from jax import core as jax_core
import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
@ -19,7 +19,7 @@ from __future__ import annotations
from collections.abc import Sequence

import jax
from jax import core as jax_core
from jax._src import core as jax_core
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas.triton import lowering
@ -14,7 +14,7 @@

from __future__ import annotations

class _UnconstrainedPartitionSingleton:
class UnconstrainedSingleton:

  def __repr__(self):
    return "UNCONSTRAINED"
@ -23,7 +23,7 @@ class _UnconstrainedPartitionSingleton:
# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()


class PartitionSpec(tuple):
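Aside (not part of the diff): this singleton backs the public PartitionSpec.UNCONSTRAINED sentinel, used to leave one dimension's partitioning up to XLA while pinning another to a mesh axis:

from jax.sharding import PartitionSpec as P

# First dimension sharded over the 'x' mesh axis, second left to XLA.
spec = P('x', P.UNCONSTRAINED)
print(spec)  # e.g. PartitionSpec('x', UNCONSTRAINED)
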
@ -498,7 +498,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
    donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
    args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
    lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
                             pgle_profiler=None)
                             pgle_profiler=None)
    return stages.Traced(
        p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
        lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)
@ -549,6 +549,7 @@ def _infer_params_impl(
    kwargs: dict[str, Any],
    in_avals: tuple[core.AbstractValue, ...] | None,
) -> tuple[PjitParams, list[Any]]:
  util.test_event("pjit._infer_params_impl", fun)
  have_kwargs = bool(kwargs)
  if have_kwargs and ji.user_specified_in_shardings:
    raise ValueError(
@ -698,9 +699,6 @@ def get_abstract_mesh_from_avals(in_avals):
    return None
  m = None
  for a in in_avals:
    # TODO(yashkatariya): Remove this when mesh context can be set by the user.
    if a.sharding is None:  # type: ignore
      continue
    if m is not None and m != a.sharding.mesh:
      raise ValueError(
          f'Mesh for all inputs should be equal. Got one mesh: {m} and'
@ -1300,6 +1298,7 @@ def _create_pjit_jaxpr(
    ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
  util.test_event("create_pjit_jaxpr")
  del ignored_inline  # just for explain_cache_miss
  if config.no_tracing.value:
    raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
@ -1787,10 +1786,9 @@ def _pjit_lower(
    lowering_platforms: tuple[str, ...] | None,
    lowering_parameters: mlir.LoweringParameters,
    pgle_profiler: profiler.PGLEProfiler | None):
  util.test_event("pjit_lower")
  if config.sharding_in_types.value:
    cur_mesh = mesh_lib.get_concrete_mesh()
    mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
    api_name = 'jit'
    mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
  else:
    mesh, api_name = ((resource_env.physical_mesh, 'pjit')
                      if resource_env is not None else (None, 'jit'))
@ -2156,8 +2154,18 @@ def _pjit_partial_eval(trace, *in_tracers,

  known_ins = tuple(pv.is_known() for pv in in_pvals)
  unknown_ins = tuple(not k for k in known_ins)
  known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
      pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
  if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect))
         for e in jaxpr.effects):
    known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \
        pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins,
                                       False, False, None)
    if num_res_ref: raise NotImplementedError
    known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts)
    unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts)
    res_avals = unknown_jaxpr.in_avals[:num_res_val]
  else:
    known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
        pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
  unknown_outs = tuple(unknown_outs)
  known_outs = tuple(not uk for uk in unknown_outs)
  num_residuals = len(res_avals)
@ -2343,8 +2351,7 @@ def _pjit_transpose(cts_in, *primals_in,
      *prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
      *prune_type(ad.Zero, out_layouts, cts_in)
  )
  global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
                              for ct in primals_and_nz_cts_in)
  global_cts_in_avals = tuple(core.get_aval(ct) for ct in primals_and_nz_cts_in)

  transpose_jaxpr, attrs_tracked = _pjit_transpose_trace(
      body, global_cts_in_avals)
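Aside (not part of the diff): the stages.Traced object built in the first hunk is what users see through jit's ahead-of-time API. A minimal sketch, assuming a recent JAX release where jitted functions expose .trace():

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return x * 2.0

traced = f.trace(jnp.float32(1.0))   # a Traced stage
lowered = traced.lower()             # a Lowered stage
compiled = lowered.compile()         # a Compiled stage
print(compiled(jnp.float32(1.0)))    # 2.0
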
@ -1249,7 +1249,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space):

random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
random_gamma_p.def_abstract_eval(lambda key, a, **_: a)
ad.defjvp2(
    random_gamma_p, None,
    lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
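Aside (not part of the diff): the JVP rule registered here is what makes the gamma sampler differentiable with respect to its shape parameter, e.g.:

import jax
import jax.numpy as jnp

key = jax.random.key(0)
# Gradient of a gamma sample with respect to the shape parameter `a`.
grad_wrt_a = jax.grad(lambda a: jax.random.gamma(key, a))(jnp.float32(2.0))
print(grad_wrt_a)
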
@ -130,7 +130,7 @@ class Rotation(typing.NamedTuple):
    else:
      return self.quat.shape[0]

  def __mul__(self, other):
  def __mul__(self, other) -> Rotation:
    """Compose this rotation with the other."""
    return Rotation.from_quat(_compose_quat(self.quat, other.quat))

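Aside (not part of the diff): the annotated __mul__ is the composition operator on jax.scipy.spatial.transform.Rotation, e.g.:

import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation

# Two quarter-turns about z compose to roughly a half-turn.
r1 = Rotation.from_euler('z', jnp.array([90.0]), degrees=True)
r2 = Rotation.from_euler('z', jnp.array([90.0]), degrees=True)
r = r1 * r2
print(r.as_euler('z', degrees=True))  # approximately [180.] (sign may flip)
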
@ -67,6 +67,19 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
          f"is also found in manual_axes: {_manual_axes}.") from None


@util.cache(max_size=128, trace_context_in_key=False)
def _check_axis_type_consistency(mesh, parsed_pspec):
  if mesh.axis_types is None:
    return
  for p in parsed_pspec:
    if p is not None:
      if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
        raise ValueError(
            'AxisTypes should be the same in a tuple subset of PartitionSpec:'
            f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
            f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')


def hashed_index(x) -> int:
  # This works for both `pjit` indices and `pmap` indices (which might
  # have an integer instead of a slice).
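Aside (not part of the diff): the rule enforced above says that every mesh axis bundled into one PartitionSpec tuple entry must share the same axis type. A standalone toy sketch with hypothetical names, not the internal API:

name_to_type = {"x": "Explicit", "y": "Explicit", "z": "Manual"}

def check_entry(entry: tuple[str, ...]) -> None:
  # Mirrors the all(...) check: every axis in the tuple must have the same
  # type as the first axis.
  if not all(name_to_type[entry[0]] == name_to_type[r] for r in entry):
    raise ValueError(f"mixed axis types in {entry}")

check_entry(("x", "y"))    # fine: both Explicit
try:
  check_entry(("x", "z"))  # raises: Explicit vs Manual
except ValueError as e:
  print(e)
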
@ -725,7 +738,7 @@ class PositionalSharding(sharding.Sharding):
    ids = self._ids.copy()
    platform_name = self._devices[0].platform.upper()
    for idx, x in np.ndenumerate(ids):
      ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))
      ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))  # type: ignore  # numpy 2.2
    body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
                           max_line_width=100)
    mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
@ -1084,6 +1097,7 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
        PartitionSpec() if spec is None else spec,
        "NamedSharding spec", allow_unconstrained_dims=True)
    _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
    _check_axis_type_consistency(mesh, parsed_pspec)
  return parsed_pspec

@ -1673,7 +1687,8 @@ def _gspmd_to_named_sharding_via_mesh(


def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
              *, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh:
              *, devices: Sequence[xc.Device] | None = None,
              axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
  """Creates an efficient mesh with the shape and axis names specified.

  This function attempts to automatically compute a good mapping from a set of
@ -1735,4 +1750,4 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
  mesh_devices = mesh_utils.create_device_mesh(
      new_axis_shapes, devices,
      allow_split_physical_axes=allow_split_physical_axes)
  return mesh_lib.Mesh(mesh_devices, axis_names)
  return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
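Aside (not part of the diff): basic usage of the public jax.make_mesh wrapper shown above; the new axis_types keyword is optional and its exact type is an internal detail, so it is omitted here.

import jax

# Shape the mesh to the available devices so the sketch runs anywhere.
mesh = jax.make_mesh((len(jax.devices()), 1), ('data', 'model'))
print(mesh.shape)  # e.g. {'data': <num devices>, 'model': 1}
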
@ -97,7 +97,7 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
  # is used to extract the corresponding shard of the logical array.
  shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
  for i, idxs in enumerate(itertools.product(*axis_indices)):
    shard_indices[i] = idxs
    shard_indices[i] = idxs  # type: ignore  # numpy 2.2
  shard_indices = shard_indices.reshape(shard_indices_shape)

  # Ensure that each sharded axis is used exactly once in the mesh mapping
@ -533,6 +533,7 @@ class Compiled(Stage):

  @staticmethod
  def call(*args, **kwargs):
    util.test_event("stages_compiled_call")
    # This is because `__call__` passes in `self._params` as the first argument.
    # Instead of making the call signature `call(params, *args, **kwargs)`
    # extract it from args because `params` can be passed as a kwarg by users
@ -153,6 +153,10 @@ def _eval_jaxpr_discharge_state(
      [invar], [outvar] = eqn.invars, eqn.outvars
      ans = env.read(invar)
      refs_to_discharge.add(id(outvar.aval))
    elif eqn.primitive is core.freeze_p:
      [invar], [outvar] = eqn.invars, eqn.outvars
      ans = env.read(invar)
      refs_to_discharge.remove(id(invar.aval))
    elif (any(should_discharge)
          or core.internal_mutable_array_effect in eqn.effects
          ):
@ -364,7 +368,7 @@ def transform_swap_array(x, transforms, val):
    case indexing.NDIndexer():
      indexer = transform
      if _is_trivial_indexer(indexer):
        _results.append(None)
        _results.append(_results[-1])
        continue
      # If everything in the indexer is a slice or ()-shaped, we can also
      # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices.
@ -654,3 +654,8 @@ def _broadcast_to_abstract_eval(aval, *, shape):
mlir.register_lowering(
    broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)

# === AD rules for mutable arrays ===

ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g))
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))
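Aside (not part of the diff): ad.defjvp registers JVP rules for internal primitives; the public analogue for user-defined functions is jax.custom_jvp, sketched here to show what such a rule expresses:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  # Primal output plus its directional derivative along tangent t.
  return f(x), jnp.cos(x) * t

print(jax.grad(f)(0.0))  # 1.0
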
@ -414,7 +414,7 @@ _ref_type_aval_mappings: dict[

def _default_value_to_ref_aval(x: Any) -> tuple[AbstractRef, Array]:
  # Default type mapping just creates an AbstractRef from the array's aval.
  aval = core.raise_to_shaped(core.get_aval(x))
  aval = core.get_aval(x)
  return AbstractRef(aval), x

Some files were not shown because too many files have changed in this diff.