diff --git a/docs/jax.rst b/docs/jax.rst index 2f6442958..7396c439e 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -77,6 +77,14 @@ Automatic differentiation closure_convert checkpoint +jax.Array (:code:`jax.Array`) +----------------------------- + +.. autosummary:: + :toctree: _autosummary + + make_array_from_callback + make_array_from_single_device_arrays Vectorization (:code:`vmap`) ---------------------------- diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 9dc8c0868..0048b66a4 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -140,12 +140,12 @@ Please use `addressable_shards` and `addressable_data` which are compatible with All JAX functions will output `jax.Array` when the `jax_array` flag is True. If you were using `GlobalDeviceArray.from_callback` or `make_sharded_device_array` or `make_device_array` functions to explicitly create the respective JAX data -types, you will need to switch them to use `jax.make_array_from_callback` or -`jax.make_array_from_single_device_arrays`. +types, you will need to switch them to use {func}`jax.make_array_from_callback` +or {func}`jax.make_array_from_single_device_arrays`. **For GDA:** -`GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become +`GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)` can become `jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)` in a 1:1 switch. diff --git a/jax/_src/array.py b/jax/_src/array.py index 85957dbeb..0fe758218 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -510,6 +510,40 @@ setattr(ArrayImpl, "__array_priority__", 100) def make_array_from_callback( shape: Shape, sharding: Sharding, data_callback: Callable[[Optional[Index]], ArrayLike]) -> ArrayImpl: + """Returns a ``jax.Array`` via data fetched from ``data_callback``. + + ``data_callback`` is used to fetch the data for each addressable shard of the + returned ``jax.Array``. + + Args: + shape : Shape of the ``jax.Array``. + sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is + laid out across devices. + data_callback : Callback that takes indices into the global array value as + input and returns the corresponding data of the global array value. + The data can be returned as any array-like object, e.g. a ``numpy.ndarray``. + + Returns: + A ``jax.Array`` via data fetched from ``data_callback``. + + Example: + + >>> from jax.experimental.maps import Mesh + >>> from jax.experimental import PartitionSpec as P + >>> import numpy as np + ... + >>> input_shape = (8, 8) + >>> global_input_data = np.arange(prod(input_shape)).reshape(input_shape) + >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) + >>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) + ... + >>> def cb(index): + ... return global_input_data[index] + ... + >>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb) + >>> arr.addressable_data(0).shape + (4, 2) + """ device_to_index_map = sharding.devices_indices_map(shape) # Use addressable_devices here instead of `_addressable_device_assignment` # because `_addressable_device_assignment` is only available on @@ -525,6 +559,41 @@ def make_array_from_callback( def make_array_from_single_device_arrays( shape: Shape, sharding: Sharding, arrays: Sequence[ArrayImpl]) -> ArrayImpl: + r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device. + + ``jax.Array`` on a single device is analogous to a ``DeviceArray``. You can use + this function if you have already ``jax.device_put`` the value on a single + device and want to create a global Array. The smaller ``jax.Array``\s should be + addressable and belong to the current process. + + Args: + shape : Shape of the ``jax.Array``. + sharding: A ``Sharding`` instance which describes how the ``jax.Array`` is + laid out across devices. + arrays: Sequence of ``jax.Array``\s that are on a single device. + + Returns: + A ``jax.Array`` from a sequence of ``jax.Array``\s on a single device. + + Example: + + >>> from jax.experimental.maps import Mesh + >>> from jax.experimental import PartitionSpec as P + >>> import numpy as np + ... + >>> shape = (8, 8) + >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) + >>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) + >>> inp_data = np.arange(prod(shape)).reshape(shape) + ... + >>> arrays = [ + ... jax.device_put(inp_data[index], d) + ... for d, index in sharding.addressable_devices_indices_map(shape).items()] + ... + >>> arr = jax.make_array_from_single_device_arrays(shape, sharding, arrays) + >>> arr.addressable_data(0).shape + (4, 2) + """ # All input arrays should be committed. Checking it is expensive on # single-controller systems. aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)