Use jax.make_array_from_process_local_data API in distributed data loading doc

PiperOrigin-RevId: 677973689
Author: Yash Katariya, 2024-09-23 16:02:51 -07:00; committed by jax authors
parent e4091a6752
commit a99ea73336


@@ -243,35 +243,10 @@ ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
-per_process_batch_size = per_process_batch.shape[0]  # adjust if your batch dim
-                                                      # isn't 0
-per_replica_batch_size = per_process_batch_size // jax.local_device_count()
-assert per_process_batch_size % per_replica_batch_size == 0, \
-  "This example doesn't implement padding."
-
-per_replica_batches = np.split(per_process_batch, jax.local_device_count())
-
-# Thanks to the very important trick about data parallelism, no need to care what
-# order the devices appear in the sharding.
-sharding = jax.sharding.PositionalSharding(jax.devices())
-# PositionalSharding must have same rank as data being sharded.
-sharding = sharding.reshape((jax.device_count(),) +
-                            (1,) * (per_process_batch.ndim - 1))
-
-global_batch_size = per_replica_batch_size * jax.device_count()
-global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
-
-global_batch_array = jax.make_array_from_single_device_arrays(
-    global_batch_shape, sharding,
-    # Thanks again to the very important trick, no need to care which device gets
-    # which per-replica batch.
-    arrays=[jax.device_put(batch, device)
-            for batch, device
-            in zip(per_replica_batches, sharding.addressable_devices)])
-
-assert global_batch_array.shape == global_batch_shape
-assert (global_batch_array.addressable_shards[0].data.shape ==
-        per_replica_batches[0].shape)
+mesh = jax.make_mesh((jax.device_count(),), ('batch',))
+sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
+global_batch_array = jax.make_array_from_process_local_data(
+    sharding, per_process_batch)
```
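With the new API, the bookkeeping that the removed lines did by hand (splitting the per-process batch, assigning pieces to devices, checking shapes) happens inside `jax.make_array_from_process_local_data`. The sketch below is not part of this commit or the doc; it is a hypothetical, standalone smoke test of the data-parallel pattern, using a synthetic `per_process_batch` and assuming the batch dimension is axis 0 and that each process contributes an equally sized batch that divides evenly across its local devices:

```python
import jax
import numpy as np

# Synthetic stand-in for the per-process batch produced by the data pipeline in
# the doc; the shape is arbitrary and chosen only for illustration.
per_process_batch = np.arange(
    8 * jax.local_device_count() * 3, dtype=np.float32
).reshape(8 * jax.local_device_count(), 3)

# Same construction as the updated doc: one 'batch' mesh axis over all devices.
mesh = jax.make_mesh((jax.device_count(),), ('batch',))
sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
global_batch_array = jax.make_array_from_process_local_data(
    sharding, per_process_batch)

# The global batch concatenates every process's local batch along axis 0...
assert global_batch_array.shape == (
    (per_process_batch.shape[0] * jax.process_count(),)
    + per_process_batch.shape[1:])
# ...and each local device holds an equal slice of this process's batch,
# mirroring the shape checks the removed code performed explicitly.
assert (global_batch_array.addressable_shards[0].data.shape[0] ==
        per_process_batch.shape[0] // jax.local_device_count())
```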
## Data + model parallelism
@@ -366,16 +341,6 @@ per_process_batch = ds.as_numpy_iterator().next()
num_model_replicas_per_process = 2 # set according to your parallelism strategy
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
-per_process_batch_size = per_process_batch.shape[0]  # adjust if your batch dim
-                                                      # isn't 0
-per_replica_batch_size = (per_process_batch_size //
-                          num_model_replicas_per_process)
-assert per_process_batch_size % per_replica_batch_size == 0, \
-  "This example doesn't implement padding."
-
-per_replica_batches = np.split(per_process_batch,
-                               num_model_replicas_per_process)
# Create an example `Mesh` for per-process data parallelism. Make sure all devices
# are grouped by process, and then resize so each row is a model replica.
mesh_devices = np.array([jax.local_devices(process_idx)
@@ -393,35 +358,8 @@ mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("model_replicas"))

-global_batch_size = per_replica_batch_size * num_model_replicas_total
-global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
-
-# Create the final jax.Array using jax.make_array_from_callback. The callback
-# will be called for each local device, and passed the N-D numpy-style index
-# that describes what shard of the global data that device should receive.
-#
-# You don't need care exactly which index is passed in due to the very important data
-# parallelism, but you do use the index argument to make sure you replicate each
-# per-replica batch correctly -- the `index` argument will be the same for
-# devices in the same model replica, and different for devices in different
-# model replicas.
-index_to_batch = {}
-def callback(index: tuple[slice, ...]) -> np.ndarray:
-  # Python `slice` objects aren't hashable, so manually create dict key.
-  index_key = tuple((slice_.start, slice_.stop) for slice_ in index)
-  if index_key not in index_to_batch:
-    # You don't care which per-replica batch goes to which replica, just take the
-    # next unused one.
-    index_to_batch[index_key] = per_replica_batches[len(index_to_batch)]
-  return index_to_batch[index_key]
-
-global_batch_array = jax.make_array_from_callback(
-    global_batch_shape, sharding, callback)
-
-assert global_batch_array.shape == global_batch_shape
-assert (global_batch_array.addressable_shards[0].data.shape ==
-        per_replica_batches[0].shape)
+global_batch_array = jax.make_array_from_process_local_data(
+    sharding, per_process_batch)
```
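As in the data-parallel case, the per-replica splitting and replication now happen inside `jax.make_array_from_process_local_data`. The checks below are illustrative only and not part of this commit; they reuse `per_process_batch`, `num_model_replicas_per_process`, `num_model_replicas_total`, and `global_batch_array` from the example above and assume the batch dimension is axis 0 and that it divides evenly across the model replicas on each process:

```python
# Per-replica batch size implied by splitting this process's batch across its
# model replicas (the arithmetic the removed code did explicitly).
per_replica_batch_size = (per_process_batch.shape[0] //
                          num_model_replicas_per_process)

# The global batch holds one per-replica batch for every model replica...
assert global_batch_array.shape == (
    (per_replica_batch_size * num_model_replicas_total,)
    + per_process_batch.shape[1:])
# ...and each addressable shard is a single per-replica batch, replicated
# across the devices on the "data_parallelism" mesh axis.
assert (global_batch_array.addressable_shards[0].data.shape ==
        (per_replica_batch_size,) + per_process_batch.shape[1:])
```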
### Model parallelism across processes