mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
improve data parallel example
fix example fix example fix example fix example fix example fix example
This commit is contained in:
parent
ae9160a4e9
commit
a192b5e541
@ -646,14 +646,13 @@ def make_array_from_single_device_arrays(
|
||||
i.e. each process receives a different part of the data, then you can use
|
||||
`make_array_from_single_device_arrays` to create a global jax.Array
|
||||
|
||||
>>> global_shape = (8, 2)
|
||||
>>> host_array = np.arange(math.prod(global_shape)).reshape(global_shape)
|
||||
>>> local_shape = (8, 2)
|
||||
>>> global_shape = (jax.process_count() * local_shape[0], ) + local_shape[1:]
|
||||
>>> local_array = np.arange(math.prod(local_shape)).reshape(local_shape)
|
||||
>>> arrays = jax.device_put(
|
||||
... np.split(host_array, len(global_mesh.local_devices), axis=0),
|
||||
... global_mesh.local_devices)
|
||||
>>> arr = jax.make_array_from_single_device_arrays(
|
||||
... global_shape, jax.sharding.NamedSharding(global_mesh, P(('x', 'y'))),
|
||||
... arrays)
|
||||
... np.split(local_array, len(global_mesh.local_devices), axis = 0), global_mesh.local_devices)
|
||||
>>> sharding = jax.sharding.NamedSharding(global_mesh, P(('x', 'y'), ))
|
||||
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
|
||||
>>> arr.addressable_data(0).shape
|
||||
(1, 2)
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user