
jax.lax package
================

.. automodule:: jax.lax

:mod:`jax.lax` is a library of primitive operations that underpins libraries
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on :mod:`jax.lax` primitives.
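
As a minimal illustrative sketch (not part of the operator reference below),
``jax.make_jaxpr`` can be used to see the :mod:`jax.lax` primitives that a
:mod:`jax.numpy` expression lowers to:

.. code-block:: python

    import jax
    import jax.numpy as jnp

    # jax.numpy functions are implemented in terms of lax primitives;
    # make_jaxpr prints the primitive graph that a function traces to.
    print(jax.make_jaxpr(lambda x: jnp.square(x).sum())(jnp.arange(3.0)))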

Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
cases JAX diverges from XLA, usually to ensure that the set of operations is
closed under JVP and transpose rules.
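
As a small sketch of these rules in action, forward-mode differentiation
applies directly to a :mod:`jax.lax` primitive:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(3.0)

    # The JVP rule for the sin primitive produces cos(x) * tangent.
    y, y_dot = jax.jvp(lax.sin, (x,), (jnp.ones_like(x),))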

Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
therefore more stable and less likely to change than the :mod:`jax.lax` API.
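
One practical difference, sketched here: :mod:`jax.lax` operations are
stricter than their :mod:`jax.numpy` counterparts and do not implicitly
promote dtypes:

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    jnp.add(1, 1.0)               # jax.numpy promotes mixed types implicitly
    lax.add(jnp.float32(1), 1.0)  # lax requires matching dtypes
    # lax.add(1, 1.0) would raise a TypeError, since the integer and float
    # arguments do not share a dtype.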

Operators
---------

.. autosummary::
  :toctree: _autosummary

    abs
    add
    acos
    asin
    atan
    atan2
    batch_matmul
    bessel_i0e
    bessel_i1e
    betainc
    bitcast_convert_type
    bitwise_not
    bitwise_and
    bitwise_or
    bitwise_xor
    population_count
    broadcast
    broadcasted_iota
    broadcast_in_dim
    ceil
    clamp
    collapse
    complex
    concatenate
    conj
    conv
    convert_element_type
    conv_general_dilated
    conv_with_general_padding
    conv_transpose
    cos
    cosh
    digamma
    div
    dot
    dot_general
    dynamic_index_in_dim
    dynamic_slice
    dynamic_slice_in_dim
    dynamic_update_index_in_dim
    dynamic_update_slice_in_dim
    eq
    erf
    erfc
    erf_inv
    exp
    expand_dims
    expm1
    fft
    floor
    full
    full_like
    gather
    ge
    gt
    igamma
    igammac
    imag
    index_in_dim
    index_take
    iota
    is_finite
    le
    lt
    lgamma
    log
    log1p
    max
    min
    mul
    ne
    neg
    nextafter
    pad
    pow
    real
    reciprocal
    reduce
    reduce_window
    reshape
    rem
    rev
    round
    rsqrt
    scatter
    scatter_add
    select
    shift_left
    shift_right_arithmetic
    shift_right_logical
    slice
    slice_in_dim
    sign
    sin
    sinh
    sort
    sort_key_val
    sqrt
    square
    squeeze
    sub
    tan
    tie_in
    top_k
    transpose

Control flow operators
----------------------

.. autosummary::
  :toctree: _autosummary

    cond
    fori_loop
    map
    scan
    while_loop
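
These combinators are traceable, jittable alternatives to Python control
flow. A minimal sketch using ``fori_loop``:

.. code-block:: python

    from jax import lax

    # Sum the integers 0..9; the loop carry must keep the same shape
    # and dtype across iterations.
    total = lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)
    # total == 45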

Custom gradient operators
-------------------------

.. autosummary::
  :toctree: _autosummary

    stop_gradient
    custom_linear_solve
    custom_root
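
As a brief sketch, ``stop_gradient`` makes its argument behave as a constant
under differentiation:

.. code-block:: python

    import jax
    from jax import lax

    f = lambda x: x * lax.stop_gradient(x)
    print(jax.grad(f)(3.0))  # 3.0 rather than 6.0: one factor is held constant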

Parallel operators
------------------

Parallelism support is experimental.

.. autosummary::
  :toctree: _autosummary

    all_gather
    all_to_all
    psum
    pmax
    pmin
    pmean
    ppermute
    pswapaxes
    axis_index
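
A minimal sketch of a collective, written to run on however many local
devices are available: ``psum`` sums a value across the devices mapped by
:func:`jax.pmap`.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax import lax

    n = jax.local_device_count()
    x = jnp.arange(float(n))

    # Each device receives one entry of x; psum sums over the mapped axis 'i'.
    y = jax.pmap(lambda v: lax.psum(v, 'i'), axis_name='i')(x)
    # Every entry of y equals the sum across all devices.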