mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix the docs build
This commit is contained in:
parent
aa8c6edc7b
commit
a65f74b392
@ -668,17 +668,18 @@ def pjit(
|
||||
will be partitioned. With this, using a mesh context manager is not
|
||||
required.
|
||||
- :py:obj:`None` is a special case whose semantics are:
|
||||
- if the mesh context manager is *not* provided, JAX has the freedom to
|
||||
choose whatever sharding it wants.
|
||||
For in_shardings, JAX will mark is as replicated but this behavior
|
||||
can change in the future.
|
||||
For out_shardings, we will rely on the XLA GSPMD partitioner to
|
||||
determine the output shardings.
|
||||
- If the mesh context manager is provided, None will imply that the
|
||||
value will be replicated on all devices of the mesh.
|
||||
- if the mesh context manager is *not* provided, JAX has the freedom to
|
||||
choose whatever sharding it wants.
|
||||
For in_shardings, JAX will mark is as replicated but this behavior
|
||||
can change in the future.
|
||||
For out_shardings, we will rely on the XLA GSPMD partitioner to
|
||||
determine the output shardings.
|
||||
- If the mesh context manager is provided, None will imply that the
|
||||
value will be replicated on all devices of the mesh.
|
||||
- For backwards compatibility, in_shardings still supports ingesting
|
||||
:py:class:`PartitionSpec`. This option can *only* be used with the
|
||||
mesh context manager.
|
||||
|
||||
- :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
|
||||
of the partitioned value. Each element can be a :py:obj:`None`, a mesh
|
||||
axis or a tuple of mesh axes, and specifies the set of resources assigned
|
||||
|
Loading…
x
Reference in New Issue
Block a user