mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add custom_dce to changelogs and API docs.
This commit is contained in:
parent
c61401ab6f
commit
782138fb6f
@ -16,6 +16,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## Unreleased
|
||||
|
||||
* New Features
|
||||
* Added an experimental {func}`jax.experimental.custom_dce.custom_dce`
|
||||
decorator to support customizing the behavior of opaque functions under
|
||||
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
|
||||
details.
|
||||
|
||||
## jax 0.5.0 (Jan 17, 2025)
|
||||
|
||||
As of this release, JAX now uses
|
||||
|
13
docs/jax.experimental.custom_dce.rst
Normal file
13
docs/jax.experimental.custom_dce.rst
Normal file
@ -0,0 +1,13 @@
|
||||
``jax.experimental.custom_dce`` module
|
||||
======================================
|
||||
|
||||
.. automodule:: jax.experimental.custom_dce
|
||||
|
||||
API
|
||||
---
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
custom_dce
|
||||
custom_dce.def_dce
|
@ -16,6 +16,7 @@ Experimental Modules
|
||||
|
||||
jax.experimental.checkify
|
||||
jax.experimental.compilation_cache
|
||||
jax.experimental.custom_dce
|
||||
jax.experimental.custom_partitioning
|
||||
jax.experimental.jet
|
||||
jax.experimental.key_reuse
|
||||
|
@ -75,9 +75,9 @@ class custom_dce:
|
||||
... x * jnp.sin(y) if used_outs[1] else None,
|
||||
... )
|
||||
|
||||
In this example, ``used_outs`` is a ``tuple`` with two ``bool``s indicating
|
||||
which outputs are required. The DCE rule only computes the required outputs,
|
||||
replacing the unused outputs with ``None``.
|
||||
In this example, ``used_outs`` is a ``tuple`` with two ``bool`` values,
|
||||
indicating which outputs are required. The DCE rule only computes the
|
||||
required outputs, replacing the unused outputs with ``None``.
|
||||
|
||||
If the ``static_argnums`` argument is provided to ``custom_dce``, the
|
||||
indicated arguments are treated as static when the function is traced, and
|
||||
@ -108,12 +108,12 @@ class custom_dce:
|
||||
|
||||
Args:
|
||||
dce_rule: A function that takes (a) any arguments indicated as static
|
||||
using ``static_argnums``, (b) a Pytree of ``bool``s (``used_outs``)
|
||||
indicating which outputs should be computed, and (c) the rest of the
|
||||
(non-static) arguments to the original function. The rule should return
|
||||
a Pytree with with the same structure as the output of the original
|
||||
function, but any unused outputs (as indicated by ``used_outs``) can be
|
||||
replaced with ``None``.
|
||||
using ``static_argnums``, (b) a Pytree of ``bool`` values
|
||||
(``used_outs``) indicating which outputs should be computed, and (c)
|
||||
the rest of the (non-static) arguments to the original function. The
|
||||
rule should return a Pytree with with the same structure as the output
|
||||
of the original function, but any unused outputs (as indicated by
|
||||
``used_outs``) can be replaced with ``None``.
|
||||
"""
|
||||
self.dce_rule = dce_rule
|
||||
return dce_rule
|
||||
|
Loading…
x
Reference in New Issue
Block a user