Clarify documentation of composites.

There were some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767.

PiperOrigin-RevId: 713726697
This commit is contained in:
Gunhyun Park 2025-01-09 10:54:09 -08:00 committed by jax authors
parent 6c8b02df01
commit 93ef0f13fe

View File

@ -700,7 +700,9 @@ def composite(
version: optional int to indicate semantic changes to the composite.
Returns:
out: callable composite function.
out: callable composite function. Note that positional arguments to this
function should be interpreted as inputs and keyword arguments should be
interpreted as attributes of the op.
Examples:
Tangent kernel:
@ -716,6 +718,13 @@ def composite(
... print(lax.tan(x))
[ 0. 1. -1. 0.]
[ 0. 1. -1. 0.]
The recommended way to create composites is via a decorator. Use `/` and `*`
in the function signature to be explicit about positional and keyword
arguments respectively:
>>> @partial(lax.composite, name="my.softmax")
... def my_softmax_composite(x, /, *, axis):
... return jax.nn.softmax(x, axis)
"""
@functools.wraps(decomposition)
def _decorator(*args, **kwargs):