``jax.lax`` module
==================

.. automodule:: jax.lax

:mod:`jax.lax` is a library of primitive operations that underpins libraries
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on :mod:`jax.lax` primitives.
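
As a quick illustration, here is a sketch that assumes a recent JAX release. The
printed form of a jaxpr varies between versions. :func:`jax.make_jaxpr` shows a
:mod:`jax.numpy` function reduced to :mod:`jax.lax` primitives such as ``sin``
and ``mul``, which is the level at which JVP and batching rules apply. The same
primitives can also be called directly. The function ``f`` is illustrative only:

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def f(x):
        # Written against the NumPy-style API.
        return jnp.sin(x) * 2.0


    # Tracing f yields a jaxpr whose equations are jax.lax primitives.
    jaxpr = jax.make_jaxpr(f)(1.0)
    print(jaxpr)

    # The same computation expressed directly with jax.lax primitives.
    direct = jax.lax.mul(jax.lax.sin(1.0), 2.0)
    print(direct, f(1.0))
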

Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
cases JAX diverges from XLA, usually to ensure that the set of operations is
closed under the operation of JVP and transpose rules.

Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
therefore more stable and less likely to change than the :mod:`jax.lax` API.
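
The sketch below shows one practical consequence of the lower-level API.
:mod:`jax.numpy` applies NumPy-style type promotion. :mod:`jax.lax` binary
primitives such as :func:`jax.lax.add` instead expect operands with matching
dtypes and raise a ``TypeError`` otherwise. The exact error message depends on
the JAX version. Explicit conversion with :func:`jax.lax.convert_element_type`
makes the call valid:

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    x = jnp.array(1, dtype=jnp.int32)
    y = jnp.array(1.0, dtype=jnp.float32)

    # jax.numpy promotes int32 + float32 to float32.
    promoted = jnp.add(x, y)

    # jax.lax leaves dtype agreement to the caller.
    try:
        lax.add(x, y)
        raised_type_error = False
    except TypeError as err:
        raised_type_error = True
        print("lax.add rejected mixed dtypes:", err)

    # Converting explicitly makes the lax call valid.
    fixed = lax.add(lax.convert_element_type(x, jnp.float32), y)
    print(promoted, fixed)
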

Operators
---------

.. autosummary::
  :toctree: _autosummary

    abs
    add
    acos
    approx_max_k
    approx_min_k
    argmax
    argmin
    asin
    atan
    atan2
    batch_matmul
    bessel_i0e
    bessel_i1e
    betainc
    bitcast_convert_type
    bitwise_not
    bitwise_and
    bitwise_or
    bitwise_xor
    population_count
    broadcast
    broadcasted_iota
    broadcast_in_dim
    cbrt
    ceil
    clamp
    collapse
    complex
    concatenate
    conj
    conv
    convert_element_type
    conv_dimension_numbers
    conv_general_dilated
    conv_general_dilated_local
    conv_general_dilated_patches
    conv_with_general_padding
    conv_transpose
    cos
    cosh
    cummax
    cummin
    cumprod
    cumsum
    digamma
    div
    dot
    dot_general
    dynamic_index_in_dim
    dynamic_slice
    dynamic_slice_in_dim
    dynamic_update_slice
    dynamic_update_index_in_dim
    dynamic_update_slice_in_dim
    eq
    erf
    erfc
    erf_inv
    exp
    expand_dims
    expm1
    fft
    floor
    full
    full_like
    gather
    ge
    gt
    igamma
    igammac
    imag
    index_in_dim
    index_take
    iota
    is_finite
    le
    lt
    lgamma
    log
    log1p
    logistic
    max
    min
    mul
    ne
    neg
    nextafter
    pad
    pow
    real
    reciprocal
    reduce
    reduce_precision
    reduce_window
    reshape
    rem
    rev
    round
    rsqrt
    scatter
    scatter_add
    scatter_max
    scatter_min
    scatter_mul
    select
    shift_left
    shift_right_arithmetic
    shift_right_logical
    slice
    slice_in_dim
    sign
    sin
    sinh
    sort
    sort_key_val
    sqrt
    square
    squeeze
    sub
    tan
    tie_in
    top_k
    transpose
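
Several of these operators make axis handling explicit instead of inferring it
from shapes. The sketch below uses shapes chosen only for illustration.
:func:`jax.lax.broadcast_in_dim` states the output dimension that each input
dimension maps to. :func:`jax.lax.dot_general` takes explicit contracting and
batch dimensions, and here it computes a batched matrix product that should match
:func:`jax.numpy.matmul`:

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    # broadcast_in_dim: input axis 0 (length 3) becomes output axis 1 of a (2, 3)
    # array, so every row of the result is a copy of v.
    v = jnp.arange(3.0)
    tiled = lax.broadcast_in_dim(v, shape=(2, 3), broadcast_dimensions=(1,))

    # dot_general: contract lhs axis 2 with rhs axis 1, and treat axis 0 of both
    # operands as a batch axis.
    lhs = jnp.arange(40.0).reshape(4, 2, 5)
    rhs = jnp.arange(60.0).reshape(4, 5, 3)
    dimension_numbers = (((2,), (1,)), ((0,), (0,)))
    batched = lax.dot_general(lhs, rhs, dimension_numbers)

    print(tiled)
    # Batch dims come first, then lhs free dims, then rhs free dims: (4, 2, 3).
    print(batched.shape)
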

.. _lax-control-flow:

Control flow operators
----------------------

.. autosummary::
  :toctree: _autosummary

    associative_scan
    cond
    fori_loop
    map
    scan
    switch
    while_loop
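
The sketch below exercises a few of these primitives on a small array.
:func:`jax.lax.scan` and :func:`jax.lax.associative_scan` both compute prefix
sums. :func:`jax.lax.fori_loop` runs a counted loop, and :func:`jax.lax.cond`
selects a branch from a traced predicate. Names such as ``step`` are
illustrative only:

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    xs = jnp.array([1.0, 2.0, 3.0, 4.0])


    def step(carry, x):
        # Add the current element to the running total; emit the total as this
        # step's output.
        total = carry + x
        return total, total


    # scan threads the carry through the leading axis of xs.
    final, running = lax.scan(step, jnp.zeros((), xs.dtype), xs)

    # associative_scan computes the same prefix sums with a parallel-friendly
    # algorithm.
    prefix = lax.associative_scan(jnp.add, xs)

    # fori_loop: accumulate i for i in range(0, 5), giving 0 + 1 + 2 + 3 + 4 = 10.
    loop_sum = lax.fori_loop(0, 5, lambda i, acc: acc + i, jnp.int32(0))

    # cond: pick a branch from a traced boolean; final is 10.0, so this doubles it.
    branch = lax.cond(final > 5.0, lambda v: v * 2.0, lambda v: v, final)

    print(final, running, prefix, loop_sum, branch)
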

Custom gradient operators
-------------------------

.. autosummary::
  :toctree: _autosummary

    stop_gradient
    custom_linear_solve
    custom_root
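
For example, :func:`jax.lax.stop_gradient` leaves the forward value unchanged
but treats its argument as a constant during differentiation. In the sketch
below, where the function ``f`` is illustrative, the gradient is 3.0 rather than
the 6.0 that ``x * x`` would give:

.. code-block:: python

    import jax
    from jax import lax


    def f(x):
        # The second factor is treated as a constant during differentiation.
        return x * lax.stop_gradient(x)


    value = f(3.0)           # forward value is unchanged: 3.0 * 3.0 = 9.0
    grad = jax.grad(f)(3.0)  # d/dx (x * c) = c, with c held at 3.0
    print(value, grad)
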

.. _jax-parallel-operators:

Parallel operators
------------------

Parallelism support is experimental.

.. autosummary::
  :toctree: _autosummary

    all_gather
    all_to_all
    psum
    pmax
    pmin
    pmean
    ppermute
    pshuffle
    pswapaxes
    axis_index
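
Collectives such as :func:`jax.lax.psum` operate over a named axis. In
multi-device code, that name is typically bound by :func:`jax.pmap` or a similar
transformation. The sketch below instead binds it with
``jax.vmap(..., axis_name=...)``, which also supports these collectives and runs
on a single device. The axis name ``"i"`` and the helpers ``normalize`` and
``max_of_axis`` are illustrative choices:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax import lax


    def normalize(x):
        # psum adds x across every member of the named axis "i".
        return x / lax.psum(x, axis_name="i")


    def max_of_axis(x):
        # pmax gives every member the maximum of x over the named axis "i".
        return lax.pmax(x, axis_name="i")


    xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.], which sums to 10
    normalized = jax.vmap(normalize, axis_name="i")(xs)
    maxes = jax.vmap(max_of_axis, axis_name="i")(xs)
    print(normalized, maxes)
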

Linear algebra operators (jax.lax.linalg)
-----------------------------------------

.. automodule:: jax.lax.linalg

.. autosummary::
  :toctree: _autosummary

    cholesky
    eig
    eigh
    hessenberg
    lu
    householder_product
    qdwh
    qr
    schur
    svd
    triangular_solve
    tridiagonal
    tridiagonal_solve
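
The sketch below shows how these routines compose. It factors a small symmetric
positive-definite matrix with :func:`jax.lax.linalg.cholesky`, which returns the
lower-triangular factor. It then solves a linear system with two calls to
:func:`jax.lax.linalg.triangular_solve`. The matrix and right-hand side are
arbitrary illustrative values:

.. code-block:: python

    import jax.numpy as jnp
    from jax.lax import linalg

    # A small symmetric positive-definite matrix and a right-hand side.
    a = jnp.array([[4.0, 2.0],
                   [2.0, 3.0]])
    b = jnp.array([[1.0],
                   [2.0]])

    # cholesky returns lower-triangular l with a == l @ l.T.
    l = linalg.cholesky(a)

    # Solve a @ x == b in two triangular steps: l @ y == b, then l.T @ x == y.
    y = linalg.triangular_solve(l, b, left_side=True, lower=True)
    x = linalg.triangular_solve(l, y, left_side=True, lower=True, transpose_a=True)
    print(x)
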

Argument classes
----------------

.. currentmodule:: jax.lax

.. autoclass:: ConvDimensionNumbers
.. autoclass:: ConvGeneralDilatedDimensionNumbers
.. autoclass:: GatherDimensionNumbers
.. autoclass:: GatherScatterMode
.. autoclass:: Precision
.. autoclass:: RoundingMethod
.. autoclass:: ScatterDimensionNumbers
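
These classes are usually built by helper functions or passed to operators as
configuration. In the sketch below, the input shapes are illustrative.
:func:`jax.lax.conv_dimension_numbers` turns layout strings into a
:class:`ConvDimensionNumbers` value. That value is then passed to
:func:`jax.lax.conv_general_dilated` together with a :class:`Precision` setting:

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((1, 5, 5, 3))  # NHWC: batch, height, width, input channels
    rhs = jnp.ones((3, 3, 3, 8))  # HWIO: kernel height, kernel width, input, output

    # Convert layout strings into a ConvDimensionNumbers named tuple.
    dn = lax.conv_dimension_numbers(lhs.shape, rhs.shape, ("NHWC", "HWIO", "NHWC"))
    print(dn)

    out = lax.conv_general_dilated(
        lhs,
        rhs,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=dn,
        precision=lax.Precision.HIGHEST,  # ask for the most precise algorithm
    )
    # "SAME" padding with stride 1 keeps the spatial size: (1, 5, 5, 8).
    print(out.shape)
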