Remove special casing on npy_value when indexing sharded arrays.

Before:

```
GPU
sda_index_2                           2912972 ns      2778716 ns          256

TPU
sda_index_1                            769968 ns       751700 ns          921
sda_index_2                           1510841 ns      1489716 ns          465
sda_index_8                           6102259 ns      6027655 ns          117
```

After:

```
GPU
sda_index_2                             28095 ns        27983 ns        25463

TPU
sda_index_1                             10302 ns        10279 ns        67884
sda_index_2                             20010 ns        19947 ns        34628
sda_index_8                             78492 ns        78306 ns         8934
```

PiperOrigin-RevId: 368380864
This commit is contained in:
Tom Hennigan 2021-04-14 01:16:37 -07:00 committed by jax authors
parent d17d9a8081
commit 9d56552517
2 changed files with 36 additions and 10 deletions

View File

@ -220,6 +220,32 @@ def pmap_simple_8_devices(state):
d.block_until_ready()
def _run_sda_index_bench(state, num_devices):
x = jax.pmap(jnp.sin)(jnp.arange(num_devices))
jax.device_get(x)
while state:
for i in range(num_devices):
_ = x[i]
@google_benchmark.register
@required_devices(1)
def sda_index_1(state):
_run_sda_index_bench(state, 1)
@google_benchmark.register
@required_devices(2)
def sda_index_2(state):
_run_sda_index_bench(state, 2)
@google_benchmark.register
@required_devices(8)
def sda_index_8(state):
_run_sda_index_bench(state, 8)
def swap(a, b):
return b, a

View File

@ -553,16 +553,16 @@ class ShardedDeviceArray(xla.DeviceArray): # type: ignore
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
if self._npy_value is None:
try:
buf_idx = self.indices.index(cidx)
except ValueError:
buf_idx = None
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
return xla.DeviceArray.__getitem__(self, idx)
try:
buf_idx = self.indices.index(cidx)
except ValueError:
# NOTE: Slow path, this will materialize the sharded array on a single
# device and use XLA's Gather to index into the resulting array.
return xla.DeviceArray.__getitem__(self, idx)
else:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
def _hashable_index(idx):