Add custom_dce to changelogs and API docs.

This commit is contained in:
Dan Foreman-Mackey 2025-01-27 11:16:00 -05:00
parent c61401ab6f
commit 782138fb6f
4 changed files with 29 additions and 9 deletions

View File

@ -16,6 +16,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## Unreleased
* New Features
* Added an experimental {func}`jax.experimental.custom_dce.custom_dce`
decorator to support customizing the behavior of opaque functions under
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
## jax 0.5.0 (Jan 17, 2025)
As of this release, JAX now uses

View File

@ -0,0 +1,13 @@
``jax.experimental.custom_dce`` module
======================================
.. automodule:: jax.experimental.custom_dce
API
---
.. autosummary::
:toctree: _autosummary
custom_dce
custom_dce.def_dce

View File

@ -16,6 +16,7 @@ Experimental Modules
jax.experimental.checkify
jax.experimental.compilation_cache
jax.experimental.custom_dce
jax.experimental.custom_partitioning
jax.experimental.jet
jax.experimental.key_reuse

View File

@ -75,9 +75,9 @@ class custom_dce:
... x * jnp.sin(y) if used_outs[1] else None,
... )
In this example, ``used_outs`` is a ``tuple`` with two ``bool``s indicating
which outputs are required. The DCE rule only computes the required outputs,
replacing the unused outputs with ``None``.
In this example, ``used_outs`` is a ``tuple`` with two ``bool`` values,
indicating which outputs are required. The DCE rule only computes the
required outputs, replacing the unused outputs with ``None``.
If the ``static_argnums`` argument is provided to ``custom_dce``, the
indicated arguments are treated as static when the function is traced, and
@ -108,12 +108,12 @@ class custom_dce:
Args:
dce_rule: A function that takes (a) any arguments indicated as static
using ``static_argnums``, (b) a Pytree of ``bool``s (``used_outs``)
indicating which outputs should be computed, and (c) the rest of the
(non-static) arguments to the original function. The rule should return
a Pytree with with the same structure as the output of the original
function, but any unused outputs (as indicated by ``used_outs``) can be
replaced with ``None``.
using ``static_argnums``, (b) a Pytree of ``bool`` values
(``used_outs``) indicating which outputs should be computed, and (c)
the rest of the (non-static) arguments to the original function. The
rule should return a Pytree with with the same structure as the output
of the original function, but any unused outputs (as indicated by
``used_outs``) can be replaced with ``None``.
"""
self.dce_rule = dce_rule
return dce_rule