mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Update lax documentation to reflect new code organization.
This commit is contained in:
parent
c5a381ed4d
commit
407306293f
6
docs/README.md
Normal file
6
docs/README.md
Normal file
@ -0,0 +1,6 @@
|
||||
To rebuild the documentation, install the `sphinx` and `sphinx_rtd_theme` pip
|
||||
packages and then run:
|
||||
|
||||
```
|
||||
sphinx-build -M html . build
|
||||
```
|
143
docs/jax.lax.rst
143
docs/jax.lax.rst
@ -2,5 +2,144 @@ jax.lax package
|
||||
================
|
||||
|
||||
.. automodule:: jax.lax
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
`lax` is a library of primitives that underpins libraries such as `jax.numpy`.
|
||||
|
||||
Many of the primitives are thin wrappers around equivalent XLA operations,
|
||||
described by the `XLA operation semantics
|
||||
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation.
|
||||
|
||||
Where possible, prefer to use libraries such as `jax.numpy` instead of using `jax.lax` directly.
|
||||
|
||||
Operators
|
||||
---------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
abs
|
||||
add
|
||||
acos
|
||||
acosh
|
||||
asin
|
||||
asinh
|
||||
atan
|
||||
atanh
|
||||
atan2
|
||||
batch_matmul
|
||||
bitcast_convert_type
|
||||
bitwise_not
|
||||
bitwise_and
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
broadcast
|
||||
broadcasted_iota
|
||||
broadcast_in_dim
|
||||
ceil
|
||||
clamp
|
||||
collapse
|
||||
complex
|
||||
concatenate
|
||||
conj
|
||||
conv
|
||||
convert_element_type
|
||||
conv_general_dilated
|
||||
conv_with_general_padding
|
||||
conv_transpose
|
||||
cos
|
||||
cosh
|
||||
digamma
|
||||
div
|
||||
dot
|
||||
dot_general
|
||||
dynamic_index_in_dim
|
||||
dynamic_slice
|
||||
dynamic_slice_in_dim
|
||||
dynamic_update_index_in_dim
|
||||
dynamic_update_slice_in_dim
|
||||
eq
|
||||
erf
|
||||
erfc
|
||||
erf_inv
|
||||
exp
|
||||
expm1
|
||||
floor
|
||||
full
|
||||
full_like
|
||||
gather
|
||||
ge
|
||||
gt
|
||||
imag
|
||||
index_in_dim
|
||||
index_take
|
||||
iota
|
||||
is_finite
|
||||
le
|
||||
lt
|
||||
lgamma
|
||||
log
|
||||
log1p
|
||||
max
|
||||
min
|
||||
mul
|
||||
ne
|
||||
neg
|
||||
pad
|
||||
pow
|
||||
real
|
||||
reciprocal
|
||||
reduce
|
||||
reduce_window
|
||||
reshape
|
||||
rem
|
||||
rev
|
||||
round
|
||||
rsqrt
|
||||
scatter
|
||||
scatter_add
|
||||
select
|
||||
shaped_identity
|
||||
shift_left
|
||||
shift_right_arithmetic
|
||||
shift_right_logical
|
||||
slice
|
||||
slice_in_dim
|
||||
sign
|
||||
sin
|
||||
sinh
|
||||
sort
|
||||
sort_key_val
|
||||
sqrt
|
||||
square
|
||||
stop_gradient
|
||||
sub
|
||||
tan
|
||||
tie_in
|
||||
transpose
|
||||
|
||||
|
||||
Control flow operators
|
||||
----------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
cond
|
||||
fori_loop
|
||||
while_loop
|
||||
|
||||
|
||||
Parallel operators
|
||||
------------------
|
||||
|
||||
Parallelism support is experimental.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
pcollect
|
||||
pmax
|
||||
psplit
|
||||
psplit_like
|
||||
psum
|
||||
pswapaxes
|
||||
|
@ -12,14 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
`lax` is a library of primitives that underpins libraries such as `jax.numpy`.
|
||||
|
||||
Many of the primitives are thin wrappers around equivalent XLA operations,
|
||||
described by the `XLA operation semantics
|
||||
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
Loading…
x
Reference in New Issue
Block a user