From 782138fb6fdba215c62b73284a70b5b2967cb85f Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 27 Jan 2025 11:16:00 -0500 Subject: [PATCH] Add custom_dce to changelogs and API docs. --- CHANGELOG.md | 6 ++++++ docs/jax.experimental.custom_dce.rst | 13 +++++++++++++ docs/jax.experimental.rst | 1 + jax/_src/custom_dce.py | 18 +++++++++--------- 4 files changed, 29 insertions(+), 9 deletions(-) create mode 100644 docs/jax.experimental.custom_dce.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f7eedb0d..d7a8e0b93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* New Features + * Added an experimental {func}`jax.experimental.custom_dce.custom_dce` + decorator to support customizing the behavior of opaque functions under + JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more + details. + ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses diff --git a/docs/jax.experimental.custom_dce.rst b/docs/jax.experimental.custom_dce.rst new file mode 100644 index 000000000..776b68ff1 --- /dev/null +++ b/docs/jax.experimental.custom_dce.rst @@ -0,0 +1,13 @@ +``jax.experimental.custom_dce`` module +====================================== + +.. automodule:: jax.experimental.custom_dce + +API +--- + +.. autosummary:: + :toctree: _autosummary + + custom_dce + custom_dce.def_dce diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index da0584f0e..39c778db7 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -16,6 +16,7 @@ Experimental Modules jax.experimental.checkify jax.experimental.compilation_cache + jax.experimental.custom_dce jax.experimental.custom_partitioning jax.experimental.jet jax.experimental.key_reuse diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index 45479bcbe..7e8bbc290 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -75,9 +75,9 @@ class custom_dce: ... x * jnp.sin(y) if used_outs[1] else None, ... ) - In this example, ``used_outs`` is a ``tuple`` with two ``bool``s indicating - which outputs are required. The DCE rule only computes the required outputs, - replacing the unused outputs with ``None``. + In this example, ``used_outs`` is a ``tuple`` with two ``bool`` values, + indicating which outputs are required. The DCE rule only computes the + required outputs, replacing the unused outputs with ``None``. If the ``static_argnums`` argument is provided to ``custom_dce``, the indicated arguments are treated as static when the function is traced, and @@ -108,12 +108,12 @@ class custom_dce: Args: dce_rule: A function that takes (a) any arguments indicated as static - using ``static_argnums``, (b) a Pytree of ``bool``s (``used_outs``) - indicating which outputs should be computed, and (c) the rest of the - (non-static) arguments to the original function. The rule should return - a Pytree with with the same structure as the output of the original - function, but any unused outputs (as indicated by ``used_outs``) can be - replaced with ``None``. + using ``static_argnums``, (b) a Pytree of ``bool`` values + (``used_outs``) indicating which outputs should be computed, and (c) + the rest of the (non-static) arguments to the original function. The + rule should return a Pytree with with the same structure as the output + of the original function, but any unused outputs (as indicated by + ``used_outs``) can be replaced with ``None``. """ self.dce_rule = dce_rule return dce_rule