``jax.lax`` module
==================

.. automodule:: jax.lax

:mod:`jax.lax` is a library of primitive operations that underpins libraries
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on :mod:`jax.lax` primitives.

Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
cases JAX diverges from XLA, usually to ensure that the set of operations is
closed under the operation of JVP and transpose rules.

Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
therefore more stable and less likely to change than the :mod:`jax.lax` API.
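
For example, :mod:`jax.numpy` follows NumPy's implicit type promotion, while
:mod:`jax.lax` requires dtypes to be handled explicitly. A minimal sketch of
the difference (the exact error type and message may vary across JAX
versions):

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    jnp.add(1, 1.0)               # jax.numpy API implicitly promotes mixed types
    # lax.add(1, 1.0)             # jax.lax API raises an error for mixed dtypes
    lax.add(jnp.float32(1), 1.0)  # explicit promotion makes the types agree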

Operators
---------

.. autosummary::
  :toctree: _autosummary

  abs
  acos
  acosh
  add
  after_all
  approx_max_k
  approx_min_k
  argmax
  argmin
  asin
  asinh
  atan
  atan2
  atanh
  batch_matmul
  bessel_i0e
  bessel_i1e
  betainc
  bitcast_convert_type
  bitwise_and
  bitwise_not
  bitwise_or
  bitwise_xor
  broadcast
  broadcast_in_dim
  broadcast_shapes
  broadcast_to_rank
  broadcasted_iota
  cbrt
  ceil
  clamp
  clz
  collapse
  complex
  concatenate
  conj
  conv
  convert_element_type
  conv_dimension_numbers
  conv_general_dilated
  conv_general_dilated_local
  conv_general_dilated_patches
  conv_transpose
  conv_with_general_padding
  cos
  cosh
  cumlogsumexp
  cummax
  cummin
  cumprod
  cumsum
  digamma
  div
  dot
  dot_general
  dynamic_index_in_dim
  dynamic_slice
  dynamic_slice_in_dim
  dynamic_update_index_in_dim
  dynamic_update_slice
  dynamic_update_slice_in_dim
  eq
  erf
  erfc
  erf_inv
  exp
  expand_dims
  expm1
  fft
  floor
  full
  full_like
  gather
  ge
  gt
  igamma
  igammac
  imag
  index_in_dim
  index_take
  integer_pow
  iota
  is_finite
  le
  lgamma
  log
  log1p
  logistic
  lt
  max
  min
  mul
  ne
  neg
  nextafter
  optimization_barrier
  pad
  platform_dependent
  polygamma
  population_count
  pow
  random_gamma_grad
  real
  reciprocal
  reduce
  reduce_precision
  reduce_window
  rem
  reshape
  rev
  rng_bit_generator
  rng_uniform
  round
  rsqrt
  scatter
  scatter_add
  scatter_apply
  scatter_max
  scatter_min
  scatter_mul
  shift_left
  shift_right_arithmetic
  shift_right_logical
  sign
  sin
  sinh
  slice
  slice_in_dim
  sort
  sort_key_val
  sqrt
  square
  squeeze
  sub
  tan
  tanh
  top_k
  transpose
  zeros_like_array
  zeta
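
Most of these functions are thin wrappers around a single XLA op, with
explicit arguments in place of :mod:`jax.numpy`'s conveniences. As a small
sketch, a plain matrix multiplication spelled with
:func:`jax.lax.dot_general` (dimension numbers follow its documented
contracting/batch-dimension convention):

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((3, 4))
    y = jnp.ones((4, 5))

    # Contract axis 1 of x with axis 0 of y; no batch dimensions.
    z = lax.dot_general(x, y, dimension_numbers=(((1,), (0,)), ((), ())))
    assert z.shape == (3, 5)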

.. _lax-control-flow:

Control flow operators
----------------------

.. autosummary::
  :toctree: _autosummary

  associative_scan
  cond
  fori_loop
  map
  scan
  select
  select_n
  switch
  while_loop
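
These operators stage loops and conditionals into a single XLA computation,
so they can be used inside :func:`jax.jit`-compiled code. A minimal sketch of
the two most common ones:

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    # fori_loop: sum the integers 0..9 into a carried accumulator.
    total = lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)

    # scan: carry a running sum down an array, emitting it at each step.
    def step(carry, x):
        carry = carry + x
        return carry, carry

    final, running = lax.scan(step, 0.0, jnp.arange(4.0))
    # final == 6.0, running == [0., 1., 3., 6.]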

Custom gradient operators
-------------------------

.. autosummary::
  :toctree: _autosummary

  stop_gradient
  custom_linear_solve
  custom_root
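
For instance, :func:`jax.lax.stop_gradient` makes its operand behave as a
constant under differentiation:

.. code-block:: python

    import jax
    from jax import lax

    def f(x):
        # The second term is held constant for gradient purposes.
        return x ** 2 + lax.stop_gradient(x)

    print(jax.grad(f)(3.0))  # 6.0 rather than 7.0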

.. _jax-parallel-operators:

Parallel operators
------------------

.. autosummary::
  :toctree: _autosummary

  all_gather
  all_to_all
  psum
  psum_scatter
  pmax
  pmin
  pmean
  ppermute
  pshuffle
  pswapaxes
  axis_index
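
These collectives communicate across a named mapped axis. A minimal sketch
using :func:`jax.vmap` (the same axis names also work under :func:`jax.pmap`):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax import lax

    # Normalize each element by the sum over the mapped axis 'i'.
    normalize = jax.vmap(lambda x: x / lax.psum(x, 'i'), axis_name='i')
    print(normalize(jnp.arange(1.0, 4.0)))  # [1/6, 2/6, 3/6]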

Sharding-related operators
--------------------------

.. autosummary::
  :toctree: _autosummary

  with_sharding_constraint
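
A hedged sketch of constraining an intermediate value's sharding inside
:func:`jax.jit`; this assumes the sizes involved divide evenly over the
available devices (on a single device the constraint is effectively a no-op):

.. code-block:: python

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), ('x',))
    sharding = NamedSharding(mesh, PartitionSpec('x'))

    @jax.jit
    def f(a):
        b = a * 2
        # Ask the compiler to lay out `b` split along mesh axis 'x'.
        return jax.lax.with_sharding_constraint(b, sharding)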

Linear algebra operators (jax.lax.linalg)
-----------------------------------------

.. automodule:: jax.lax.linalg

.. autosummary::
  :toctree: _autosummary

  cholesky
  eig
  eigh
  hessenberg
  lu
  householder_product
  qdwh
  qr
  schur
  svd
  triangular_solve
  tridiagonal
  tridiagonal_solve
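
As a small example, a symmetric positive-definite solve can be composed from
these primitives (a sketch; dtype and batching rules follow the individual
function docs):

.. code-block:: python

    import jax.numpy as jnp
    from jax.lax import linalg

    a = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
    b = jnp.array([[1.0], [2.0]])

    l = linalg.cholesky(a)  # lower-triangular factor, a == l @ l.T
    # Solve l @ y = b, then l.T @ x = y.
    y = linalg.triangular_solve(l, b, left_side=True, lower=True)
    x = linalg.triangular_solve(l, y, left_side=True, lower=True,
                                transpose_a=True)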

Argument classes
----------------

.. currentmodule:: jax.lax

.. autoclass:: ConvDimensionNumbers

.. autoclass:: ConvGeneralDilatedDimensionNumbers

.. autoclass:: DotAlgorithm

.. autoclass:: DotAlgorithmPreset
  :members:
  :undoc-members:
  :member-order: bysource

.. autoclass:: FftType

.. autoclass:: GatherDimensionNumbers

.. autoclass:: GatherScatterMode

.. autoclass:: Precision

.. autoclass:: PrecisionLike

.. autoclass:: RoundingMethod
  :members:

.. autoclass:: ScatterDimensionNumbers
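
Several of these classes are accepted as arguments by the operators above.
For example, matmul-like operations take a ``precision`` argument:

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((128, 128))
    # Request the most precise accumulation on backends (such as TPU)
    # whose default matmul precision is lower.
    y = lax.dot(x, x, precision=lax.Precision.HIGHEST)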