Merge pull request #9897 from jakevdp:numpy-doc

PiperOrigin-RevId: 434837122
This commit is contained in:
jax authors 2022-03-15 13:13:13 -07:00
commit 5354a016e6

View File

@ -1,4 +1,3 @@
jax.numpy package
=================
@ -15,24 +14,24 @@ cannot follow NumPy exactly.
in-place cannot be implemented in JAX. However, often JAX is able to provide
an alternative API that is purely functional. For example, instead of in-place
array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
update function :code:`x.at[i].set(y)`.
update function :code:`x.at[i].set(y)` (see :attr:`ndarray.at`).
* Relatedly, some NumPy functions return views of arrays when possible (examples
are :func:`numpy.transpose` and :func:`numpy.reshape`). JAX versions of such
functions will return copies instead, although such copies can often be optimized
* Relatedly, some NumPy functions often return views of arrays when possible
(examples are :func:`transpose` and :func:`reshape`). JAX versions of such
functions will return copies instead, although such are often optimized
away by XLA when sequences of operations are compiled using :func:`jax.jit`.
* NumPy is very aggressive at promoting values to :code:`float64` type. JAX
sometimes is less aggressive about type promotion (See :ref:`type-promotion`).
A small number of NumPy operations that have data-dependent output shapes are
incompatible with :func:`jax.jit` compilation. The XLA compiler requires that
shapes of arrays be known at compile time. While it would be possible to provide
a JAX implementation of an API such as :func:`numpy.nonzero`, we would be unable
to JIT-compile it because the shape of its output depends on the contents of the
input data.
* Some NumPy routines have data-dependent output shapes (examples include
:func:`unique` and :func:`nonzero`). Because the XLA compiler requires array
shapes to be known at compile time, such operations are not compatible with
JIT. For this reason, JAX adds an optional ``size`` argument to such functions
which may be specified statically in order to use them with JIT.
Not every function in NumPy is implemented; contributions are welcome!
Nearly all applicable NumPy functions are implemented in the ``jax.numpy``
namespace; they are listed below.
.. Generate the list below as follows:
>>> import jax.numpy, numpy
@ -42,14 +41,10 @@ Not every function in NumPy is implemented; contributions are welcome!
# Finally, sort the list using sort(1), which is different than Python's
# sorted() function.
.. autosummary::
:toctree: _autosummary
ndarray.at
.. autosummary::
:toctree: _autosummary
ndarray.at
abs
absolute
add
@ -135,9 +130,9 @@ Not every function in NumPy is implemented; contributions are welcome!
degrees
delete
diag
diagflat
diag_indices
diag_indices_from
diagflat
diagonal
diff
digitize
@ -169,11 +164,11 @@ Not every function in NumPy is implemented; contributions are welcome!
fliplr
flipud
float_
float_power
float16
float32
float64
floating
float_power
floor
floor_divide
fmax
@ -183,8 +178,9 @@ Not every function in NumPy is implemented; contributions are welcome!
full
full_like
gcd
get_printoptions
generic
geomspace
get_printoptions
gradient
greater
greater_equal
@ -269,6 +265,7 @@ Not every function in NumPy is implemented; contributions are welcome!
moveaxis
msort
multiply
nan_to_num
nanargmax
nanargmin
nancumprod
@ -385,6 +382,7 @@ Not every function in NumPy is implemented; contributions are welcome!
triu_indices_from
true_divide
trunc
uint
uint16
uint32
uint64