jax.lax package
===============
.. automodule:: jax.lax
:mod:`jax.lax` is a library of primitive operations that underpins libraries
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on :mod:`jax.lax` primitives.
Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
cases JAX diverges from XLA, usually to ensure that the set of operations is
closed under the operation of JVP and transpose rules.
Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
therefore more stable and less likely to change than the :mod:`jax.lax` API.
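As an illustration of how transformation rules attach to primitives, the following minimal sketch applies ``jax.jvp`` and ``jax.vmap`` to ``lax.sin`` and then prints the primitive that ``jnp.sin`` lowers to. The input values are arbitrary and chosen only for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.float32(0.5)

# JVP rule: push a unit tangent through the sin primitive.
# The resulting tangent should match cos(x).
primal_out, tangent_out = jax.jvp(lax.sin, (x,), (jnp.float32(1.0),))

# Batching rule: vmap maps the same primitive over a leading axis.
xs = jnp.linspace(0.0, 1.0, 4)
batched = jax.vmap(lax.sin)(xs)

# jax.numpy functions are built on lax primitives; the jaxpr shows
# the sin primitive that jnp.sin is traced to.
jaxpr_text = str(jax.make_jaxpr(jnp.sin)(x))
print(jaxpr_text)
```

The exact jaxpr text varies between JAX versions, so treat the printed form as informational only.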
Operators
---------
.. autosummary::
:toctree: _autosummary
abs
add
acos
approx_max_k
approx_min_k
argmax
argmin
asin
atan
atan2
batch_matmul
bessel_i0e
bessel_i1e
betainc
bitcast_convert_type
bitwise_not
bitwise_and
bitwise_or
bitwise_xor
population_count
broadcast
broadcasted_iota
broadcast_in_dim
cbrt
ceil
clamp
collapse
complex
concatenate
conj
conv
convert_element_type
conv_dimension_numbers
conv_general_dilated
conv_general_dilated_local
conv_general_dilated_patches
conv_with_general_padding
conv_transpose
cos
cosh
cummax
cummin
cumprod
cumsum
digamma
div
dot
dot_general
dynamic_index_in_dim
dynamic_slice
dynamic_slice_in_dim
dynamic_update_slice
dynamic_update_index_in_dim
dynamic_update_slice_in_dim
eq
erf
erfc
erf_inv
exp
expand_dims
expm1
fft
floor
full
full_like
gather
ge
gt
igamma
igammac
imag
index_in_dim
index_take
iota
is_finite
le
lt
lgamma
log
log1p
logistic
max
min
mul
ne
neg
nextafter
pad
pow
real
reciprocal
reduce
reduce_precision
reduce_window
reshape
rem
rev
round
rsqrt
scatter
scatter_add
scatter_max
scatter_min
scatter_mul
select
shift_left
shift_right_arithmetic
shift_right_logical
slice
slice_in_dim
sign
sin
sinh
sort
sort_key_val
sqrt
square
squeeze
sub
tan
tie_in
top_k
transpose
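A few of the operators listed above can be exercised as in the following sketch. It is only an illustration with made-up inputs: ``dynamic_slice`` takes explicit start indices and slice sizes, and ``dot_general`` takes explicit contracting and batch dimensions rather than inferring them as ``jnp.matmul`` does.

```python
import jax.numpy as jnp
from jax import lax

# dynamic_slice: take 4 elements starting at index 3.
x = jnp.arange(10)
window = lax.dynamic_slice(x, (3,), (4,))

# dot_general: contract axis 1 of `a` with axis 0 of `b`, no batch axes.
# With these dimension numbers the result is an ordinary matrix product.
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)
dimension_numbers = (((1,), (0,)), ((), ()))
product = lax.dot_general(a, b, dimension_numbers)

# convert_element_type: explicit dtype conversion.
x_float = lax.convert_element_type(x, jnp.float32)
```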
.. _lax-control-flow:
Control flow operators
----------------------
.. autosummary::
:toctree: _autosummary
associative_scan
cond
fori_loop
map
scan
switch
while_loop
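The following minimal sketch shows how three of these control flow operators might be called. The loop bodies and inputs are invented for illustration; carried values are given explicit dtypes so that the carry type stays the same across iterations.

```python
import jax.numpy as jnp
from jax import lax

# fori_loop: sum the integers 0..4 in a carried accumulator.
total = lax.fori_loop(0, 5, lambda i, acc: acc + i, jnp.int32(0))

# scan: running sum over a sequence; returns the final carry and
# the stacked per-step outputs.
def step(carry, x):
    carry = carry + x
    return carry, carry

final, running = lax.scan(step, jnp.float32(0.0), jnp.array([1.0, 2.0, 3.0]))

# cond: choose one of two branches based on a scalar predicate.
value = jnp.float32(3.0)
result = lax.cond(value > 0, lambda v: v * 2, lambda v: -v, value)
```

Both branches passed to ``cond`` must return values of the same shape and dtype, which is why each branch here maps a ``float32`` scalar to a ``float32`` scalar.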
Custom gradient operators
-------------------------
.. autosummary::
:toctree: _autosummary
stop_gradient
custom_linear_solve
custom_root
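As a small illustration of ``stop_gradient``, the sketch below differentiates ``x * stop_gradient(x)``. The function ``f`` is a made-up example; the point is that the wrapped factor is treated as a constant during differentiation.

```python
import jax
from jax import lax

def f(x):
    # The second factor is treated as a constant by autodiff.
    return x * lax.stop_gradient(x)

# d/dx [x * c] with c held at 3.0.
grad_stopped = jax.grad(f)(3.0)

# For comparison, the ordinary derivative of x * x is 2 * x.
grad_full = jax.grad(lambda x: x * x)(3.0)
```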
.. _jax-parallel-operators:
Parallel operators
------------------
Parallelism support is experimental.
.. autosummary::
:toctree: _autosummary
all_gather
all_to_all
psum
pmax
pmin
pmean
ppermute
pshuffle
pswapaxes
axis_index
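The collectives above operate over a named axis. The sketch below is one way to try them on a single device: it binds the axis name with ``jax.vmap`` rather than ``jax.pmap``, so no multi-device setup is assumed. The axis name ``"i"`` and the function ``normalize`` are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def normalize(x):
    # psum sums x across every position along the named axis "i".
    return x / lax.psum(x, axis_name="i")

def center(x):
    # pmean averages x across the named axis "i".
    return x - lax.pmean(x, axis_name="i")

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
normalized = jax.vmap(normalize, axis_name="i")(xs)
centered = jax.vmap(center, axis_name="i")(xs)
```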
Linear algebra operators (jax.lax.linalg)
-----------------------------------------
.. automodule:: jax.lax.linalg
.. autosummary::
:toctree: _autosummary
cholesky
eig
eigh
hessenberg
lu
householder_product
qdwh
qr
schur
svd
triangular_solve
tridiagonal_solve
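As a sketch of how these routines compose, the example below solves a small symmetric positive-definite system with ``cholesky`` followed by two calls to ``triangular_solve``. The matrix and right-hand side are made up; note that ``triangular_solve`` expects the right-hand side to be a matrix, so ``b`` has shape ``(2, 1)``.

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])
b = jnp.array([[2.0],
               [1.0]])

# Lower-triangular factor L with a = L @ L.T
l = lax_linalg.cholesky(a)

# Forward substitution: solve L y = b.
y = lax_linalg.triangular_solve(l, b, left_side=True, lower=True)

# Back substitution: solve L.T x = y.
x = lax_linalg.triangular_solve(
    l, y, left_side=True, lower=True, transpose_a=True)
```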
Argument classes
----------------
.. currentmodule:: jax.lax
.. autoclass:: ConvDimensionNumbers
.. autoclass:: ConvGeneralDilatedDimensionNumbers
.. autoclass:: GatherDimensionNumbers
.. autoclass:: GatherScatterMode
.. autoclass:: Precision
.. autoclass:: RoundingMethod
.. autoclass:: ScatterDimensionNumbers
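The following sketch shows one way two of these argument classes might appear in practice: ``conv_dimension_numbers`` builds a ``ConvDimensionNumbers`` from layout strings, and ``Precision`` is passed to a convolution. The shapes and the ``"NCHW"``/``"OIHW"`` layouts are assumptions chosen for the example.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 1, 4, 4), dtype=jnp.float32)  # input, NCHW layout
rhs = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)  # kernel, OIHW layout

# Convert layout strings into a ConvDimensionNumbers instance.
dn = lax.conv_dimension_numbers(
    lhs.shape, rhs.shape, ("NCHW", "OIHW", "NCHW"))

out = lax.conv_general_dilated(
    lhs,
    rhs,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=dn,
    precision=lax.Precision.HIGHEST,
)
```

With an all-ones input and a 3x3 all-ones kernel, each interior output element sums nine ones, while corner elements only overlap four input values because of the zero padding.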