mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Fix syntax error and typos for composite primitive docstring.
PiperOrigin-RevId: 735808000
This commit is contained in:
parent
6f7ce9d048
commit
d191927b24
@ -58,6 +58,7 @@ Operators
|
||||
clz
|
||||
collapse
|
||||
complex
|
||||
composite
|
||||
concatenate
|
||||
conj
|
||||
conv
|
||||
|
@ -1489,14 +1489,14 @@ def composite(
|
||||
):
|
||||
"""Composite with semantics defined by the decomposition function.
|
||||
|
||||
A composite is a higher-order JAX function that encapsulates an operation mad
|
||||
A composite is a higher-order JAX function that encapsulates an operation made
|
||||
up (composed) of other JAX functions. The semantics of the op are implemented
|
||||
by the ``decomposition`` function. In other words, the defined composite
|
||||
function can be replaced with its decomposed implementation without changing
|
||||
the semantics of the encapsulated operation.
|
||||
|
||||
The compiler can recognize specific composite operations by their ``name``,
|
||||
``version``, ``kawargs``, and dtypes to emit more efficient code, potentially
|
||||
``version``, ``kwargs``, and dtypes to emit more efficient code, potentially
|
||||
leveraging hardware-specific instructions or optimizations. If the compiler
|
||||
doesn't recognize the composite, it falls back to compiling the
|
||||
``decomposition`` function.
|
||||
@ -1505,11 +1505,11 @@ def composite(
|
||||
be implemented as ``sin(x) / cos(x)``. A hardware-aware compiler could
|
||||
recognize the "tangent" composite and emit a single ``tangent`` instruction
|
||||
instead of three separate instructions (``sin``, ``divide``, and ``cos``).
|
||||
With compilers for hardwares without dedicated tangent support, it would fall
|
||||
back to compiling the decomposition.
|
||||
For hardware without dedicated tangent support, it would fall back to
|
||||
compiling the decomposition.
|
||||
|
||||
This is useful for preserving high level abstraction that would otherwise be
|
||||
lost while lowering which allows for easier pattern-matching in low-level IR.
|
||||
This is useful for preserving high-level abstractions that would otherwise be
|
||||
lost while lowering, which allows for easier pattern-matching in low-level IR.
|
||||
|
||||
Args:
|
||||
decomposition: function that implements the semantics of the composite op.
|
||||
@ -1517,19 +1517,20 @@ def composite(
|
||||
version: optional int to indicate semantic changes to the composite.
|
||||
|
||||
Returns:
|
||||
out: callable composite function. Note that positional arguments to this
|
||||
function should be interpreted as inputs and keyword arguments should be
|
||||
interpreted as attributes of the op. Any keyword arguments that are passed
|
||||
with ``None`` as a value will be omitted from the
|
||||
``composite_attributes``.
|
||||
Callable: Returns a composite function. Note that positional arguments to
|
||||
this function should be interpreted as inputs and keyword arguments should
|
||||
be interpreted as attributes of the op. Any keyword arguments that are
|
||||
passed with ``None`` as a value will be omitted from the
|
||||
``composite_attributes``.
|
||||
|
||||
Examples:
|
||||
Tangent kernel:
|
||||
|
||||
>>> def my_tangent_composite(x):
|
||||
... return lax.composite(
|
||||
... lambda x: lax.sin(x) / lax.cos(x), name='my.tangent'
|
||||
... lambda x: lax.sin(x) / lax.cos(x), name="my.tangent"
|
||||
... )(x)
|
||||
...
|
||||
>>>
|
||||
>>> pi = jnp.pi
|
||||
>>> x = jnp.array([0.0, pi / 4, 3 * pi / 4, pi])
|
||||
>>> with jnp.printoptions(precision=3, suppress=True):
|
||||
@ -1538,9 +1539,10 @@ def composite(
|
||||
[ 0. 1. -1. 0.]
|
||||
[ 0. 1. -1. 0.]
|
||||
|
||||
The recommended way to create composites is via a decorator. Use `/` and `*`
|
||||
in the function signature to be explicit about positional and keyword
|
||||
arguments respectively:
|
||||
The recommended way to create composites is via a decorator. Use ``/`` and
|
||||
``*`` in the function signature to be explicit about positional and keyword
|
||||
arguments, respectively:
|
||||
|
||||
>>> @partial(lax.composite, name="my.softmax")
|
||||
... def my_softmax_composite(x, /, *, axis):
|
||||
... return jax.nn.softmax(x, axis)
|
||||
|
Loading…
x
Reference in New Issue
Block a user