Merge pull request #9621 from hawkinsp:docs3

PiperOrigin-RevId: 429401742
jax authors 2022-02-17 14:23:13 -08:00
commit e545daa1e5
6 changed files with 41 additions and 12 deletions


@@ -19,7 +19,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
   class MyTestCase(jtu.JaxTestCase):
     ...
   ```
+* Added ``jax.scipy.linalg.schur``, ``jax.scipy.linalg.sqrtm``,
+  ``jax.scipy.signal.csd``, ``jax.scipy.signal.stft``,
+  ``jax.scipy.signal.welch``.
 
 ## jaxlib 0.3.1 (Unreleased)
 * Changes
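
The changelog entry above adds two matrix decompositions to ``jax.scipy.linalg``. A minimal sketch of how they are called, mirroring the SciPy API (the input matrix here is illustrative):

```python
import jax.numpy as jnp
from jax.scipy import linalg

a = jnp.array([[4.0, 1.0],
               [0.0, 9.0]])

t, z = linalg.schur(a)  # Schur decomposition: a == z @ t @ z.T (real form)
s = linalg.sqrtm(a)     # principal matrix square root; may be returned
                        # with a complex dtype even for real input
```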

docs/jax.distributed.rst Normal file

@@ -0,0 +1,11 @@
+jax.distributed module
+======================
+
+.. currentmodule:: jax.distributed
+
+.. automodule:: jax.distributed
+
+.. autosummary::
+   :toctree: _autosummary
+
+   initialize


@@ -1,6 +1,12 @@
 jax.dlpack module
 =================
 
+.. currentmodule:: jax.dlpack
+
 .. automodule:: jax.dlpack
-   :members:
-   :show-inheritance:
+
+.. autosummary::
+   :toctree: _autosummary
+
+   from_dlpack
+   to_dlpack
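
For reviewers unfamiliar with the two functions now summarized in the ``jax.dlpack`` docs, a minimal round-trip sketch (assuming the capsule-based DLPack API this module documented at the time):

```python
import jax.numpy as jnp
import jax.dlpack

x = jnp.arange(4.0)
capsule = jax.dlpack.to_dlpack(x)    # export the array as a DLPack capsule
y = jax.dlpack.from_dlpack(capsule)  # import the capsule as a new JAX array;
                                     # a capsule may only be consumed once
```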


@@ -11,18 +11,19 @@ Subpackages
    jax.numpy
    jax.scipy
+   jax.config
+   jax.dlpack
+   jax.distributed
    jax.example_libraries
    jax.experimental
+   jax.flatten_util
    jax.image
    jax.lax
    jax.nn
    jax.ops
+   jax.profiler
    jax.random
    jax.tree_util
-   jax.flatten_util
-   jax.dlpack
-   jax.profiler
-   jax.config
 
 .. toctree::
    :hidden:


@@ -33,6 +33,8 @@ jax.scipy.linalg
    lu_factor
    lu_solve
    qr
+   schur
+   sqrtm
    solve
    solve_triangular
    svd
@@ -72,6 +74,9 @@ jax.scipy.signal
    convolve2d
    correlate
    correlate2d
+   csd
+   stft
+   welch
 
 jax.scipy.sparse.linalg
 -----------------------
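
The three spectral estimators added to ``jax.scipy.signal`` follow their SciPy counterparts. A small sketch (the sample rate, test signal, and ``nperseg`` value are assumptions for illustration):

```python
import jax.numpy as jnp
from jax.scipy import signal

fs = 1000.0                          # assumed sample rate (Hz)
t = jnp.arange(1024) / fs
x = jnp.sin(2 * jnp.pi * 100.0 * t)  # 100 Hz test tone

f, pxx = signal.welch(x, fs=fs, nperseg=256)     # power spectral density
f, tt, zxx = signal.stft(x, fs=fs, nperseg=256)  # short-time Fourier transform
f, pxy = signal.csd(x, x, fs=fs, nperseg=256)    # cross spectral density
```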


@@ -27,20 +27,24 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
   is not required for CPU or TPU backends.
 
   Args:
-    coordinator_address: IP address of the coordinator.
+    coordinator_address: IP address and port of the coordinator. The choice of
+      port does not matter, so long as the port is available on the coordinator
+      and all processes agree on the port.
     num_processes: Number of processes.
-    process_id: Id of the current processe.
+    process_id: Id of the current process.
 
   Example:
 
   Suppose there are two GPU hosts, and host 0 is the designated coordinator
-  with address '10.0.0.1:1234', to initialize the GPU cluster, run the
+  with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the
   following commands before anything else.
 
-  On host 0
+  On host 0:
 
   >>> jax.distributed.initialize('10.0.0.1:1234', 2, 0)  # doctest: +SKIP
 
-  On host 1
+  On host 1:
 
   >>> jax.distributed.initialize('10.0.0.1:1234', 2, 1)  # doctest: +SKIP
   """
   if process_id == 0:
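
Putting the corrected docstring together, a minimal two-host launch sketch (the address comes from the docstring's own example; the keyword arguments match the signature shown in the hunk header):

```python
import jax

# Per the docstring, run this before any other JAX operations on each host.
# On host 0 (the coordinator):
jax.distributed.initialize('10.0.0.1:1234', num_processes=2, process_id=0)

# On host 1:
# jax.distributed.initialize('10.0.0.1:1234', num_processes=2, process_id=1)
```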