mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: add documentation of DeviceArray object properties & methods
This commit is contained in:
parent
010c383ab3
commit
26f74e64a6
@ -440,3 +440,20 @@ jax.numpy.linalg
|
||||
svd
|
||||
tensorinv
|
||||
tensorsolve
|
||||
|
||||
JAX DeviceArray
|
||||
---------------
|
||||
The JAX :class:`~jax.numpy.DeviceArray` is the core array object in JAX: you can
|
||||
think of it as the equivalent of a :class:`numpy.ndarray` backed by a memory buffer
|
||||
on a single device. Like :class:`numpy.ndarray`, most users will not need to
|
||||
instantiate :class:`DeviceArray`s manually, but rather will create them via
|
||||
:mod:`jax.numpy` functions like :func:`~jax.numpy.array`, :func:`~jax.numpy.arange`,
|
||||
:func:`~jax.numpy.linspace`, and others listed above.
|
||||
|
||||
.. autoclass:: jax.numpy.DeviceArray
|
||||
|
||||
.. autoclass:: jaxlib.xla_extension.DeviceArrayBase
|
||||
|
||||
.. autoclass:: jaxlib.xla_extension.DeviceArray
|
||||
:members:
|
||||
:inherited-members:
|
@ -5837,18 +5837,20 @@ class _IndexUpdateHelper:
|
||||
# Note: this docstring will appear as the docstring for the `at` property.
|
||||
"""Indexable helper object to call indexed update functions.
|
||||
|
||||
The `at` property is syntactic sugar for calling the indexed update functions
|
||||
The ``at`` property is syntactic sugar for calling the indexed update functions
|
||||
defined in :mod:`jax.ops`, and acts as a pure equivalent of in-place
|
||||
modificatons.
|
||||
modificatons. For further information, see `Syntactic Sugar for Index Update Operators
|
||||
<https://jax.readthedocs.io/en/latest/jax.ops.html#syntactic-sugar-for-indexed-update-operators>`_.
|
||||
|
||||
In particular:
|
||||
|
||||
- ``x = x.at[idx].set(y)`` is a pure equivalent of ``x[idx] = y``.
|
||||
- ``x = x.at[idx].add(y)`` is a pure equivalent of ``x[idx] += y``.
|
||||
- ``x = x.at[idx].mul(y)`` is a pure equivalent of ``x[idx] *= y``.
|
||||
- ``x = x.at[idx].min(y)`` is a pure equivalent of
|
||||
``x[idx] = minimum(x[idx], y)``.
|
||||
``x[idx] = minimum(x[idx], y)``.
|
||||
- ``x = x.at[idx].max(y)`` is a pure equivalent of
|
||||
``x[idx] = maximum(x[idx], y)``.
|
||||
``x[idx] = maximum(x[idx], y)``.
|
||||
"""
|
||||
__slots__ = ("array",)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user