Fix syntax error and typos for composite primitive docstring.

PiperOrigin-RevId: 735808000
This commit is contained in:
Gunhyun Park 2025-03-11 10:36:32 -07:00 committed by jax authors
parent 6f7ce9d048
commit d191927b24
2 changed files with 19 additions and 16 deletions

View File

@ -58,6 +58,7 @@ Operators
clz
collapse
complex
composite
concatenate
conj
conv

View File

@ -1489,14 +1489,14 @@ def composite(
):
"""Composite with semantics defined by the decomposition function.
A composite is a higher-order JAX function that encapsulates an operation mad
A composite is a higher-order JAX function that encapsulates an operation made
up (composed) of other JAX functions. The semantics of the op are implemented
by the ``decomposition`` function. In other words, the defined composite
function can be replaced with its decomposed implementation without changing
the semantics of the encapsulated operation.
The compiler can recognize specific composite operations by their ``name``,
``version``, ``kawargs``, and dtypes to emit more efficient code, potentially
``version``, ``kwargs``, and dtypes to emit more efficient code, potentially
leveraging hardware-specific instructions or optimizations. If the compiler
doesn't recognize the composite, it falls back to compiling the
``decomposition`` function.
@ -1505,11 +1505,11 @@ def composite(
be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could
recognize the "tangent" composite and emit a single ``tangent`` instruction
instead of three separate instructions (``sin``, ``divide``, and ``cos``).
With compilers for hardwares without dedicated tangent support, it would fall
back to compiling the decomposition.
For hardware without dedicated tangent support, it would fall back to
compiling the decomposition.
This is useful for preserving high level abstraction that would otherwise be
lost while lowering which allows for easier pattern-matching in low-level IR.
This is useful for preserving high-level abstractions that would otherwise be
lost while lowering, which allows for easier pattern-matching in low-level IR.
Args:
decomposition: function that implements the semantics of the composite op.
@ -1517,19 +1517,20 @@ def composite(
version: optional int to indicate semantic changes to the composite.
Returns:
out: callable composite function. Note that positional arguments to this
function should be interpreted as inputs and keyword arguments should be
interpreted as attributes of the op. Any keyword arguments that are passed
with ``None`` as a value will be omitted from the
``composite_attributes``.
Callable: Returns a composite function. Note that positional arguments to
this function should be interpreted as inputs and keyword arguments should
be interpreted as attributes of the op. Any keyword arguments that are
passed with ``None`` as a value will be omitted from the
``composite_attributes``.
Examples:
Tangent kernel:
>>> def my_tangent_composite(x):
... return lax.composite(
... lambda x: lax.sin(x) / lax.cos(x), name='my.tangent'
... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
... )(x)
...
>>>
>>> pi = jnp.pi
>>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
>>> with jnp.printoptions(precision=3, suppress=True):
@ -1538,9 +1539,10 @@ def composite(
[ 0. 1. -1. 0.]
[ 0. 1. -1. 0.]
The recommended way to create composites is via a decorator. Use `/` and `*`
in the function signature to be explicit about positional and keyword
arguments respectively:
The recommended way to create composites is via a decorator. Use ``/`` and
``*`` in the function signature to be explicit about positional and keyword
arguments, respectively:
>>> @partial(lax.composite, name="my.softmax")
... def my_softmax_composite(x, /, *, axis):
... return jax.nn.softmax(x, axis)