Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00
Use jax.make_array_from_process_local_data API in distributed data loading doc
PiperOrigin-RevId: 677973689
This commit is contained in:
parent e4091a6752
commit a99ea73336
@@ -243,35 +243,10 @@ ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
 # Grab just the first batch from the Dataset for this example
 per_process_batch = ds.as_numpy_iterator().next()
 
-per_process_batch_size = per_process_batch.shape[0]  # adjust if your batch dim
-                                                     # isn't 0
-
-per_replica_batch_size = per_process_batch_size // jax.local_device_count()
-assert per_process_batch_size % per_replica_batch_size == 0, \
-  "This example doesn't implement padding."
-per_replica_batches = np.split(per_process_batch, jax.local_device_count())
-
-# Thanks to the very important trick about data parallelism, no need to care what
-# order the devices appear in the sharding.
-sharding = jax.sharding.PositionalSharding(jax.devices())
-# PositionalSharding must have same rank as data being sharded.
-sharding = sharding.reshape((jax.device_count(),) +
-                            (1,) * (per_process_batch.ndim - 1))
-
-global_batch_size = per_replica_batch_size * jax.device_count()
-global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
-
-global_batch_array = jax.make_array_from_single_device_arrays(
-    global_batch_shape, sharding,
-    # Thanks again to the very important trick, no need to care which device gets
-    # which per-replica batch.
-    arrays=[jax.device_put(batch, device)
-            for batch, device
-            in zip(per_replica_batches, sharding.addressable_devices)])
-
-assert global_batch_array.shape == global_batch_shape
-assert (global_batch_array.addressable_shards[0].data.shape ==
-        per_replica_batches[0].shape)
+mesh = jax.make_mesh((jax.device_count(),), ('batch',))
+sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
+global_batch_array = jax.make_array_from_process_local_data(
+    sharding, per_process_batch)
 ```
 
 ## Data + model parallelism
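Reading the new hunk in isolation: the per-process batch is handed straight to `jax.make_array_from_process_local_data`, which infers the global shape from the sharding and places the data on the local devices itself. Below is a minimal sketch of that pattern, assuming a synthetic NumPy batch in place of the doc's tf.data pipeline; the shape, dtype, and final assert are illustrative, not part of the commit.

```python
import jax
import numpy as np

# Stand-in for the doc's tf.data pipeline: each process produces its own
# per-process batch as a NumPy array. Shape and dtype are illustrative.
per_process_batch = np.zeros((8 * jax.local_device_count(), 28, 28, 1),
                             dtype=np.float32)

# One-dimensional mesh over all devices; the batch axis is sharded along it.
mesh = jax.make_mesh((jax.device_count(),), ('batch',))
sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))

# Assemble the global jax.Array directly from process-local data. Device
# placement is handled internally, replacing the manual np.split /
# jax.device_put / make_array_from_single_device_arrays bookkeeping removed above.
global_batch_array = jax.make_array_from_process_local_data(
    sharding, per_process_batch)

# Each process contributes an equal slice of the global batch dimension.
assert global_batch_array.shape == (
    per_process_batch.shape[0] * jax.process_count(),
    *per_process_batch.shape[1:])
```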
@@ -366,16 +341,6 @@ per_process_batch = ds.as_numpy_iterator().next()
 num_model_replicas_per_process = 2 # set according to your parallelism strategy
 num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
 
-per_process_batch_size = per_process_batch.shape[0]  # adjust if your batch dim
-                                                     # isn't 0
-
-per_replica_batch_size = (per_process_batch_size //
-                          num_model_replicas_per_process)
-assert per_process_batch_size % per_replica_batch_size == 0, \
-  "This example doesn't implement padding."
-per_replica_batches = np.split(per_process_batch,
-                               num_model_replicas_per_process)
-
 # Create an example `Mesh` for per-process data parallelism. Make sure all devices
 # are grouped by process, and then resize so each row is a model replica.
 mesh_devices = np.array([jax.local_devices(process_idx)
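The context above cuts off partway through the mesh construction. For orientation, here is a hedged sketch of one way such a process-grouped, one-row-per-model-replica mesh can be built; the replica count and the reshape are assumptions for illustration, not the doc's exact code.

```python
import jax
import numpy as np

num_model_replicas_per_process = 2  # assumed, as in the doc snippet
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()

# Group devices by process so every mesh row stays within one process, then
# reshape so each row is a single model replica. Assumes local_device_count()
# is divisible by num_model_replicas_per_process.
mesh_devices = np.array([jax.local_devices(process_idx)
                         for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(
    num_model_replicas_total,
    jax.device_count() // num_model_replicas_total)
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
```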
@@ -393,35 +358,8 @@ mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
 sharding = jax.sharding.NamedSharding(
     mesh, jax.sharding.PartitionSpec("model_replicas"))
 
-global_batch_size = per_replica_batch_size * num_model_replicas_total
-global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:])
-
-# Create the final jax.Array using jax.make_array_from_callback. The callback
-# will be called for each local device, and passed the N-D numpy-style index
-# that describes what shard of the global data that device should receive.
-#
-# You don't need care exactly which index is passed in due to the very important data
-# parallelism, but you do use the index argument to make sure you replicate each
-# per-replica batch correctly -- the `index` argument will be the same for
-# devices in the same model replica, and different for devices in different
-# model replicas.
-
-index_to_batch = {}
-def callback(index: tuple[slice, ...]) -> np.ndarray:
-  # Python `slice` objects aren't hashable, so manually create dict key.
-  index_key = tuple((slice_.start, slice_.stop) for slice_ in index)
-  if index_key not in index_to_batch:
-    # You don't care which per-replica batch goes to which replica, just take the
-    # next unused one.
-    index_to_batch[index_key] = per_replica_batches[len(index_to_batch)]
-  return index_to_batch[index_key]
-
-global_batch_array = jax.make_array_from_callback(
-    global_batch_shape, sharding, callback)
-
-assert global_batch_array.shape == global_batch_shape
-assert (global_batch_array.addressable_shards[0].data.shape ==
-        per_replica_batches[0].shape)
+global_batch_array = jax.make_array_from_process_local_data(
+    sharding, per_process_batch)
 ```
 
 ### Model parallelism across processes
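Putting the new hunk together: sharding only along "model_replicas" and leaving "data_parallelism" out of the PartitionSpec replicates each per-replica batch across the data-parallel axis, which is the replication the removed make_array_from_callback bookkeeping did by hand. Below is a self-contained sketch under the same assumptions as the previous sketch; batch sizes and feature shape are illustrative, and the asserts mirror the invariants the removed code checked.

```python
import jax
import numpy as np

# Assumed setup, mirroring the doc's data + model parallelism example.
num_model_replicas_per_process = 2
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
per_replica_batch_size = 8       # illustrative
feature_shape = (128,)           # illustrative
per_process_batch = np.zeros(
    (per_replica_batch_size * num_model_replicas_per_process, *feature_shape),
    dtype=np.float32)

# Process-grouped mesh, one row per model replica (see the previous sketch;
# assumes local_device_count() divides evenly by the per-process replica count).
mesh_devices = np.array([jax.local_devices(process_idx)
                         for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(
    num_model_replicas_total, jax.device_count() // num_model_replicas_total)
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])

# Shard the batch across model replicas only; the unspecified
# "data_parallelism" axis means each per-replica batch is replicated along it.
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("model_replicas"))

global_batch_array = jax.make_array_from_process_local_data(
    sharding, per_process_batch)

# Same invariants the removed asserts checked: full global batch size, and each
# addressable shard is exactly one per-replica batch.
assert global_batch_array.shape == (
    per_replica_batch_size * num_model_replicas_total, *feature_shape)
assert (global_batch_array.addressable_shards[0].data.shape ==
        (per_replica_batch_size, *feature_shape))
```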