mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Clarify documentation of composites.
There were some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767. PiperOrigin-RevId: 713726697
This commit is contained in:
parent
6c8b02df01
commit
93ef0f13fe
@ -700,7 +700,9 @@ def composite(
|
||||
version: optional int to indicate semantic changes to the composite.
|
||||
|
||||
Returns:
|
||||
out: callable composite function.
|
||||
out: callable composite function. Note that positional arguments to this
|
||||
function should be interpreted as inputs and keyword arguments should be
|
||||
interpreted as attributes of the op.
|
||||
|
||||
Examples:
|
||||
Tangent kernel:
|
||||
@ -716,6 +718,13 @@ def composite(
|
||||
... print(lax.tan(x))
|
||||
[ 0. 1. -1. 0.]
|
||||
[ 0. 1. -1. 0.]
|
||||
|
||||
The recommended way to create composites is via a decorator. Use `/` and `*`
|
||||
in the function signature to be explicit about positional and keyword
|
||||
arguments respectively:
|
||||
>>> @partial(lax.composite, name="my.softmax")
|
||||
... def my_softmax_composite(x, /, *, axis):
|
||||
... return jax.nn.softmax(x, axis)
|
||||
"""
|
||||
@functools.wraps(decomposition)
|
||||
def _decorator(*args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user