Fix the docs build

This commit is contained in:
yashkatariya 2023-06-16 13:14:38 -07:00
parent aa8c6edc7b
commit a65f74b392

View File

@ -668,17 +668,18 @@ def pjit(
will be partitioned. With this, using a mesh context manager is not
required.
- :py:obj:`None` is a special case whose semantics are:
- if the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
For in_shardings, JAX will mark is as replicated but this behavior
can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
- If the mesh context manager is provided, None will imply that the
value will be replicated on all devices of the mesh.
- if the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
For in_shardings, JAX will mark is as replicated but this behavior
can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
- If the mesh context manager is provided, None will imply that the
value will be replicated on all devices of the mesh.
- For backwards compatibility, in_shardings still supports ingesting
:py:class:`PartitionSpec`. This option can *only* be used with the
mesh context manager.
- :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
of the partitioned value. Each element can be a :py:obj:`None`, a mesh
axis or a tuple of mesh axes, and specifies the set of resources assigned