mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Merge pull request #9621 from hawkinsp:docs3
PiperOrigin-RevId: 429401742
This commit is contained in:
commit
e545daa1e5
@ -19,7 +19,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
class MyTestCase(jtu.JaxTestCase):
|
||||
...
|
||||
```
|
||||
|
||||
* Added ``jax.scipy.linalg.schur``, ``jax.scipy.linalg.sqrtm``,
|
||||
``jax.scipy.signal.csd``, ``jax.scipy.signal.stft``,
|
||||
``jax.scipy.signal.welch``.
|
||||
|
||||
## jaxlib 0.3.1 (Unreleased)
|
||||
* Changes
|
||||
|
11
docs/jax.distributed.rst
Normal file
11
docs/jax.distributed.rst
Normal file
@ -0,0 +1,11 @@
|
||||
jax.distributed module
|
||||
======================
|
||||
|
||||
.. currentmodule:: jax.distributed
|
||||
|
||||
.. automodule:: jax.distributed
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
initialize
|
@ -1,6 +1,12 @@
|
||||
jax.dlpack module
|
||||
=================
|
||||
|
||||
.. currentmodule:: jax.dlpack
|
||||
|
||||
.. automodule:: jax.dlpack
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
from_dlpack
|
||||
to_dlpack
|
@ -11,18 +11,19 @@ Subpackages
|
||||
|
||||
jax.numpy
|
||||
jax.scipy
|
||||
jax.config
|
||||
jax.dlpack
|
||||
jax.distributed
|
||||
jax.example_libraries
|
||||
jax.experimental
|
||||
jax.flatten_util
|
||||
jax.image
|
||||
jax.lax
|
||||
jax.nn
|
||||
jax.ops
|
||||
jax.profiler
|
||||
jax.random
|
||||
jax.tree_util
|
||||
jax.flatten_util
|
||||
jax.dlpack
|
||||
jax.profiler
|
||||
jax.config
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
@ -33,6 +33,8 @@ jax.scipy.linalg
|
||||
lu_factor
|
||||
lu_solve
|
||||
qr
|
||||
schur
|
||||
sqrtm
|
||||
solve
|
||||
solve_triangular
|
||||
svd
|
||||
@ -72,6 +74,9 @@ jax.scipy.signal
|
||||
convolve2d
|
||||
correlate
|
||||
correlate2d
|
||||
csd
|
||||
stft
|
||||
welch
|
||||
|
||||
jax.scipy.sparse.linalg
|
||||
-----------------------
|
||||
|
@ -27,20 +27,24 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int):
|
||||
is not required for CPU or TPU backends.
|
||||
|
||||
Args:
|
||||
coordinator_address: IP address of the coordinator.
|
||||
coordinator_address: IP address and port of the coordinator. The choice of
|
||||
port does not matter, so long as the port is available on the coordinator
|
||||
and all processes agree on the port.
|
||||
num_processes: Number of processes.
|
||||
process_id: Id of the current processe.
|
||||
process_id: Id of the current process.
|
||||
|
||||
Example:
|
||||
|
||||
Suppose there are two GPU hosts, and host 0 is the designated coordinator
|
||||
with address '10.0.0.1:1234', to initialize the GPU cluster, run the
|
||||
with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the
|
||||
following commands before anything else.
|
||||
|
||||
On host 0
|
||||
On host 0:
|
||||
|
||||
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 0) # doctest: +SKIP
|
||||
|
||||
On host 1
|
||||
On host 1:
|
||||
|
||||
>>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP
|
||||
"""
|
||||
if process_id == 0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user