mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #9897 from jakevdp:numpy-doc
PiperOrigin-RevId: 434837122
This commit is contained in:
commit
5354a016e6
@ -1,4 +1,3 @@
|
||||
|
||||
jax.numpy package
|
||||
=================
|
||||
|
||||
@ -15,24 +14,24 @@ cannot follow NumPy exactly.
|
||||
in-place cannot be implemented in JAX. However, often JAX is able to provide
|
||||
an alternative API that is purely functional. For example, instead of in-place
|
||||
array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
|
||||
update function :code:`x.at[i].set(y)`.
|
||||
update function :code:`x.at[i].set(y)` (see :attr:`ndarray.at`).
|
||||
|
||||
* Relatedly, some NumPy functions return views of arrays when possible (examples
|
||||
are :func:`numpy.transpose` and :func:`numpy.reshape`). JAX versions of such
|
||||
functions will return copies instead, although such copies can often be optimized
|
||||
* Relatedly, some NumPy functions often return views of arrays when possible
|
||||
(examples are :func:`transpose` and :func:`reshape`). JAX versions of such
|
||||
functions will return copies instead, although such are often optimized
|
||||
away by XLA when sequences of operations are compiled using :func:`jax.jit`.
|
||||
|
||||
* NumPy is very aggressive at promoting values to :code:`float64` type. JAX
|
||||
sometimes is less aggressive about type promotion (See :ref:`type-promotion`).
|
||||
|
||||
A small number of NumPy operations that have data-dependent output shapes are
|
||||
incompatible with :func:`jax.jit` compilation. The XLA compiler requires that
|
||||
shapes of arrays be known at compile time. While it would be possible to provide
|
||||
a JAX implementation of an API such as :func:`numpy.nonzero`, we would be unable
|
||||
to JIT-compile it because the shape of its output depends on the contents of the
|
||||
input data.
|
||||
* Some NumPy routines have data-dependent output shapes (examples include
|
||||
:func:`unique` and :func:`nonzero`). Because the XLA compiler requires array
|
||||
shapes to be known at compile time, such operations are not compatible with
|
||||
JIT. For this reason, JAX adds an optional ``size`` argument to such functions
|
||||
which may be specified statically in order to use them with JIT.
|
||||
|
||||
Not every function in NumPy is implemented; contributions are welcome!
|
||||
Nearly all applicable NumPy functions are implemented in the ``jax.numpy``
|
||||
namespace; they are listed below.
|
||||
|
||||
.. Generate the list below as follows:
|
||||
>>> import jax.numpy, numpy
|
||||
@ -42,14 +41,10 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
# Finally, sort the list using sort(1), which is different than Python's
|
||||
# sorted() function.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
ndarray.at
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
ndarray.at
|
||||
abs
|
||||
absolute
|
||||
add
|
||||
@ -135,9 +130,9 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
degrees
|
||||
delete
|
||||
diag
|
||||
diagflat
|
||||
diag_indices
|
||||
diag_indices_from
|
||||
diagflat
|
||||
diagonal
|
||||
diff
|
||||
digitize
|
||||
@ -169,11 +164,11 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
fliplr
|
||||
flipud
|
||||
float_
|
||||
float_power
|
||||
float16
|
||||
float32
|
||||
float64
|
||||
floating
|
||||
float_power
|
||||
floor
|
||||
floor_divide
|
||||
fmax
|
||||
@ -183,8 +178,9 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
full
|
||||
full_like
|
||||
gcd
|
||||
get_printoptions
|
||||
generic
|
||||
geomspace
|
||||
get_printoptions
|
||||
gradient
|
||||
greater
|
||||
greater_equal
|
||||
@ -269,6 +265,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
moveaxis
|
||||
msort
|
||||
multiply
|
||||
nan_to_num
|
||||
nanargmax
|
||||
nanargmin
|
||||
nancumprod
|
||||
@ -385,6 +382,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
triu_indices_from
|
||||
true_divide
|
||||
trunc
|
||||
uint
|
||||
uint16
|
||||
uint32
|
||||
uint64
|
||||
|
Loading…
x
Reference in New Issue
Block a user