From 93ef0f13fe0d192046c6faefc76742fa8257509a Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 9 Jan 2025 10:54:09 -0800 Subject: [PATCH] Clarify documentation of composites. There were some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767. PiperOrigin-RevId: 713726697 --- jax/_src/lax/lax.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index dd946a81d..869e3247f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -700,7 +700,9 @@ def composite( version: optional int to indicate semantic changes to the composite. Returns: - out: callable composite function. + out: callable composite function. Note that positional arguments to this + function should be interpreted as inputs and keyword arguments should be + interpreted as attributes of the op. Examples: Tangent kernel: @@ -716,6 +718,13 @@ def composite( ... print(lax.tan(x)) [ 0. 1. -1. 0.] [ 0. 1. -1. 0.] + + The recommended way to create composites is via a decorator. Use `/` and `*` + in the function signature to be explicit about positional and keyword + arguments respectively: + >>> @partial(lax.composite, name="my.softmax") + ... def my_softmax_composite(x, /, *, axis): + ... return jax.nn.softmax(x, axis) """ @functools.wraps(decomposition) def _decorator(*args, **kwargs):