diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 5f51cdb3b..9db79f591 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -58,6 +58,7 @@ Operators clz collapse complex + composite concatenate conj conv diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 99760099d..d4131e69b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1489,14 +1489,14 @@ def composite( ): """Composite with semantics defined by the decomposition function. - A composite is a higher-order JAX function that encapsulates an operation mad + A composite is a higher-order JAX function that encapsulates an operation made up (composed) of other JAX functions. The semantics of the op are implemented by the ``decomposition`` function. In other words, the defined composite function can be replaced with its decomposed implementation without changing the semantics of the encapsulated operation. The compiler can recognize specific composite operations by their ``name``, - ``version``, ``kawargs``, and dtypes to emit more efficient code, potentially + ``version``, ``kwargs``, and dtypes to emit more efficient code, potentially leveraging hardware-specific instructions or optimizations. If the compiler doesn't recognize the composite, it falls back to compiling the ``decomposition`` function. @@ -1505,11 +1505,11 @@ def composite( be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could recognize the "tangent" composite and emit a single ``tangent`` instruction instead of three separate instructions (``sin``, ``divide``, and ``cos``). - With compilers for hardwares without dedicated tangent support, it would fall - back to compiling the decomposition. + For hardware without dedicated tangent support, it would fall back to + compiling the decomposition. - This is useful for preserving high level abstraction that would otherwise be - lost while lowering which allows for easier pattern-matching in low-level IR. + This is useful for preserving high-level abstractions that would otherwise be + lost while lowering, which allows for easier pattern-matching in low-level IR. Args: decomposition: function that implements the semantics of the composite op. @@ -1517,19 +1517,20 @@ def composite( version: optional int to indicate semantic changes to the composite. Returns: - out: callable composite function. Note that positional arguments to this - function should be interpreted as inputs and keyword arguments should be - interpreted as attributes of the op. Any keyword arguments that are passed - with ``None`` as a value will be omitted from the - ``composite_attributes``. + Callable: Returns a composite function. Note that positional arguments to + this function should be interpreted as inputs and keyword arguments should + be interpreted as attributes of the op. Any keyword arguments that are + passed with ``None`` as a value will be omitted from the + ``composite_attributes``. Examples: Tangent kernel: + >>> def my_tangent_composite(x): ... return lax.composite( - ... lambda x: lax.sin(x) / lax.cos(x), name='my.tangent' + ... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent" ... )(x) - ... + >>> >>> pi = jnp.pi >>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi]) >>> with jnp.printoptions(precision=3, suppress=True): @@ -1538,9 +1539,10 @@ def composite( [ 0. 1. -1. 0.] [ 0. 1. -1. 0.] - The recommended way to create composites is via a decorator. Use `/` and `*` - in the function signature to be explicit about positional and keyword - arguments respectively: + The recommended way to create composites is via a decorator. Use ``/`` and + ``*`` in the function signature to be explicit about positional and keyword + arguments, respectively: + >>> @partial(lax.composite, name="my.softmax") ... def my_softmax_composite(x, /, *, axis): ... return jax.nn.softmax(x, axis)