rocm_jax/docs/jax.numpy.rst
2022-03-15 11:02:59 -07:00


jax.numpy package
=================

.. currentmodule:: jax.numpy

.. automodule:: jax.numpy

Implements the NumPy API, using the primitives in :mod:`jax.lax`.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
cannot follow NumPy exactly.

* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
  in place cannot be implemented in JAX. However, JAX is often able to provide
  an alternative API that is purely functional. For example, instead of in-place
  array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
  update function :code:`x.at[i].set(y)` (see :attr:`ndarray.at`).

* Relatedly, some NumPy functions return views of arrays when possible
  (examples are :func:`transpose` and :func:`reshape`). The JAX versions of
  such functions return copies instead, although these copies are often
  optimized away by XLA when sequences of operations are compiled using
  :func:`jax.jit`.

* NumPy is very aggressive at promoting values to :code:`float64` type. JAX
  is sometimes less aggressive about type promotion (see :ref:`type-promotion`).

* Some NumPy routines have data-dependent output shapes (examples include
  :func:`unique` and :func:`nonzero`). Because the XLA compiler requires array
  shapes to be known at compile time, such operations are not compatible with
  JIT. For this reason, JAX adds an optional ``size`` argument to such functions,
  which may be specified statically in order to use them with JIT.
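
The following is a minimal sketch of three of the differences above (in-place
updates, ``float64`` promotion, and the ``size`` argument), assuming the default
JAX configuration in which 64-bit mode is disabled. The helper name
``first_two_nonzero`` is hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Immutability: instead of the in-place update x[2] = 10.0, use the
# functional .at[...] API, which returns a new array and leaves x unchanged.
x = jnp.zeros(5)
y = x.at[2].set(10.0)
print(x)  # [0. 0. 0. 0. 0.]
print(y)  # [ 0.  0. 10.  0.  0.]

# Type promotion: NumPy promotes an integer array times a Python float to
# float64, while JAX (with 64-bit mode disabled) produces float32.
print((np.arange(3) * 1.5).dtype)   # float64
print((jnp.arange(3) * 1.5).dtype)  # float32

# Data-dependent shapes: a static size fixes the output shape of nonzero,
# so it can be traced by jax.jit. Unused slots would be padded with fill_value.
@jax.jit
def first_two_nonzero(a):
    return jnp.nonzero(a, size=2, fill_value=-1)

(idx,) = first_two_nonzero(jnp.array([0, 3, 0, 5, 7]))
print(idx)  # [1 3]
```

Without ``size=2``, calling ``jnp.nonzero`` inside ``jax.jit`` fails with an
error, because the output shape would depend on the values in ``a``.
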

Nearly all applicable NumPy functions are implemented in the ``jax.numpy``
namespace; they are listed below.

.. Generate the list below as follows:
   >>> import jax.numpy, numpy
   >>> fns = set(dir(numpy)) & set(dir(jax.numpy)) - set(jax.numpy._NOT_IMPLEMENTED)
   >>> print('\n'.join('   ' + x for x in fns if callable(getattr(jax.numpy, x)))) # doctest: +SKIP

   # Finally, sort the list using sort(1), which is different from Python's
   # sorted() function.

.. autosummary::
   :toctree: _autosummary

   ndarray.at
   abs
   absolute
   add
   all
   allclose
   alltrue
   amax
   amin
   angle
   any
   append
   apply_along_axis
   apply_over_axes
   arange
   arccos
   arccosh
   arcsin
   arcsinh
   arctan
   arctan2
   arctanh
   argmax
   argmin
   argsort
   argwhere
   around
   array
   array_equal
   array_equiv
   array_repr
   array_split
   array_str
   asarray
   atleast_1d
   atleast_2d
   atleast_3d
   average
   bartlett
   bincount
   bitwise_and
   bitwise_not
   bitwise_or
   bitwise_xor
   blackman
   block
   bool_
   broadcast_arrays
   broadcast_shapes
   broadcast_to
   c_
   can_cast
   cbrt
   cdouble
   ceil
   character
   choose
   clip
   column_stack
   complex_
   complex128
   complex64
   complexfloating
   ComplexWarning
   compress
   concatenate
   conj
   conjugate
   convolve
   copy
   copysign
   corrcoef
   correlate
   cos
   cosh
   count_nonzero
   cov
   cross
   csingle
   cumprod
   cumproduct
   cumsum
   deg2rad
   degrees
   delete
   diag
   diag_indices
   diag_indices_from
   diagflat
   diagonal
   diff
   digitize
   divide
   divmod
   dot
   double
   dsplit
   dstack
   dtype
   ediff1d
   einsum
   einsum_path
   empty
   empty_like
   equal
   exp
   exp2
   expand_dims
   expm1
   extract
   eye
   fabs
   finfo
   fix
   flatnonzero
   flexible
   flip
   fliplr
   flipud
   float_
   float_power
   float16
   float32
   float64
   floating
   floor
   floor_divide
   fmax
   fmin
   fmod
   frexp
   full
   full_like
   gcd
   generic
   geomspace
   get_printoptions
   gradient
   greater
   greater_equal
   hamming
   hanning
   heaviside
   histogram
   histogram_bin_edges
   histogram2d
   histogramdd
   hsplit
   hstack
   hypot
   i0
   identity
   iinfo
   imag
   in1d
   index_exp
   indices
   inexact
   inner
   insert
   int_
   int16
   int32
   int64
   int8
   integer
   interp
   intersect1d
   invert
   isclose
   iscomplex
   iscomplexobj
   isfinite
   isin
   isinf
   isnan
   isneginf
   isposinf
   isreal
   isrealobj
   isscalar
   issubdtype
   issubsctype
   iterable
   ix_
   kaiser
   kron
   lcm
   ldexp
   left_shift
   less
   less_equal
   lexsort
   linspace
   load
   log
   log10
   log1p
   log2
   logaddexp
   logaddexp2
   logical_and
   logical_not
   logical_or
   logical_xor
   logspace
   mask_indices
   matmul
   max
   maximum
   mean
   median
   meshgrid
   mgrid
   min
   minimum
   mod
   modf
   moveaxis
   msort
   multiply
   nan_to_num
   nanargmax
   nanargmin
   nancumprod
   nancumsum
   nanmax
   nanmean
   nanmedian
   nanmin
   nanpercentile
   nanprod
   nanquantile
   nanstd
   nansum
   nanvar
   ndarray
   ndim
   negative
   nextafter
   nonzero
   not_equal
   number
   object_
   ogrid
   ones
   ones_like
   outer
   packbits
   pad
   percentile
   piecewise
   poly
   polyadd
   polyder
   polyfit
   polyint
   polymul
   polysub
   polyval
   positive
   power
   printoptions
   prod
   product
   promote_types
   ptp
   quantile
   r_
   rad2deg
   radians
   ravel
   ravel_multi_index
   real
   reciprocal
   remainder
   repeat
   reshape
   resize
   result_type
   right_shift
   rint
   roll
   rollaxis
   roots
   rot90
   round
   round_
   row_stack
   s_
   save
   savez
   searchsorted
   select
   set_printoptions
   setdiff1d
   setxor1d
   shape
   sign
   signbit
   signedinteger
   sin
   sinc
   single
   sinh
   size
   sometrue
   sort
   sort_complex
   split
   sqrt
   square
   squeeze
   stack
   std
   subtract
   sum
   swapaxes
   take
   take_along_axis
   tan
   tanh
   tensordot
   tile
   trace
   transpose
   trapz
   tri
   tril
   tril_indices
   tril_indices_from
   trim_zeros
   triu
   triu_indices
   triu_indices_from
   true_divide
   trunc
   uint
   uint16
   uint32
   uint64
   uint8
   union1d
   unique
   unpackbits
   unravel_index
   unsignedinteger
   unwrap
   vander
   var
   vdot
   vectorize
   vsplit
   vstack
   where
   zeros
   zeros_like

jax.numpy.fft
-------------

.. automodule:: jax.numpy.fft

.. autosummary::
   :toctree: _autosummary

   fft
   fft2
   fftfreq
   fftn
   fftshift
   hfft
   ifft
   ifft2
   ifftn
   ifftshift
   ihfft
   irfft
   irfft2
   irfftn
   rfft
   rfft2
   rfftfreq
   rfftn
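
As a sketch of how the real-input transforms listed above fit together (the
eight-sample cosine is an arbitrary test signal, chosen so the result is easy to
check), ``rfft`` concentrates the signal's energy in a single frequency bin,
``rfftfreq`` reports the frequency of each bin, and ``irfft`` recovers the
original samples:

```python
import jax.numpy as jnp
import numpy as np

# A real test signal: a cosine that completes 2 cycles over n = 8 samples.
n = 8
t = jnp.arange(n)
signal = jnp.cos(2 * jnp.pi * 2 * t / n)

spectrum = jnp.fft.rfft(signal)           # n // 2 + 1 = 5 complex coefficients
freqs = jnp.fft.rfftfreq(n)               # bin frequencies, in cycles per sample
recovered = jnp.fft.irfft(spectrum, n=n)  # inverse of rfft for length-n input

peak = int(jnp.argmax(jnp.abs(spectrum)))
print(peak, float(freqs[peak]))                   # 2 0.25
print(np.allclose(recovered, signal, atol=1e-5))  # True
```
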

jax.numpy.linalg
----------------

.. automodule:: jax.numpy.linalg

.. autosummary::
   :toctree: _autosummary

   cholesky
   cond
   det
   eig
   eigh
   eigvals
   eigvalsh
   inv
   lstsq
   matrix_power
   matrix_rank
   multi_dot
   norm
   pinv
   qr
   slogdet
   solve
   svd
   tensorinv
   tensorsolve
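
A minimal sketch exercising a few of the routines listed above on a small
symmetric positive-definite matrix. The matrix and right-hand side are arbitrary
values chosen for illustration, and each result is checked against its defining
identity rather than against hard-coded output:

```python
import jax.numpy as jnp
import numpy as np

# A small symmetric positive-definite matrix and a right-hand side.
A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

x = jnp.linalg.solve(A, b)   # solves A @ x = b; here x = [1/11, 7/11]
L = jnp.linalg.cholesky(A)   # lower-triangular factor with A = L @ L.T
w, v = jnp.linalg.eigh(A)    # eigenvalues (ascending) and eigenvectors

print(np.allclose(A @ x, b, atol=1e-5))      # True
print(np.allclose(L @ L.T, A, atol=1e-5))    # True
print(np.allclose(A @ v, v * w, atol=1e-5))  # True: A v_i = w_i v_i per column
```
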

JAX DeviceArray
---------------

The JAX :class:`~jax.numpy.DeviceArray` is the core array object in JAX: you can
think of it as the equivalent of a :class:`numpy.ndarray` backed by a memory buffer
on a single device. As with :class:`numpy.ndarray`, most users will not need to
instantiate :class:`DeviceArray` objects manually; instead, they will create them via
:mod:`jax.numpy` functions like :func:`~jax.numpy.array`, :func:`~jax.numpy.arange`,
:func:`~jax.numpy.linspace`, and others listed above.
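
The sketch below shows typical ways such arrays are created and moved between
host and device. The class name reported by ``type()`` differs across JAX
releases, so the example deliberately avoids depending on it:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Device arrays are normally created through jax.numpy functions rather than
# by instantiating the array class directly.
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.arange(4)
c = jnp.linspace(0.0, 1.0, num=5)
print(a.shape, a.dtype)  # (3,) float32

# jax.device_put explicitly transfers host (NumPy) data to the default device.
d = jax.device_put(np.ones((2, 2), dtype=np.float32))

# The concrete class name printed here depends on the installed JAX version.
print(type(a))

# np.asarray converts a device array back into a host numpy.ndarray.
host = np.asarray(a + 1)
print(type(host))  # <class 'numpy.ndarray'>
```
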

.. autoclass:: jax.numpy.DeviceArray

.. autoclass:: jaxlib.xla_extension.DeviceArrayBase

.. autoclass:: jaxlib.xla_extension.DeviceArray
   :members:
   :inherited-members: