Add ShardedDeviceArray integer-indexing benchmarks (1, 2 and 8 devices) and
make ShardedDeviceArray.__getitem__ always try the direct device-buffer lookup
first. Previously that lookup was only attempted while self._npy_value was
None; now xla.DeviceArray.__getitem__ is used only when the canonicalized
index is not found in self.indices.

NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy, assuming 2-space indentation. Confirm
against the upstream commit before applying.

diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py
index 22fda5a4e..39e6674a4 100644
--- a/benchmarks/api_benchmark.py
+++ b/benchmarks/api_benchmark.py
@@ -220,6 +220,32 @@ def pmap_simple_8_devices(state):
     d.block_until_ready()
 
 
+def _run_sda_index_bench(state, num_devices):
+  x = jax.pmap(jnp.sin)(jnp.arange(num_devices))
+  jax.device_get(x)
+  while state:
+    for i in range(num_devices):
+      _ = x[i]
+
+
+@google_benchmark.register
+@required_devices(1)
+def sda_index_1(state):
+  _run_sda_index_bench(state, 1)
+
+
+@google_benchmark.register
+@required_devices(2)
+def sda_index_2(state):
+  _run_sda_index_bench(state, 2)
+
+
+@google_benchmark.register
+@required_devices(8)
+def sda_index_8(state):
+  _run_sda_index_bench(state, 8)
+
+
 def swap(a, b):
   return b, a
 
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index aa45011f8..b37f94592 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -553,16 +553,16 @@ class ShardedDeviceArray(xla.DeviceArray): # type: ignore
       cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
     else:
       cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
-    if self._npy_value is None:
-      try:
-        buf_idx = self.indices.index(cidx)
-      except ValueError:
-        buf_idx = None
-      if buf_idx is not None:
-        buf = self.device_buffers[buf_idx]
-        aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
-        return xla.make_device_array(aval, None, buf)
-    return xla.DeviceArray.__getitem__(self, idx)
+    try:
+      buf_idx = self.indices.index(cidx)
+    except ValueError:
+      # NOTE: Slow path, this will materialize the sharded array on a single
+      # device and use XLA's Gather to index into the resulting array.
+      return xla.DeviceArray.__getitem__(self, idx)
+    else:
+      buf = self.device_buffers[buf_idx]
+      aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
+      return xla.make_device_array(aval, None, buf)
 
 
 def _hashable_index(idx):