diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ce178e4e..04304216b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. class MyTestCase(jtu.JaxTestCase): ... ``` - +* Added ``jax.scipy.linalg.schur``, ``jax.scipy.linalg.sqrtm``, + ``jax.scipy.signal.csd``, ``jax.scipy.signal.stft``, + ``jax.scipy.signal.welch``. ## jaxlib 0.3.1 (Unreleased) * Changes diff --git a/docs/jax.distributed.rst b/docs/jax.distributed.rst new file mode 100644 index 000000000..92718cee0 --- /dev/null +++ b/docs/jax.distributed.rst @@ -0,0 +1,11 @@ +jax.distributed module +====================== + +.. currentmodule:: jax.distributed + +.. automodule:: jax.distributed + +.. autosummary:: + :toctree: _autosummary + + initialize \ No newline at end of file diff --git a/docs/jax.dlpack.rst b/docs/jax.dlpack.rst index 3f0dcf4a8..996ee3f0e 100644 --- a/docs/jax.dlpack.rst +++ b/docs/jax.dlpack.rst @@ -1,6 +1,12 @@ jax.dlpack module ================= +.. currentmodule:: jax.dlpack + .. automodule:: jax.dlpack - :members: - :show-inheritance: \ No newline at end of file + +.. autosummary:: + :toctree: _autosummary + + from_dlpack + to_dlpack \ No newline at end of file diff --git a/docs/jax.rst b/docs/jax.rst index 8ad5955c9..aa2e5e472 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -11,18 +11,19 @@ Subpackages jax.numpy jax.scipy + jax.config + jax.dlpack + jax.distributed jax.example_libraries jax.experimental + jax.flatten_util jax.image jax.lax jax.nn jax.ops + jax.profiler jax.random jax.tree_util - jax.flatten_util - jax.dlpack - jax.profiler - jax.config .. toctree:: :hidden: diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 3ff75a7e5..622dc30fe 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -33,6 +33,8 @@ jax.scipy.linalg lu_factor lu_solve qr + schur + sqrtm solve solve_triangular svd @@ -72,6 +74,9 @@ jax.scipy.signal convolve2d correlate correlate2d + csd + stft + welch jax.scipy.sparse.linalg ----------------------- diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 5487f4f80..e1c1c2430 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -27,20 +27,24 @@ def initialize(coordinator_address: str, num_processes: int, process_id: int): is not required for CPU or TPU backends. Args: - coordinator_address: IP address of the coordinator. + coordinator_address: IP address and port of the coordinator. The choice of + port does not matter, so long as the port is available on the coordinator + and all processes agree on the port. num_processes: Number of processes. - process_id: Id of the current processe. + process_id: Id of the current process. Example: Suppose there are two GPU hosts, and host 0 is the designated coordinator - with address '10.0.0.1:1234', to initialize the GPU cluster, run the + with address ``10.0.0.1:1234``. To initialize the GPU cluster, run the following commands before anything else. - On host 0 + On host 0: + >>> jax.distributed.initialize('10.0.0.1:1234', 2, 0) # doctest: +SKIP - On host 1 + On host 1: + >>> jax.distributed.initialize('10.0.0.1:1234', 2, 1) # doctest: +SKIP """ if process_id == 0: