jax.ops package
===============

.. currentmodule:: jax.ops

.. automodule:: jax.ops

Indexed update operators
------------------------

JAX is intended to be used with a functional style of programming, and hence
does not support NumPy-style in-place indexed assignment such as ``x[idx] = y``.
Instead, JAX provides pure alternatives, namely :func:`jax.ops.index_update` and
its relatives, each of which returns an updated copy of its input array rather
than modifying it.

.. autosummary::
   :toctree: _autosummary

   index
   index_update
   index_add
   index_mul
   index_min
   index_max

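
A minimal sketch of the difference between NumPy-style assignment and a pure
update. It assumes a JAX version in which ``jax.ops.index_update`` and
``jax.ops.index`` are still available. Later JAX releases removed them, so the
sketch falls back to the ``x.at[idx].set(y)`` form when the import fails. The
array contents are illustrative only.

.. code-block:: python

   import jax.numpy as jnp

   x = jnp.arange(5)  # [0, 1, 2, 3, 4]

   # NumPy-style in-place assignment is rejected: JAX arrays are immutable.
   try:
       x[0] = 10
       assigned = True
   except TypeError:
       assigned = False

   try:
       # Pure update through this module (present in older JAX releases).
       from jax.ops import index, index_update
       y = index_update(x, index[0], 10)
   except ImportError:
       # Later JAX releases removed index_update; use the `at` syntax instead.
       y = x.at[0].set(10)

   # x is untouched; y is a new array whose element 0 is 10.

Because the update returns a new array, it composes with ``jax.jit`` and
``jax.grad`` like any other pure function.
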
Syntactic sugar for indexed update operators
--------------------------------------------

JAX also provides an alternate syntax for these indexed update operators.
Specifically, JAX arrays have an ``at`` property that can be used as follows,
where ``idx`` can be an arbitrary index expression:

==================== ===================================================
Alternate syntax     Equivalent expression
==================== ===================================================
``x.at[idx].set(y)`` ``jax.ops.index_update(x, jax.ops.index[idx], y)``
``x.at[idx].add(y)`` ``jax.ops.index_add(x, jax.ops.index[idx], y)``
``x.at[idx].mul(y)`` ``jax.ops.index_mul(x, jax.ops.index[idx], y)``
``x.at[idx].min(y)`` ``jax.ops.index_min(x, jax.ops.index[idx], y)``
``x.at[idx].max(y)`` ``jax.ops.index_max(x, jax.ops.index[idx], y)``
==================== ===================================================

Note that none of these expressions modifies the original ``x``; instead, each
returns a modified copy of ``x``.

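
A short illustration of the ``at`` syntax from the table, using a made-up
four-element array. It covers ``set``, ``add``, ``min`` and ``max`` with both a
single index and a slice. It also shows that the same expressions can be used
inside a ``jax.jit``-compiled function with a traced index. The helper name
``bump`` is hypothetical.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   x = jnp.array([1.0, 2.0, 3.0, 4.0])

   a = x.at[1].set(10.0)   # element 1 replaced:        [1, 10, 3, 4]
   b = x.at[1:3].add(5.0)  # 5 added to elements 1, 2:  [1, 7, 8, 4]
   c = x.at[2].min(0.0)    # element 2 -> min(3, 0):    [1, 2, 0, 4]
   d = x.at[0].max(7.0)    # element 0 -> max(1, 7):    [7, 2, 3, 4]

   # Hypothetical helper: the same syntax works under jit with a traced index.
   @jax.jit
   def bump(v, i):
       return v.at[i].add(1.0)

   e = bump(x, 3)          # [1, 2, 3, 5]

   # x itself is never modified by any of the calls above.
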
Other operators
---------------

.. autosummary::
   :toctree: _autosummary

   segment_sum
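
``segment_sum`` adds together the entries of an array that share a segment id.
The output has one entry per segment. A minimal sketch with made-up data follows.
It passes ``num_segments`` explicitly so that the output length is fixed rather
than derived from the values in ``segment_ids``.

.. code-block:: python

   import jax.numpy as jnp
   from jax import ops

   data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
   segment_ids = jnp.array([0, 0, 1, 2, 2])  # segment of each entry in data

   # Segment 0 sums 1 + 2, segment 1 holds 3, segment 2 sums 4 + 5.
   totals = ops.segment_sum(data, segment_ids, num_segments=3)  # [3., 3., 9.]
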