Adds a note that pjit is equivalent to jit.

PiperOrigin-RevId: 535296532
This commit is contained in:
Mark Sandler 2023-05-25 10:13:50 -07:00 committed by jax authors
parent 32026ad18b
commit bc547aa318

View File

@ -578,6 +578,7 @@ def pjit(
) -> stages.Wrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
NOTE: This function is now equivalent to jax.jit please use that instead.
The returned function has semantics equivalent to those of ``fun``, but is
compiled to an XLA computation that runs across multiple devices
(e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted