mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #194 from ROCm/ci-upstream-sync-80_1
CI: 01/07/25 upstream sync
This commit is contained in:
commit
972f95b95d
4
.github/workflows/ci-build.yaml
vendored
4
.github/workflows/ci-build.yaml
vendored
@ -132,8 +132,8 @@ jobs:
|
||||
JAX_ARRAY: 1
|
||||
PY_COLORS: 1
|
||||
run: |
|
||||
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
|
||||
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py
|
||||
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
|
||||
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py
|
||||
|
||||
|
||||
documentation_render:
|
||||
|
13
CHANGELOG.md
13
CHANGELOG.md
@ -19,17 +19,27 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* Changes:
|
||||
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
|
||||
supported version until June 2025.
|
||||
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
|
||||
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
|
||||
the case of many arguments ({jax-issue}`#25214`).
|
||||
|
||||
* New Features
|
||||
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
|
||||
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
|
||||
transforms in more than 3 dimensions, which was previously the limit. See
|
||||
{jax-issue}`#25606` for more details.
|
||||
* Support added for user defined state in the FFI via the new
|
||||
{func}`jax.ffi.register_ffi_type_id` function.
|
||||
|
||||
* Deprecations
|
||||
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
|
||||
are now deprecated, having been replaced by symbols of the same name
|
||||
in {mod}`jax.core`.
|
||||
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
|
||||
are deprecated, following their deprecation in SciPy v1.15.0. There are
|
||||
no plans to replace these deprecated functions with new APIs.
|
||||
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
|
||||
previous import path is deprecated.
|
||||
|
||||
* Deletions
|
||||
* `jax_enable_memories` flag has been deleted and the behavior of that flag
|
||||
@ -37,6 +47,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
|
||||
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
|
||||
and `jax.errors.JaxRuntimeError` respectively.
|
||||
* The `jax.experimental.array_api` module has been removed after being
|
||||
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
|
||||
the array API directly.
|
||||
|
||||
## jax 0.4.38 (Dec 17, 2024)
|
||||
|
||||
|
@ -468,6 +468,9 @@ async def main():
|
||||
# Enable clang settings that are needed for the build to work with newer
|
||||
# versions of Clang.
|
||||
wheel_build_command_base.append("--config=clang")
|
||||
if clang_major_version < 19:
|
||||
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
|
||||
|
||||
else:
|
||||
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
|
||||
logging.debug(
|
||||
@ -477,6 +480,10 @@ async def main():
|
||||
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
|
||||
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
|
||||
|
||||
gcc_major_version = utils.get_gcc_major_version(gcc_path)
|
||||
if gcc_major_version < 13:
|
||||
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
|
||||
|
||||
if not args.disable_mkl_dnn:
|
||||
logging.debug("Enabling MKL DNN")
|
||||
if target_cpu == "aarch64":
|
||||
|
@ -74,9 +74,8 @@ class System(object):
|
||||
env = dict(os.environ)
|
||||
if self.pkgbin == "apt":
|
||||
env["DEBIAN_FRONTEND"] = "noninteractive"
|
||||
|
||||
# Update indexes.
|
||||
subprocess.check_call(["apt-get", "update"])
|
||||
# Update indexes.
|
||||
subprocess.check_call(["apt-get", "update"])
|
||||
|
||||
LOG.info("Running %r" % cmd)
|
||||
subprocess.check_call(cmd, env=env)
|
||||
|
@ -201,6 +201,18 @@ def get_clang_major_version(clang_path):
|
||||
|
||||
return major_version
|
||||
|
||||
def get_gcc_major_version(gcc_path: str):
|
||||
gcc_version_proc = subprocess.run(
|
||||
[gcc_path, "-dumpversion"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
major_version = int(gcc_version_proc.stdout)
|
||||
|
||||
return major_version
|
||||
|
||||
|
||||
def get_jax_configure_bazel_options(bazel_command: list[str]):
|
||||
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
|
||||
# Get the index of the "run" parameter. Build options will come after "run" so
|
||||
|
@ -362,4 +362,5 @@ rediraffe_redirects = {
|
||||
'jax-101/index.rst': 'tutorials.rst',
|
||||
'notebooks/external_callbacks.md': 'external-callbacks.md',
|
||||
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
|
||||
'jax.extend.ffi.rst': 'jax.ffi.rst',
|
||||
}
|
||||
|
@ -659,7 +659,7 @@ You can find the up-to-date command to run doctests in
|
||||
E.g., you can run:
|
||||
|
||||
```
|
||||
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
|
||||
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
|
||||
```
|
||||
|
||||
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
|
||||
|
@ -21,7 +21,7 @@
|
||||
"JAX's FFI support is provided in two parts:\n",
|
||||
"\n",
|
||||
"1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n",
|
||||
"2. A Python front end, available in the `jax.extend.ffi` submodule.\n",
|
||||
"2. A Python front end, available in the `jax.ffi` submodule.\n",
|
||||
"\n",
|
||||
"In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
|
||||
"We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
|
||||
@ -191,9 +191,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.\n",
|
||||
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.\n",
|
||||
"This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).\n",
|
||||
"JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:"
|
||||
"JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -204,12 +204,11 @@
|
||||
"source": [
|
||||
"import ctypes\n",
|
||||
"from pathlib import Path\n",
|
||||
"import jax.extend as jex\n",
|
||||
"\n",
|
||||
"path = next(Path(\"ffi\").glob(\"librms_norm*\"))\n",
|
||||
"rms_norm_lib = ctypes.cdll.LoadLibrary(path)\n",
|
||||
"jex.ffi.register_ffi_target(\n",
|
||||
" \"rms_norm\", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
|
||||
"jax.ffi.register_ffi_target(\n",
|
||||
" \"rms_norm\", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -217,7 +216,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```{tip}\n",
|
||||
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
|
||||
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"**An alternative approach**:\n",
|
||||
@ -251,7 +250,7 @@
|
||||
"# Assuming that we compiled a nanobind extension called `rms_norm`:\n",
|
||||
"import rms_norm as rms_norm_lib\n",
|
||||
"\n",
|
||||
"jex.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
|
||||
"jax.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
@ -261,7 +260,7 @@
|
||||
"source": [
|
||||
"## Frontend code\n",
|
||||
"\n",
|
||||
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:"
|
||||
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -282,7 +281,7 @@
|
||||
" if x.dtype != jnp.float32:\n",
|
||||
" raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n",
|
||||
"\n",
|
||||
" call = jex.ffi.ffi_call(\n",
|
||||
" call = jax.ffi.ffi_call(\n",
|
||||
" # The target name must be the same string as we used to register the target\n",
|
||||
" # above in `register_custom_call_target`\n",
|
||||
" \"rms_norm\",\n",
|
||||
@ -314,25 +313,25 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n",
|
||||
"Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
|
||||
"It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
|
||||
"Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
|
||||
"It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
|
||||
"\n",
|
||||
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
|
||||
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.\n",
|
||||
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
|
||||
"\n",
|
||||
"The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
|
||||
"The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
|
||||
"\n",
|
||||
"```{tip}\n",
|
||||
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
|
||||
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.\n",
|
||||
"In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.\n",
|
||||
"One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
|
||||
"One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"(ffi-call-vmap)=\n",
|
||||
"### Batching with `vmap`\n",
|
||||
"\n",
|
||||
"{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
|
||||
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
|
||||
"{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
|
||||
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.\n",
|
||||
"\n",
|
||||
"The simplest `vmap_method` is `\"sequential\"`.\n",
|
||||
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
|
||||
@ -395,7 +394,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def rms_norm_sequential(x, eps=1e-5):\n",
|
||||
" return jex.ffi.ffi_call(\n",
|
||||
" return jax.ffi.ffi_call(\n",
|
||||
" \"rms_norm\",\n",
|
||||
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
|
||||
" vmap_method=\"sequential\",\n",
|
||||
@ -418,9 +417,9 @@
|
||||
"source": [
|
||||
"### Differentiation\n",
|
||||
"\n",
|
||||
"Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
|
||||
"Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
|
||||
"As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n",
|
||||
"Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
|
||||
"Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
|
||||
"\n",
|
||||
"More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n",
|
||||
"In this case, we actually define two new FFI calls:\n",
|
||||
@ -429,7 +428,7 @@
|
||||
"2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
|
||||
"\n",
|
||||
"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
|
||||
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
|
||||
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
|
||||
"\n",
|
||||
"This custom derivative rule can be wired in as follows:"
|
||||
]
|
||||
@ -440,16 +439,16 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"jex.ffi.register_ffi_target(\n",
|
||||
" \"rms_norm_fwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
|
||||
"jax.ffi.register_ffi_target(\n",
|
||||
" \"rms_norm_fwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
|
||||
")\n",
|
||||
"jex.ffi.register_ffi_target(\n",
|
||||
" \"rms_norm_bwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
|
||||
"jax.ffi.register_ffi_target(\n",
|
||||
" \"rms_norm_bwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def rms_norm_fwd(x, eps=1e-5):\n",
|
||||
" y, res = jex.ffi.ffi_call(\n",
|
||||
" y, res = jax.ffi.ffi_call(\n",
|
||||
" \"rms_norm_fwd\",\n",
|
||||
" (\n",
|
||||
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
|
||||
@ -466,7 +465,7 @@
|
||||
" assert res.shape == ct.shape[:-1]\n",
|
||||
" assert x.shape == ct.shape\n",
|
||||
" return (\n",
|
||||
" jex.ffi.ffi_call(\n",
|
||||
" jax.ffi.ffi_call(\n",
|
||||
" \"rms_norm_bwd\",\n",
|
||||
" jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
|
||||
" vmap_method=\"broadcast_all\",\n",
|
||||
@ -533,7 +532,7 @@
|
||||
"On the front end, the registration code would be updated to specify the appropriate platform:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"jex.ffi.register_ffi_target(\n",
|
||||
"jax.ffi.register_ffi_target(\n",
|
||||
" \"rms_norm_cuda\", rms_norm_lib_cuda.rms_norm(), platform=\"CUDA\"\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
@ -554,7 +553,7 @@
|
||||
" out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
|
||||
"\n",
|
||||
" def impl(target_name):\n",
|
||||
" return lambda x: jex.ffi.ffi_call(\n",
|
||||
" return lambda x: jax.ffi.ffi_call(\n",
|
||||
" target_name,\n",
|
||||
" out_type,\n",
|
||||
" vmap_method=\"broadcast_all\",\n",
|
||||
@ -620,9 +619,9 @@
|
||||
"This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n",
|
||||
"We will leave these topics to future tutorials, but here are some possibly useful references:\n",
|
||||
"\n",
|
||||
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
|
||||
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
|
||||
"\n",
|
||||
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
|
||||
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
|
||||
"\n",
|
||||
"* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details."
|
||||
]
|
||||
|
63
docs/ffi.md
63
docs/ffi.md
@ -29,7 +29,7 @@ We will discuss some possible approaches below, but it is important to call this
|
||||
JAX's FFI support is provided in two parts:
|
||||
|
||||
1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and
|
||||
2. A Python front end, available in the `jax.extend.ffi` submodule.
|
||||
2. A Python front end, available in the `jax.ffi` submodule.
|
||||
|
||||
In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.
|
||||
We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.
|
||||
@ -171,23 +171,22 @@ To compile the shared library, we're using CMake here, but you should be able to
|
||||
!cmake --install ffi/_build
|
||||
```
|
||||
|
||||
With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.
|
||||
With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.
|
||||
This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).
|
||||
JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:
|
||||
JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:
|
||||
|
||||
```{code-cell} ipython3
|
||||
import ctypes
|
||||
from pathlib import Path
|
||||
import jax.extend as jex
|
||||
|
||||
path = next(Path("ffi").glob("librms_norm*"))
|
||||
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
|
||||
jex.ffi.register_ffi_target(
|
||||
"rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
|
||||
jax.ffi.register_ffi_target(
|
||||
"rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
|
||||
```
|
||||
|
||||
```{tip}
|
||||
If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here.
|
||||
If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here.
|
||||
```
|
||||
|
||||
**An alternative approach**:
|
||||
@ -221,14 +220,14 @@ Then, in Python we can register this handler using:
|
||||
# Assuming that we compiled a nanobind extension called `rms_norm`:
|
||||
import rms_norm as rms_norm_lib
|
||||
|
||||
jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
|
||||
jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
|
||||
```
|
||||
|
||||
+++
|
||||
|
||||
## Frontend code
|
||||
|
||||
Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:
|
||||
Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:
|
||||
|
||||
```{code-cell} ipython3
|
||||
import numpy as np
|
||||
@ -243,7 +242,7 @@ def rms_norm(x, eps=1e-5):
|
||||
if x.dtype != jnp.float32:
|
||||
raise ValueError("Only the float32 dtype is implemented by rms_norm")
|
||||
|
||||
call = jex.ffi.ffi_call(
|
||||
call = jax.ffi.ffi_call(
|
||||
# The target name must be the same string as we used to register the target
|
||||
# above in `register_custom_call_target`
|
||||
"rms_norm",
|
||||
@ -271,25 +270,25 @@ np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
|
||||
```
|
||||
|
||||
This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.
|
||||
Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.
|
||||
It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.
|
||||
Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.
|
||||
It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.
|
||||
|
||||
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.
|
||||
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.
|
||||
Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.
|
||||
|
||||
The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
|
||||
The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
|
||||
|
||||
```{tip}
|
||||
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.
|
||||
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.
|
||||
In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.
|
||||
One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.
|
||||
One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.
|
||||
```
|
||||
|
||||
(ffi-call-vmap)=
|
||||
### Batching with `vmap`
|
||||
|
||||
{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
|
||||
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.
|
||||
{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
|
||||
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.
|
||||
|
||||
The simplest `vmap_method` is `"sequential"`.
|
||||
In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
|
||||
@ -326,7 +325,7 @@ Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {f
|
||||
|
||||
```{code-cell} ipython3
|
||||
def rms_norm_sequential(x, eps=1e-5):
|
||||
return jex.ffi.ffi_call(
|
||||
return jax.ffi.ffi_call(
|
||||
"rms_norm",
|
||||
jax.ShapeDtypeStruct(x.shape, x.dtype),
|
||||
vmap_method="sequential",
|
||||
@ -342,9 +341,9 @@ If your foreign function provides an efficient batching rule that isn't supporte
|
||||
|
||||
### Differentiation
|
||||
|
||||
Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.
|
||||
Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.
|
||||
As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.
|
||||
Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.
|
||||
Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.
|
||||
|
||||
More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.
|
||||
In this case, we actually define two new FFI calls:
|
||||
@ -353,21 +352,21 @@ In this case, we actually define two new FFI calls:
|
||||
2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.
|
||||
|
||||
We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.
|
||||
The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.
|
||||
The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.
|
||||
|
||||
This custom derivative rule can be wired in as follows:
|
||||
|
||||
```{code-cell} ipython3
|
||||
jex.ffi.register_ffi_target(
|
||||
"rms_norm_fwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
|
||||
jax.ffi.register_ffi_target(
|
||||
"rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
|
||||
)
|
||||
jex.ffi.register_ffi_target(
|
||||
"rms_norm_bwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
|
||||
jax.ffi.register_ffi_target(
|
||||
"rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_fwd(x, eps=1e-5):
|
||||
y, res = jex.ffi.ffi_call(
|
||||
y, res = jax.ffi.ffi_call(
|
||||
"rms_norm_fwd",
|
||||
(
|
||||
jax.ShapeDtypeStruct(x.shape, x.dtype),
|
||||
@ -384,7 +383,7 @@ def rms_norm_bwd(eps, res, ct):
|
||||
assert res.shape == ct.shape[:-1]
|
||||
assert x.shape == ct.shape
|
||||
return (
|
||||
jex.ffi.ffi_call(
|
||||
jax.ffi.ffi_call(
|
||||
"rms_norm_bwd",
|
||||
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
|
||||
vmap_method="broadcast_all",
|
||||
@ -447,7 +446,7 @@ Then, the `RmsNormImpl` can use the CUDA stream to launch CUDA kernels.
|
||||
On the front end, the registration code would be updated to specify the appropriate platform:
|
||||
|
||||
```python
|
||||
jex.ffi.register_ffi_target(
|
||||
jax.ffi.register_ffi_target(
|
||||
"rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
|
||||
)
|
||||
```
|
||||
@ -462,7 +461,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
|
||||
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
|
||||
|
||||
def impl(target_name):
|
||||
return lambda x: jex.ffi.ffi_call(
|
||||
return lambda x: jax.ffi.ffi_call(
|
||||
target_name,
|
||||
out_type,
|
||||
vmap_method="broadcast_all",
|
||||
@ -499,8 +498,8 @@ and there will be no runtime overhead to using {func}`jax.lax.platform_dependent
|
||||
This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.
|
||||
We will leave these topics to future tutorials, but here are some possibly useful references:
|
||||
|
||||
* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.
|
||||
* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.
|
||||
|
||||
* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.
|
||||
* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.
|
||||
|
||||
* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details.
|
||||
|
@ -1,28 +0,0 @@
|
||||
``jax.experimental.array_api`` module
|
||||
=====================================
|
||||
|
||||
.. note::
|
||||
The ``jax.experimental.array_api`` module is deprecated as of JAX v0.4.32, and
|
||||
importing ``jax.experimental.array_api`` is no longer necessary. {mod}`jax.numpy`
|
||||
implements the array API standard directly by default. See :ref:`python-array-api`
|
||||
for details.
|
||||
|
||||
This module includes experimental JAX support for the `Python array API standard`_.
|
||||
Support for this is currently experimental and not fully complete.
|
||||
|
||||
Example Usage::
|
||||
|
||||
>>> from jax.experimental import array_api as xp
|
||||
|
||||
>>> xp.__array_api_version__
|
||||
'2023.12'
|
||||
|
||||
>>> arr = xp.arange(1000)
|
||||
|
||||
>>> arr.sum()
|
||||
Array(499500, dtype=int32)
|
||||
|
||||
The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
|
||||
and implements most of the API listed in the standard.
|
||||
|
||||
.. _Python array API standard: https://data-apis.org/array-api/
|
@ -14,7 +14,6 @@ Experimental Modules
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.array_api
|
||||
jax.experimental.checkify
|
||||
jax.experimental.compilation_cache
|
||||
jax.experimental.custom_partitioning
|
||||
|
@ -1,12 +0,0 @@
|
||||
``jax.extend.ffi`` module
|
||||
=========================
|
||||
|
||||
.. automodule:: jax.extend.ffi
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
ffi_call
|
||||
ffi_lowering
|
||||
pycapsule
|
||||
register_ffi_target
|
@ -12,7 +12,6 @@ Modules
|
||||
:maxdepth: 1
|
||||
|
||||
jax.extend.core
|
||||
jax.extend.ffi
|
||||
jax.extend.linear_util
|
||||
jax.extend.mlir
|
||||
jax.extend.random
|
||||
|
31
docs/jax.ffi.rst
Normal file
31
docs/jax.ffi.rst
Normal file
@ -0,0 +1,31 @@
|
||||
``jax.ffi`` module
|
||||
==================
|
||||
|
||||
.. automodule:: jax.ffi
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
ffi_call
|
||||
ffi_lowering
|
||||
pycapsule
|
||||
register_ffi_target
|
||||
register_ffi_type_id
|
||||
|
||||
|
||||
``jax.extend.ffi`` module (deprecated)
|
||||
======================================
|
||||
|
||||
The ``jax.extend.ffi`` module has been moved to ``jax.ffi``, and that import
|
||||
path should be used instead, but these functions remain documented here while
|
||||
the legacy import is being deprecated.
|
||||
|
||||
.. automodule:: jax.extend.ffi
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
ffi_call
|
||||
ffi_lowering
|
||||
pycapsule
|
||||
register_ffi_target
|
@ -542,7 +542,8 @@ Python Array API standard
|
||||
|
||||
Prior to JAX v0.4.32, you must ``import jax.experimental.array_api`` in order
|
||||
to enable the array API for JAX arrays. After JAX v0.4.32, importing this
|
||||
module is no longer required, and will raise a deprecation warning.
|
||||
module is no longer required, and will raise a deprecation warning. After
|
||||
JAX v0.5.0, this import will raise an error.
|
||||
|
||||
Starting with JAX v0.4.32, :class:`jax.Array` and :mod:`jax.numpy` are compatible
|
||||
with the `Python Array API Standard`_. You can access the Array API namespace via
|
||||
|
@ -18,6 +18,7 @@ Subpackages
|
||||
jax.dlpack
|
||||
jax.distributed
|
||||
jax.dtypes
|
||||
jax.ffi
|
||||
jax.flatten_util
|
||||
jax.image
|
||||
jax.nn
|
||||
|
@ -121,13 +121,13 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from scipy import misc\n",
|
||||
"from scipy import datasets\n",
|
||||
"import jax.scipy as jsp\n",
|
||||
"\n",
|
||||
"fig, ax = plt.subplots(1, 3, figsize=(12, 5))\n",
|
||||
"\n",
|
||||
"# Load a sample image; compute mean() to convert from RGB to grayscale.\n",
|
||||
"image = jnp.array(misc.face().mean(-1))\n",
|
||||
"image = jnp.array(datasets.face().mean(-1))\n",
|
||||
"ax[0].imshow(image, cmap='binary_r')\n",
|
||||
"ax[0].set_title('original')\n",
|
||||
"\n",
|
||||
|
@ -75,13 +75,13 @@ For example, here is a simple approach to de-noising an image based on convoluti
|
||||
:id: Jk5qdnbv6QgT
|
||||
:outputId: 292205eb-aa09-446f-eec2-af8c23cfc718
|
||||
|
||||
from scipy import misc
|
||||
from scipy import datasets
|
||||
import jax.scipy as jsp
|
||||
|
||||
fig, ax = plt.subplots(1, 3, figsize=(12, 5))
|
||||
|
||||
# Load a sample image; compute mean() to convert from RGB to grayscale.
|
||||
image = jnp.array(misc.face().mean(-1))
|
||||
image = jnp.array(datasets.face().mean(-1))
|
||||
ax[0].imshow(image, cmap='binary_r')
|
||||
ax[0].set_title('original')
|
||||
|
||||
|
@ -17,6 +17,7 @@ pytest-xdist
|
||||
# Packages used for notebook execution
|
||||
matplotlib
|
||||
scikit-learn
|
||||
pooch
|
||||
numpy
|
||||
rich[jupyter]
|
||||
cmake
|
||||
|
@ -15,28 +15,27 @@
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
|
||||
from jax_ffi_example import _cpu_examples
|
||||
|
||||
for name, target in _cpu_examples.registrations().items():
|
||||
jex.ffi.register_ffi_target(name, target)
|
||||
jax.ffi.register_ffi_target(name, target)
|
||||
|
||||
|
||||
def array_attr(num: int):
|
||||
return jex.ffi.ffi_call(
|
||||
return jax.ffi.ffi_call(
|
||||
"array_attr",
|
||||
jax.ShapeDtypeStruct((), np.int32),
|
||||
)(array=np.arange(num, dtype=np.int32))
|
||||
|
||||
|
||||
def dictionary_attr(**kwargs):
|
||||
return jex.ffi.ffi_call(
|
||||
return jax.ffi.ffi_call(
|
||||
"dictionary_attr",
|
||||
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def counter(index):
|
||||
return jex.ffi.ffi_call(
|
||||
return jax.ffi.ffi_call(
|
||||
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))
|
||||
|
@ -24,15 +24,14 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.extend as jex
|
||||
|
||||
# Load the shared library with the FFI target definitions
|
||||
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so")
|
||||
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
|
||||
|
||||
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),
|
||||
jax.ffi.register_ffi_target("foo-fwd", jax.ffi.pycapsule(library.FooFwd),
|
||||
platform="CUDA")
|
||||
jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd),
|
||||
jax.ffi.register_ffi_target("foo-bwd", jax.ffi.pycapsule(library.FooBwd),
|
||||
platform="CUDA")
|
||||
|
||||
|
||||
@ -42,7 +41,7 @@ def foo_fwd(a, b):
|
||||
assert a.dtype == b.dtype
|
||||
n = np.prod(a.shape).astype(np.uint64)
|
||||
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
|
||||
c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
|
||||
c, b_plus_1 = jax.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
|
||||
return c, (a, b_plus_1)
|
||||
|
||||
|
||||
@ -55,7 +54,7 @@ def foo_bwd(res, c_grad):
|
||||
assert a.dtype == b_plus_1.dtype
|
||||
n = np.prod(a.shape).astype(np.uint64)
|
||||
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
|
||||
return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
|
||||
return jax.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
|
||||
n=n)
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
This example is exactly the same as the one in the `FFI tutorial
|
||||
<https://jax.readthedocs.io/en/latest/ffi.html>`, so more details can be found
|
||||
on that page. But, the high level summary is that we implement our custom
|
||||
extension in ``rms_norm.cc``, then call it usin ``jax.extend.ffi.ffi_call`` in
|
||||
extension in ``rms_norm.cc``, then call it usin ``jax.ffi.ffi_call`` in
|
||||
this module. The behavior under autodiff is implemented using
|
||||
``jax.custom_vjp``.
|
||||
"""
|
||||
@ -26,13 +26,12 @@ from functools import partial
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax_ffi_example import _rms_norm
|
||||
|
||||
for name, target in _rms_norm.registrations().items():
|
||||
jex.ffi.register_ffi_target(name, target)
|
||||
jax.ffi.register_ffi_target(name, target)
|
||||
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(1,))
|
||||
@ -53,7 +52,7 @@ def rms_norm(x, eps=1e-5):
|
||||
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
|
||||
# type (which corresponds to numpy's `float32` type), and it must be a
|
||||
# static parameter (i.e. not a JAX array).
|
||||
return jex.ffi.ffi_call(
|
||||
return jax.ffi.ffi_call(
|
||||
# The target name must be the same string as we used to register the target
|
||||
# above in `register_ffi_target`
|
||||
"rms_norm",
|
||||
@ -63,7 +62,7 @@ def rms_norm(x, eps=1e-5):
|
||||
|
||||
|
||||
def rms_norm_fwd(x, eps=1e-5):
|
||||
y, res = jex.ffi.ffi_call(
|
||||
y, res = jax.ffi.ffi_call(
|
||||
"rms_norm_fwd",
|
||||
(
|
||||
jax.ShapeDtypeStruct(x.shape, x.dtype),
|
||||
@ -80,7 +79,7 @@ def rms_norm_bwd(eps, res, ct):
|
||||
assert res.shape == ct.shape[:-1]
|
||||
assert x.shape == ct.shape
|
||||
return (
|
||||
jex.ffi.ffi_call(
|
||||
jax.ffi.ffi_call(
|
||||
"rms_norm_bwd",
|
||||
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
|
||||
vmap_method="broadcast_all",
|
||||
|
21
jax/BUILD
21
jax/BUILD
@ -19,6 +19,7 @@ load("@rules_python//python:defs.bzl", "py_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_building_jaxlib",
|
||||
"jax_export_file_visibility",
|
||||
"jax_extend_internal_users",
|
||||
"jax_extra_deps",
|
||||
"jax_internal_export_back_compat_test_util_visibility",
|
||||
@ -62,6 +63,11 @@ exports_files([
|
||||
"version.py",
|
||||
])
|
||||
|
||||
exports_files(
|
||||
["_src/export/serialization.fbs"],
|
||||
visibility = jax_export_file_visibility,
|
||||
)
|
||||
|
||||
# Packages that have access to JAX-internal implementation details.
|
||||
package_group(
|
||||
name = "internal",
|
||||
@ -199,6 +205,7 @@ py_library_providing_imports_info(
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/earray.py",
|
||||
"_src/ffi.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/interpreters/__init__.py",
|
||||
"_src/interpreters/ad.py",
|
||||
@ -730,7 +737,6 @@ py_library(
|
||||
":jax",
|
||||
":mlir",
|
||||
"//jax/_src/lib",
|
||||
"//jax/extend:ffi",
|
||||
"//jaxlib/mlir:arithmetic_dialect",
|
||||
"//jaxlib/mlir:builtin_dialect",
|
||||
"//jaxlib/mlir:func_dialect",
|
||||
@ -1047,19 +1053,6 @@ pytype_library(
|
||||
deps = [":jax"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "experimental_array_api",
|
||||
srcs = glob(
|
||||
[
|
||||
"experimental/array_api/*.py",
|
||||
],
|
||||
),
|
||||
visibility = [":internal"],
|
||||
deps = [
|
||||
":jax",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "experimental_sparse",
|
||||
srcs = glob(
|
||||
|
@ -160,6 +160,7 @@ from jax import debug as debug
|
||||
from jax import dlpack as dlpack
|
||||
from jax import dtypes as dtypes
|
||||
from jax import errors as errors
|
||||
from jax import ffi as ffi
|
||||
from jax import image as image
|
||||
from jax import lax as lax
|
||||
from jax import monitoring as monitoring
|
||||
|
@ -96,7 +96,7 @@ class Array(abc.ABC):
|
||||
def __truediv__(self, other) -> Array: ...
|
||||
def __floordiv__(self, other) -> Array: ...
|
||||
def __mod__(self, other) -> Array: ...
|
||||
def __divmod__(self, other) -> Array: ...
|
||||
def __divmod__(self, other) -> tuple[Array, Array]: ...
|
||||
def __pow__(self, other) -> Array: ...
|
||||
def __lshift__(self, other) -> Array: ...
|
||||
def __rshift__(self, other) -> Array: ...
|
||||
|
@ -2095,11 +2095,6 @@ def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimS
|
||||
# If d < window_size then (d - window_size) // window_stride < 0
|
||||
return max_dim((d - window_size) // window_stride + 1, 0)
|
||||
|
||||
# TODO(necula): Deprecated Jan 2024, to be removed.
|
||||
def non_negative_dim(d: DimSize) -> DimSize:
|
||||
"""max(d, 0)."""
|
||||
return max_dim(d, 0)
|
||||
|
||||
def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
|
||||
"""Like min(d1, d2) but for both constant and symbolic dimensions."""
|
||||
d1_is_constant = is_constant_dim(d1)
|
||||
|
@ -390,6 +390,10 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
|
||||
if devices is None:
|
||||
raise AssertionError(
|
||||
'Please file a bug at https://github.com/jax-ml/jax/issues')
|
||||
am = axis_context.abstract_mesh
|
||||
if am is not None:
|
||||
mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes),
|
||||
am.axis_names)
|
||||
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
||||
devices = axis_context.mesh._flat_devices_tuple
|
||||
else:
|
||||
|
@ -128,9 +128,6 @@ class _DimFactor:
|
||||
MOD = "mod"
|
||||
MAX = "max"
|
||||
MIN = "min"
|
||||
# TODO(necula): remove non_negative
|
||||
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
|
||||
# max but kept here for backwards compatibility.
|
||||
|
||||
__slots__ = ["var", "operation", "operands", "_hash", "_size"]
|
||||
|
||||
@ -446,11 +443,6 @@ class _DimExpr:
|
||||
@staticmethod
|
||||
def _from_operation(operation: str, *operands: DimSize,
|
||||
scope: SymbolicScope) -> DimSize:
|
||||
if operation == _DimFactor.NON_NEGATIVE: # For parsing, for backwards compatibility
|
||||
return _DimExpr._from_term(
|
||||
_DimTerm.from_operation(_DimFactor.MAX, *operands, 0,
|
||||
scope=scope), 1,
|
||||
scope=scope)
|
||||
return _DimExpr._from_term(
|
||||
_DimTerm.from_operation(operation, *operands, scope=scope), 1,
|
||||
scope=scope)
|
||||
@ -1665,8 +1657,6 @@ class _Parser:
|
||||
if tok.exact_type == tokenize.NAME:
|
||||
if tok.string in (_DimFactor.MOD, _DimFactor.FLOORDIV, _DimFactor.MAX, _DimFactor.MIN):
|
||||
return self.factor_binary_op(tok.string, self.next_tok())
|
||||
if tok.string == _DimFactor.NON_NEGATIVE: # We still parse this for backwards compatibility
|
||||
return self.factor_unary_op(_DimFactor.NON_NEGATIVE, self.next_tok())
|
||||
return _DimExpr._from_var(tok.string, self.scope), self.next_tok()
|
||||
number_sign = 1
|
||||
if tok.exact_type == tokenize.MINUS: # -k are negative constants
|
||||
|
@ -306,18 +306,21 @@ class _DecisionByElimination:
|
||||
scope=scope):
|
||||
# `c =[eq] 0` AND `t*t_k*t_s + c*c_s` contains only terms smaller than t
|
||||
# AND c_s > 0.
|
||||
# rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
|
||||
# `rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
|
||||
# `rest` contains only terms smaller than `t`.
|
||||
rest = _DimExpr._linear_combination_sorted_pairs(e, i, t_s,
|
||||
c._sorted_terms, 0, c_s)
|
||||
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0,
|
||||
BoundsPrecision.BEST)
|
||||
if rest_ub < np.inf:
|
||||
# We have: e[i:]*t_s = rest - c*c_s <= rest_ub
|
||||
if t_s > 0:
|
||||
ub = min(ub, int(np.floor(rest_ub / t_s)))
|
||||
else:
|
||||
lb = max(lb, int(np.ceil(rest_ub / t_s)))
|
||||
|
||||
if rest_lb > - np.inf and c_eq == Comparator.EQ:
|
||||
# We have: e[i:]*t_s = rest - c*c_s = rest >= rest_lb
|
||||
if t_s > 0:
|
||||
lb = max(lb, int(np.ceil(rest_lb / t_s)))
|
||||
else:
|
||||
|
@ -72,12 +72,27 @@ def register_ffi_target(
|
||||
**kwargs)
|
||||
|
||||
|
||||
def register_ffi_type_id(
|
||||
name: str,
|
||||
obj: Any,
|
||||
platform: str = "cpu",
|
||||
) -> None:
|
||||
"""Registers a custom type ID for a FFI target.
|
||||
|
||||
Args:
|
||||
name: the name of the type ID. This name must be unique within the process.
|
||||
obj: a ``PyCapsule`` object encapsulating a pointer to the type ID.
|
||||
platform: the target platform.
|
||||
"""
|
||||
return xla_client.register_custom_type_id(name, obj, platform=platform)
|
||||
|
||||
|
||||
def pycapsule(funcptr):
|
||||
"""Wrap a ctypes function pointer in a PyCapsule.
|
||||
|
||||
The primary use of this function, and the reason why it lives with in the
|
||||
``jax.extend.ffi`` submodule, is to wrap function calls from external
|
||||
compiled libraries to be registered as XLA custom calls.
|
||||
``jax.ffi`` submodule, is to wrap function calls from external compiled
|
||||
libraries to be registered as XLA custom calls.
|
||||
|
||||
Example usage::
|
||||
|
||||
@ -88,7 +103,7 @@ def pycapsule(funcptr):
|
||||
libfoo = ctypes.cdll.LoadLibrary('./foo.so')
|
||||
xla_client.register_custom_call_target(
|
||||
name="bar",
|
||||
fn=jax.extend.ffi.pycapsule(libfoo.bar),
|
||||
fn=jax.ffi.pycapsule(libfoo.bar),
|
||||
platform=PLATFORM,
|
||||
api_version=API_VERSION
|
||||
)
|
||||
@ -145,7 +160,7 @@ def ffi_lowering(
|
||||
|
||||
Note that layouts passed to this function as tuples should be in
|
||||
minor-to-major order (as expected by XLA) rather than major-to-minor as used
|
||||
by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``.
|
||||
by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.
|
||||
|
||||
If keyword arguments are passed to the lowering rule, these are treated as
|
||||
attributes, and added to `backend_config`.
|
||||
@ -310,7 +325,7 @@ def ffi_call(
|
||||
|
||||
Args:
|
||||
target_name: the name of the XLA FFI custom call target that was registered
|
||||
using :func:`~jax.extend.ffi.register_ffi_target`.
|
||||
using :func:`~jax.ffi.register_ffi_target`.
|
||||
result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
|
||||
``dtype`` attributes which are expected to match the shape and dtype of
|
||||
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
|
@ -34,7 +34,7 @@ from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.core import (
|
||||
Primitive, ShapedArray, is_constant_dim, is_constant_shape)
|
||||
from jax._src.extend import ffi
|
||||
from jax._src import ffi
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
|
@ -1634,6 +1634,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
|
||||
if (isinstance(axis_context, SPMDAxisContext) and
|
||||
axis_context.manual_axes and
|
||||
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
|
||||
if axis_env.sizes[axis_pos] == 1:
|
||||
return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32)))
|
||||
x = hlo.iota(ir.RankedTensorType.get(
|
||||
[axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
|
||||
sharding_proto = (
|
||||
|
@ -7667,13 +7667,14 @@ def tril(m: ArrayLike, k: int = 0) -> Array:
|
||||
to sub-diagonal above the main diagonal.
|
||||
|
||||
Returns:
|
||||
An array with same shape as input containing the upper triangle of the given
|
||||
array with elements below the sub-diagonal specified by ``k`` are set to zero.
|
||||
An array with same shape as input containing the lower triangle of the given
|
||||
array with elements above the sub-diagonal specified by ``k`` are set to
|
||||
zero.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.triu`: Returns an upper triangle of an array.
|
||||
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
|
||||
and zeros elsewhere.
|
||||
- :func:`jax.numpy.tri`: Returns an array with ones on and below the
|
||||
diagonal and zeros elsewhere.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([[1, 2, 3, 4],
|
||||
@ -7729,13 +7730,14 @@ def triu(m: ArrayLike, k: int = 0) -> Array:
|
||||
to sub-diagonal above the main diagonal.
|
||||
|
||||
Returns:
|
||||
An array with same shape as input containing the lower triangle of the given
|
||||
array with elements above the sub-diagonal specified by ``k`` are set to zero.
|
||||
An array with same shape as input containing the upper triangle of the given
|
||||
array with elements below the sub-diagonal specified by ``k`` are set to
|
||||
zero.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.tril`: Returns a lower triangle of an array.
|
||||
- :func:`jax.numpy.tri`: Returns an array with ones on and below the diagonal
|
||||
and zeros elsewhere.
|
||||
- :func:`jax.numpy.tri`: Returns an array with ones on and below the
|
||||
diagonal and zeros elsewhere.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([[1, 2, 3],
|
||||
@ -9501,7 +9503,7 @@ def einsum(
|
||||
subscript: str, /,
|
||||
*operands: ArrayLike,
|
||||
out: None = None,
|
||||
optimize: str | bool | list[tuple[int, ...]] = "optimal",
|
||||
optimize: str | bool | list[tuple[int, ...]] = "auto",
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
_dot_general: Callable[..., Array] = lax.dot_general,
|
||||
@ -9514,7 +9516,7 @@ def einsum(
|
||||
axes: Sequence[Any], /,
|
||||
*operands: ArrayLike | Sequence[Any],
|
||||
out: None = None,
|
||||
optimize: str | bool | list[tuple[int, ...]] = "optimal",
|
||||
optimize: str | bool | list[tuple[int, ...]] = "auto",
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
_dot_general: Callable[..., Array] = lax.dot_general,
|
||||
@ -9526,7 +9528,7 @@ def einsum(
|
||||
subscripts, /,
|
||||
*operands,
|
||||
out: None = None,
|
||||
optimize: str | bool | list[tuple[int, ...]] = "optimal",
|
||||
optimize: str | bool | list[tuple[int, ...]] = "auto",
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: DTypeLike | None = None,
|
||||
_dot_general: Callable[..., Array] = lax.dot_general,
|
||||
@ -9546,10 +9548,10 @@ def einsum(
|
||||
subscripts: string containing axes names separated by commas.
|
||||
*operands: sequence of one or more arrays corresponding to the subscripts.
|
||||
optimize: specify how to optimize the order of computation. In JAX this defaults
|
||||
to ``"optimal"`` which produces optimized expressions via the opt_einsum_
|
||||
to ``"auto"`` which produces optimized expressions via the opt_einsum_
|
||||
package. Other options are ``True`` (same as ``"optimal"``), ``False``
|
||||
(unoptimized), or any string supported by ``opt_einsum``, which
|
||||
includes ``"auto"``, ``"greedy"``, ``"eager"``, and others. It may also
|
||||
includes ``"optimal"``, ``"greedy"``, ``"eager"``, and others. It may also
|
||||
be a pre-computed path (see :func:`~jax.numpy.einsum_path`).
|
||||
precision: either ``None`` (default), which means the default precision for
|
||||
the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``,
|
||||
|
@ -1068,22 +1068,35 @@ def threefry_2x32(keypair, count):
|
||||
msg = "threefry_2x32 requires uint32 arguments, got {}"
|
||||
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
|
||||
|
||||
odd_size = count.size % 2
|
||||
if not isinstance(odd_size, int):
|
||||
msg = ("jax.random functions have limited support for shape polymorphism "
|
||||
"when using threefry. "
|
||||
f"In particular, the array size ({count.size}) must be even.")
|
||||
raise core.InconclusiveDimensionOperation(msg)
|
||||
|
||||
if odd_size:
|
||||
x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
|
||||
flat_count = count.ravel()
|
||||
odd_size = flat_count.shape[0] % 2
|
||||
if core.is_constant_dim(odd_size):
|
||||
if odd_size:
|
||||
x = list(jnp.split(jnp.concatenate([flat_count, np.uint32([0])]), 2))
|
||||
else:
|
||||
x = list(jnp.split(flat_count, 2))
|
||||
else:
|
||||
x = list(jnp.split(count.ravel(), 2))
|
||||
# With symbolic shapes we cannot always tell statically if odd_size is true
|
||||
# or false, so we rewrite this without a conditional.
|
||||
flat_count_padded = jnp.concatenate([flat_count, np.uint32([0])])
|
||||
flat_count_padded_half_size = flat_count_padded.shape[0] // 2
|
||||
x = [
|
||||
lax.dynamic_slice(flat_count_padded, (0,),
|
||||
(flat_count_padded_half_size,)),
|
||||
lax.dynamic_slice(flat_count_padded,
|
||||
(flat_count_padded_half_size,),
|
||||
(flat_count_padded_half_size,))
|
||||
]
|
||||
assert x[0].shape == x[1].shape, (x[0].shape, x[1].shape)
|
||||
|
||||
x = threefry2x32_p.bind(key1, key2, x[0], x[1])
|
||||
out = jnp.concatenate(x)
|
||||
assert out.dtype == np.uint32
|
||||
return lax.reshape(out[:-1] if odd_size else out, count.shape)
|
||||
if core.is_constant_dim(odd_size):
|
||||
return lax.reshape(out[:-1] if odd_size else out, count.shape)
|
||||
else:
|
||||
out_no_padding = lax.dynamic_slice(out, (0,), (flat_count.shape[0],))
|
||||
return lax.reshape(out_no_padding, count.shape)
|
||||
|
||||
|
||||
def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:
|
||||
|
@ -2086,10 +2086,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
|
||||
lambda key, a: jax.random.uniform(key, a.shape, dtype=_f32),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 5), _f32)],
|
||||
polymorphic_shapes=[None, "b0, ..."],
|
||||
expect_error=(
|
||||
(core.InconclusiveDimensionOperation,
|
||||
"array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)),
|
||||
polymorphic_shapes=[None, "b0, b1"],
|
||||
override_jax_config_flags=override_jax_config_flags) # type: ignore
|
||||
]
|
||||
for key_size, flags_name, override_jax_config_flags in [
|
||||
|
@ -22,7 +22,6 @@ import warnings
|
||||
|
||||
import jax
|
||||
from jax._src.lib import xla_client
|
||||
from jax.extend import ffi
|
||||
import jax.numpy as jnp
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
@ -54,7 +53,7 @@ P = ParamSpec("P")
|
||||
|
||||
def _event_record(args, *, copy_before):
|
||||
flat_args, treedef = jax.tree.flatten(args)
|
||||
event, *flat_outs = ffi.ffi_call(
|
||||
event, *flat_outs = jax.ffi.ffi_call(
|
||||
"mgpu_event_record",
|
||||
result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args),
|
||||
input_output_aliases={i: i + 1 for i in range(len(flat_args))},
|
||||
@ -63,7 +62,7 @@ def _event_record(args, *, copy_before):
|
||||
|
||||
|
||||
def _event_elapsed(start_event, end_event):
|
||||
return ffi.ffi_call(
|
||||
return jax.ffi.ffi_call(
|
||||
"mgpu_event_elapsed",
|
||||
result_shape_dtypes=jax.core.ShapedArray((), jnp.float32),
|
||||
)(start_event, end_event)
|
||||
|
@ -52,7 +52,7 @@ from jax._src.api import _shared_code_pmap, _prepare_pmap
|
||||
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
|
||||
windowed_reductions, convolution, fft, linalg,
|
||||
special, control_flow, ann)
|
||||
from jax._src.extend import ffi
|
||||
from jax._src import ffi
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
|
||||
@ -70,6 +70,7 @@ from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
|
||||
generate_key_paths, KeyPath)
|
||||
from jax.experimental.multihost_utils import (host_local_array_to_global_array,
|
||||
global_array_to_host_local_array)
|
||||
from jax._src.pjit import sharding_constraint_p
|
||||
|
||||
P = PartitionSpec
|
||||
|
||||
@ -1130,7 +1131,7 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
|
||||
|
||||
for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p,
|
||||
control_flow.loops.cumprod_p, control_flow.loops.cummax_p,
|
||||
control_flow.loops.cummin_p]:
|
||||
control_flow.loops.cummin_p, sharding_constraint_p]:
|
||||
register_standard_check(p)
|
||||
register_standard_rewrite(p)
|
||||
|
||||
|
@ -12,13 +12,42 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
from jax._src import ffi as _ffi
|
||||
|
||||
from jax._src.extend.ffi import (
|
||||
ffi_call as ffi_call,
|
||||
ffi_lowering as ffi_lowering,
|
||||
include_dir as include_dir,
|
||||
pycapsule as pycapsule,
|
||||
register_ffi_target as register_ffi_target,
|
||||
)
|
||||
_deprecations = {
|
||||
# Added 2024-12-20
|
||||
"ffi_call": (
|
||||
"jax.extend.ffi.ffi_call is deprecated, use jax.ffi.ffi_call instead.",
|
||||
_ffi.ffi_call,
|
||||
),
|
||||
"ffi_lowering": (
|
||||
"jax.extend.ffi.ffi_lowering is deprecated, use jax.ffi.ffi_lowering instead.",
|
||||
_ffi.ffi_lowering,
|
||||
),
|
||||
"include_dir": (
|
||||
"jax.extend.ffi.include_dir is deprecated, use jax.ffi.include_dir instead.",
|
||||
_ffi.include_dir,
|
||||
),
|
||||
"pycapsule": (
|
||||
"jax.extend.ffi.pycapsule is deprecated, use jax.ffi.pycapsule instead.",
|
||||
_ffi.pycapsule,
|
||||
),
|
||||
"register_ffi_target": (
|
||||
"jax.extend.ffi.register_ffi_target is deprecated, use jax.ffi.register_ffi_target instead.",
|
||||
_ffi.register_ffi_target,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
ffi_call = _ffi.ffi_call
|
||||
ffi_lowering = _ffi.ffi_lowering
|
||||
include_dir = _ffi.include_dir
|
||||
pycapsule = _ffi.pycapsule
|
||||
register_ffi_target = _ffi.register_ffi_target
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
del _ffi
|
||||
|
@ -15,18 +15,11 @@
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
import sys as _sys
|
||||
import warnings as _warnings
|
||||
|
||||
import jax.numpy as _array_api
|
||||
|
||||
# Added 2024-08-01
|
||||
_warnings.warn(
|
||||
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
|
||||
"jax.numpy supports the array API by default.",
|
||||
DeprecationWarning, stacklevel=2
|
||||
from jax._src.ffi import (
|
||||
ffi_call as ffi_call,
|
||||
ffi_lowering as ffi_lowering,
|
||||
include_dir as include_dir,
|
||||
pycapsule as pycapsule,
|
||||
register_ffi_target as register_ffi_target,
|
||||
register_ffi_type_id as register_ffi_type_id,
|
||||
)
|
||||
|
||||
_sys.modules['jax.experimental.array_api'] = _array_api
|
||||
|
||||
del _array_api, _sys, _warnings
|
@ -46,8 +46,8 @@ from jax._src.scipy.special import (
|
||||
log_softmax as log_softmax,
|
||||
logit as logit,
|
||||
logsumexp as logsumexp,
|
||||
lpmn as lpmn,
|
||||
lpmn_values as lpmn_values,
|
||||
lpmn as _deprecated_lpmn,
|
||||
lpmn_values as _deprecated_lpmn_values,
|
||||
multigammaln as multigammaln,
|
||||
ndtr as ndtr,
|
||||
ndtri as ndtri,
|
||||
@ -65,3 +65,25 @@ from jax._src.scipy.special import (
|
||||
from jax._src.third_party.scipy.special import (
|
||||
fresnel as fresnel,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added Jan 3 2024
|
||||
"lpmn": (
|
||||
"jax.scipy.special.lpmn is deprecated; no replacement is planned.",
|
||||
_deprecated_lpmn,
|
||||
),
|
||||
"lpmn_values": (
|
||||
"jax.scipy.special.lpmn_values is deprecated; no replacement is planned.",
|
||||
_deprecated_lpmn_values,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
lpmn = _deprecated_lpmn
|
||||
lpmn_values = _deprecated_lpmn_values
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
|
@ -93,5 +93,11 @@ def initialize():
|
||||
)
|
||||
for _name, _value in cuda_plugin_extension.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
xla_client.register_custom_type_id_handler(
|
||||
"CUDA",
|
||||
functools.partial(
|
||||
cuda_plugin_extension.register_custom_type_id, c_api
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.warning('cuda_plugin_extension is not found.')
|
||||
|
@ -94,5 +94,11 @@ def initialize():
|
||||
)
|
||||
for _name, _value in rocm_plugin_extension.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
xla_client.register_custom_type_id_handler(
|
||||
"ROCM",
|
||||
functools.partial(
|
||||
rocm_plugin_extension.register_custom_type_id, c_api
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.warning('rocm_plugin_extension is not found.')
|
||||
|
11
jaxlib/BUILD
11
jaxlib/BUILD
@ -16,8 +16,8 @@
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"nanobind_extension",
|
||||
"py_library_providing_imports_info",
|
||||
"pybind_extension",
|
||||
"pytype_library",
|
||||
)
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_files")
|
||||
@ -198,7 +198,7 @@ cc_library(
|
||||
|
||||
# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong
|
||||
# target architecture.
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "cpu_feature_guard",
|
||||
srcs = ["cpu_feature_guard.c"],
|
||||
module_name = "cpu_feature_guard",
|
||||
@ -207,7 +207,7 @@ pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "utils",
|
||||
srcs = ["utils.cc"],
|
||||
module_name = "utils",
|
||||
@ -238,6 +238,7 @@ cc_library(
|
||||
"@xla//xla:util",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/pjrt:status_casters",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
|
||||
@ -246,7 +247,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "cuda_plugin_extension",
|
||||
srcs = ["cuda_plugin_extension.cc"],
|
||||
module_name = "cuda_plugin_extension",
|
||||
@ -262,7 +263,7 @@ pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "rocm_plugin_extension",
|
||||
srcs = ["rocm_plugin_extension.cc"],
|
||||
module_name = "rocm_plugin_extension",
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"pybind_extension",
|
||||
"nanobind_extension",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
@ -53,7 +53,7 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_lapack",
|
||||
srcs = ["lapack.cc"],
|
||||
copts = [
|
||||
|
@ -758,6 +758,35 @@ static ffi::Error SvdKernel(
|
||||
work_data.get(), &workspace_dim_v,
|
||||
iwork_data.get(), info_data);
|
||||
}
|
||||
|
||||
// Suppress MSAN warnings when using a copy of LAPACK uninstrumented by
|
||||
// MSAN.
|
||||
using T [[maybe_unused]] = typename svd::SVDType<dtype>::ValueType;
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(*info_data));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data,
|
||||
x_cols_v * x_leading_dim_v * sizeof(T));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
|
||||
singular_values_data, std::min(x_rows_v, x_cols_v) * sizeof(RealType));
|
||||
if (mode_v == 'A') {
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
|
||||
u_data, u_leading_dim_v * x_rows_v * sizeof(T));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
|
||||
vt_data, vt_leading_dim_v * x_cols_v * sizeof(T));
|
||||
} else if (mode_v == 'O') {
|
||||
if (x_rows_v < x_cols_v) {
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
|
||||
u_data, u_leading_dim_v * x_rows_v * sizeof(T));
|
||||
} else {
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
|
||||
vt_data, vt_leading_dim_v * x_cols_v * sizeof(T));
|
||||
}
|
||||
} else if (mode_v == 'S') {
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
|
||||
u_data, u_leading_dim_v * std::min(x_rows_v, x_cols_v) * sizeof(T));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(
|
||||
vt_data, vt_leading_dim_v * x_cols_v * sizeof(T));
|
||||
}
|
||||
|
||||
x_out_data += x_out_step;
|
||||
singular_values_data += singular_values_step;
|
||||
u_data += u_step;
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
// Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either
|
||||
// by the pybind wrapper that links them to an existing SciPy lapack instance,
|
||||
// by the nanobind wrapper that links them to an existing SciPy lapack instance,
|
||||
// or using the lapack_kernels_strong.cc static initialization to link them
|
||||
// directly to lapack for use in a pure C++ context.
|
||||
|
||||
|
@ -19,7 +19,7 @@ load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"cuda_library",
|
||||
"if_cuda_is_configured",
|
||||
"pybind_extension",
|
||||
"nanobind_extension",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
@ -125,7 +125,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_blas",
|
||||
srcs = ["//jaxlib/gpu:blas.cc"],
|
||||
copts = [
|
||||
@ -173,7 +173,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_rnn",
|
||||
srcs = ["//jaxlib/gpu:rnn.cc"],
|
||||
copts = [
|
||||
@ -265,7 +265,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_solver",
|
||||
srcs = ["//jaxlib/gpu:solver.cc"],
|
||||
copts = [
|
||||
@ -321,7 +321,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_sparse",
|
||||
srcs = ["//jaxlib/gpu:sparse.cc"],
|
||||
copts = [
|
||||
@ -399,7 +399,7 @@ cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_linalg",
|
||||
srcs = ["//jaxlib/gpu:linalg.cc"],
|
||||
copts = [
|
||||
@ -457,7 +457,7 @@ cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_prng",
|
||||
srcs = ["//jaxlib/gpu:prng.cc"],
|
||||
copts = [
|
||||
@ -497,7 +497,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_hybrid",
|
||||
srcs = ["//jaxlib/gpu:hybrid.cc"],
|
||||
copts = [
|
||||
@ -588,7 +588,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_triton",
|
||||
srcs = ["//jaxlib/gpu:triton.cc"],
|
||||
copts = [
|
||||
@ -639,7 +639,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_versions",
|
||||
srcs = ["versions.cc"],
|
||||
copts = [
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "jaxlib/kernel_nanobind_helpers.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
||||
#include "xla/pjrt/status_casters.h"
|
||||
@ -44,21 +45,14 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
|
||||
size_t fn_name_size, nb::object fn,
|
||||
int api_version,
|
||||
XLA_FFI_Handler_Traits traits) {
|
||||
if (c_api->extension_start == nullptr) {
|
||||
return Unimplemented("The plugin does not have extension.");
|
||||
}
|
||||
const PJRT_Extension_Base* next =
|
||||
reinterpret_cast<const PJRT_Extension_Base*>(c_api->extension_start);
|
||||
while (next != nullptr &&
|
||||
next->type !=
|
||||
PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
|
||||
next = next->next;
|
||||
}
|
||||
if (next == nullptr) {
|
||||
const PJRT_Gpu_Custom_Call* custom_call_ext =
|
||||
pjrt::FindExtension<PJRT_Gpu_Custom_Call>(
|
||||
c_api, PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call);
|
||||
if (custom_call_ext == nullptr) {
|
||||
return Unimplemented("The plugin does not have a custom call extension.");
|
||||
}
|
||||
PJRT_Gpu_Register_Custom_Call* register_custom_call =
|
||||
reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call;
|
||||
custom_call_ext->custom_call;
|
||||
|
||||
if (traits != 0) {
|
||||
return Unimplemented("The plugin does not support custom call traits.");
|
||||
@ -137,6 +131,34 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
|
||||
#endif
|
||||
}
|
||||
|
||||
absl::Status RegisterCustomTypeId(const PJRT_Api* c_api,
|
||||
const char* type_name_c_str,
|
||||
size_t type_name_size, nb::object type_id) {
|
||||
const PJRT_FFI_Extension* ffi_ext = pjrt::FindExtension<PJRT_FFI_Extension>(
|
||||
c_api, PJRT_Extension_Type::PJRT_Extension_Type_FFI);
|
||||
if (ffi_ext == nullptr) {
|
||||
return Unimplemented("The plugin does not have the FFI extension.");
|
||||
}
|
||||
|
||||
PJRT_FFI_TypeID_Register_Args args;
|
||||
args.struct_size = PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE;
|
||||
args.type_name = type_name_c_str;
|
||||
args.type_name_size = type_name_size;
|
||||
RETURN_STATUS_IF_PJRT_ERROR(ffi_ext->type_id_register(&args), c_api);
|
||||
|
||||
nb::capsule capsule;
|
||||
if (!nb::try_cast<nb::capsule>(type_id, capsule)) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The type_id argument to register_custom_call_type_id must be a "
|
||||
"PyCapsule object holding a pointer to a XLA_FFI_TypeId.");
|
||||
}
|
||||
XLA_FFI_TypeId* type_id_ptr =
|
||||
reinterpret_cast<XLA_FFI_TypeId*>(static_cast<void*>(capsule.data()));
|
||||
type_id_ptr->type_id = args.type_id;
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
nb::dict Registrations() {
|
||||
nb::dict dict;
|
||||
dict["xla_python_gpu_callback"] =
|
||||
@ -171,6 +193,16 @@ void BuildGpuPluginExtension(nanobind::module_& m) {
|
||||
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
|
||||
nb::arg("xla_platform_name"), nb::arg("api_version") = 0,
|
||||
nb::arg("traits") = 0);
|
||||
m.def(
|
||||
"register_custom_type_id",
|
||||
[](nb::capsule c_api, nb::str type_name_py, nb::object type_id) {
|
||||
const char* type_name_c_str = type_name_py.c_str();
|
||||
size_t type_name_size = nb::len(type_name_py);
|
||||
xla::ThrowIfError(RegisterCustomTypeId(
|
||||
static_cast<const PJRT_Api*>(c_api.data()), type_name_c_str,
|
||||
type_name_size, std::move(type_id)));
|
||||
},
|
||||
nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id"));
|
||||
m.def("registrations", &Registrations);
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,7 @@ cc_proto_library = _cc_proto_library
|
||||
cuda_library = _cuda_library
|
||||
rocm_library = _rocm_library
|
||||
pytype_test = native.py_test
|
||||
pybind_extension = _pybind_extension
|
||||
nanobind_extension = _pybind_extension
|
||||
if_cuda_is_configured = _if_cuda_is_configured
|
||||
if_rocm_is_configured = _if_rocm_is_configured
|
||||
if_windows = _if_windows
|
||||
@ -120,7 +120,7 @@ def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pyt
|
||||
lib_rule(name = name, **kwargs)
|
||||
|
||||
def py_extension(name, srcs, copts, deps, linkopts = []):
|
||||
pybind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
|
||||
nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
|
||||
|
||||
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []):
|
||||
"""Workaround DLL building issue.
|
||||
@ -399,6 +399,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
|
||||
|
||||
jax_test_file_visibility = []
|
||||
|
||||
jax_export_file_visibility = []
|
||||
|
||||
def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable
|
||||
pass
|
||||
|
||||
|
@ -15,8 +15,8 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_windows",
|
||||
"nanobind_extension",
|
||||
"py_extension",
|
||||
"pybind_extension",
|
||||
"windows_cc_shared_mlir_library",
|
||||
)
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
|
||||
@ -190,7 +190,7 @@ py_extension(
|
||||
)
|
||||
|
||||
# This target contains the extension and it's Python dependencies, which are not
|
||||
# supported by the `py_extension`/`pybind_extension` macros.
|
||||
# supported by the `py_extension`/`nanobind_extension` macros.
|
||||
py_library(
|
||||
name = "_tpu_ext_lib",
|
||||
deps = [
|
||||
@ -200,7 +200,7 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_triton_ext",
|
||||
srcs = ["triton_ext.cc"],
|
||||
copts = COPTS,
|
||||
|
@ -67,12 +67,12 @@ constexpr const char IR_MODULE[] = "jaxlib.mlir.ir";
|
||||
// TODO(tlongeri): Get rid of this somehow
|
||||
constexpr MlirTpuI64TargetTuple TARGET_SHAPE{8, 128};
|
||||
|
||||
// TODO(tlongeri): Add type annotations from pybind11/typing.h once there is
|
||||
// TODO(tlongeri): Add type annotations via nanobind once there is
|
||||
// a release for it (and maybe add a custom Sequence<T> one as well).
|
||||
|
||||
// TODO(tlongeri): For our use-case, we don't really need C++ exceptions - just
|
||||
// setting the exception object and returning NULL to Python should suffice, but
|
||||
// not sure if this is possible with pybind.
|
||||
// not sure if this is possible with nanobind.
|
||||
class NotImplementedException : public std::runtime_error {
|
||||
using runtime_error::runtime_error;
|
||||
};
|
||||
|
@ -13,6 +13,8 @@
|
||||
// NOLINTNEXTLINE(misc-include-cleaner)
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
// NOLINTNEXTLINE(misc-include-cleaner)
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
@ -23,6 +25,7 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||
#include "mlir/include/mlir/IR/AffineExpr.h"
|
||||
@ -294,8 +297,13 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
|
||||
auto element_type = ty.getElementType();
|
||||
// PowFOp and DivFOp do not seem to be supported in bf16 on later
|
||||
// hardware.
|
||||
// There's an annoying hodgepodge of elementwise ops that need to be
|
||||
// rewritten to f32 on later hardware.
|
||||
// TODO(mvoz): Look into (1) what it would take to support these ops
|
||||
// natively on later hardware, and (2) how to better organize this list.
|
||||
bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
|
||||
isa<arith::DivFOp>(op);
|
||||
isa<arith::DivFOp>(op) || isa<math::TanhOp>(op) ||
|
||||
isa<math::ExpOp>(op) || isa<math::LogOp>(op);
|
||||
if (needs_cast && element_type.isBF16()) {
|
||||
auto target_f32 =
|
||||
builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
|
||||
@ -552,7 +560,10 @@ const llvm::StringSet<> &elementwise_convertible_ops() {
|
||||
arith::SubFOp::getOperationName(),
|
||||
arith::MaximumFOp::getOperationName(),
|
||||
arith::MinimumFOp::getOperationName(),
|
||||
math::PowFOp::getOperationName()};
|
||||
math::PowFOp::getOperationName(),
|
||||
math::TanhOp::getOperationName(),
|
||||
math::ExpOp::getOperationName(),
|
||||
math::LogOp::getOperationName()};
|
||||
return *ops;
|
||||
}
|
||||
|
||||
|
@ -57,6 +57,9 @@ FailureOr<TypedValue<VectorType>> relayout(
|
||||
auto src_int_vty = make_vty(src.bitwidth());
|
||||
auto dst_int_vty = make_vty(dst.bitwidth());
|
||||
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
|
||||
// TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the
|
||||
// extSI or truncI below, we can reuse the inferExt and inferTrunc from
|
||||
// infer-vector-layout pass.
|
||||
auto dst_bitwidth_layout = VectorLayout(
|
||||
dst.bitwidth(),
|
||||
{
|
||||
@ -66,6 +69,12 @@ FailureOr<TypedValue<VectorType>> relayout(
|
||||
: LayoutOffset(),
|
||||
},
|
||||
src.tiling(), src.implicit_dim());
|
||||
if (!dst_bitwidth_layout.isValid(target_shape)) {
|
||||
return emitError(v.getLoc(),
|
||||
"Not implemented: failed to infer valid layout during "
|
||||
"relayout, got ")
|
||||
<< dst_bitwidth_layout;
|
||||
}
|
||||
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
|
||||
setLayout(ext_op, src, src);
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("//jaxlib:jax.bzl", "pybind_extension")
|
||||
load("//jaxlib:jax.bzl", "nanobind_extension")
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
@ -171,7 +171,7 @@ cc_library(
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_mosaic_gpu_ext",
|
||||
srcs = ["mosaic_gpu_ext.cc"],
|
||||
copts = [
|
||||
|
@ -18,7 +18,7 @@ load("@rules_python//python:defs.bzl", "py_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_rocm_is_configured",
|
||||
"pybind_extension",
|
||||
"nanobind_extension",
|
||||
"rocm_library",
|
||||
)
|
||||
|
||||
@ -113,7 +113,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_blas",
|
||||
srcs = ["//jaxlib/gpu:blas.cc"],
|
||||
copts = [
|
||||
@ -206,7 +206,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_solver",
|
||||
srcs = ["//jaxlib/gpu:solver.cc"],
|
||||
copts = [
|
||||
@ -252,7 +252,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_sparse",
|
||||
srcs = ["//jaxlib/gpu:sparse.cc"],
|
||||
copts = [
|
||||
@ -317,7 +317,7 @@ rocm_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_linalg",
|
||||
srcs = ["//jaxlib/gpu:linalg.cc"],
|
||||
copts = [
|
||||
@ -370,7 +370,7 @@ rocm_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_prng",
|
||||
srcs = ["//jaxlib/gpu:prng.cc"],
|
||||
copts = [
|
||||
@ -410,7 +410,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_hybrid",
|
||||
srcs = ["//jaxlib/gpu:hybrid.cc"],
|
||||
copts = [
|
||||
@ -472,7 +472,7 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
nanobind_extension(
|
||||
name = "_triton",
|
||||
srcs = ["//jaxlib/gpu:triton.cc"],
|
||||
copts = [
|
||||
|
@ -59,7 +59,6 @@ jax_py_test(
|
||||
srcs = ["array_api_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:experimental_array_api",
|
||||
"//jax:test_util",
|
||||
] + py_deps("absl/testing"),
|
||||
)
|
||||
@ -142,6 +141,13 @@ jax_multiplatform_test(
|
||||
deps = ["//jax:extend"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "ffi_test",
|
||||
srcs = ["ffi_test.py"],
|
||||
# TODO(dfm): Remove after removal of jex.ffi imports.
|
||||
deps = ["//jax:extend"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "fft_test",
|
||||
srcs = ["fft_test.py"],
|
||||
|
@ -246,12 +246,6 @@ class ArrayAPISmokeTest(absltest.TestCase):
|
||||
self.assertIsInstance(x, jax.Array)
|
||||
self.assertIs(x.__array_namespace__(), ARRAY_API_NAMESPACE)
|
||||
|
||||
def test_deprecated_import(self):
|
||||
msg = "jax.experimental.array_api import is no longer required"
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
import jax.experimental.array_api as nx
|
||||
self.assertIs(nx, ARRAY_API_NAMESPACE)
|
||||
|
||||
|
||||
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -1133,9 +1133,9 @@ class InspectShardingTest(jtu.JaxTestCase):
|
||||
f(np.arange(8, dtype=jnp.float32))
|
||||
self.assertTrue(is_called)
|
||||
|
||||
def test_inspect_sharding_3d_input_pos_sharding(self):
|
||||
def test_inspect_sharding_3d_jit(self):
|
||||
def _cb(sd):
|
||||
self.assertIsInstance(sd, jax.sharding.PositionalSharding)
|
||||
self.assertIsInstance(sd, jax.sharding.NamedSharding)
|
||||
self.assertLen(sd.device_set, 2)
|
||||
|
||||
def f_(x):
|
||||
@ -1149,7 +1149,7 @@ class InspectShardingTest(jtu.JaxTestCase):
|
||||
|
||||
f(arr)
|
||||
|
||||
def test_inspect_sharding_3d_input_named_sharding(self):
|
||||
def test_inspect_sharding_3d_pjit(self):
|
||||
def _cb(sd):
|
||||
self.assertIsInstance(sd, jax.sharding.NamedSharding)
|
||||
self.assertLen(sd.device_set, 2)
|
||||
|
@ -12,35 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.extend as jex
|
||||
import jax.numpy as jnp
|
||||
import jax.sharding as shd
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import linear_util
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lax import linalg as lax_linalg_internal
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
@ -110,294 +94,6 @@ class RandomTest(jtu.JaxTestCase):
|
||||
self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})")
|
||||
|
||||
|
||||
class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
def find_custom_call_in_module(self, module):
|
||||
for func in module.body.operations:
|
||||
for block in func.body.blocks:
|
||||
for op in block.operations:
|
||||
if op.OPERATION_NAME == "stablehlo.custom_call":
|
||||
return op
|
||||
self.fail("No custom_call found in the lowered IR")
|
||||
|
||||
def testHeadersExist(self):
|
||||
base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
|
||||
for header in ["c_api.h", "api.h", "ffi.h"]:
|
||||
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
|
||||
|
||||
@parameterized.parameters([
|
||||
(tuple(range(3)), tuple(range(3))),
|
||||
(None, tuple(reversed(range(3)))),
|
||||
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
|
||||
])
|
||||
def testLoweringLayouts(self, layout_spec, expected_layout):
|
||||
# Regression test to ensure that the lowering rule properly captures
|
||||
# layouts.
|
||||
def lowering_rule(ctx, x):
|
||||
aval, = ctx.avals_in
|
||||
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
|
||||
result_layouts=[layout_spec])(ctx, x)
|
||||
prim = core.Primitive("test_ffi")
|
||||
prim.def_impl(lambda x: x)
|
||||
prim.def_abstract_eval(lambda x: x)
|
||||
mlir.register_lowering(prim, lowering_rule)
|
||||
|
||||
x = jnp.ones((3,) * len(expected_layout))
|
||||
lowered = jax.jit(prim.bind).lower(x)
|
||||
module = lowered.compiler_ir("stablehlo")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
self.assertIn("operand_layouts", op.attributes)
|
||||
self.assertIn("result_layouts", op.attributes)
|
||||
|
||||
text = lowered.as_text()
|
||||
expected = ", ".join(map(str, expected_layout))
|
||||
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
|
||||
self.assertRegex(text, pattern)
|
||||
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
|
||||
self.assertRegex(text, pattern)
|
||||
|
||||
@parameterized.parameters([
|
||||
(True, mlir.ir.BoolAttr.get),
|
||||
(1, mlir.i64_attr),
|
||||
(5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)),
|
||||
("param", mlir.ir.StringAttr.get),
|
||||
(np.float32(0.5),
|
||||
lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)),
|
||||
])
|
||||
def testParams(self, param, expected_builder):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test_ffi", x)(x, param=param)
|
||||
|
||||
# Here we inspect the lowered IR to test that the parameter has been
|
||||
# serialized with the appropriate type.
|
||||
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
config = op.attributes["mhlo.backend_config"]
|
||||
self.assertIsInstance(config, mlir.ir.DictAttr)
|
||||
self.assertIn("param", config)
|
||||
with mlir.make_ir_context(), mlir.ir.Location.unknown():
|
||||
expected = expected_builder(param)
|
||||
self.assertEqual(type(config["param"]), type(expected))
|
||||
self.assertTrue(expected.type.isinstance(config["param"].type))
|
||||
|
||||
def testToken(self):
|
||||
def fun():
|
||||
token = lax.create_token()
|
||||
return jex.ffi.ffi_call("test_ffi", core.abstract_token)(token)
|
||||
|
||||
# Ensure that token inputs and outputs are translated to the correct type
|
||||
module = jax.jit(fun).lower().compiler_ir("stablehlo")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
|
||||
|
||||
def testEffectsHlo(self):
|
||||
# The target name must exist on the current platform, but we don't actually
|
||||
# need to call it with the correct syntax, because we're only checking the
|
||||
# compiled HLO.
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
target_name = "lapack_sgetrf_ffi"
|
||||
elif jtu.test_device_matches(["rocm"]):
|
||||
target_name = "hipsolver_getrf_ffi"
|
||||
elif jtu.test_device_matches(["cuda", "gpu"]):
|
||||
target_name = "cusolver_getrf_ffi"
|
||||
else:
|
||||
raise unittest.SkipTest("Unsupported device")
|
||||
def fun():
|
||||
jex.ffi.ffi_call(target_name, (), has_side_effect=True)()
|
||||
hlo = jax.jit(fun).lower()
|
||||
self.assertIn(target_name, hlo.as_text())
|
||||
self.assertIn("has_side_effect = true", hlo.as_text())
|
||||
self.assertIn(target_name, hlo.compile().as_text())
|
||||
|
||||
def testJvpError(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The FFI call to `.+` cannot be differentiated."):
|
||||
jax.jvp(fun, (0.5,), (0.5,))
|
||||
|
||||
def testNonHashableAttributes(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
|
||||
|
||||
self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5))))
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertIn("non_hashable_arg = {a = 1", hlo)
|
||||
|
||||
# If non-hashable arguments aren't handled properly, this will raise a
|
||||
# TypeError. We make sure it doesn't.
|
||||
with self.assertRaises(Exception) as manager:
|
||||
fun(jnp.ones(5))
|
||||
self.assertNotIsInstance(manager.exception, TypeError)
|
||||
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3))
|
||||
self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5))))
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertIn("non_hashable_arg = array<i64: 0, 1, 2>", hlo)
|
||||
with self.assertRaises(Exception) as manager:
|
||||
fun(jnp.ones(5))
|
||||
self.assertNotIsInstance(manager.exception, TypeError)
|
||||
|
||||
@jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def testFfiCall(self, shape):
|
||||
x = self.rng().randn(*shape).astype(np.float32)
|
||||
expected = lax_linalg_internal.geqrf(x)
|
||||
actual = ffi_call_geqrf(x)
|
||||
for a, b in zip(actual, expected):
|
||||
self.assertArraysEqual(a, b)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(6, 5), (4, 5, 6)],
|
||||
vmap_method=["expand_dims", "broadcast_all", "sequential"],
|
||||
)
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def testFfiCallBatching(self, shape, vmap_method):
|
||||
shape = (10,) + shape
|
||||
x = self.rng().randn(*shape).astype(np.float32)
|
||||
expected = lax_linalg_internal.geqrf(x)
|
||||
actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
|
||||
for a, b in zip(actual, expected):
|
||||
if vmap_method == "sequential" and len(shape) == 3:
|
||||
# On GPU, the batched FFI call to geqrf uses an algorithm with
|
||||
# different numerics than the unbatched version (which is used when
|
||||
# vmap_method="sequential"). Therefore, we need to include floating
|
||||
# point tolerance for this check.
|
||||
self.assertArraysAllClose(a, b)
|
||||
else:
|
||||
self.assertArraysEqual(a, b)
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def testVectorizedDeprecation(self):
|
||||
x = self.rng().randn(3, 5, 4).astype(np.float32)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
ffi_call_geqrf(x, vectorized=True)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
jax.vmap(ffi_call_geqrf)(x)
|
||||
|
||||
def testBackwardCompatSyntax(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
|
||||
msg = "Calling ffi_call directly with input arguments is deprecated"
|
||||
if deprecations.is_accelerated("jax-ffi-call-args"):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
else:
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
|
||||
def testInputOutputAliases(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
|
||||
|
||||
def testInvalidInputOutputAliases(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
|
||||
with self.assertRaisesRegex(ValueError, "with input index"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
|
||||
with self.assertRaisesRegex(ValueError, "with output index"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
|
||||
input_output_aliases={0: 0})(x)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"referring to an input with abstract value"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
|
||||
x.dtype),
|
||||
input_output_aliases={0: 0})(x)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"referring to an input with abstract value"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def testLegacyBackendConfig(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", x, custom_call_api_version=2,
|
||||
legacy_backend_config="12345")(x)
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertRegex(hlo, 'backend_config = "12345"')
|
||||
|
||||
def testInvalidBackendConfig(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"The use of the legacy_backend_config"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test", x,
|
||||
custom_call_api_version=2)(x, attribute=1)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"The use of ffi_call attributes requires"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def testAllow64(self):
|
||||
if config.enable_x64.value:
|
||||
self.skipTest("Requires enable_x64=False")
|
||||
def fun():
|
||||
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))()
|
||||
self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text())
|
||||
|
||||
def testInvalidResultType(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "All elements of result_shape_dtypes.*position 0"):
|
||||
jex.ffi.ffi_call("test", None)()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "All elements of result_shape_dtypes.*position 1"):
|
||||
jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))()
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def testShardMap(self):
|
||||
mesh = jtu.create_mesh((1,), ("i",))
|
||||
x = self.rng().randn(8, 4, 5).astype(np.float32)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
|
||||
out_specs=shd.PartitionSpec('i'))
|
||||
def f(x):
|
||||
return ffi_call_geqrf(x)
|
||||
|
||||
f(x) # eager mode doesn't crash
|
||||
jax.jit(f)(x) # neither does JIT
|
||||
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
|
||||
|
||||
|
||||
def ffi_call_geqrf(x, **kwargs):
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
lapack._lapack.initialize()
|
||||
|
||||
assert x.dtype == np.float32
|
||||
ndim = x.ndim
|
||||
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
|
||||
output_types = [
|
||||
x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
|
||||
|
||||
def call(platform, x):
|
||||
target_name = dict(
|
||||
cpu="lapack_sgeqrf_ffi",
|
||||
rocm="hipsolver_geqrf_ffi",
|
||||
cuda="cusolver_geqrf_ffi",
|
||||
)[platform]
|
||||
return jex.ffi.ffi_call(
|
||||
target_name, output_types, input_output_aliases={0: 0},
|
||||
input_layouts=[x_major_to_minor],
|
||||
output_layouts=[x_major_to_minor, None],
|
||||
**kwargs)(x)
|
||||
|
||||
return lax.platform_dependent(
|
||||
x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
|
||||
cuda=partial(call, "cuda"))
|
||||
|
||||
|
||||
class MlirRegisterLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
def test_unknown_platform_error(self):
|
||||
|
337
tests/ffi_test.py
Normal file
337
tests/ffi_test.py
Normal file
@ -0,0 +1,337 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.extend as jex
|
||||
import jax.numpy as jnp
|
||||
import jax.sharding as shd
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lax import linalg as lax_linalg_internal
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
def find_custom_call_in_module(self, module):
|
||||
for func in module.body.operations:
|
||||
for block in func.body.blocks:
|
||||
for op in block.operations:
|
||||
if op.OPERATION_NAME == "stablehlo.custom_call":
|
||||
return op
|
||||
self.fail("No custom_call found in the lowered IR")
|
||||
|
||||
def test_headers_exist(self):
|
||||
base_dir = os.path.join(jax.ffi.include_dir(), "xla", "ffi", "api")
|
||||
for header in ["c_api.h", "api.h", "ffi.h"]:
|
||||
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
|
||||
|
||||
@parameterized.parameters([
|
||||
(tuple(range(3)), tuple(range(3))),
|
||||
(None, tuple(reversed(range(3)))),
|
||||
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
|
||||
])
|
||||
def test_lowering_layouts(self, layout_spec, expected_layout):
|
||||
# Regression test to ensure that the lowering rule properly captures
|
||||
# layouts.
|
||||
def lowering_rule(ctx, x):
|
||||
return jax.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
|
||||
result_layouts=[layout_spec])(ctx, x)
|
||||
prim = core.Primitive("test_ffi")
|
||||
prim.def_impl(lambda x: x)
|
||||
prim.def_abstract_eval(lambda x: x)
|
||||
mlir.register_lowering(prim, lowering_rule)
|
||||
|
||||
x = jnp.ones((3,) * len(expected_layout))
|
||||
lowered = jax.jit(prim.bind).lower(x)
|
||||
module = lowered.compiler_ir("stablehlo")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
self.assertIn("operand_layouts", op.attributes)
|
||||
self.assertIn("result_layouts", op.attributes)
|
||||
|
||||
text = lowered.as_text()
|
||||
expected = ", ".join(map(str, expected_layout))
|
||||
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
|
||||
self.assertRegex(text, pattern)
|
||||
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
|
||||
self.assertRegex(text, pattern)
|
||||
|
||||
@parameterized.parameters([
|
||||
(True, mlir.ir.BoolAttr.get),
|
||||
(1, mlir.i64_attr),
|
||||
(5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)),
|
||||
("param", mlir.ir.StringAttr.get),
|
||||
(np.float32(0.5),
|
||||
lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)),
|
||||
])
|
||||
def test_params(self, param, expected_builder):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test_ffi", x)(x, param=param)
|
||||
|
||||
# Here we inspect the lowered IR to test that the parameter has been
|
||||
# serialized with the appropriate type.
|
||||
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
config = op.attributes["mhlo.backend_config"]
|
||||
self.assertIsInstance(config, mlir.ir.DictAttr)
|
||||
self.assertIn("param", config)
|
||||
with mlir.make_ir_context(), mlir.ir.Location.unknown():
|
||||
expected = expected_builder(param)
|
||||
self.assertEqual(type(config["param"]), type(expected))
|
||||
self.assertTrue(expected.type.isinstance(config["param"].type))
|
||||
|
||||
def test_token(self):
|
||||
def fun():
|
||||
token = lax.create_token()
|
||||
return jax.ffi.ffi_call("test_ffi", core.abstract_token)(token)
|
||||
|
||||
# Ensure that token inputs and outputs are translated to the correct type
|
||||
module = jax.jit(fun).lower().compiler_ir("stablehlo")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
|
||||
|
||||
def test_effects_hlo(self):
|
||||
# The target name must exist on the current platform, but we don't actually
|
||||
# need to call it with the correct syntax, because we're only checking the
|
||||
# compiled HLO.
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
target_name = "lapack_sgetrf_ffi"
|
||||
elif jtu.test_device_matches(["rocm"]):
|
||||
target_name = "hipsolver_getrf_ffi"
|
||||
elif jtu.test_device_matches(["cuda", "gpu"]):
|
||||
target_name = "cusolver_getrf_ffi"
|
||||
else:
|
||||
raise unittest.SkipTest("Unsupported device")
|
||||
def fun():
|
||||
jax.ffi.ffi_call(target_name, (), has_side_effect=True)()
|
||||
hlo = jax.jit(fun).lower()
|
||||
self.assertIn(target_name, hlo.as_text())
|
||||
self.assertIn("has_side_effect = true", hlo.as_text())
|
||||
self.assertIn(target_name, hlo.compile().as_text())
|
||||
|
||||
def test_jvp_error(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The FFI call to `.+` cannot be differentiated."):
|
||||
jax.jvp(fun, (0.5,), (0.5,))
|
||||
|
||||
def test_non_hashable_attributes(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
|
||||
|
||||
self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5))))
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertIn("non_hashable_arg = {a = 1", hlo)
|
||||
|
||||
# If non-hashable arguments aren't handled properly, this will raise a
|
||||
# TypeError. We make sure it doesn't.
|
||||
with self.assertRaises(Exception) as manager:
|
||||
fun(jnp.ones(5))
|
||||
self.assertNotIsInstance(manager.exception, TypeError)
|
||||
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3))
|
||||
self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5))))
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertIn("non_hashable_arg = array<i64: 0, 1, 2>", hlo)
|
||||
with self.assertRaises(Exception) as manager:
|
||||
fun(jnp.ones(5))
|
||||
self.assertNotIsInstance(manager.exception, TypeError)
|
||||
|
||||
@jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def test_ffi_call(self, shape):
|
||||
x = self.rng().randn(*shape).astype(np.float32)
|
||||
expected = lax_linalg_internal.geqrf(x)
|
||||
actual = ffi_call_geqrf(x)
|
||||
for a, b in zip(actual, expected):
|
||||
self.assertArraysEqual(a, b)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(6, 5), (4, 5, 6)],
|
||||
vmap_method=["expand_dims", "broadcast_all", "sequential"],
|
||||
)
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def test_ffi_call_batching(self, shape, vmap_method):
|
||||
shape = (10,) + shape
|
||||
x = self.rng().randn(*shape).astype(np.float32)
|
||||
expected = lax_linalg_internal.geqrf(x)
|
||||
actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
|
||||
for a, b in zip(actual, expected):
|
||||
if vmap_method == "sequential" and len(shape) == 3:
|
||||
# On GPU, the batched FFI call to geqrf uses an algorithm with
|
||||
# different numerics than the unbatched version (which is used when
|
||||
# vmap_method="sequential"). Therefore, we need to include floating
|
||||
# point tolerance for this check.
|
||||
self.assertArraysAllClose(a, b)
|
||||
else:
|
||||
self.assertArraysEqual(a, b)
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def test_vectorized_deprecation(self):
|
||||
x = self.rng().randn(3, 5, 4).astype(np.float32)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
ffi_call_geqrf(x, vectorized=True)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
jax.vmap(ffi_call_geqrf)(x)
|
||||
|
||||
def test_backward_compat_syntax(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test_ffi", x, x, param=0.5)
|
||||
msg = "Calling ffi_call directly with input arguments is deprecated"
|
||||
if deprecations.is_accelerated("jax-ffi-call-args"):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
else:
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
|
||||
def test_input_output_aliases(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
|
||||
|
||||
def test_invalid_input_output_aliases(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
|
||||
with self.assertRaisesRegex(ValueError, "with input index"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
|
||||
with self.assertRaisesRegex(ValueError, "with output index"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
|
||||
input_output_aliases={0: 0})(x)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"referring to an input with abstract value"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
|
||||
x.dtype),
|
||||
input_output_aliases={0: 0})(x)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"referring to an input with abstract value"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def test_legacy_backend_config(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", x, custom_call_api_version=2,
|
||||
legacy_backend_config="12345")(x)
|
||||
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
self.assertRegex(hlo, 'backend_config = "12345"')
|
||||
|
||||
def test_invalid_backend_config(self):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"The use of the legacy_backend_config"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test", x,
|
||||
custom_call_api_version=2)(x, attribute=1)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"The use of ffi_call attributes requires"):
|
||||
jax.jit(fun).lower(jnp.ones(5)).as_text()
|
||||
|
||||
def test_allow_x64(self):
|
||||
if config.enable_x64.value:
|
||||
self.skipTest("Requires enable_x64=False")
|
||||
def fun():
|
||||
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))()
|
||||
self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text())
|
||||
|
||||
def test_invalid_result_type(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "All elements of result_shape_dtypes.*position 0"):
|
||||
jax.ffi.ffi_call("test", None)()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "All elements of result_shape_dtypes.*position 1"):
|
||||
jax.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))()
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def test_shard_map(self):
|
||||
mesh = jtu.create_mesh((1,), ("i",))
|
||||
x = self.rng().randn(8, 4, 5).astype(np.float32)
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
|
||||
out_specs=shd.PartitionSpec('i'))
|
||||
def f(x):
|
||||
return ffi_call_geqrf(x)
|
||||
|
||||
f(x) # eager mode doesn't crash
|
||||
jax.jit(f)(x) # neither does JIT
|
||||
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
@jtu.ignore_warning(category=DeprecationWarning)
|
||||
def test_extend_import_shim(self):
|
||||
ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True)
|
||||
|
||||
|
||||
def ffi_call_geqrf(x, _use_extend=False, **kwargs):
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
lapack._lapack.initialize()
|
||||
|
||||
assert x.dtype == np.float32
|
||||
ndim = x.ndim
|
||||
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
|
||||
output_types = [
|
||||
x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
|
||||
|
||||
def call(platform, x):
|
||||
target_name = dict(
|
||||
cpu="lapack_sgeqrf_ffi",
|
||||
rocm="hipsolver_geqrf_ffi",
|
||||
cuda="cusolver_geqrf_ffi",
|
||||
)[platform]
|
||||
f = jex.ffi.ffi_call if _use_extend else jax.ffi.ffi_call
|
||||
return f(
|
||||
target_name, output_types, input_output_aliases={0: 0},
|
||||
input_layouts=[x_major_to_minor],
|
||||
output_layouts=[x_major_to_minor, None],
|
||||
**kwargs)(x)
|
||||
|
||||
return lax.platform_dependent(
|
||||
x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
|
||||
cuda=partial(call, "cuda"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -327,12 +327,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1E-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-8)
|
||||
|
||||
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
|
||||
@jtu.sample_product(
|
||||
l_max=[1, 2, 3, 6],
|
||||
shape=[(5,), (10,)],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
|
||||
def testLpmn(self, l_max, shape, dtype):
|
||||
if jtu.is_device_tpu(6, "e"):
|
||||
self.skipTest("TODO(b/364258243): fails on TPU v6e")
|
||||
@ -350,12 +350,12 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
atol=3e-3, check_dtypes=False)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
|
||||
|
||||
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
|
||||
@jtu.sample_product(
|
||||
l_max=[3, 4, 6, 32],
|
||||
shape=[(2,), (3,), (4,), (64,)],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
|
||||
def testNormalizedLpmnValues(self, l_max, shape, dtype):
|
||||
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -383,7 +383,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
rtol=1e-5, atol=1e-5, check_dtypes=False)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
|
||||
|
||||
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmAccuracy(self):
|
||||
m = jnp.arange(-3, 3)[:, None]
|
||||
@ -398,6 +399,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmOrderZeroDegreeZero(self):
|
||||
"""Tests the spherical harmonics of order zero and degree zero."""
|
||||
@ -411,6 +414,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmOrderZeroDegreeOne(self):
|
||||
"""Tests the spherical harmonics of order one and degree zero."""
|
||||
@ -424,6 +429,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmOrderOneDegreeOne(self):
|
||||
"""Tests the spherical harmonics of order one and degree one."""
|
||||
@ -445,6 +452,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
],
|
||||
dtype=jtu.dtypes.all_integer,
|
||||
)
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
|
||||
"""Tests against JIT compatibility and Numpy."""
|
||||
@ -469,6 +478,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
with self.subTest('Test against numpy.'):
|
||||
self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmCornerCaseWithWrongNmax(self):
|
||||
"""Tests the corner case where `n_max` is not the maximum value of `n`."""
|
||||
|
@ -200,16 +200,11 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
"3*a*mod(a + 2, b + 2)"),
|
||||
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2,
|
||||
"6*floordiv(a + 2, b + 2)"),
|
||||
# Keep for backwards compatibility. We ought to be able to parse
|
||||
# non_negative
|
||||
("non_negative(a - 2)", "build_inside", "max(a - 2, 0)"),
|
||||
("max(a, b)", "build_inside", "max(a, b)"),
|
||||
("min(a, b)", "build_inside", "min(a, b)"),
|
||||
]])
|
||||
def test_parse_dim(self, dim_spec, dim_poly, expected_str):
|
||||
if dim_spec == "non_negative(a - 2)":
|
||||
dim_poly = core.non_negative_dim(DimExprTest.a - 2)
|
||||
elif dim_spec == "max(a, b)":
|
||||
if dim_spec == "max(a, b)":
|
||||
dim_poly = core.max_dim(DimExprTest.a, DimExprTest.b)
|
||||
elif dim_spec == "min(a, b)":
|
||||
dim_poly = core.min_dim(DimExprTest.a, DimExprTest.b)
|
||||
@ -382,13 +377,6 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
[b * (a % 4), b * (a // 4), a * (a // 4), a // 4,
|
||||
a * a, b, 15])
|
||||
|
||||
# This failed with a previous implementation of factor equality
|
||||
self.assertNotEqual(shape_poly._DimTerm.from_operation(shape_poly._DimFactor.NON_NEGATIVE,
|
||||
a - b - 1,
|
||||
scope=a.scope),
|
||||
shape_poly._DimTerm.from_operation(shape_poly._DimFactor.NON_NEGATIVE,
|
||||
a - 2 * b - 1,
|
||||
scope=a.scope))
|
||||
def test_bounds_arithmetic(self):
|
||||
a, b, c = shape_poly.symbolic_shape("a, b, c")
|
||||
bounded_le4 = 5 - a
|
||||
@ -471,6 +459,10 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertEqual(_bounds(-b // (a + 1)), (-np.inf, -1))
|
||||
|
||||
self.assertEqual(_bounds(a - a // 2), (1, np.inf))
|
||||
self.assertEqual(_bounds((a + 3) - (a + 3) // 2), (2, np.inf))
|
||||
self.assertEqual(_bounds((a + 6) - 1 * (a + 6) // 4), (6, np.inf))
|
||||
self.assertEqual(_bounds((a + 6) - 2 * ((a + 6) // 4)), (4, np.inf))
|
||||
self.assertEqual(_bounds((a + 6) - 3 * ((a + 6) // 4)), (2, np.inf))
|
||||
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Possible division by 0"):
|
||||
@ -495,15 +487,6 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertGreaterEqual(fact_val, lb)
|
||||
self.assertLessEqual(fact_val, ub)
|
||||
|
||||
def test_bounds_non_negative(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
|
||||
self.assertEqual(_bounds(core.non_negative_dim(a)), (1, np.inf))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(a - 5)), (0, np.inf))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(15 - a)), (0, 14))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(15 - a) // 3), (0, 4))
|
||||
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)), (1, 3))
|
||||
|
||||
def test_max_dim(self):
|
||||
a, b, c, d = shape_poly.symbolic_shape("a, b, c, d")
|
||||
|
||||
@ -599,7 +582,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
|
||||
def test_bounds_complex(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
min_a_b = b - core.non_negative_dim(b - a)
|
||||
min_a_b = b - core.max_dim(0, b - a)
|
||||
# This comes up in slicing with stride
|
||||
self.assertGreaterEqual(min_a_b // 2, 0)
|
||||
|
||||
@ -805,16 +788,6 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
set(decision.combine_term_with_existing(_m(d), 2, scope=scope,
|
||||
only_smaller_than_t=True)))
|
||||
|
||||
def test_non_negative_dim(self):
|
||||
a, = shape_poly.symbolic_shape("a,")
|
||||
|
||||
self.sampled_assertion(2, core.non_negative_dim, 2)
|
||||
self.sampled_assertion(0, core.non_negative_dim, 0)
|
||||
self.sampled_assertion(0, core.non_negative_dim, -1)
|
||||
self.sampled_assertion(a, core.non_negative_dim, a)
|
||||
self.sampled_assertion(2 * a - 1, core.non_negative_dim, 2 * a - 1)
|
||||
self.sampled_assertion(core.non_negative_dim(a - 2),
|
||||
core.non_negative_dim, a - 2)
|
||||
|
||||
def test_dilate_dim(self):
|
||||
"""0 if d == 0 else 1 + dilation * (d - 1))"""
|
||||
@ -3013,31 +2986,30 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
RandArg((3, 4, 5), _f32)],
|
||||
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
# TODO(necula): The known dimensions product must be even.
|
||||
PolyHarness("random_categorical", f"axis=0_{flags_name}",
|
||||
lambda key, a: jax.random.categorical(
|
||||
jax.random.wrap_key_data(key), a, axis=0),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((3, 8), _f32)],
|
||||
polymorphic_shapes=[None, "b0, ..."],
|
||||
polymorphic_shapes=[None, "b0, b1"],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
PolyHarness("random_categorical", f"axis=1_{flags_name}",
|
||||
lambda key, a: jax.random.categorical(
|
||||
jax.random.wrap_key_data(key), a, axis=1),
|
||||
jax.random.wrap_key_data(key), a, axis=1),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((3, 5, 8), _f32)],
|
||||
polymorphic_shapes=[None, "b0, b1, ..."],
|
||||
polymorphic_shapes=[None, "b0, b1, b2"],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
PolyHarness("random_categorical", f"axis=1_then_reshape_{flags_name}",
|
||||
lambda key, a: jax.random.categorical(
|
||||
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
|
||||
jax.random.wrap_key_data(key), a, axis=1).reshape(-1),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((3, 5, 8), _f32)],
|
||||
polymorphic_shapes=[None, "b0, b1, ..."],
|
||||
polymorphic_shapes=[None, "b0, b1, b2"],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
PolyHarness("random_categorical", f"0_dim_{flags_name}", # One axis has 0 size
|
||||
lambda key, a: jax.random.categorical(
|
||||
jax.random.wrap_key_data(key), a, axis=1),
|
||||
jax.random.wrap_key_data(key), a, axis=1),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((3, 5, 0), _f32)],
|
||||
polymorphic_shapes=[None, "b0, b1, ..."],
|
||||
@ -3055,14 +3027,13 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
RandArg((64, 12, 4), _f32), # sample on axis=1
|
||||
RandArg((3, 4), _f32),
|
||||
StaticArg(use_p)],
|
||||
# TODO(necula): threefry requires even-sized samples.
|
||||
polymorphic_shapes=[None,
|
||||
"_, 2*b1, _" if arr_poly else None,
|
||||
"b0, b1, b2" if arr_poly else None,
|
||||
"b3, b4" if shape_poly else None],
|
||||
# The array sampled dimension must be larger than res_shape.size
|
||||
symbolic_constraints=[
|
||||
"2*b1 >= 12" if arr_poly else "1 >= 0",
|
||||
"2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
|
||||
"b1 >= 12" if arr_poly else "1 >= 0",
|
||||
"b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0",
|
||||
"12 >= b3*b4" if shape_poly else "1 >= 0"
|
||||
],
|
||||
override_jax_config_flags=override_jax_config_flags,
|
||||
@ -3089,24 +3060,20 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
|
||||
a.shape, dtype=_f32),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4, 5), _f32)],
|
||||
polymorphic_shapes=[None, "b0, ..."],
|
||||
polymorphic_shapes=[None, "b0, 4, 5"],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
PolyHarness("random_uniform", f"even_2_{flags_name}",
|
||||
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
|
||||
(2 * a.shape[0], a.shape[1]),
|
||||
dtype=_f32),
|
||||
a.shape, dtype=_f32),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32), RandArg((3, 4), _f32)],
|
||||
polymorphic_shapes=[None, "b0, b1, ..."],
|
||||
polymorphic_shapes=[None, "b0, 2*b1"],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
PolyHarness("random_uniform", f"error_not_even_{flags_name}",
|
||||
PolyHarness("random_uniform", f"error_unknown_evenness_{flags_name}",
|
||||
lambda key, a: jax.random.uniform(jax.random.wrap_key_data(key),
|
||||
a.shape, dtype=_f32),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((3, 5), _f32)],
|
||||
polymorphic_shapes=[None, "b0, ..."],
|
||||
expect_error=(
|
||||
(core.InconclusiveDimensionOperation,
|
||||
"array size .* must be even") if flags_name == "threefry_non_partitionable" else None),
|
||||
polymorphic_shapes=[None, "b0, b1"],
|
||||
override_jax_config_flags=override_jax_config_flags) # type: ignore
|
||||
]
|
||||
for key_size, flags_name, override_jax_config_flags in [
|
||||
|
@ -1904,7 +1904,6 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
x = shard_map(g, mesh,
|
||||
in_specs=P('i', None),
|
||||
out_specs=P('i', None),
|
||||
check_rep=False,
|
||||
auto=frozenset({'j'}))(x)
|
||||
return jax.lax.with_sharding_constraint(
|
||||
x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
|
||||
@ -2162,7 +2161,22 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
mesh, in_specs=P('i', None), out_specs=P('i', None),
|
||||
check_rep=False, auto=frozenset({'j'}))()
|
||||
|
||||
self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
|
||||
self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))
|
||||
|
||||
def test_partial_auto_axis_index_degenerated_axis(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest('Shardy does not support full-to-shard.')
|
||||
|
||||
mesh = jtu.create_mesh((1, 2), ('i', 'j'))
|
||||
out_sharding = NamedSharding(mesh, P('i', None))
|
||||
|
||||
@partial(jax.jit, out_shardings=out_sharding)
|
||||
def f():
|
||||
return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
|
||||
mesh, in_specs=P('i', None), out_specs=P('i', None),
|
||||
check_rep=False, auto=frozenset({'j'}))()
|
||||
|
||||
self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))
|
||||
|
||||
def test_partial_auto_ppermute(self):
|
||||
if xla_extension_version < 302:
|
||||
|
4
third_party/xla/workspace.bzl
vendored
4
third_party/xla/workspace.bzl
vendored
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
|
||||
# and update XLA_SHA256 with the result.
|
||||
|
||||
XLA_COMMIT = "ac6e71fe0cf864eec152de5ba761b76d8bef3153"
|
||||
XLA_SHA256 = "2b568ff365bc4b5c2b257002aa71f094a2b60357ceb1f2a1c6c33f4ad1a411bd"
|
||||
XLA_COMMIT = "1a6361a734c5cd10dc93938fc6163a51fd37b82e"
|
||||
XLA_SHA256 = "01159fd52f0e402829a3823472a309562817c72d0212f81cd5555f77394c094f"
|
||||
|
||||
def repo():
|
||||
tf_http_archive(
|
||||
|
Loading…
x
Reference in New Issue
Block a user