mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fixes broken examples, and (invalid) comment for PartitionSpec
PiperOrigin-RevId: 517531823
This commit is contained in:
parent
c25ea3f0f2
commit
bab1098866
@ -88,11 +88,12 @@ Example:
|
||||
```
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
x = jnp.arange(8)
|
||||
|
||||
# Let's say there are 8 devices in jax.devices()
|
||||
mesh = jax.sharding.Mesh(jax.devices().reshape(4, 2), ('x', 'y'))
|
||||
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
|
||||
sharding = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
||||
sharded_x = jax.device_put(x, sharding)
|
||||
|
@ -153,11 +153,13 @@ _UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
|
||||
|
||||
|
||||
class PartitionSpec(tuple):
|
||||
"""Tuple of integer specifying how a value should be partitioned.
|
||||
"""Tuple describing how to partition tensor into mesh .
|
||||
|
||||
Each integer corresponds to how many ways a dimension is partitioned. We
|
||||
create a separate class for this so JAX's pytree utilities can distinguish it
|
||||
from a tuple that should be treated as a pytree.
|
||||
Each element is either None, string or a tuple of strings.
|
||||
See``NamedSharding`` class for more details.
|
||||
|
||||
We create a separate class for this so JAX's pytree utilities can distinguish
|
||||
it from a tuple that should be treated as a pytree.
|
||||
"""
|
||||
|
||||
# A sentinel value representing a dim is unconstrained.
|
||||
@ -186,7 +188,7 @@ class NamedSharding(XLACompatibleSharding):
|
||||
where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
|
||||
``Mesh`` is "logical mesh".
|
||||
|
||||
``PartitionSpec`` is a named tuple, whose elements can be a ``None``,
|
||||
``PartitionSpec`` is a tuple, whose elements can be a ``None``,
|
||||
a mesh axis or a tuple of mesh axes. Each element describes how an input
|
||||
dimension is partitioned across zero or more mesh dimensions. For example,
|
||||
PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data
|
||||
|
Loading…
x
Reference in New Issue
Block a user