Added structure to make_array_from_single_device_arrays doc.

This commit is contained in:
Rafi Witten 2024-01-19 21:22:30 +00:00
parent 03a8e5885b
commit 28d25a1196


@@ -745,21 +745,26 @@ def make_array_from_single_device_arrays(
with each device receiving at least one example. In this case, the following recipe will use
`make_array_from_single_device_arrays` to create a global jax.Array.
First, we create the per-host data as NumPy arrays.
>>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
>>> rows_per_device = 2
>>> feature_length = 32
>>> per_device_shape = (rows_per_device, feature_length)
>>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
>>> global_shape = (rows_per_device * jax.device_count(), ) + per_device_shape[1:]
>>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
...
>>> per_host_data = per_host_generator() # replace with your own per-host data pipeline that outputs numpy arrays
Second, we put the NumPy data onto the local devices as single-device JAX Arrays. Then we call
`make_array_from_single_device_arrays` to make the global Array.
>>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:]
>>> per_device_data = np.split(per_host_data, len(mesh.local_devices), axis = 0) # per device data, but on host
>>> per_device_data_on_device = jax.device_put(per_device_data, mesh.local_devices) # per device data, now on device
>>> output_global_array = jax.make_array_from_single_device_arrays(global_shape, sharding, per_device_data_on_device)
...
>>> assert output_global_array.addressable_data(0).shape == (rows_per_device, feature_length)
>>> assert output_global_array.shape == (rows_per_device * jax.device_count(), feature_length)
>>> assert output_global_array.addressable_data(0).shape == per_device_shape
>>> assert output_global_array.shape == global_shape
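For readers who want to try the recipe outside the docstring, the following is a minimal single-process sketch, not part of this commit. It assumes every device is local to the process, builds a hypothetical one-axis mesh named `x` (the docstring's mesh has axes `x` and `y` and is defined elsewhere), and places each shard with an explicit loop rather than the single `jax.device_put` call used above.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a single process, so all devices are local (addressable).
# The one-axis mesh "x" is a stand-in for the docstring's ('x', 'y') mesh.
mesh = Mesh(np.array(jax.local_devices()), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

rows_per_device = 2
feature_length = 32
per_device_shape = (rows_per_device, feature_length)
per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
global_shape = (rows_per_device * len(sharding.device_set),) + per_device_shape[1:]

# Stand-in for a real per-host input pipeline that yields NumPy arrays.
per_host_data = np.arange(np.prod(per_host_shape), dtype=np.float32).reshape(per_host_shape)

# Split the host batch into one chunk per local device (still on host).
per_device_data = np.split(per_host_data, len(mesh.local_devices), axis=0)

# Place each chunk on its device, producing single-device jax.Arrays.
per_device_arrays = [
    jax.device_put(chunk, device)
    for chunk, device in zip(per_device_data, mesh.local_devices)
]

# Assemble the single-device arrays into one global jax.Array.
global_array = jax.make_array_from_single_device_arrays(
    global_shape, sharding, per_device_arrays
)
```

In a real multi-host job each process would run the same code on its own `per_host_data`, and only the locally addressable shards could be read back on that host.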
When using tensor parallelism (equivalent to sharding across both rows and columns in the
above example), the above example doesn't generate the data in the sharding that you plan