Merge pull request #194 from ROCm/ci-upstream-sync-80_1

CI: 01/07/25 upstream sync
This commit is contained in:
github-actions[bot] 2025-01-07 11:20:18 -06:00 committed by GitHub
commit 972f95b95d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 850 additions and 651 deletions

View File

@ -132,8 +132,8 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py
documentation_render:

View File

@ -19,17 +19,27 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`).
* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Support added for user defined state in the FFI via the new
{func}`jax.ffi.register_ffi_type_id` function.
* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
previous import path is deprecated.
* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
@ -37,6 +47,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
* The `jax.experimental.array_api` module has been removed after being
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
the array API directly.
## jax 0.4.38 (Dec 17, 2024)

View File

@ -468,6 +468,9 @@ async def main():
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command_base.append("--config=clang")
if clang_major_version < 19:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
else:
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
logging.debug(
@ -477,6 +480,10 @@ async def main():
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
gcc_major_version = utils.get_gcc_major_version(gcc_path)
if gcc_major_version < 13:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":

View File

@ -74,9 +74,8 @@ class System(object):
env = dict(os.environ)
if self.pkgbin == "apt":
env["DEBIAN_FRONTEND"] = "noninteractive"
# Update indexes.
subprocess.check_call(["apt-get", "update"])
# Update indexes.
subprocess.check_call(["apt-get", "update"])
LOG.info("Running %r" % cmd)
subprocess.check_call(cmd, env=env)

View File

@ -201,6 +201,18 @@ def get_clang_major_version(clang_path):
return major_version
def get_gcc_major_version(gcc_path: str):
gcc_version_proc = subprocess.run(
[gcc_path, "-dumpversion"],
check=True,
capture_output=True,
text=True,
)
major_version = int(gcc_version_proc.stdout)
return major_version
def get_jax_configure_bazel_options(bazel_command: list[str]):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so

View File

@ -362,4 +362,5 @@ rediraffe_redirects = {
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
'jax.extend.ffi.rst': 'jax.ffi.rst',
}

View File

@ -659,7 +659,7 @@ You can find the up-to-date command to run doctests in
E.g., you can run:
```
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
```
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in

View File

@ -21,7 +21,7 @@
"JAX's FFI support is provided in two parts:\n",
"\n",
"1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n",
"2. A Python front end, available in the `jax.extend.ffi` submodule.\n",
"2. A Python front end, available in the `jax.ffi` submodule.\n",
"\n",
"In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
"We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
@ -191,9 +191,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.\n",
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.\n",
"This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).\n",
"JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:"
"JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:"
]
},
{
@ -204,12 +204,11 @@
"source": [
"import ctypes\n",
"from pathlib import Path\n",
"import jax.extend as jex\n",
"\n",
"path = next(Path(\"ffi\").glob(\"librms_norm*\"))\n",
"rms_norm_lib = ctypes.cdll.LoadLibrary(path)\n",
"jex.ffi.register_ffi_target(\n",
" \"rms_norm\", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
"jax.ffi.register_ffi_target(\n",
" \"rms_norm\", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
]
},
{
@ -217,7 +216,7 @@
"metadata": {},
"source": [
"```{tip}\n",
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
"```\n",
"\n",
"**An alternative approach**:\n",
@ -251,7 +250,7 @@
"# Assuming that we compiled a nanobind extension called `rms_norm`:\n",
"import rms_norm as rms_norm_lib\n",
"\n",
"jex.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
"jax.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
"```"
]
},
@ -261,7 +260,7 @@
"source": [
"## Frontend code\n",
"\n",
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:"
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:"
]
},
{
@ -282,7 +281,7 @@
" if x.dtype != jnp.float32:\n",
" raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n",
"\n",
" call = jex.ffi.ffi_call(\n",
" call = jax.ffi.ffi_call(\n",
" # The target name must be the same string as we used to register the target\n",
" # above in `register_custom_call_target`\n",
" \"rms_norm\",\n",
@ -314,25 +313,25 @@
"metadata": {},
"source": [
"This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n",
"Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
"It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
"Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
"It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
"\n",
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.\n",
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
"\n",
"The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"\n",
"```{tip}\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.\n",
"In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.\n",
"One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
"One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
"```\n",
"\n",
"(ffi-call-vmap)=\n",
"### Batching with `vmap`\n",
"\n",
"{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
"{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.\n",
"\n",
"The simplest `vmap_method` is `\"sequential\"`.\n",
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
@ -395,7 +394,7 @@
"outputs": [],
"source": [
"def rms_norm_sequential(x, eps=1e-5):\n",
" return jex.ffi.ffi_call(\n",
" return jax.ffi.ffi_call(\n",
" \"rms_norm\",\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" vmap_method=\"sequential\",\n",
@ -418,9 +417,9 @@
"source": [
"### Differentiation\n",
"\n",
"Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
"Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
"As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n",
"Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
"Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
"\n",
"More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n",
"In this case, we actually define two new FFI calls:\n",
@ -429,7 +428,7 @@
"2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
"\n",
"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
"\n",
"This custom derivative rule can be wired in as follows:"
]
@ -440,16 +439,16 @@
"metadata": {},
"outputs": [],
"source": [
"jex.ffi.register_ffi_target(\n",
" \"rms_norm_fwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_fwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
")\n",
"jex.ffi.register_ffi_target(\n",
" \"rms_norm_bwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_bwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
")\n",
"\n",
"\n",
"def rms_norm_fwd(x, eps=1e-5):\n",
" y, res = jex.ffi.ffi_call(\n",
" y, res = jax.ffi.ffi_call(\n",
" \"rms_norm_fwd\",\n",
" (\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
@ -466,7 +465,7 @@
" assert res.shape == ct.shape[:-1]\n",
" assert x.shape == ct.shape\n",
" return (\n",
" jex.ffi.ffi_call(\n",
" jax.ffi.ffi_call(\n",
" \"rms_norm_bwd\",\n",
" jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
" vmap_method=\"broadcast_all\",\n",
@ -533,7 +532,7 @@
"On the front end, the registration code would be updated to specify the appropriate platform:\n",
"\n",
"```python\n",
"jex.ffi.register_ffi_target(\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_cuda\", rms_norm_lib_cuda.rms_norm(), platform=\"CUDA\"\n",
")\n",
"```\n",
@ -554,7 +553,7 @@
" out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
"\n",
" def impl(target_name):\n",
" return lambda x: jex.ffi.ffi_call(\n",
" return lambda x: jax.ffi.ffi_call(\n",
" target_name,\n",
" out_type,\n",
" vmap_method=\"broadcast_all\",\n",
@ -620,9 +619,9 @@
"This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n",
"We will leave these topics to future tutorials, but here are some possibly useful references:\n",
"\n",
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
"\n",
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
"\n",
"* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details."
]

View File

@ -29,7 +29,7 @@ We will discuss some possible approaches below, but it is important to call this
JAX's FFI support is provided in two parts:
1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and
2. A Python front end, available in the `jax.extend.ffi` submodule.
2. A Python front end, available in the `jax.ffi` submodule.
In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.
We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.
@ -171,23 +171,22 @@ To compile the shared library, we're using CMake here, but you should be able to
!cmake --install ffi/_build
```
With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.
With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.
This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).
JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:
JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:
```{code-cell} ipython3
import ctypes
from pathlib import Path
import jax.extend as jex
path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jex.ffi.register_ffi_target(
"rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
jax.ffi.register_ffi_target(
"rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
```
```{tip}
If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here.
If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here.
```
**An alternative approach**:
@ -221,14 +220,14 @@ Then, in Python we can register this handler using:
# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib
jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
```
+++
## Frontend code
Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:
Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:
```{code-cell} ipython3
import numpy as np
@ -243,7 +242,7 @@ def rms_norm(x, eps=1e-5):
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
call = jex.ffi.ffi_call(
call = jax.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_custom_call_target`
"rms_norm",
@ -271,25 +270,25 @@ np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
```
This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.
Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.
It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.
Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.
It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.
Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.
The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
```{tip}
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.
In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.
One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.
One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.
```
(ffi-call-vmap)=
### Batching with `vmap`
{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.
{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.
The simplest `vmap_method` is `"sequential"`.
In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
@ -326,7 +325,7 @@ Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {f
```{code-cell} ipython3
def rms_norm_sequential(x, eps=1e-5):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
vmap_method="sequential",
@ -342,9 +341,9 @@ If your foreign function provides an efficient batching rule that isn't supporte
### Differentiation
Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.
Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.
As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.
Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.
Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.
More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.
In this case, we actually define two new FFI calls:
@ -353,21 +352,21 @@ In this case, we actually define two new FFI calls:
2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.
We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.
The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.
The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.
This custom derivative rule can be wired in as follows:
```{code-cell} ipython3
jex.ffi.register_ffi_target(
"rms_norm_fwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
jax.ffi.register_ffi_target(
"rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jex.ffi.register_ffi_target(
"rms_norm_bwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
jax.ffi.register_ffi_target(
"rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)
def rms_norm_fwd(x, eps=1e-5):
y, res = jex.ffi.ffi_call(
y, res = jax.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
@ -384,7 +383,7 @@ def rms_norm_bwd(eps, res, ct):
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jex.ffi.ffi_call(
jax.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_all",
@ -447,7 +446,7 @@ Then, the `RmsNormImpl` can use the CUDA stream to launch CUDA kernels.
On the front end, the registration code would be updated to specify the appropriate platform:
```python
jex.ffi.register_ffi_target(
jax.ffi.register_ffi_target(
"rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)
```
@ -462,7 +461,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
def impl(target_name):
return lambda x: jex.ffi.ffi_call(
return lambda x: jax.ffi.ffi_call(
target_name,
out_type,
vmap_method="broadcast_all",
@ -499,8 +498,8 @@ and there will be no runtime overhead to using {func}`jax.lax.platform_dependent
This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.
We will leave these topics to future tutorials, but here are some possibly useful references:
* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.
* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.
* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.
* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.
* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details.

View File

@ -1,28 +0,0 @@
``jax.experimental.array_api`` module
=====================================
.. note::
The ``jax.experimental.array_api`` module is deprecated as of JAX v0.4.32, and
importing ``jax.experimental.array_api`` is no longer necessary. {mod}`jax.numpy`
implements the array API standard directly by default. See :ref:`python-array-api`
for details.
This module includes experimental JAX support for the `Python array API standard`_.
Support for this is currently experimental and not fully complete.
Example Usage::
>>> from jax.experimental import array_api as xp
>>> xp.__array_api_version__
'2023.12'
>>> arr = xp.arange(1000)
>>> arr.sum()
Array(499500, dtype=int32)
The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
and implements most of the API listed in the standard.
.. _Python array API standard: https://data-apis.org/array-api/

View File

@ -14,7 +14,6 @@ Experimental Modules
.. toctree::
:maxdepth: 1
jax.experimental.array_api
jax.experimental.checkify
jax.experimental.compilation_cache
jax.experimental.custom_partitioning

View File

@ -1,12 +0,0 @@
``jax.extend.ffi`` module
=========================
.. automodule:: jax.extend.ffi
.. autosummary::
:toctree: _autosummary
ffi_call
ffi_lowering
pycapsule
register_ffi_target

View File

@ -12,7 +12,6 @@ Modules
:maxdepth: 1
jax.extend.core
jax.extend.ffi
jax.extend.linear_util
jax.extend.mlir
jax.extend.random

31
docs/jax.ffi.rst Normal file
View File

@ -0,0 +1,31 @@
``jax.ffi`` module
==================
.. automodule:: jax.ffi
.. autosummary::
:toctree: _autosummary
ffi_call
ffi_lowering
pycapsule
register_ffi_target
register_ffi_type_id
``jax.extend.ffi`` module (deprecated)
======================================
The ``jax.extend.ffi`` module has been moved to ``jax.ffi``, and that import
path should be used instead, but these functions remain documented here while
the legacy import is being deprecated.
.. automodule:: jax.extend.ffi
.. autosummary::
:toctree: _autosummary
ffi_call
ffi_lowering
pycapsule
register_ffi_target

View File

@ -542,7 +542,8 @@ Python Array API standard
Prior to JAX v0.4.32, you must ``import jax.experimental.array_api`` in order
to enable the array API for JAX arrays. After JAX v0.4.32, importing this
module is no longer required, and will raise a deprecation warning.
module is no longer required, and will raise a deprecation warning. After
JAX v0.5.0, this import will raise an error.
Starting with JAX v0.4.32, :class:`jax.Array` and :mod:`jax.numpy` are compatible
with the `Python Array API Standard`_. You can access the Array API namespace via

View File

@ -18,6 +18,7 @@ Subpackages
jax.dlpack
jax.distributed
jax.dtypes
jax.ffi
jax.flatten_util
jax.image
jax.nn

View File

@ -121,13 +121,13 @@
}
],
"source": [
"from scipy import misc\n",
"from scipy import datasets\n",
"import jax.scipy as jsp\n",
"\n",
"fig, ax = plt.subplots(1, 3, figsize=(12, 5))\n",
"\n",
"# Load a sample image; compute mean() to convert from RGB to grayscale.\n",
"image = jnp.array(misc.face().mean(-1))\n",
"image = jnp.array(datasets.face().mean(-1))\n",
"ax[0].imshow(image, cmap='binary_r')\n",
"ax[0].set_title('original')\n",
"\n",

View File

@ -75,13 +75,13 @@ For example, here is a simple approach to de-noising an image based on convoluti
:id: Jk5qdnbv6QgT
:outputId: 292205eb-aa09-446f-eec2-af8c23cfc718
from scipy import misc
from scipy import datasets
import jax.scipy as jsp
fig, ax = plt.subplots(1, 3, figsize=(12, 5))
# Load a sample image; compute mean() to convert from RGB to grayscale.
image = jnp.array(misc.face().mean(-1))
image = jnp.array(datasets.face().mean(-1))
ax[0].imshow(image, cmap='binary_r')
ax[0].set_title('original')

View File

@ -17,6 +17,7 @@ pytest-xdist
# Packages used for notebook execution
matplotlib
scikit-learn
pooch
numpy
rich[jupyter]
cmake

View File

@ -15,28 +15,27 @@
import numpy as np
import jax
import jax.extend as jex
from jax_ffi_example import _cpu_examples
for name, target in _cpu_examples.registrations().items():
jex.ffi.register_ffi_target(name, target)
jax.ffi.register_ffi_target(name, target)
def array_attr(num: int):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"array_attr",
jax.ShapeDtypeStruct((), np.int32),
)(array=np.arange(num, dtype=np.int32))
def dictionary_attr(**kwargs):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"dictionary_attr",
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
)(**kwargs)
def counter(index):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))

View File

@ -24,15 +24,14 @@ import numpy as np
import jax
import jax.numpy as jnp
import jax.extend as jex
# Load the shared library with the FFI target definitions
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so")
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),
jax.ffi.register_ffi_target("foo-fwd", jax.ffi.pycapsule(library.FooFwd),
platform="CUDA")
jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd),
jax.ffi.register_ffi_target("foo-bwd", jax.ffi.pycapsule(library.FooBwd),
platform="CUDA")
@ -42,7 +41,7 @@ def foo_fwd(a, b):
assert a.dtype == b.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
c, b_plus_1 = jax.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
return c, (a, b_plus_1)
@ -55,7 +54,7 @@ def foo_bwd(res, c_grad):
assert a.dtype == b_plus_1.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
return jax.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
n=n)

View File

@ -16,7 +16,7 @@
This example is exactly the same as the one in the `FFI tutorial
<https://jax.readthedocs.io/en/latest/ffi.html>`, so more details can be found
on that page. But, the high level summary is that we implement our custom
extension in ``rms_norm.cc``, then call it usin ``jax.extend.ffi.ffi_call`` in
extension in ``rms_norm.cc``, then call it usin ``jax.ffi.ffi_call`` in
this module. The behavior under autodiff is implemented using
``jax.custom_vjp``.
"""
@ -26,13 +26,12 @@ from functools import partial
import numpy as np
import jax
import jax.extend as jex
import jax.numpy as jnp
from jax_ffi_example import _rms_norm
for name, target in _rms_norm.registrations().items():
jex.ffi.register_ffi_target(name, target)
jax.ffi.register_ffi_target(name, target)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
@ -53,7 +52,7 @@ def rms_norm(x, eps=1e-5):
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_ffi_target`
"rms_norm",
@ -63,7 +62,7 @@ def rms_norm(x, eps=1e-5):
def rms_norm_fwd(x, eps=1e-5):
y, res = jex.ffi.ffi_call(
y, res = jax.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
@ -80,7 +79,7 @@ def rms_norm_bwd(eps, res, ct):
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jex.ffi.ffi_call(
jax.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_all",

View File

@ -19,6 +19,7 @@ load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"if_building_jaxlib",
"jax_export_file_visibility",
"jax_extend_internal_users",
"jax_extra_deps",
"jax_internal_export_back_compat_test_util_visibility",
@ -62,6 +63,11 @@ exports_files([
"version.py",
])
exports_files(
["_src/export/serialization.fbs"],
visibility = jax_export_file_visibility,
)
# Packages that have access to JAX-internal implementation details.
package_group(
name = "internal",
@ -199,6 +205,7 @@ py_library_providing_imports_info(
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/ffi.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",
"_src/interpreters/ad.py",
@ -730,7 +737,6 @@ py_library(
":jax",
":mlir",
"//jax/_src/lib",
"//jax/extend:ffi",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:func_dialect",
@ -1047,19 +1053,6 @@ pytype_library(
deps = [":jax"],
)
pytype_library(
name = "experimental_array_api",
srcs = glob(
[
"experimental/array_api/*.py",
],
),
visibility = [":internal"],
deps = [
":jax",
],
)
pytype_library(
name = "experimental_sparse",
srcs = glob(

View File

@ -160,6 +160,7 @@ from jax import debug as debug
from jax import dlpack as dlpack
from jax import dtypes as dtypes
from jax import errors as errors
from jax import ffi as ffi
from jax import image as image
from jax import lax as lax
from jax import monitoring as monitoring

View File

@ -96,7 +96,7 @@ class Array(abc.ABC):
def __truediv__(self, other) -> Array: ...
def __floordiv__(self, other) -> Array: ...
def __mod__(self, other) -> Array: ...
def __divmod__(self, other) -> Array: ...
def __divmod__(self, other) -> tuple[Array, Array]: ...
def __pow__(self, other) -> Array: ...
def __lshift__(self, other) -> Array: ...
def __rshift__(self, other) -> Array: ...

View File

@ -2095,11 +2095,6 @@ def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimS
# If d < window_size then (d - window_size) // window_stride < 0
return max_dim((d - window_size) // window_stride + 1, 0)
# TODO(necula): Deprecated Jan 2024, to be removed.
def non_negative_dim(d: DimSize) -> DimSize:
"""max(d, 0)."""
return max_dim(d, 0)
def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
"""Like min(d1, d2) but for both constant and symbolic dimensions."""
d1_is_constant = is_constant_dim(d1)

View File

@ -390,6 +390,10 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
if devices is None:
raise AssertionError(
'Please file a bug at https://github.com/jax-ml/jax/issues')
am = axis_context.abstract_mesh
if am is not None:
mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes),
am.axis_names)
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = axis_context.mesh._flat_devices_tuple
else:

View File

@ -128,9 +128,6 @@ class _DimFactor:
MOD = "mod"
MAX = "max"
MIN = "min"
# TODO(necula): remove non_negative
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
# max but kept here for backwards compatibility.
__slots__ = ["var", "operation", "operands", "_hash", "_size"]
@ -446,11 +443,6 @@ class _DimExpr:
@staticmethod
def _from_operation(operation: str, *operands: DimSize,
scope: SymbolicScope) -> DimSize:
if operation == _DimFactor.NON_NEGATIVE: # For parsing, for backwards compatibility
return _DimExpr._from_term(
_DimTerm.from_operation(_DimFactor.MAX, *operands, 0,
scope=scope), 1,
scope=scope)
return _DimExpr._from_term(
_DimTerm.from_operation(operation, *operands, scope=scope), 1,
scope=scope)
@ -1665,8 +1657,6 @@ class _Parser:
if tok.exact_type == tokenize.NAME:
if tok.string in (_DimFactor.MOD, _DimFactor.FLOORDIV, _DimFactor.MAX, _DimFactor.MIN):
return self.factor_binary_op(tok.string, self.next_tok())
if tok.string == _DimFactor.NON_NEGATIVE: # We still parse this for backwards compatibility
return self.factor_unary_op(_DimFactor.NON_NEGATIVE, self.next_tok())
return _DimExpr._from_var(tok.string, self.scope), self.next_tok()
number_sign = 1
if tok.exact_type == tokenize.MINUS: # -k are negative constants

View File

@ -306,18 +306,21 @@ class _DecisionByElimination:
scope=scope):
# `c =[eq] 0` AND `t*t_k*t_s + c*c_s` contains only terms smaller than t
# AND c_s > 0.
# rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
# `rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
# `rest` contains only terms smaller than `t`.
rest = _DimExpr._linear_combination_sorted_pairs(e, i, t_s,
c._sorted_terms, 0, c_s)
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0,
BoundsPrecision.BEST)
if rest_ub < np.inf:
# We have: e[i:]*t_s = rest - c*c_s <= rest_ub
if t_s > 0:
ub = min(ub, int(np.floor(rest_ub / t_s)))
else:
lb = max(lb, int(np.ceil(rest_ub / t_s)))
if rest_lb > - np.inf and c_eq == Comparator.EQ:
# We have: e[i:]*t_s = rest - c*c_s = rest >= rest_lb
if t_s > 0:
lb = max(lb, int(np.ceil(rest_lb / t_s)))
else:

View File

@ -72,12 +72,27 @@ def register_ffi_target(
**kwargs)
def register_ffi_type_id(
name: str,
obj: Any,
platform: str = "cpu",
) -> None:
"""Registers a custom type ID for a FFI target.
Args:
name: the name of the type ID. This name must be unique within the process.
obj: a ``PyCapsule`` object encapsulating a pointer to the type ID.
platform: the target platform.
"""
return xla_client.register_custom_type_id(name, obj, platform=platform)
def pycapsule(funcptr):
"""Wrap a ctypes function pointer in a PyCapsule.
The primary use of this function, and the reason why it lives with in the
``jax.extend.ffi`` submodule, is to wrap function calls from external
compiled libraries to be registered as XLA custom calls.
``jax.ffi`` submodule, is to wrap function calls from external compiled
libraries to be registered as XLA custom calls.
Example usage::
@ -88,7 +103,7 @@ def pycapsule(funcptr):
libfoo = ctypes.cdll.LoadLibrary('./foo.so')
xla_client.register_custom_call_target(
name="bar",
fn=jax.extend.ffi.pycapsule(libfoo.bar),
fn=jax.ffi.pycapsule(libfoo.bar),
platform=PLATFORM,
api_version=API_VERSION
)
@ -145,7 +160,7 @@ def ffi_lowering(
Note that layouts passed to this function as tuples should be in
minor-to-major order (as expected by XLA) rather than major-to-minor as used
by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``.
by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.
If keyword arguments are passed to the lowering rule, these are treated as
attributes, and added to `backend_config`.
@ -310,7 +325,7 @@ def ffi_call(
Args:
target_name: the name of the XLA FFI custom call target that was registered
using :func:`~jax.extend.ffi.register_ffi_target`.
using :func:`~jax.ffi.register_ffi_target`.
result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
``dtype`` attributes which are expected to match the shape and dtype of
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often

View File

@ -34,7 +34,7 @@ from jax._src import dtypes
from jax._src import util
from jax._src.core import (
Primitive, ShapedArray, is_constant_dim, is_constant_shape)
from jax._src.extend import ffi
from jax._src import ffi
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir

View File

@ -1634,6 +1634,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
if (isinstance(axis_context, SPMDAxisContext) and
axis_context.manual_axes and
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
if axis_env.sizes[axis_pos] == 1:
return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32)))
x = hlo.iota(ir.RankedTensorType.get(
[axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
sharding_proto = (

View File

@ -7667,13 +7667,14 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
to sub-diagonal above the main diagonal.
Returns:
An array with same shape as input containing the upper triangle of the given
array with elements below the sub-diagonal specified by ``k`` are set to zero.
An array with same shape as input containing the lower triangle of the given
array with elements above the sub-diagonal specified by ``k`` are set to
zero.
See also:
- :func:`jax.numpy.triu`: Returns an upper triangle of an array.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
and zeros elsewhere.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the
diagonal and zeros elsewhere.
Examples:
>>> x = jnp.array([[1, 2, 3, 4],
@ -7729,13 +7730,14 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
to sub-diagonal above the main diagonal.
Returns:
An array with same shape as input containing the lower triangle of the given
array with elements above the sub-diagonal specified by ``k`` are set to zero.
An array with same shape as input containing the upper triangle of the given
array with elements below the sub-diagonal specified by ``k`` are set to
zero.
See also:
- :func:`jax.numpy.tril`: Returns a lower triangle of an array.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
and zeros elsewhere.
- :func:`jax.numpy.tri`: Returns an array with ones on and below the
diagonal and zeros elsewhere.
Examples:
>>> x = jnp.array([[1, 2, 3],
@ -9501,7 +9503,7 @@ def einsum(
subscript: str, /,
*operands: ArrayLike,
out: None = None,
optimize: str | bool | list[tuple[int, ...]] = "optimal",
optimize: str | bool | list[tuple[int, ...]] = "auto",
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
@ -9514,7 +9516,7 @@ def einsum(
axes: Sequence[Any], /,
*operands: ArrayLike | Sequence[Any],
out: None = None,
optimize: str | bool | list[tuple[int, ...]] = "optimal",
optimize: str | bool | list[tuple[int, ...]] = "auto",
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
@ -9526,7 +9528,7 @@ def einsum(
subscripts, /,
*operands,
out: None = None,
optimize: str | bool | list[tuple[int, ...]] = "optimal",
optimize: str | bool | list[tuple[int, ...]] = "auto",
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None,
_dot_general: Callable[..., Array] = lax.dot_general,
@ -9546,10 +9548,10 @@ def einsum(
subscripts: string containing axes names separated by commas.
*operands: sequence of one or more arrays corresponding to the subscripts.
optimize: specify how to optimize the order of computation. In JAX this defaults
to ``"optimal"`` which produces optimized expressions via the opt_einsum_
to ``"auto"`` which produces optimized expressions via the opt_einsum_
package. Other options are ``True`` (same as ``"optimal"``), ``False``
(unoptimized), or any string supported by ``opt_einsum``, which
includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also
includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also
be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
precision: either ``None`` (default), which means the default precision for
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,

View File

@ -1068,22 +1068,35 @@ def threefry_2x32(keypair, count):
msg = "threefry_2x32 requires uint32 arguments, got {}"
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
odd_size = count.size % 2
if not isinstance(odd_size, int):
msg = ("jax.random functions have limited support for shape polymorphism "
"when using threefry. "
f"In particular, the array size ({count.size}) must be even.")
raise core.InconclusiveDimensionOperation(msg)
if odd_size:
x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
flat_count = count.ravel()
odd_size = flat_count.shape[0] % 2
if core.is_constant_dim(odd_size):
if odd_size:
x = list(jnp.split(jnp.concatenate([flat_count, np.uint32([0])]), 2))
else:
x = list(jnp.split(flat_count, 2))
else:
x = list(jnp.split(count.ravel(), 2))
# With symbolic shapes we cannot always tell statically if odd_size is true
# or false, so we rewrite this without a conditional.
flat_count_padded = jnp.concatenate([flat_count, np.uint32([0])])
flat_count_padded_half_size = flat_count_padded.shape[0] // 2
x = [
lax.dynamic_slice(flat_count_padded, (0,),
(flat_count_padded_half_size,)),
lax.dynamic_slice(flat_count_padded,
(flat_count_padded_half_size,),
(flat_count_padded_half_size,))
]
assert x[0].shape == x[1].shape, (x[0].shape, x[1].shape)
x = threefry2x32_p.bind(key1, key2, x[0], x[1])
out = jnp.concatenate(x)
assert out.dtype == np.uint32
return lax.reshape(out[:-1] if odd_size else out, count.shape)
if core.is_constant_dim(odd_size):
return lax.reshape(out[:-1] if odd_size else out, count.shape)
else:
out_no_padding = lax.dynamic_slice(out, (0,), (flat_count.shape[0],))
return lax.reshape(out_no_padding, count.shape)
def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:

View File

@ -2086,10 +2086,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [

View File

@ -22,7 +22,6 @@ import warnings
import jax
from jax._src.lib import xla_client
from jax.extend import ffi
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
@ -54,7 +53,7 @@ P = ParamSpec("P")
def _event_record(args, *, copy_before):
flat_args, treedef = jax.tree.flatten(args)
event, *flat_outs = ffi.ffi_call(
event, *flat_outs = jax.ffi.ffi_call(
"mgpu_event_record",
result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args),
input_output_aliases={i: i + 1 for i in range(len(flat_args))},
@ -63,7 +62,7 @@ def _event_record(args, *, copy_before):
def _event_elapsed(start_event, end_event):
return ffi.ffi_call(
return jax.ffi.ffi_call(
"mgpu_event_elapsed",
result_shape_dtypes=jax.core.ShapedArray((), jnp.float32),
)(start_event, end_event)

View File

@ -52,7 +52,7 @@ from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
special, control_flow, ann)
from jax._src.extend import ffi
from jax._src import ffi
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
@ -70,6 +70,7 @@ from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
generate_key_paths, KeyPath)
from jax.experimental.multihost_utils import (host_local_array_to_global_array,
global_array_to_host_local_array)
from jax._src.pjit import sharding_constraint_p
P = PartitionSpec
@ -1130,7 +1131,7 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p,
control_flow.loops.cumprod_p, control_flow.loops.cummax_p,
control_flow.loops.cummin_p]:
control_flow.loops.cummin_p, sharding_constraint_p]:
register_standard_check(p)
register_standard_rewrite(p)

View File

@ -12,13 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src import ffi as _ffi
from jax._src.extend.ffi import (
ffi_call as ffi_call,
ffi_lowering as ffi_lowering,
include_dir as include_dir,
pycapsule as pycapsule,
register_ffi_target as register_ffi_target,
)
_deprecations = {
# Added 2024-12-20
"ffi_call": (
"jax.extend.ffi.ffi_call is deprecated, use jax.ffi.ffi_call instead.",
_ffi.ffi_call,
),
"ffi_lowering": (
"jax.extend.ffi.ffi_lowering is deprecated, use jax.ffi.ffi_lowering instead.",
_ffi.ffi_lowering,
),
"include_dir": (
"jax.extend.ffi.include_dir is deprecated, use jax.ffi.include_dir instead.",
_ffi.include_dir,
),
"pycapsule": (
"jax.extend.ffi.pycapsule is deprecated, use jax.ffi.pycapsule instead.",
_ffi.pycapsule,
),
"register_ffi_target": (
"jax.extend.ffi.register_ffi_target is deprecated, use jax.ffi.register_ffi_target instead.",
_ffi.register_ffi_target,
),
}
import typing
if typing.TYPE_CHECKING:
ffi_call = _ffi.ffi_call
ffi_lowering = _ffi.ffi_lowering
include_dir = _ffi.include_dir
pycapsule = _ffi.pycapsule
register_ffi_target = _ffi.register_ffi_target
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _ffi

View File

@ -15,18 +15,11 @@
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
import sys as _sys
import warnings as _warnings
import jax.numpy as _array_api
# Added 2024-08-01
_warnings.warn(
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
"jax.numpy supports the array API by default.",
DeprecationWarning, stacklevel=2
from jax._src.ffi import (
ffi_call as ffi_call,
ffi_lowering as ffi_lowering,
include_dir as include_dir,
pycapsule as pycapsule,
register_ffi_target as register_ffi_target,
register_ffi_type_id as register_ffi_type_id,
)
_sys.modules['jax.experimental.array_api'] = _array_api
del _array_api, _sys, _warnings

View File

@ -46,8 +46,8 @@ from jax._src.scipy.special import (
log_softmax as log_softmax,
logit as logit,
logsumexp as logsumexp,
lpmn as lpmn,
lpmn_values as lpmn_values,
lpmn as _deprecated_lpmn,
lpmn_values as _deprecated_lpmn_values,
multigammaln as multigammaln,
ndtr as ndtr,
ndtri as ndtri,
@ -65,3 +65,25 @@ from jax._src.scipy.special import (
from jax._src.third_party.scipy.special import (
fresnel as fresnel,
)
_deprecations = {
# Added Jan 3 2024
"lpmn": (
"jax.scipy.special.lpmn is deprecated; no replacement is planned.",
_deprecated_lpmn,
),
"lpmn_values": (
"jax.scipy.special.lpmn_values is deprecated; no replacement is planned.",
_deprecated_lpmn_values,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
lpmn = _deprecated_lpmn
lpmn_values = _deprecated_lpmn_values
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing

View File

@ -93,5 +93,11 @@ def initialize():
)
for _name, _value in cuda_plugin_extension.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
xla_client.register_custom_type_id_handler(
"CUDA",
functools.partial(
cuda_plugin_extension.register_custom_type_id, c_api
),
)
else:
logger.warning('cuda_plugin_extension is not found.')

View File

@ -94,5 +94,11 @@ def initialize():
)
for _name, _value in rocm_plugin_extension.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
xla_client.register_custom_type_id_handler(
"ROCM",
functools.partial(
rocm_plugin_extension.register_custom_type_id, c_api
),
)
else:
logger.warning('rocm_plugin_extension is not found.')

View File

@ -16,8 +16,8 @@
load(
"//jaxlib:jax.bzl",
"nanobind_extension",
"py_library_providing_imports_info",
"pybind_extension",
"pytype_library",
)
load("//jaxlib:symlink_files.bzl", "symlink_files")
@ -198,7 +198,7 @@ cc_library(
# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong
# target architecture.
pybind_extension(
nanobind_extension(
name = "cpu_feature_guard",
srcs = ["cpu_feature_guard.c"],
module_name = "cpu_feature_guard",
@ -207,7 +207,7 @@ pybind_extension(
],
)
pybind_extension(
nanobind_extension(
name = "utils",
srcs = ["utils.cc"],
module_name = "utils",
@ -238,6 +238,7 @@ cc_library(
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
@ -246,7 +247,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "cuda_plugin_extension",
srcs = ["cuda_plugin_extension.cc"],
module_name = "cuda_plugin_extension",
@ -262,7 +263,7 @@ pybind_extension(
],
)
pybind_extension(
nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",

View File

@ -16,7 +16,7 @@
load(
"//jaxlib:jax.bzl",
"pybind_extension",
"nanobind_extension",
)
licenses(["notice"])
@ -53,7 +53,7 @@ cc_library(
alwayslink = 1,
)
pybind_extension(
nanobind_extension(
name = "_lapack",
srcs = ["lapack.cc"],
copts = [

View File

@ -758,6 +758,35 @@ static ffi::Error SvdKernel(
work_data.get(), &workspace_dim_v,
iwork_data.get(), info_data);
}
// Suppress MSAN warnings when using a copy of LAPACK uninstrumented by
// MSAN.
using T [[maybe_unused]] = typename svd::SVDType<dtype>::ValueType;
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(*info_data));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data,
x_cols_v * x_leading_dim_v * sizeof(T));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
singular_values_data, std::min(x_rows_v, x_cols_v) * sizeof(RealType));
if (mode_v == 'A') {
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
u_data, u_leading_dim_v * x_rows_v * sizeof(T));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
vt_data, vt_leading_dim_v * x_cols_v * sizeof(T));
} else if (mode_v == 'O') {
if (x_rows_v < x_cols_v) {
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
u_data, u_leading_dim_v * x_rows_v * sizeof(T));
} else {
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
vt_data, vt_leading_dim_v * x_cols_v * sizeof(T));
}
} else if (mode_v == 'S') {
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
u_data, u_leading_dim_v * std::min(x_rows_v, x_cols_v) * sizeof(T));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
vt_data, vt_leading_dim_v * x_cols_v * sizeof(T));
}
x_out_data += x_out_step;
singular_values_data += singular_values_step;
u_data += u_step;

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "xla/service/custom_call_status.h"
// Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either
// by the pybind wrapper that links them to an existing SciPy lapack instance,
// by the nanobind wrapper that links them to an existing SciPy lapack instance,
// or using the lapack_kernels_strong.cc static initialization to link them
// directly to lapack for use in a pure C++ context.

View File

@ -19,7 +19,7 @@ load(
"//jaxlib:jax.bzl",
"cuda_library",
"if_cuda_is_configured",
"pybind_extension",
"nanobind_extension",
)
licenses(["notice"])
@ -125,7 +125,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_blas",
srcs = ["//jaxlib/gpu:blas.cc"],
copts = [
@ -173,7 +173,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_rnn",
srcs = ["//jaxlib/gpu:rnn.cc"],
copts = [
@ -265,7 +265,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_solver",
srcs = ["//jaxlib/gpu:solver.cc"],
copts = [
@ -321,7 +321,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_sparse",
srcs = ["//jaxlib/gpu:sparse.cc"],
copts = [
@ -399,7 +399,7 @@ cuda_library(
],
)
pybind_extension(
nanobind_extension(
name = "_linalg",
srcs = ["//jaxlib/gpu:linalg.cc"],
copts = [
@ -457,7 +457,7 @@ cuda_library(
],
)
pybind_extension(
nanobind_extension(
name = "_prng",
srcs = ["//jaxlib/gpu:prng.cc"],
copts = [
@ -497,7 +497,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_hybrid",
srcs = ["//jaxlib/gpu:hybrid.cc"],
copts = [
@ -588,7 +588,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_triton",
srcs = ["//jaxlib/gpu:triton.cc"],
copts = [
@ -639,7 +639,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_versions",
srcs = ["versions.cc"],
copts = [

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/status_casters.h"
@ -44,21 +45,14 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
size_t fn_name_size, nb::object fn,
int api_version,
XLA_FFI_Handler_Traits traits) {
if (c_api->extension_start == nullptr) {
return Unimplemented("The plugin does not have extension.");
}
const PJRT_Extension_Base* next =
reinterpret_cast<const PJRT_Extension_Base*>(c_api->extension_start);
while (next != nullptr &&
next->type !=
PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
next = next->next;
}
if (next == nullptr) {
const PJRT_Gpu_Custom_Call* custom_call_ext =
pjrt::FindExtension<PJRT_Gpu_Custom_Call>(
c_api, PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call);
if (custom_call_ext == nullptr) {
return Unimplemented("The plugin does not have a custom call extension.");
}
PJRT_Gpu_Register_Custom_Call* register_custom_call =
reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call;
custom_call_ext->custom_call;
if (traits != 0) {
return Unimplemented("The plugin does not support custom call traits.");
@ -137,6 +131,34 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
#endif
}
absl::Status RegisterCustomTypeId(const PJRT_Api* c_api,
const char* type_name_c_str,
size_t type_name_size, nb::object type_id) {
const PJRT_FFI_Extension* ffi_ext = pjrt::FindExtension<PJRT_FFI_Extension>(
c_api, PJRT_Extension_Type::PJRT_Extension_Type_FFI);
if (ffi_ext == nullptr) {
return Unimplemented("The plugin does not have the FFI extension.");
}
PJRT_FFI_TypeID_Register_Args args;
args.struct_size = PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE;
args.type_name = type_name_c_str;
args.type_name_size = type_name_size;
RETURN_STATUS_IF_PJRT_ERROR(ffi_ext->type_id_register(&args), c_api);
nb::capsule capsule;
if (!nb::try_cast<nb::capsule>(type_id, capsule)) {
return absl::InvalidArgumentError(
"The type_id argument to register_custom_call_type_id must be a "
"PyCapsule object holding a pointer to a XLA_FFI_TypeId.");
}
XLA_FFI_TypeId* type_id_ptr =
reinterpret_cast<XLA_FFI_TypeId*>(static_cast<void*>(capsule.data()));
type_id_ptr->type_id = args.type_id;
return absl::OkStatus();
}
nb::dict Registrations() {
nb::dict dict;
dict["xla_python_gpu_callback"] =
@ -171,6 +193,16 @@ void BuildGpuPluginExtension(nanobind::module_& m) {
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
nb::arg("xla_platform_name"), nb::arg("api_version") = 0,
nb::arg("traits") = 0);
m.def(
"register_custom_type_id",
[](nb::capsule c_api, nb::str type_name_py, nb::object type_id) {
const char* type_name_c_str = type_name_py.c_str();
size_t type_name_size = nb::len(type_name_py);
xla::ThrowIfError(RegisterCustomTypeId(
static_cast<const PJRT_Api*>(c_api.data()), type_name_c_str,
type_name_size, std::move(type_id)));
},
nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id"));
m.def("registrations", &Registrations);
}

View File

@ -29,7 +29,7 @@ cc_proto_library = _cc_proto_library
cuda_library = _cuda_library
rocm_library = _rocm_library
pytype_test = native.py_test
pybind_extension = _pybind_extension
nanobind_extension = _pybind_extension
if_cuda_is_configured = _if_cuda_is_configured
if_rocm_is_configured = _if_rocm_is_configured
if_windows = _if_windows
@ -120,7 +120,7 @@ def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pyt
lib_rule(name = name, **kwargs)
def py_extension(name, srcs, copts, deps, linkopts = []):
pybind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []):
"""Workaround DLL building issue.
@ -399,6 +399,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
jax_test_file_visibility = []
jax_export_file_visibility = []
def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable
pass

View File

@ -15,8 +15,8 @@
load(
"//jaxlib:jax.bzl",
"if_windows",
"nanobind_extension",
"py_extension",
"pybind_extension",
"windows_cc_shared_mlir_library",
)
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
@ -190,7 +190,7 @@ py_extension(
)
# This target contains the extension and it's Python dependencies, which are not
# supported by the `py_extension`/`pybind_extension` macros.
# supported by the `py_extension`/`nanobind_extension` macros.
py_library(
name = "_tpu_ext_lib",
deps = [
@ -200,7 +200,7 @@ py_library(
],
)
pybind_extension(
nanobind_extension(
name = "_triton_ext",
srcs = ["triton_ext.cc"],
copts = COPTS,

View File

@ -67,12 +67,12 @@ constexpr const char IR_MODULE[] = "jaxlib.mlir.ir";
// TODO(tlongeri): Get rid of this somehow
constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128};
// TODO(tlongeri): Add type annotations from pybind11/typing.h once there is
// TODO(tlongeri): Add type annotations via nanobind once there is
// a release for it (and maybe add a custom Sequence<T> one as well).
// TODO(tlongeri): For our use-case, we don't really need C++ exceptions - just
// setting the exception object and returning NULL to Python should suffice, but
// not sure if this is possible with pybind.
// not sure if this is possible with nanobind.
class NotImplementedException : public std::runtime_error {
using runtime_error::runtime_error;
};

View File

@ -13,6 +13,8 @@
// NOLINTNEXTLINE(misc-include-cleaner)
#include "mlir/Dialect/MemRef/IR/MemRef.h"
// NOLINTNEXTLINE(misc-include-cleaner)
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -23,6 +25,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/include/mlir/IR/AffineExpr.h"
@ -294,8 +297,13 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
auto element_type = ty.getElementType();
// PowFOp and DivFOp do not seem to be supported in bf16 on later
// hardware.
// There's an annoying hodgepodge of elementwise ops that need to be
// rewritten to f32 on later hardware.
// TODO(mvoz): Look into (1) what it would take to support these ops
// natively on later hardware, and (2) how to better organize this list.
bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
isa<arith::DivFOp>(op);
isa<arith::DivFOp>(op) || isa<math::TanhOp>(op) ||
isa<math::ExpOp>(op) || isa<math::LogOp>(op);
if (needs_cast && element_type.isBF16()) {
auto target_f32 =
builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
@ -552,7 +560,10 @@ const llvm::StringSet<> &elementwise_convertible_ops() {
arith::SubFOp::getOperationName(),
arith::MaximumFOp::getOperationName(),
arith::MinimumFOp::getOperationName(),
math::PowFOp::getOperationName()};
math::PowFOp::getOperationName(),
math::TanhOp::getOperationName(),
math::ExpOp::getOperationName(),
math::LogOp::getOperationName()};
return *ops;
}

View File

@ -57,6 +57,9 @@ FailureOr<TypedValue<VectorType>> relayout(
auto src_int_vty = make_vty(src.bitwidth());
auto dst_int_vty = make_vty(dst.bitwidth());
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
// TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the
// extSI or truncI below, we can reuse the inferExt and inferTrunc from
// infer-vector-layout pass.
auto dst_bitwidth_layout = VectorLayout(
dst.bitwidth(),
{
@ -66,6 +69,12 @@ FailureOr<TypedValue<VectorType>> relayout(
: LayoutOffset(),
},
src.tiling(), src.implicit_dim());
if (!dst_bitwidth_layout.isValid(target_shape)) {
return emitError(v.getLoc(),
"Not implemented: failed to infer valid layout during "
"relayout, got ")
<< dst_bitwidth_layout;
}
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
setLayout(ext_op, src, src);

View File

@ -13,7 +13,7 @@
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "pybind_extension")
load("//jaxlib:jax.bzl", "nanobind_extension")
package(
default_applicable_licenses = [],
@ -171,7 +171,7 @@ cc_library(
alwayslink = True,
)
pybind_extension(
nanobind_extension(
name = "_mosaic_gpu_ext",
srcs = ["mosaic_gpu_ext.cc"],
copts = [

View File

@ -18,7 +18,7 @@ load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"if_rocm_is_configured",
"pybind_extension",
"nanobind_extension",
"rocm_library",
)
@ -113,7 +113,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_blas",
srcs = ["//jaxlib/gpu:blas.cc"],
copts = [
@ -206,7 +206,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_solver",
srcs = ["//jaxlib/gpu:solver.cc"],
copts = [
@ -252,7 +252,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_sparse",
srcs = ["//jaxlib/gpu:sparse.cc"],
copts = [
@ -317,7 +317,7 @@ rocm_library(
],
)
pybind_extension(
nanobind_extension(
name = "_linalg",
srcs = ["//jaxlib/gpu:linalg.cc"],
copts = [
@ -370,7 +370,7 @@ rocm_library(
],
)
pybind_extension(
nanobind_extension(
name = "_prng",
srcs = ["//jaxlib/gpu:prng.cc"],
copts = [
@ -410,7 +410,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_hybrid",
srcs = ["//jaxlib/gpu:hybrid.cc"],
copts = [
@ -472,7 +472,7 @@ cc_library(
],
)
pybind_extension(
nanobind_extension(
name = "_triton",
srcs = ["//jaxlib/gpu:triton.cc"],
copts = [

View File

@ -59,7 +59,6 @@ jax_py_test(
srcs = ["array_api_test.py"],
deps = [
"//jax",
"//jax:experimental_array_api",
"//jax:test_util",
] + py_deps("absl/testing"),
)
@ -142,6 +141,13 @@ jax_multiplatform_test(
deps = ["//jax:extend"],
)
jax_multiplatform_test(
name = "ffi_test",
srcs = ["ffi_test.py"],
# TODO(dfm): Remove after removal of jex.ffi imports.
deps = ["//jax:extend"],
)
jax_multiplatform_test(
name = "fft_test",
srcs = ["fft_test.py"],

View File

@ -246,12 +246,6 @@ class ArrayAPISmokeTest(absltest.TestCase):
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), ARRAY_API_NAMESPACE)
def test_deprecated_import(self):
msg = "jax.experimental.array_api import is no longer required"
with self.assertWarnsRegex(DeprecationWarning, msg):
import jax.experimental.array_api as nx
self.assertIs(nx, ARRAY_API_NAMESPACE)
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):

View File

@ -1133,9 +1133,9 @@ class InspectShardingTest(jtu.JaxTestCase):
f(np.arange(8, dtype=jnp.float32))
self.assertTrue(is_called)
def test_inspect_sharding_3d_input_pos_sharding(self):
def test_inspect_sharding_3d_jit(self):
def _cb(sd):
self.assertIsInstance(sd, jax.sharding.PositionalSharding)
self.assertIsInstance(sd, jax.sharding.NamedSharding)
self.assertLen(sd.device_set, 2)
def f_(x):
@ -1149,7 +1149,7 @@ class InspectShardingTest(jtu.JaxTestCase):
f(arr)
def test_inspect_sharding_3d_input_named_sharding(self):
def test_inspect_sharding_3d_pjit(self):
def _cb(sd):
self.assertIsInstance(sd, jax.sharding.NamedSharding)
self.assertLen(sd.device_set, 2)

View File

@ -12,35 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from functools import partial
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp
import jax.sharding as shd
from jax._src import abstract_arrays
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
from jax.experimental.shard_map import shard_map
jax.config.parse_flags_with_absl()
@ -110,294 +94,6 @@ class RandomTest(jtu.JaxTestCase):
self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})")
class FfiTest(jtu.JaxTestCase):
def find_custom_call_in_module(self, module):
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
return op
self.fail("No custom_call found in the lowered IR")
def testHeadersExist(self):
base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
for header in ["c_api.h", "api.h", "ffi.h"]:
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
@parameterized.parameters([
(tuple(range(3)), tuple(range(3))),
(None, tuple(reversed(range(3)))),
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
])
def testLoweringLayouts(self, layout_spec, expected_layout):
# Regression test to ensure that the lowering rule properly captures
# layouts.
def lowering_rule(ctx, x):
aval, = ctx.avals_in
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
result_layouts=[layout_spec])(ctx, x)
prim = core.Primitive("test_ffi")
prim.def_impl(lambda x: x)
prim.def_abstract_eval(lambda x: x)
mlir.register_lowering(prim, lowering_rule)
x = jnp.ones((3,) * len(expected_layout))
lowered = jax.jit(prim.bind).lower(x)
module = lowered.compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertIn("operand_layouts", op.attributes)
self.assertIn("result_layouts", op.attributes)
text = lowered.as_text()
expected = ", ".join(map(str, expected_layout))
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
@parameterized.parameters([
(True, mlir.ir.BoolAttr.get),
(1, mlir.i64_attr),
(5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)),
("param", mlir.ir.StringAttr.get),
(np.float32(0.5),
lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)),
])
def testParams(self, param, expected_builder):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, param=param)
# Here we inspect the lowered IR to test that the parameter has been
# serialized with the appropriate type.
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
config = op.attributes["mhlo.backend_config"]
self.assertIsInstance(config, mlir.ir.DictAttr)
self.assertIn("param", config)
with mlir.make_ir_context(), mlir.ir.Location.unknown():
expected = expected_builder(param)
self.assertEqual(type(config["param"]), type(expected))
self.assertTrue(expected.type.isinstance(config["param"].type))
def testToken(self):
def fun():
token = lax.create_token()
return jex.ffi.ffi_call("test_ffi", core.abstract_token)(token)
# Ensure that token inputs and outputs are translated to the correct type
module = jax.jit(fun).lower().compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
def testEffectsHlo(self):
# The target name must exist on the current platform, but we don't actually
# need to call it with the correct syntax, because we're only checking the
# compiled HLO.
if jtu.test_device_matches(["cpu"]):
target_name = "lapack_sgetrf_ffi"
elif jtu.test_device_matches(["rocm"]):
target_name = "hipsolver_getrf_ffi"
elif jtu.test_device_matches(["cuda", "gpu"]):
target_name = "cusolver_getrf_ffi"
else:
raise unittest.SkipTest("Unsupported device")
def fun():
jex.ffi.ffi_call(target_name, (), has_side_effect=True)()
hlo = jax.jit(fun).lower()
self.assertIn(target_name, hlo.as_text())
self.assertIn("has_side_effect = true", hlo.as_text())
self.assertIn(target_name, hlo.compile().as_text())
def testJvpError(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
with self.assertRaisesRegex(
ValueError, "The FFI call to `.+` cannot be differentiated."):
jax.jvp(fun, (0.5,), (0.5,))
def testNonHashableAttributes(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = {a = 1", hlo)
# If non-hashable arguments aren't handled properly, this will raise a
# TypeError. We make sure it doesn't.
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3))
self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = array<i64: 0, 1, 2>", hlo)
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
@jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
@jtu.run_on_devices("gpu", "cpu")
def testFfiCall(self, shape):
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = ffi_call_geqrf(x)
for a, b in zip(actual, expected):
self.assertArraysEqual(a, b)
@jtu.sample_product(
shape=[(6, 5), (4, 5, 6)],
vmap_method=["expand_dims", "broadcast_all", "sequential"],
)
@jtu.run_on_devices("gpu", "cpu")
def testFfiCallBatching(self, shape, vmap_method):
shape = (10,) + shape
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
for a, b in zip(actual, expected):
if vmap_method == "sequential" and len(shape) == 3:
# On GPU, the batched FFI call to geqrf uses an algorithm with
# different numerics than the unbatched version (which is used when
# vmap_method="sequential"). Therefore, we need to include floating
# point tolerance for this check.
self.assertArraysAllClose(a, b)
else:
self.assertArraysEqual(a, b)
@jtu.run_on_devices("gpu", "cpu")
def testVectorizedDeprecation(self):
x = self.rng().randn(3, 5, 4).astype(np.float32)
with self.assertWarns(DeprecationWarning):
ffi_call_geqrf(x, vectorized=True)
with self.assertWarns(DeprecationWarning):
jax.vmap(ffi_call_geqrf)(x)
def testBackwardCompatSyntax(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))
def testInputOutputAliases(self):
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
def testInvalidInputOutputAliases(self):
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
with self.assertRaisesRegex(ValueError, "with input index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
with self.assertRaisesRegex(ValueError, "with output index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
x.dtype),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def testLegacyBackendConfig(self):
def fun(x):
return jex.ffi.ffi_call("test", x, custom_call_api_version=2,
legacy_backend_config="12345")(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, 'backend_config = "12345"')
def testInvalidBackendConfig(self):
def fun(x):
return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
with self.assertRaisesRegex(ValueError,
"The use of the legacy_backend_config"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", x,
custom_call_api_version=2)(x, attribute=1)
with self.assertRaisesRegex(ValueError,
"The use of ffi_call attributes requires"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def testAllow64(self):
if config.enable_x64.value:
self.skipTest("Requires enable_x64=False")
def fun():
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))()
self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text())
def testInvalidResultType(self):
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 0"):
jex.ffi.ffi_call("test", None)()
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 1"):
jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))()
@jtu.run_on_devices("gpu", "cpu")
def testShardMap(self):
mesh = jtu.create_mesh((1,), ("i",))
x = self.rng().randn(8, 4, 5).astype(np.float32)
@partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
out_specs=shd.PartitionSpec('i'))
def f(x):
return ffi_call_geqrf(x)
f(x) # eager mode doesn't crash
jax.jit(f)(x) # neither does JIT
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
def ffi_call_geqrf(x, **kwargs):
if jtu.test_device_matches(["cpu"]):
lapack._lapack.initialize()
assert x.dtype == np.float32
ndim = x.ndim
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
output_types = [
x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
def call(platform, x):
target_name = dict(
cpu="lapack_sgeqrf_ffi",
rocm="hipsolver_geqrf_ffi",
cuda="cusolver_geqrf_ffi",
)[platform]
return jex.ffi.ffi_call(
target_name, output_types, input_output_aliases={0: 0},
input_layouts=[x_major_to_minor],
output_layouts=[x_major_to_minor, None],
**kwargs)(x)
return lax.platform_dependent(
x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
cuda=partial(call, "cuda"))
class MlirRegisterLoweringTest(jtu.JaxTestCase):
def test_unknown_platform_error(self):

337
tests/ffi_test.py Normal file
View File

@ -0,0 +1,337 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from functools import partial
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp
import jax.sharding as shd
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
from jax.experimental.shard_map import shard_map
jax.config.parse_flags_with_absl()
class FfiTest(jtu.JaxTestCase):
def find_custom_call_in_module(self, module):
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
return op
self.fail("No custom_call found in the lowered IR")
def test_headers_exist(self):
base_dir = os.path.join(jax.ffi.include_dir(), "xla", "ffi", "api")
for header in ["c_api.h", "api.h", "ffi.h"]:
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
@parameterized.parameters([
(tuple(range(3)), tuple(range(3))),
(None, tuple(reversed(range(3)))),
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
])
def test_lowering_layouts(self, layout_spec, expected_layout):
# Regression test to ensure that the lowering rule properly captures
# layouts.
def lowering_rule(ctx, x):
return jax.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
result_layouts=[layout_spec])(ctx, x)
prim = core.Primitive("test_ffi")
prim.def_impl(lambda x: x)
prim.def_abstract_eval(lambda x: x)
mlir.register_lowering(prim, lowering_rule)
x = jnp.ones((3,) * len(expected_layout))
lowered = jax.jit(prim.bind).lower(x)
module = lowered.compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertIn("operand_layouts", op.attributes)
self.assertIn("result_layouts", op.attributes)
text = lowered.as_text()
expected = ", ".join(map(str, expected_layout))
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
@parameterized.parameters([
(True, mlir.ir.BoolAttr.get),
(1, mlir.i64_attr),
(5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)),
("param", mlir.ir.StringAttr.get),
(np.float32(0.5),
lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)),
])
def test_params(self, param, expected_builder):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, param=param)
# Here we inspect the lowered IR to test that the parameter has been
# serialized with the appropriate type.
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
config = op.attributes["mhlo.backend_config"]
self.assertIsInstance(config, mlir.ir.DictAttr)
self.assertIn("param", config)
with mlir.make_ir_context(), mlir.ir.Location.unknown():
expected = expected_builder(param)
self.assertEqual(type(config["param"]), type(expected))
self.assertTrue(expected.type.isinstance(config["param"].type))
def test_token(self):
def fun():
token = lax.create_token()
return jax.ffi.ffi_call("test_ffi", core.abstract_token)(token)
# Ensure that token inputs and outputs are translated to the correct type
module = jax.jit(fun).lower().compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
def test_effects_hlo(self):
# The target name must exist on the current platform, but we don't actually
# need to call it with the correct syntax, because we're only checking the
# compiled HLO.
if jtu.test_device_matches(["cpu"]):
target_name = "lapack_sgetrf_ffi"
elif jtu.test_device_matches(["rocm"]):
target_name = "hipsolver_getrf_ffi"
elif jtu.test_device_matches(["cuda", "gpu"]):
target_name = "cusolver_getrf_ffi"
else:
raise unittest.SkipTest("Unsupported device")
def fun():
jax.ffi.ffi_call(target_name, (), has_side_effect=True)()
hlo = jax.jit(fun).lower()
self.assertIn(target_name, hlo.as_text())
self.assertIn("has_side_effect = true", hlo.as_text())
self.assertIn(target_name, hlo.compile().as_text())
def test_jvp_error(self):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
with self.assertRaisesRegex(
ValueError, "The FFI call to `.+` cannot be differentiated."):
jax.jvp(fun, (0.5,), (0.5,))
def test_non_hashable_attributes(self):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = {a = 1", hlo)
# If non-hashable arguments aren't handled properly, this will raise a
# TypeError. We make sure it doesn't.
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3))
self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = array<i64: 0, 1, 2>", hlo)
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
@jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
@jtu.run_on_devices("gpu", "cpu")
def test_ffi_call(self, shape):
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = ffi_call_geqrf(x)
for a, b in zip(actual, expected):
self.assertArraysEqual(a, b)
@jtu.sample_product(
shape=[(6, 5), (4, 5, 6)],
vmap_method=["expand_dims", "broadcast_all", "sequential"],
)
@jtu.run_on_devices("gpu", "cpu")
def test_ffi_call_batching(self, shape, vmap_method):
shape = (10,) + shape
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
for a, b in zip(actual, expected):
if vmap_method == "sequential" and len(shape) == 3:
# On GPU, the batched FFI call to geqrf uses an algorithm with
# different numerics than the unbatched version (which is used when
# vmap_method="sequential"). Therefore, we need to include floating
# point tolerance for this check.
self.assertArraysAllClose(a, b)
else:
self.assertArraysEqual(a, b)
@jtu.run_on_devices("gpu", "cpu")
def test_vectorized_deprecation(self):
x = self.rng().randn(3, 5, 4).astype(np.float32)
with self.assertWarns(DeprecationWarning):
ffi_call_geqrf(x, vectorized=True)
with self.assertWarns(DeprecationWarning):
jax.vmap(ffi_call_geqrf)(x)
def test_backward_compat_syntax(self):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x, x, param=0.5)
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))
def test_input_output_aliases(self):
def fun(x):
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
def test_invalid_input_output_aliases(self):
def fun(x):
return jax.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
with self.assertRaisesRegex(ValueError, "with input index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
with self.assertRaisesRegex(ValueError, "with output index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
x.dtype),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def test_legacy_backend_config(self):
def fun(x):
return jax.ffi.ffi_call("test", x, custom_call_api_version=2,
legacy_backend_config="12345")(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, 'backend_config = "12345"')
def test_invalid_backend_config(self):
def fun(x):
return jax.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
with self.assertRaisesRegex(ValueError,
"The use of the legacy_backend_config"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", x,
custom_call_api_version=2)(x, attribute=1)
with self.assertRaisesRegex(ValueError,
"The use of ffi_call attributes requires"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def test_allow_x64(self):
if config.enable_x64.value:
self.skipTest("Requires enable_x64=False")
def fun():
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))()
self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text())
def test_invalid_result_type(self):
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 0"):
jax.ffi.ffi_call("test", None)()
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 1"):
jax.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))()
@jtu.run_on_devices("gpu", "cpu")
def test_shard_map(self):
mesh = jtu.create_mesh((1,), ("i",))
x = self.rng().randn(8, 4, 5).astype(np.float32)
@partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
out_specs=shd.PartitionSpec('i'))
def f(x):
return ffi_call_geqrf(x)
f(x) # eager mode doesn't crash
jax.jit(f)(x) # neither does JIT
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
@jtu.run_on_devices("gpu", "cpu")
@jtu.ignore_warning(category=DeprecationWarning)
def test_extend_import_shim(self):
ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True)
def ffi_call_geqrf(x, _use_extend=False, **kwargs):
if jtu.test_device_matches(["cpu"]):
lapack._lapack.initialize()
assert x.dtype == np.float32
ndim = x.ndim
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
output_types = [
x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
def call(platform, x):
target_name = dict(
cpu="lapack_sgeqrf_ffi",
rocm="hipsolver_geqrf_ffi",
cuda="cusolver_geqrf_ffi",
)[platform]
f = jex.ffi.ffi_call if _use_extend else jax.ffi.ffi_call
return f(
target_name, output_types, input_output_aliases={0: 0},
input_layouts=[x_major_to_minor],
output_layouts=[x_major_to_minor, None],
**kwargs)(x)
return lax.platform_dependent(
x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
cuda=partial(call, "cuda"))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -327,12 +327,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1E-6)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-8)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
l_max=[1, 2, 3, 6],
shape=[(5,), (10,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testLpmn(self, l_max, shape, dtype):
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(b/364258243): fails on TPU v6e")
@ -350,12 +350,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
atol=3e-3, check_dtypes=False)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
l_max=[3, 4, 6, 32],
shape=[(2,), (3,), (4,), (64,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testNormalizedLpmnValues(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]
@ -383,7 +383,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
rtol=1e-5, atol=1e-5, check_dtypes=False)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmAccuracy(self):
m = jnp.arange(-3, 3)[:, None]
@ -398,6 +399,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmOrderZeroDegreeZero(self):
"""Tests the spherical harmonics of order zero and degree zero."""
@ -411,6 +414,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmOrderZeroDegreeOne(self):
"""Tests the spherical harmonics of order one and degree zero."""
@ -424,6 +429,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmOrderOneDegreeOne(self):
"""Tests the spherical harmonics of order one and degree one."""
@ -445,6 +452,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
],
dtype=jtu.dtypes.all_integer,
)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
"""Tests against JIT compatibility and Numpy."""
@ -469,6 +478,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
with self.subTest('Test against numpy.'):
self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmCornerCaseWithWrongNmax(self):
"""Tests the corner case where `n_max` is not the maximum value of `n`."""

View File

@ -200,16 +200,11 @@ class DimExprTest(jtu.JaxTestCase):
"3*a*mod(a + 2, b + 2)"),
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2,
"6*floordiv(a + 2, b + 2)"),
# Keep for backwards compatibility. We ought to be able to parse
# non_negative
("non_negative(a - 2)", "build_inside", "max(a - 2, 0)"),
("max(a, b)", "build_inside", "max(a, b)"),
("min(a, b)", "build_inside", "min(a, b)"),
]])
def test_parse_dim(self, dim_spec, dim_poly, expected_str):
if dim_spec == "non_negative(a - 2)":
dim_poly = core.non_negative_dim(DimExprTest.a - 2)
elif dim_spec == "max(a, b)":
if dim_spec == "max(a, b)":
dim_poly = core.max_dim(DimExprTest.a, DimExprTest.b)
elif dim_spec == "min(a, b)":
dim_poly = core.min_dim(DimExprTest.a, DimExprTest.b)
@ -382,13 +377,6 @@ class DimExprTest(jtu.JaxTestCase):
[b * (a % 4), b * (a // 4), a * (a // 4), a // 4,
a * a, b, 15])
# This failed with a previous implementation of factor equality
self.assertNotEqual(shape_poly._DimTerm.from_operation(shape_poly._DimFactor.NON_NEGATIVE,
a - b - 1,
scope=a.scope),
shape_poly._DimTerm.from_operation(shape_poly._DimFactor.NON_NEGATIVE,
a - 2 * b - 1,
scope=a.scope))
def test_bounds_arithmetic(self):
a, b, c = shape_poly.symbolic_shape("a, b, c")
bounded_le4 = 5 - a
@ -471,6 +459,10 @@ class DimExprTest(jtu.JaxTestCase):
self.assertEqual(_bounds(-b // (a + 1)), (-np.inf, -1))
self.assertEqual(_bounds(a - a // 2), (1, np.inf))
self.assertEqual(_bounds((a + 3) - (a + 3) // 2), (2, np.inf))
self.assertEqual(_bounds((a + 6) - 1 * (a + 6) // 4), (6, np.inf))
self.assertEqual(_bounds((a + 6) - 2 * ((a + 6) // 4)), (4, np.inf))
self.assertEqual(_bounds((a + 6) - 3 * ((a + 6) // 4)), (2, np.inf))
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Possible division by 0"):
@ -495,15 +487,6 @@ class DimExprTest(jtu.JaxTestCase):
self.assertGreaterEqual(fact_val, lb)
self.assertLessEqual(fact_val, ub)
def test_bounds_non_negative(self):
a, b = shape_poly.symbolic_shape("a, b")
self.assertEqual(_bounds(core.non_negative_dim(a)), (1, np.inf))
self.assertEqual(_bounds(core.non_negative_dim(a - 5)), (0, np.inf))
self.assertEqual(_bounds(core.non_negative_dim(15 - a)), (0, 14))
self.assertEqual(_bounds(core.non_negative_dim(15 - a) // 3), (0, 4))
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)), (1, 3))
def test_max_dim(self):
a, b, c, d = shape_poly.symbolic_shape("a, b, c, d")
@ -599,7 +582,7 @@ class DimExprTest(jtu.JaxTestCase):
def test_bounds_complex(self):
a, b = shape_poly.symbolic_shape("a, b")
min_a_b = b - core.non_negative_dim(b - a)
min_a_b = b - core.max_dim(0, b - a)
# This comes up in slicing with stride
self.assertGreaterEqual(min_a_b // 2, 0)
@ -805,16 +788,6 @@ class DimExprTest(jtu.JaxTestCase):
set(decision.combine_term_with_existing(_m(d), 2, scope=scope,
only_smaller_than_t=True)))
def test_non_negative_dim(self):
a, = shape_poly.symbolic_shape("a,")
self.sampled_assertion(2, core.non_negative_dim, 2)
self.sampled_assertion(0, core.non_negative_dim, 0)
self.sampled_assertion(0, core.non_negative_dim, -1)
self.sampled_assertion(a, core.non_negative_dim, a)
self.sampled_assertion(2 * a - 1, core.non_negative_dim, 2 * a - 1)
self.sampled_assertion(core.non_negative_dim(a - 2),
core.non_negative_dim, a - 2)
def test_dilate_dim(self):
"""0 if d == 0 else 1 + dilation * (d - 1))"""
@ -3013,31 +2986,30 @@ _POLY_SHAPE_TEST_HARNESSES = [
RandArg((3, 4, 5), _f32)],
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
override_jax_config_flags=override_jax_config_flags), # type: ignore
# TODO(necula): The known dimensions product must be even.
PolyHarness("random_categorical", f"axis=0_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=0),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 8), _f32)],
polymorphic_shapes=[None, "b0, ..."],
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"axis=1_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1),
jax.random.wrap_key_data(key), a, axis=1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 8), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, b1, b2"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"axis=1_then_reshape_{flags_name}",
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 8), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, b1, b2"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_categorical", f"0_dim_{flags_name}", # One axis has 0 size
lambda key, a: jax.random.categorical(
jax.random.wrap_key_data(key), a, axis=1),
jax.random.wrap_key_data(key), a, axis=1),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5, 0), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
@ -3055,14 +3027,13 @@ _POLY_SHAPE_TEST_HARNESSES = [
RandArg((64, 12, 4), _f32), # sample on axis=1
RandArg((3, 4), _f32),
StaticArg(use_p)],
# TODO(necula): threefry requires even-sized samples.
polymorphic_shapes=[None,
"_, 2*b1, _" if arr_poly else None,
"b0, b1, b2" if arr_poly else None,
"b3, b4" if shape_poly else None],
# The array sampled dimension must be larger than res_shape.size
symbolic_constraints=[
"2*b1 >= 12" if arr_poly else "1 >= 0",
"2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
"b1 >= 12" if arr_poly else "1 >= 0",
"b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
"12 >= b3*b4" if shape_poly else "1 >= 0"
],
override_jax_config_flags=override_jax_config_flags,
@ -3089,24 +3060,20 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
polymorphic_shapes=[None, "b0, 4, 5"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_uniform", f"even_2_{flags_name}",
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
(2 * a.shape[0], a.shape[1]),
dtype=_f32),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4), _f32)],
polymorphic_shapes=[None, "b0, b1, ..."],
polymorphic_shapes=[None, "b0, 2*b1"],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
PolyHarness("random_uniform", f"error_unknown_evenness_{flags_name}",
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
a.shape, dtype=_f32),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 5), _f32)],
polymorphic_shapes=[None, "b0, ..."],
expect_error=(
(core.InconclusiveDimensionOperation,
"array size .* must be even") if flags_name == "threefry_non_partitionable" else None),
polymorphic_shapes=[None, "b0, b1"],
override_jax_config_flags=override_jax_config_flags) # type: ignore
]
for key_size, flags_name, override_jax_config_flags in [

View File

@ -1904,7 +1904,6 @@ class ShardMapTest(jtu.JaxTestCase):
x = shard_map(g, mesh,
in_specs=P('i', None),
out_specs=P('i', None),
check_rep=False,
auto=frozenset({'j'}))(x)
return jax.lax.with_sharding_constraint(
x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
@ -2162,7 +2161,22 @@ class ShardMapTest(jtu.JaxTestCase):
mesh, in_specs=P('i', None), out_specs=P('i', None),
check_rep=False, auto=frozenset({'j'}))()
self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))
def test_partial_auto_axis_index_degenerated_axis(self):
if config.use_shardy_partitioner.value:
self.skipTest('Shardy does not support full-to-shard.')
mesh = jtu.create_mesh((1, 2), ('i', 'j'))
out_sharding = NamedSharding(mesh, P('i', None))
@partial(jax.jit, out_shardings=out_sharding)
def f():
return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
mesh, in_specs=P('i', None), out_specs=P('i', None),
check_rep=False, auto=frozenset({'j'}))()
self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))
def test_partial_auto_ppermute(self):
if xla_extension_version < 302:

View File

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "ac6e71fe0cf864eec152de5ba761b76d8bef3153"
XLA_SHA256 = "2b568ff365bc4b5c2b257002aa71f094a2b60357ceb1f2a1c6c33f4ad1a411bd"
XLA_COMMIT = "1a6361a734c5cd10dc93938fc6163a51fd37b82e"
XLA_SHA256 = "01159fd52f0e402829a3823472a309562817c72d0212f81cd5555f77394c094f"
def repo():
tf_http_archive(