``jax.lax`` module
==================

.. automodule:: jax.lax

:mod:`jax.lax` is a library of primitive operations that underpins libraries
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on :mod:`jax.lax` primitives.
Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
cases JAX diverges from XLA, usually to ensure that the set of operations is
closed under the operation of JVP and transpose rules.

Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
therefore more stable and less likely to change than the :mod:`jax.lax` API.
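
For example, :func:`jax.numpy.add` follows NumPy's implicit type promotion
rules, while the lower-level :func:`jax.lax.add` requires operands of matching
dtype. A minimal sketch, assuming a standard JAX installation:

.. code-block:: python

  import jax.numpy as jnp
  from jax import lax

  # jax.numpy follows NumPy and promotes mixed types implicitly.
  jnp.add(1, 1.0)               # ok: returns 2.0

  # jax.lax is stricter: lax.add(1, 1.0) raises a dtype mismatch error,
  # so any promotion must be requested explicitly.
  lax.add(jnp.float32(1), 1.0)  # ok: returns 2.0
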
Operators
---------

.. autosummary::
  :toctree: _autosummary

  abs
  acos
  acosh
  add
  after_all
  approx_max_k
  approx_min_k
  argmax
  argmin
  asin
  asinh
  atan
  atan2
  atanh
  batch_matmul
  bessel_i0e
  bessel_i1e
  betainc
  bitcast_convert_type
  bitwise_and
  bitwise_not
  bitwise_or
  bitwise_xor
  broadcast
  broadcast_in_dim
  broadcast_shapes
  broadcast_to_rank
  broadcasted_iota
  cbrt
  ceil
  clamp
  clz
  collapse
  complex
  composite
  concatenate
  conj
  conv
  convert_element_type
  conv_dimension_numbers
  conv_general_dilated
  conv_general_dilated_local
  conv_general_dilated_patches
  conv_transpose
  conv_with_general_padding
  cos
  cosh
  cumlogsumexp
  cummax
  cummin
  cumprod
  cumsum
  digamma
  div
  dot
  dot_general
  dynamic_index_in_dim
  dynamic_slice
  dynamic_slice_in_dim
  dynamic_update_index_in_dim
  dynamic_update_slice
  dynamic_update_slice_in_dim
  eq
  erf
  erfc
  erf_inv
  exp
  exp2
  expand_dims
  expm1
  fft
  floor
  full
  full_like
  gather
  ge
  gt
  igamma
  igammac
  imag
  index_in_dim
  index_take
  integer_pow
  iota
  is_finite
  le
  lgamma
  log
  log1p
  logistic
  lt
  max
  min
  mul
  ne
  neg
  nextafter
  optimization_barrier
  pad
  platform_dependent
  polygamma
  population_count
  pow
  random_gamma_grad
  real
  reciprocal
  reduce
  reduce_and
  reduce_max
  reduce_min
  reduce_or
  reduce_precision
  reduce_prod
  reduce_sum
  reduce_window
  reduce_xor
  rem
  reshape
  rev
  rng_bit_generator
  rng_uniform
  round
  rsqrt
  scatter
  scatter_add
  scatter_apply
  scatter_max
  scatter_min
  scatter_mul
  shift_left
  shift_right_arithmetic
  shift_right_logical
  sign
  sin
  sinh
  slice
  slice_in_dim
  sort
  sort_key_val
  split
  sqrt
  square
  squeeze
  sub
  tan
  tanh
  top_k
  transpose
  zeros_like_array
  zeta

.. _lax-control-flow:

Control flow operators
----------------------

.. autosummary::
  :toctree: _autosummary

  associative_scan
  cond
  fori_loop
  map
  scan
  select
  select_n
  switch
  while_loop

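These operators let traced code express data-dependent control flow. A minimal
sketch (the helper functions below are illustrative):

.. code-block:: python

  import jax
  import jax.numpy as jnp
  from jax import lax

  def total(x):
    # fori_loop(lower, upper, body, init) threads a carry through
    # iterations i = lower, ..., upper - 1.
    return lax.fori_loop(0, x.shape[0], lambda i, acc: acc + x[i], 0.0)

  def safe_recip(x):
    # cond traces both branches, but only one executes at run time.
    return lax.cond(x == 0.0, lambda v: 0.0, lambda v: 1.0 / v, x)

  print(jax.jit(total)(jnp.arange(4.0)))  # 6.0
  print(jax.jit(safe_recip)(0.0))         # 0.0
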
Custom gradient operators
-------------------------

.. autosummary::
  :toctree: _autosummary

  stop_gradient
  custom_linear_solve
  custom_root

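For instance, :func:`stop_gradient` blocks derivatives from flowing through a
value, which is useful when part of a computation should be treated as a
constant under differentiation. A small sketch (the loss below is
illustrative):

.. code-block:: python

  import jax
  import jax.numpy as jnp
  from jax import lax

  def loss(w, x):
    # No gradient flows through stop_gradient: the target is a constant
    # as far as differentiation is concerned.
    target = lax.stop_gradient(2.0 * x)
    return jnp.sum((w * x - target) ** 2)

  print(jax.grad(loss)(1.0, jnp.arange(3.0)))  # -10.0
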
.. _jax-parallel-operators:

Parallel operators
------------------

.. autosummary::
  :toctree: _autosummary

  all_gather
  all_to_all
  psum
  psum_scatter
  pmax
  pmin
  pmean
  ppermute
  pshuffle
  pswapaxes
  axis_index

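These collectives are only meaningful inside a computation that is mapped over
a named axis, for example under :func:`jax.pmap`. A minimal sketch (the axis
name ``"i"`` is arbitrary):

.. code-block:: python

  from functools import partial

  import jax
  import jax.numpy as jnp
  from jax import lax

  n = jax.local_device_count()

  @partial(jax.pmap, axis_name="i")
  def normalize(x):
    # psum sums x across all devices mapped over the axis named "i".
    return x / lax.psum(x, axis_name="i")

  print(normalize(jnp.ones(n)))  # each entry is 1 / n
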
Sharding-related operators
--------------------------

.. autosummary::
  :toctree: _autosummary

  with_sharding_constraint

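:func:`with_sharding_constraint` asks the compiler to give an intermediate
value a particular sharding. A rough sketch, assuming a recent JAX version
that provides :func:`jax.make_mesh` (the mesh axis name ``"x"`` is
illustrative):

.. code-block:: python

  import jax
  import jax.numpy as jnp
  from jax.sharding import NamedSharding, PartitionSpec as P

  mesh = jax.make_mesh((jax.device_count(),), ("x",))

  @jax.jit
  def f(a):
    b = a * 2
    # Constrain the intermediate to be sharded along the "x" mesh axis.
    return jax.lax.with_sharding_constraint(b, NamedSharding(mesh, P("x")))

  print(f(jnp.arange(2.0 * jax.device_count())))
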
Linear algebra operators (jax.lax.linalg)
-----------------------------------------

.. automodule:: jax.lax.linalg

.. autosummary::
  :toctree: _autosummary

  cholesky
  cholesky_update
  eig
  eigh
  hessenberg
  householder_product
  lu
  lu_pivots_to_permutation
  qdwh
  qr
  schur
  svd
  SvdAlgorithm
  symmetric_product
  triangular_solve
  tridiagonal
  tridiagonal_solve

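As a small sketch, a symmetric positive definite system can be solved by
combining :func:`~jax.lax.linalg.cholesky` with two calls to
:func:`~jax.lax.linalg.triangular_solve`:

.. code-block:: python

  import jax.numpy as jnp
  from jax import lax

  a = jnp.array([[4.0, 2.0],
                 [2.0, 3.0]])   # symmetric positive definite
  b = jnp.array([[1.0], [2.0]])

  l = lax.linalg.cholesky(a)    # lower-triangular factor: a = l @ l.T

  # Solve a @ x = b as l @ y = b followed by l.T @ x = y.
  y = lax.linalg.triangular_solve(l, b, left_side=True, lower=True)
  x = lax.linalg.triangular_solve(l, y, left_side=True, lower=True,
                                  transpose_a=True)

  print(jnp.allclose(a @ x, b))  # True
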
Argument classes
----------------

.. currentmodule:: jax.lax

.. autoclass:: ConvDimensionNumbers

.. autoclass:: ConvGeneralDilatedDimensionNumbers

.. autoclass:: DotAlgorithm

.. autoclass:: DotAlgorithmPreset
  :members:
  :undoc-members:
  :member-order: bysource

.. autoclass:: FftType
  :members:

.. autoclass:: GatherDimensionNumbers

.. autoclass:: GatherScatterMode

.. autoclass:: Precision

.. autoclass:: PrecisionLike

.. autoclass:: RandomAlgorithm
  :members:
  :member-order: bysource

.. autoclass:: RoundingMethod
  :members:
  :member-order: bysource

.. autoclass:: ScatterDimensionNumbers
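
For example, the :class:`Precision` values can be passed to matmul-like
operators such as :func:`dot` and :func:`conv_general_dilated` to trade off
speed against accuracy on accelerators. A brief sketch:

.. code-block:: python

  import jax.numpy as jnp
  from jax import lax

  a = jnp.ones((4, 8), jnp.float32)
  b = jnp.ones((8, 2), jnp.float32)

  # On some accelerators the default uses faster, lower-precision
  # arithmetic; HIGHEST requests the most accurate implementation.
  c = lax.dot(a, b, precision=lax.Precision.HIGHEST)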