Fixes broken examples, and (invalid) comment for PartitionSpec

PiperOrigin-RevId: 517531823
This commit is contained in:
Mark Sandler 2023-03-17 16:09:14 -07:00 committed by jax authors
parent c25ea3f0f2
commit bab1098866
2 changed files with 10 additions and 7 deletions

View File

@ -88,11 +88,12 @@ Example:
```
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(jax.devices().reshape(4, 2), ('x', 'y'))
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)

View File

@ -153,11 +153,13 @@ _UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
class PartitionSpec(tuple):
"""Tuple of integer specifying how a value should be partitioned.
"""Tuple describing how to partition tensor into mesh .
Each integer corresponds to how many ways a dimension is partitioned. We
create a separate class for this so JAX's pytree utilities can distinguish it
from a tuple that should be treated as a pytree.
Each element is either None, string or a tuple of strings.
See``NamedSharding`` class for more details.
We create a separate class for this so JAX's pytree utilities can distinguish
it from a tuple that should be treated as a pytree.
"""
# A sentinel value representing a dim is unconstrained.
@ -186,7 +188,7 @@ class NamedSharding(XLACompatibleSharding):
where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
``Mesh`` is "logical mesh".
``PartitionSpec`` is a named tuple, whose elements can be a ``None``,
``PartitionSpec`` is a tuple, whose elements can be a ``None``,
a mesh axis or a tuple of mesh axes. Each element describes how an input
dimension is partitioned across zero or more mesh dimensions. For example,
PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data