mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 19:06:07 +00:00
568 lines
10 KiB
ReStructuredText
568 lines
10 KiB
ReStructuredText
``jax.numpy`` module
|
|
====================
|
|
|
|
.. currentmodule:: jax.numpy
|
|
|
|
.. automodule:: jax.numpy
|
|
|
|
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
|
|
|
|
While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
|
|
cannot follow NumPy exactly.
|
|
|
|
* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
|
|
in-place cannot be implemented in JAX. However, often JAX is able to provide
|
|
an alternative API that is purely functional. For example, instead of in-place
|
|
array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
|
|
update function :code:`x.at[i].set(y)` (see :attr:`ndarray.at`).
|
|
|
|
* Relatedly, some NumPy functions often return views of arrays when possible
|
|
(examples are :func:`transpose` and :func:`reshape`). JAX versions of such
|
|
functions will return copies instead, although such are often optimized
|
|
away by XLA when sequences of operations are compiled using :func:`jax.jit`.
|
|
|
|
* NumPy is very aggressive at promoting values to :code:`float64` type. JAX
|
|
sometimes is less aggressive about type promotion (See :ref:`type-promotion`).
|
|
|
|
* Some NumPy routines have data-dependent output shapes (examples include
|
|
:func:`unique` and :func:`nonzero`). Because the XLA compiler requires array
|
|
shapes to be known at compile time, such operations are not compatible with
|
|
JIT. For this reason, JAX adds an optional ``size`` argument to such functions
|
|
which may be specified statically in order to use them with JIT.
|
|
|
|
Nearly all applicable NumPy functions are implemented in the ``jax.numpy``
|
|
namespace; they are listed below.
|
|
|
|
.. Generate the list below as follows:
|
|
>>> import jax.numpy, numpy
|
|
>>> fns = set(dir(numpy)) & set(dir(jax.numpy))
|
|
>>> print('\n'.join(' ' + x for x in fns if callable(getattr(jax.numpy, x)))) # doctest: +SKIP
|
|
|
|
# Finally, sort the list using sort(1), which is different than Python's
|
|
# sorted() function.
|
|
|
|
.. autosummary::
|
|
:toctree: _autosummary
|
|
|
|
ndarray.at
|
|
abs
|
|
absolute
|
|
acos
|
|
acosh
|
|
add
|
|
all
|
|
allclose
|
|
amax
|
|
amin
|
|
angle
|
|
any
|
|
append
|
|
apply_along_axis
|
|
apply_over_axes
|
|
arange
|
|
arccos
|
|
arccosh
|
|
arcsin
|
|
arcsinh
|
|
arctan
|
|
arctan2
|
|
arctanh
|
|
argmax
|
|
argmin
|
|
argpartition
|
|
argsort
|
|
argwhere
|
|
around
|
|
array
|
|
array_equal
|
|
array_equiv
|
|
array_repr
|
|
array_split
|
|
array_str
|
|
asarray
|
|
asin
|
|
asinh
|
|
astype
|
|
atan
|
|
atanh
|
|
atan2
|
|
atleast_1d
|
|
atleast_2d
|
|
atleast_3d
|
|
average
|
|
bartlett
|
|
bincount
|
|
bitwise_and
|
|
bitwise_count
|
|
bitwise_invert
|
|
bitwise_left_shift
|
|
bitwise_not
|
|
bitwise_or
|
|
bitwise_right_shift
|
|
bitwise_xor
|
|
blackman
|
|
block
|
|
bool_
|
|
broadcast_arrays
|
|
broadcast_shapes
|
|
broadcast_to
|
|
c_
|
|
can_cast
|
|
cbrt
|
|
cdouble
|
|
ceil
|
|
character
|
|
choose
|
|
clip
|
|
column_stack
|
|
complex_
|
|
complex128
|
|
complex64
|
|
complexfloating
|
|
ComplexWarning
|
|
compress
|
|
concat
|
|
concatenate
|
|
conj
|
|
conjugate
|
|
convolve
|
|
copy
|
|
copysign
|
|
corrcoef
|
|
correlate
|
|
cos
|
|
cosh
|
|
count_nonzero
|
|
cov
|
|
cross
|
|
csingle
|
|
cumprod
|
|
cumsum
|
|
cumulative_prod
|
|
cumulative_sum
|
|
deg2rad
|
|
degrees
|
|
delete
|
|
diag
|
|
diag_indices
|
|
diag_indices_from
|
|
diagflat
|
|
diagonal
|
|
diff
|
|
digitize
|
|
divide
|
|
divmod
|
|
dot
|
|
double
|
|
dsplit
|
|
dstack
|
|
dtype
|
|
ediff1d
|
|
einsum
|
|
einsum_path
|
|
empty
|
|
empty_like
|
|
equal
|
|
exp
|
|
exp2
|
|
expand_dims
|
|
expm1
|
|
extract
|
|
eye
|
|
fabs
|
|
fill_diagonal
|
|
finfo
|
|
fix
|
|
flatnonzero
|
|
flexible
|
|
flip
|
|
fliplr
|
|
flipud
|
|
float_
|
|
float_power
|
|
float16
|
|
float32
|
|
float64
|
|
floating
|
|
floor
|
|
floor_divide
|
|
fmax
|
|
fmin
|
|
fmod
|
|
frexp
|
|
frombuffer
|
|
fromfile
|
|
fromfunction
|
|
fromiter
|
|
frompyfunc
|
|
fromstring
|
|
from_dlpack
|
|
full
|
|
full_like
|
|
gcd
|
|
generic
|
|
geomspace
|
|
get_printoptions
|
|
gradient
|
|
greater
|
|
greater_equal
|
|
hamming
|
|
hanning
|
|
heaviside
|
|
histogram
|
|
histogram_bin_edges
|
|
histogram2d
|
|
histogramdd
|
|
hsplit
|
|
hstack
|
|
hypot
|
|
i0
|
|
identity
|
|
iinfo
|
|
imag
|
|
index_exp
|
|
indices
|
|
inexact
|
|
inner
|
|
insert
|
|
int_
|
|
int16
|
|
int32
|
|
int64
|
|
int8
|
|
integer
|
|
interp
|
|
intersect1d
|
|
invert
|
|
isclose
|
|
iscomplex
|
|
iscomplexobj
|
|
isdtype
|
|
isfinite
|
|
isin
|
|
isinf
|
|
isnan
|
|
isneginf
|
|
isposinf
|
|
isreal
|
|
isrealobj
|
|
isscalar
|
|
issubdtype
|
|
iterable
|
|
ix_
|
|
kaiser
|
|
kron
|
|
lcm
|
|
ldexp
|
|
left_shift
|
|
less
|
|
less_equal
|
|
lexsort
|
|
linspace
|
|
load
|
|
log
|
|
log10
|
|
log1p
|
|
log2
|
|
logaddexp
|
|
logaddexp2
|
|
logical_and
|
|
logical_not
|
|
logical_or
|
|
logical_xor
|
|
logspace
|
|
mask_indices
|
|
matmul
|
|
matrix_transpose
|
|
matvec
|
|
max
|
|
maximum
|
|
mean
|
|
median
|
|
meshgrid
|
|
mgrid
|
|
min
|
|
minimum
|
|
mod
|
|
modf
|
|
moveaxis
|
|
multiply
|
|
nan_to_num
|
|
nanargmax
|
|
nanargmin
|
|
nancumprod
|
|
nancumsum
|
|
nanmax
|
|
nanmean
|
|
nanmedian
|
|
nanmin
|
|
nanpercentile
|
|
nanprod
|
|
nanquantile
|
|
nanstd
|
|
nansum
|
|
nanvar
|
|
ndarray
|
|
ndim
|
|
negative
|
|
nextafter
|
|
nonzero
|
|
not_equal
|
|
number
|
|
object_
|
|
ogrid
|
|
ones
|
|
ones_like
|
|
outer
|
|
packbits
|
|
pad
|
|
partition
|
|
percentile
|
|
permute_dims
|
|
piecewise
|
|
place
|
|
poly
|
|
polyadd
|
|
polyder
|
|
polydiv
|
|
polyfit
|
|
polyint
|
|
polymul
|
|
polysub
|
|
polyval
|
|
positive
|
|
pow
|
|
power
|
|
printoptions
|
|
prod
|
|
promote_types
|
|
ptp
|
|
put
|
|
put_along_axis
|
|
quantile
|
|
r_
|
|
rad2deg
|
|
radians
|
|
ravel
|
|
ravel_multi_index
|
|
real
|
|
reciprocal
|
|
remainder
|
|
repeat
|
|
reshape
|
|
resize
|
|
result_type
|
|
right_shift
|
|
rint
|
|
roll
|
|
rollaxis
|
|
roots
|
|
rot90
|
|
round
|
|
s_
|
|
save
|
|
savez
|
|
searchsorted
|
|
select
|
|
set_printoptions
|
|
setdiff1d
|
|
setxor1d
|
|
shape
|
|
sign
|
|
signbit
|
|
signedinteger
|
|
sin
|
|
sinc
|
|
single
|
|
sinh
|
|
size
|
|
sort
|
|
sort_complex
|
|
spacing
|
|
split
|
|
sqrt
|
|
square
|
|
squeeze
|
|
stack
|
|
std
|
|
subtract
|
|
sum
|
|
swapaxes
|
|
take
|
|
take_along_axis
|
|
tan
|
|
tanh
|
|
tensordot
|
|
tile
|
|
trace
|
|
trapezoid
|
|
transpose
|
|
tri
|
|
tril
|
|
tril_indices
|
|
tril_indices_from
|
|
trim_zeros
|
|
triu
|
|
triu_indices
|
|
triu_indices_from
|
|
true_divide
|
|
trunc
|
|
ufunc
|
|
uint
|
|
uint16
|
|
uint32
|
|
uint64
|
|
uint8
|
|
union1d
|
|
unique
|
|
unique_all
|
|
unique_counts
|
|
unique_inverse
|
|
unique_values
|
|
unpackbits
|
|
unravel_index
|
|
unstack
|
|
unsignedinteger
|
|
unwrap
|
|
vander
|
|
var
|
|
vdot
|
|
vecdot
|
|
vecmat
|
|
vectorize
|
|
vsplit
|
|
vstack
|
|
where
|
|
zeros
|
|
zeros_like
|
|
|
|
jax.numpy.fft
|
|
-------------
|
|
|
|
.. automodule:: jax.numpy.fft
|
|
|
|
.. autosummary::
|
|
:toctree: _autosummary
|
|
|
|
fft
|
|
fft2
|
|
fftfreq
|
|
fftn
|
|
fftshift
|
|
hfft
|
|
ifft
|
|
ifft2
|
|
ifftn
|
|
ifftshift
|
|
ihfft
|
|
irfft
|
|
irfft2
|
|
irfftn
|
|
rfft
|
|
rfft2
|
|
rfftfreq
|
|
rfftn
|
|
|
|
jax.numpy.linalg
|
|
----------------
|
|
|
|
.. automodule:: jax.numpy.linalg
|
|
|
|
.. autosummary::
|
|
:toctree: _autosummary
|
|
|
|
cholesky
|
|
cond
|
|
cross
|
|
det
|
|
diagonal
|
|
eig
|
|
eigh
|
|
eigvals
|
|
eigvalsh
|
|
inv
|
|
lstsq
|
|
matmul
|
|
matrix_norm
|
|
matrix_power
|
|
matrix_rank
|
|
matrix_transpose
|
|
multi_dot
|
|
norm
|
|
outer
|
|
pinv
|
|
qr
|
|
slogdet
|
|
solve
|
|
svd
|
|
svdvals
|
|
tensordot
|
|
tensorinv
|
|
tensorsolve
|
|
trace
|
|
vector_norm
|
|
vecdot
|
|
|
|
JAX Array
|
|
---------
|
|
The JAX :class:`~jax.Array` (along with its alias, :class:`jax.numpy.ndarray`) is
|
|
the core array object in JAX: you can think of it as JAX's equivalent of a
|
|
:class:`numpy.ndarray`. Like :class:`numpy.ndarray`, most users will not need to
|
|
instantiate :class:`~jax.Array` objects manually, but rather will create them via
|
|
:mod:`jax.numpy` functions like :func:`~jax.numpy.array`, :func:`~jax.numpy.arange`,
|
|
:func:`~jax.numpy.linspace`, and others listed above.
|
|
|
|
Copying and Serialization
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
JAX :class:`~jax.Array` objects are designed to work seamlessly with Python
|
|
standard library tools where appropriate.
|
|
|
|
With the built-in :mod:`copy` module, when :func:`copy.copy` or :func:`copy.deepcopy`
|
|
encounder an :class:`~jax.Array`, it is equivalent to calling the
|
|
:meth:`~jax.Array.copy` method, which will create a copy of
|
|
the buffer on the same device as the original array. This will work correctly within
|
|
traced/JIT-compiled code, though copy operations may be elided by the compiler
|
|
in this context.
|
|
|
|
When the built-in :mod:`pickle` module encounters an :class:`~jax.Array`,
|
|
it will be serialized via a compact bit representation in a similar manner to pickled
|
|
:class:`numpy.ndarray` objects. When unpickled, the result will be a new
|
|
:class:`~jax.Array` object *on the default device.*
|
|
This is because in general, pickling and unpickling may take place in different runtime
|
|
environments, and there is no general way to map the device IDs of one runtime
|
|
to the device IDs of another. If :mod:`pickle` is used in traced/JIT-compiled code,
|
|
it will result in a :class:`~jax.errors.ConcretizationTypeError`.
|
|
|
|
.. _python-array-api:
|
|
|
|
Python Array API standard
|
|
-------------------------
|
|
|
|
.. note::
|
|
|
|
Prior to JAX v0.4.32, you must ``import jax.experimental.array_api`` in order
|
|
to enable the array API for JAX arrays. After JAX v0.4.32, importing this
|
|
module is no longer required, and will raise a deprecation warning. After
|
|
JAX v0.5.0, this import will raise an error.
|
|
|
|
Starting with JAX v0.4.32, :class:`jax.Array` and :mod:`jax.numpy` are compatible
|
|
with the `Python Array API Standard`_. You can access the Array API namespace via
|
|
:meth:`jax.Array.__array_namespace__`::
|
|
|
|
>>> def f(x):
|
|
... nx = x.__array_namespace__()
|
|
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
|
|
|
|
>>> import jax.numpy as jnp
|
|
>>> x = jnp.arange(5)
|
|
>>> f(x).round()
|
|
Array([1., 1., 1., 1., 1.], dtype=float32)
|
|
|
|
JAX departs from the standard in a few places, namely because JAX arrays are
|
|
immutable, in-place updates are not supported. Some of these incompatibilities
|
|
are being addressed via the `array-api-compat`_ module.
|
|
|
|
For more information, refer to the `Python Array API Standard`_ documentation.
|
|
|
|
.. _Python Array API Standard: https://data-apis.org/array-api
|
|
.. _array-api-compat: https://github.com/data-apis/array-api-compat |