diff --git a/benchmarks/gda_benchmark.py b/benchmarks/gda_benchmark.py index 66fbcfe4e..2b33ef621 100644 --- a/benchmarks/gda_benchmark.py +++ b/benchmarks/gda_benchmark.py @@ -87,6 +87,26 @@ def indices_replica_id_calc_cached(mesh_shape, mesh_axes, state): gda.get_shard_indices_replica_ids(global_input_shape, global_mesh, mesh_axes) +def gda_local_shards(mesh_shape, mesh_axes, state): + # `device_put` time is not measured in this benchmark. All the devices here + # are local. + global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y")) + global_input_shape = (2048, 2048) + global_input_data = np.arange( + prod(global_input_shape)).reshape(global_input_shape) + global_indices = gda.get_shard_indices(global_input_shape, global_mesh, + mesh_axes) + dbs = [ + jax.device_put(global_input_data[global_indices[device]], device) + for device in global_mesh.local_devices + ] + gda_inp = gda.GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, + dbs) + + while state: + gda_inp._create_local_shards() + + benchmarks = [] for mesh_shape, axes in mesh_shapes_axes: benchmarks.extend([ @@ -102,6 +122,9 @@ for mesh_shape, axes in mesh_shapes_axes: google_benchmark.register( partial(indices_replica_id_calc_cached, mesh_shape, axes), name=f"indices_replica_id_calc_cached_{mesh_shape}_{axes}"), + google_benchmark.register( + partial(gda_local_shards, mesh_shape, axes), + name=f"gda_local_shards_{mesh_shape}_{axes}"), ]) diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index 45c3f6b97..a400143ec 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -21,6 +21,7 @@ from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Di from jax import core from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.config import config from jax.interpreters import pxla, xla from jax._src.util import prod, safe_zip, cache from jax._src.api import device_put @@ -130,6 +131,12 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape: return tuple(chunk_size) +def _set_aval(val): + if val.aval is None: + val.aval = core.ShapedArray(val.shape, val.dtype) + return val + + @dataclasses.dataclass(frozen=True) class Shard: """A single data shard of a GlobalDeviceArray. @@ -274,7 +281,8 @@ class GlobalDeviceArray: def __init__(self, global_shape: Shape, global_mesh: pxla.Mesh, mesh_axes: MeshAxes, device_buffers: Sequence[DeviceArray], - _gda_fast_path_args: Optional[_GdaFastPathArgs] = None): + _gda_fast_path_args: Optional[_GdaFastPathArgs] = None, + _enable_checks: bool = True): self._global_shape = global_shape self._global_mesh = global_mesh self._mesh_axes = mesh_axes @@ -288,24 +296,25 @@ class GlobalDeviceArray: else: self._local_devices = self._gda_fast_path_args.local_devices - for db, ld in safe_zip(device_buffers, self._local_devices): - if db.device() != ld: - raise ValueError( - "The `global_mesh.local_devices` and `device_buffers` device order " - "doesn't match. Please use `global_mesh.local_devices` to put " - "arrays on devices instead of `jax.local_devices()`") + if _enable_checks or config.jax_enable_checks: + for db, ld in safe_zip(device_buffers, self._local_devices): + if db.device() != ld: + raise ValueError( + "The `global_mesh.local_devices` and `device_buffers` device " + "order doesn't match. Please use `global_mesh.local_devices` to " + "put arrays on devices instead of `jax.local_devices()`") - self._local_shards = self._create_local_shards() - - ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes) - assert all(db.shape == ss for db in device_buffers), ( - f"Expected shard shape {ss} doesn't match the device buffer " - f"shape, got: {[db.shape for db in device_buffers]}") + if _enable_checks or config.jax_enable_checks: + ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes) + assert all(db.shape == ss for db in device_buffers), ( + f"Expected shard shape {ss} doesn't match the device buffer " + f"shape, got: {[db.shape for db in device_buffers]}") dtype = device_buffers[0].dtype - assert all(db.dtype == dtype for db in device_buffers), ( - "Input arrays to GlobalDeviceArray must have matching dtypes, " - f"got: {[db.dtype for db in device_buffers]}") + if _enable_checks or config.jax_enable_checks: + assert all(db.dtype == dtype for db in device_buffers), ( + "Input arrays to GlobalDeviceArray must have matching dtypes, " + f"got: {[db.dtype for db in device_buffers]}") self.dtype = dtype def __eq__(self, other: object): @@ -358,16 +367,15 @@ class GlobalDeviceArray: out = [] for db in self._device_buffers: - if db.aval is None: - db.aval = core.ShapedArray(db.shape, db.dtype) + db = _set_aval(db) device = db.device() index, rid = global_indices_rid[device] out.append(Shard(device, index, rid, db)) return out - @property + @pxla.maybe_cached_property def local_shards(self) -> Sequence[Shard]: - return self._local_shards + return self._create_local_shards() @property def global_shards(self) -> Sequence[Shard]: @@ -393,11 +401,11 @@ class GlobalDeviceArray: return global_shards def local_data(self, index) -> DeviceArray: - return self.local_shards[index].data + return _set_aval(self._device_buffers[index]) def block_until_ready(self): - for s in self.local_shards: - s.data.block_until_ready() + for db in self._device_buffers: + db.block_until_ready() return self @classmethod @@ -563,7 +571,7 @@ xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray( xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity def _gda_shard_arg(x, devices, indices): - return [s.data for s in x.local_shards] + return [d for d in x._device_buffers] pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg @@ -573,6 +581,7 @@ def _gda_array_result_handler(global_aval, out_axis_resources, global_mesh): local_devices = global_mesh.local_devices fast_path_args = _GdaFastPathArgs(global_idx_rid, local_devices) return lambda bufs: GlobalDeviceArray( - global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args) + global_aval.shape, global_mesh, out_axis_resources, bufs, fast_path_args, + _enable_checks=False) pxla.global_result_handlers[core.ShapedArray] = _gda_array_result_handler pxla.global_result_handlers[core.ConcreteArray] = _gda_array_result_handler