jax.numpy package
=================

.. currentmodule:: jax.numpy

.. automodule:: jax.numpy

Implements the NumPy API, using the primitives in :mod:`jax.lax`.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
cannot follow NumPy exactly.

* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
  in-place cannot be implemented in JAX. However, JAX is often able to provide
  an alternative API that is purely functional. For example, instead of
  in-place array updates (:code:`x[i] = y`), JAX provides an alternative pure
  indexed update function :code:`x.at[i].set(y)` (see :attr:`ndarray.at`).

* Relatedly, some NumPy functions return views of arrays when possible
  (examples are :func:`transpose` and :func:`reshape`). JAX versions of such
  functions return copies instead, although such copies are often optimized
  away by XLA when sequences of operations are compiled using :func:`jax.jit`.

* NumPy is very aggressive at promoting values to :code:`float64` type. JAX
  is sometimes less aggressive about type promotion (see
  :ref:`type-promotion`).

* Some NumPy routines have data-dependent output shapes (examples include
  :func:`unique` and :func:`nonzero`). Because the XLA compiler requires array
  shapes to be known at compile time, such operations are not compatible with
  JIT. For this reason, JAX adds an optional ``size`` argument to such
  functions, which may be specified statically in order to use them with JIT.

Nearly all applicable NumPy functions are implemented in the ``jax.numpy``
namespace; they are listed below.

.. Generate the list below as follows:
   >>> import jax.numpy, numpy
   >>> fns = set(dir(numpy)) & set(dir(jax.numpy)) - set(jax.numpy._NOT_IMPLEMENTED)
   >>> print('\n'.join(' ' + x for x in fns if callable(getattr(jax.numpy, x))))  # doctest: +SKIP

   # Finally, sort the list using sort(1), which is different than Python's
   # sorted() function.

.. autosummary::
    :toctree: _autosummary

    ndarray.at
    abs
    absolute
    add
    all
    allclose
    alltrue
    amax
    amin
    angle
    any
    append
    apply_along_axis
    apply_over_axes
    arange
    arccos
    arccosh
    arcsin
    arcsinh
    arctan
    arctan2
    arctanh
    argmax
    argmin
    argsort
    argwhere
    around
    array
    array_equal
    array_equiv
    array_repr
    array_split
    array_str
    asarray
    atleast_1d
    atleast_2d
    atleast_3d
    average
    bartlett
    bincount
    bitwise_and
    bitwise_not
    bitwise_or
    bitwise_xor
    blackman
    block
    bool_
    broadcast_arrays
    broadcast_shapes
    broadcast_to
    c_
    can_cast
    cbrt
    cdouble
    ceil
    character
    choose
    clip
    column_stack
    complex_
    complex128
    complex64
    complexfloating
    ComplexWarning
    compress
    concatenate
    conj
    conjugate
    convolve
    copy
    copysign
    corrcoef
    correlate
    cos
    cosh
    count_nonzero
    cov
    cross
    csingle
    cumprod
    cumproduct
    cumsum
    deg2rad
    degrees
    delete
    diag
    diag_indices
    diag_indices_from
    diagflat
    diagonal
    diff
    digitize
    divide
    divmod
    dot
    double
    dsplit
    dstack
    dtype
    ediff1d
    einsum
    einsum_path
    empty
    empty_like
    equal
    exp
    exp2
    expand_dims
    expm1
    extract
    eye
    fabs
    finfo
    fix
    flatnonzero
    flexible
    flip
    fliplr
    flipud
    float_
    float_power
    float16
    float32
    float64
    floating
    floor
    floor_divide
    fmax
    fmin
    fmod
    frexp
    frombuffer
    fromfunction
    fromstring
    full
    full_like
    gcd
    generic
    geomspace
    get_printoptions
    gradient
    greater
    greater_equal
    hamming
    hanning
    heaviside
    histogram
    histogram_bin_edges
    histogram2d
    histogramdd
    hsplit
    hstack
    hypot
    i0
    identity
    iinfo
    imag
    in1d
    index_exp
    indices
    inexact
    inner
    insert
    int_
    int16
    int32
    int64
    int8
    integer
    interp
    intersect1d
    invert
    isclose
    iscomplex
    iscomplexobj
    isfinite
    isin
    isinf
    isnan
    isneginf
    isposinf
    isreal
    isrealobj
    isscalar
    issubdtype
    issubsctype
    iterable
    ix_
    kaiser
    kron
    lcm
    ldexp
    left_shift
    less
    less_equal
    lexsort
    linspace
    load
    log
    log10
    log1p
    log2
    logaddexp
    logaddexp2
    logical_and
    logical_not
    logical_or
    logical_xor
    logspace
    mask_indices
    matmul
    max
    maximum
    mean
    median
    meshgrid
    mgrid
    min
    minimum
    mod
    modf
    moveaxis
    msort
    multiply
    nanargmax
    nanargmin
    nancumprod
    nancumsum
    nanmax
    nanmean
    nanmedian
    nanmin
    nanpercentile
    nanprod
    nanquantile
    nanstd
    nansum
    nan_to_num
    nanvar
    ndarray
    ndim
    negative
    nextafter
    nonzero
    not_equal
    number
    object_
    ogrid
    ones
    ones_like
    outer
    packbits
    pad
    percentile
    piecewise
    poly
    polyadd
    polyder
    polyfit
    polyint
    polymul
    polysub
    polyval
    positive
    power
    printoptions
    prod
    product
    promote_types
    ptp
    quantile
    r_
    rad2deg
    radians
    ravel
    ravel_multi_index
    real
    reciprocal
    remainder
    repeat
    reshape
    resize
    result_type
    right_shift
    rint
    roll
    rollaxis
    roots
    rot90
    round
    round_
    row_stack
    s_
    save
    savez
    searchsorted
    select
    set_printoptions
    setdiff1d
    setxor1d
    shape
    sign
    signbit
    signedinteger
    sin
    sinc
    single
    sinh
    size
    sometrue
    sort
    sort_complex
    split
    sqrt
    square
    squeeze
    stack
    std
    subtract
    sum
    swapaxes
    take
    take_along_axis
    tan
    tanh
    tensordot
    tile
    trace
    transpose
    trapz
    tri
    tril
    tril_indices
    tril_indices_from
    trim_zeros
    triu
    triu_indices
    triu_indices_from
    true_divide
    trunc
    uint
    uint16
    uint32
    uint64
    uint8
    union1d
    unique
    unpackbits
    unravel_index
    unsignedinteger
    unwrap
    vander
    var
    vdot
    vectorize
    vsplit
    vstack
    where
    zeros
    zeros_like

jax.numpy.fft
-------------

.. automodule:: jax.numpy.fft

.. autosummary::
    :toctree: _autosummary

    fft
    fft2
    fftfreq
    fftn
    fftshift
    hfft
    ifft
    ifft2
    ifftn
    ifftshift
    ihfft
    irfft
    irfft2
    irfftn
    rfft
    rfft2
    rfftfreq
    rfftn

jax.numpy.linalg
----------------

.. automodule:: jax.numpy.linalg

.. autosummary::
    :toctree: _autosummary

    cholesky
    cond
    det
    eig
    eigh
    eigvals
    eigvalsh
    inv
    lstsq
    matrix_power
    matrix_rank
    multi_dot
    norm
    pinv
    qr
    slogdet
    solve
    svd
    tensorinv
    tensorsolve

JAX DeviceArray
---------------

The JAX :class:`~jax.numpy.DeviceArray` is the core array object in JAX: you
can think of it as the equivalent of a :class:`numpy.ndarray` backed by a
memory buffer on a single device. As with :class:`numpy.ndarray`, most users
will not need to instantiate :class:`DeviceArray` objects manually, but will
instead create them via :mod:`jax.numpy` functions like
:func:`~jax.numpy.array`, :func:`~jax.numpy.arange`,
:func:`~jax.numpy.linspace`, and others listed above.

.. autoclass:: jax.numpy.DeviceArray

.. autoclass:: jaxlib.xla_extension.DeviceArrayBase

.. autoclass:: jaxlib.xla_extension.DeviceArray
    :members:
    :inherited-members:
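
The following is a minimal sketch, not part of the generated API listing, of
how arrays are typically created through :mod:`jax.numpy` and how two of the
differences from NumPy described at the top of this page show up in practice:
immutability (``x.at[i].set(y)`` in place of ``x[i] = y``) and the static
``size`` argument for functions with data-dependent output shapes. It assumes
a working JAX installation. The helper name ``first_three_unique`` and the
padding value ``-1`` passed as ``fill_value`` are illustrative choices, not
names taken from this page.

```python
import jax
import jax.numpy as jnp

# Arrays are normally created through jax.numpy functions rather than by
# instantiating the array class directly.
x = jnp.arange(5)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    x[2] = 10
    in_place_allowed = True
except TypeError:
    in_place_allowed = False

# The functional alternative returns a new array and leaves x unchanged.
y = x.at[2].set(10)


# unique() has a data-dependent output shape, so under jit a static size
# is supplied; unused slots are padded with fill_value (hypothetical -1).
@jax.jit
def first_three_unique(values):
    return jnp.unique(values, size=3, fill_value=-1)


u = first_three_unique(jnp.array([3, 1, 3, 1]))

print(in_place_allowed)
print(x.tolist(), y.tolist())
print(u.tolist())
```

Choosing ``size`` larger than the true number of unique values pads the
result, while choosing it smaller truncates the result, so the value should
be picked with the expected data in mind.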