jax.lax package
===============

.. automodule:: jax.lax

:mod:`jax.lax` is a library of primitive operations that underpins libraries
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on :mod:`jax.lax` primitives.
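
As a minimal sketch of what this means in practice, the snippet below applies
two transformations directly to a single primitive, ``lax.sin``: ``jax.jvp``
uses its JVP rule and ``jax.vmap`` uses its batching rule. The input values are
arbitrary and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.float32(0.5)

# JVP rule: push the tangent 1.0 forward through the sin primitive.
primal, tangent = jax.jvp(lax.sin, (x,), (jnp.float32(1.0),))

# Batching rule: map the same primitive over a leading axis.
batched = jax.vmap(lax.sin)(jnp.linspace(0.0, 1.0, 4))
```

Here ``tangent`` should agree with ``cos(x)``, the derivative of ``sin`` at
``x``, and ``batched`` has one output per input element.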

Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
cases JAX diverges from XLA, usually to ensure that the set of operations is
closed under the operation of JVP and transpose rules.

Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
therefore more stable and less likely to change than the :mod:`jax.lax` API.

Operators
---------

.. autosummary::
    :toctree: _autosummary

    abs
    add
    acos
    approx_max_k
    approx_min_k
    argmax
    argmin
    asin
    atan
    atan2
    batch_matmul
    bessel_i0e
    bessel_i1e
    betainc
    bitcast_convert_type
    bitwise_not
    bitwise_and
    bitwise_or
    bitwise_xor
    population_count
    broadcast
    broadcasted_iota
    broadcast_in_dim
    cbrt
    ceil
    clamp
    collapse
    complex
    concatenate
    conj
    conv
    convert_element_type
    conv_dimension_numbers
    conv_general_dilated
    conv_general_dilated_local
    conv_general_dilated_patches
    conv_with_general_padding
    conv_transpose
    cos
    cosh
    cummax
    cummin
    cumprod
    cumsum
    digamma
    div
    dot
    dot_general
    dynamic_index_in_dim
    dynamic_slice
    dynamic_slice_in_dim
    dynamic_update_slice
    dynamic_update_index_in_dim
    dynamic_update_slice_in_dim
    eq
    erf
    erfc
    erf_inv
    exp
    expand_dims
    expm1
    fft
    floor
    full
    full_like
    gather
    ge
    gt
    igamma
    igammac
    imag
    index_in_dim
    index_take
    iota
    is_finite
    le
    lt
    lgamma
    log
    log1p
    logistic
    max
    min
    mul
    ne
    neg
    nextafter
    pad
    pow
    real
    reciprocal
    reduce
    reduce_precision
    reduce_window
    reshape
    rem
    rev
    round
    rsqrt
    scatter
    scatter_add
    scatter_max
    scatter_min
    scatter_mul
    select
    shift_left
    shift_right_arithmetic
    shift_right_logical
    slice
    slice_in_dim
    sign
    sin
    sinh
    sort
    sort_key_val
    sqrt
    square
    squeeze
    sub
    tan
    tie_in
    top_k
    transpose

.. _lax-control-flow:

Control flow operators
----------------------

.. autosummary::
    :toctree: _autosummary

    associative_scan
    cond
    fori_loop
    map
    scan
    switch
    while_loop
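
The following is a small, illustrative sketch of three of these operators. The
helper name ``step`` and all input values are made up for the example; see the
individual function pages for the full signatures.

```python
import jax.numpy as jnp
from jax import lax

# fori_loop: add up the integers 0..9 in a carried accumulator.
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, jnp.int32(0))

# scan: running sum that returns the final carry and the per-step outputs.
def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry

final, running = lax.scan(step, jnp.float32(0.0), jnp.arange(1.0, 5.0))

# cond: pick one of two branches based on a scalar predicate.
y = lax.cond(final > 5.0, lambda v: v * 2.0, lambda v: v, final)
```

With these inputs, ``total`` is 45, ``running`` is ``[1, 3, 6, 10]``, and
``y`` is 20 because the final carry of 10 exceeds 5.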

Custom gradient operators
-------------------------

.. autosummary::
    :toctree: _autosummary

    stop_gradient
    custom_linear_solve
    custom_root
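
As a brief sketch of the simplest of these, ``stop_gradient`` makes its
argument behave like a constant under differentiation. The function ``f`` below
is a made-up example, not part of the library.

```python
import jax
from jax import lax

def f(x):
    # The second factor is treated as a constant when differentiating.
    return x * lax.stop_gradient(x)

g = jax.grad(f)(3.0)
```

Without ``stop_gradient`` the derivative of ``x * x`` at 3 would be 6; with it,
only the first factor contributes, so ``g`` is 3.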

.. _jax-parallel-operators:

Parallel operators
------------------

Parallelism support is experimental.

.. autosummary::
    :toctree: _autosummary

    all_gather
    all_to_all
    psum
    pmax
    pmin
    pmean
    ppermute
    pshuffle
    pswapaxes
    axis_index
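
These collectives operate over a named axis. The sketch below binds the axis
name with ``jax.vmap`` rather than a multi-device transformation, purely so
that it runs on a single device; the function name ``normalize`` and the axis
name ``"i"`` are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def normalize(x):
    # psum adds x across every position of the mapped axis named "i".
    return x / lax.psum(x, axis_name="i")

out = jax.vmap(normalize, axis_name="i")(jnp.arange(1.0, 5.0))
```

Each element is divided by the sum over the whole mapped axis, so the entries
of ``out`` add up to 1.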

Linear algebra operators (jax.lax.linalg)
-----------------------------------------

.. automodule:: jax.lax.linalg

.. autosummary::
    :toctree: _autosummary

    cholesky
    eig
    eigh
    lu
    qdwh
    qr
    schur
    svd
    triangular_solve
    tridiagonal_solve

Argument classes
----------------

.. currentmodule:: jax.lax

.. autoclass:: ConvDimensionNumbers
.. autoclass:: ConvGeneralDilatedDimensionNumbers
.. autoclass:: GatherDimensionNumbers
.. autoclass:: GatherScatterMode
.. autoclass:: Precision
.. autoclass:: RoundingMethod
.. autoclass:: ScatterDimensionNumbers
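
These classes are passed as arguments to the operators above. As a hedged
illustration with arbitrary shapes, the sketch below requests the highest
available precision for a matrix product and builds a
``ConvDimensionNumbers`` value from layout strings.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3))
b = jnp.ones((3, 4))

# Precision selects the numerical precision used by the backend for the dot.
c = lax.dot(a, b, precision=lax.Precision.HIGHEST)

# conv_dimension_numbers turns layout strings into a ConvDimensionNumbers.
dn = lax.conv_dimension_numbers(
    (1, 3, 8, 8),              # lhs (input) shape
    (4, 3, 3, 3),              # rhs (kernel) shape
    ("NCHW", "OIHW", "NCHW"),  # lhs, rhs, and output layouts
)
```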