Added structure to make_array_from_single_device_arrays doc.
This commit is contained in:
parent 03a8e5885b
commit 28d25a1196
@@ -745,21 +745,26 @@ def make_array_from_single_device_arrays(
with each device receiving at least one example. In this case, the following recipe will use
`make_array_from_single_device_arrays` to create a global jax.Array.
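This recipe assumes that `mesh`, `P` (`jax.sharding.PartitionSpec`), and `np` are already in scope, as in the earlier example in this docstring. A minimal sketch of such a setup (the 2 by `jax.device_count() // 2` mesh shape is an illustrative assumption, not a requirement of the recipe) might look like:

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> # Illustrative mesh only: 2 rows of devices by jax.device_count() // 2 columns
>>> # (assumes an even device count).
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, jax.device_count() // 2), ('x', 'y'))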
First, we create the per-host data as NumPy arrays.
>>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
>>> rows_per_device = 2
>>> feature_length = 32
>>> per_device_shape = (rows_per_device, feature_length)
>>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
>>> global_shape = (rows_per_device * jax.device_count(), ) + per_device_shape[1:]
>>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
...
>>> per_host_data = per_host_generator() # replace with your own per-host data pipeline that outputs numpy arrays
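To make the shape arithmetic concrete, assume a topology of 2 processes with 4 local devices each, 8 devices in total; this is only an illustration, not part of the recipe:

>>> # Assumed topology: 2 processes x 4 local devices = 8 devices total, so
>>> # per_device_shape == (2, 32)   (rows_per_device, feature_length)
>>> # per_host_shape   == (8, 32)   (2 rows per device * 4 local devices, feature_length)
>>> # global_shape     == (16, 32)  (2 rows per device * 8 devices, feature_length)
>>> assert per_host_data.shape == per_host_shape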
Second, we put the NumPy data onto the local devices as single-device JAX Arrays. Then we call
`make_array_from_single_device_arrays` to make the global Array.
>>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:]
>>> per_device_data = np.split(per_host_data, len(mesh.local_devices), axis = 0) # per device data, but on host
>>> per_device_data_on_device = jax.device_put(per_device_data, mesh.local_devices) # per device data, now on device
>>> output_global_array = jax.make_array_from_single_device_arrays(global_shape, sharding, per_device_data_on_device)
...
>>> assert output_global_array.addressable_data(0).shape == (rows_per_device, feature_length)
>>> assert output_global_array.shape == (rows_per_device * jax.device_count(), feature_length)
>>> assert output_global_array.addressable_data(0).shape == per_device_shape
>>> assert output_global_array.shape == global_shape
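As an optional sanity check (a sketch, not part of the original example), each addressable shard of the result can be inspected to confirm that it has the per-device shape and lives on a local device:

>>> for shard in output_global_array.addressable_shards:  # one entry per local device
...   assert shard.data.shape == per_device_shape
...   assert shard.device in mesh.local_devices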
When using tensor parallelism (equivalent to sharding across both rows and columns in the
above example), the above example doesn't generate the data in the sharding that you plan