mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove special casing on npy_value when indexing sharded arrays.
Before:
```
GPU
sda_index_2  2912972 ns  2778716 ns  256
TPU
sda_index_1  769968 ns  751700 ns  921
sda_index_2  1510841 ns  1489716 ns  465
sda_index_8  6102259 ns  6027655 ns  117
```
After:
```
GPU
sda_index_2  28095 ns  27983 ns  25463
TPU
sda_index_1  10302 ns  10279 ns  67884
sda_index_2  20010 ns  19947 ns  34628
sda_index_8  78492 ns  78306 ns  8934
```
PiperOrigin-RevId: 368380864
This commit is contained in:
parent
d17d9a8081
commit
9d56552517
@ -220,6 +220,32 @@ def pmap_simple_8_devices(state):
|
||||
d.block_until_ready()
|
||||
|
||||
|
||||
def _run_sda_index_bench(state, num_devices):
  """Benchmarks integer indexing into the result of a `jax.pmap` call.

  Args:
    state: benchmark state object; the measured loop keeps running for as
      long as it evaluates truthy.
    num_devices: length of the array mapped by `jax.pmap`, and the number of
      leading-axis indices read on every benchmark iteration.
  """
  mapped_sin = jax.pmap(jnp.sin)
  inputs = jnp.arange(num_devices)
  sharded = mapped_sin(inputs)
  # Fetch the value to the host once, outside the measured loop.
  # NOTE(review): presumably this populates the array's cached host copy so
  # the loop measures indexing with that cache present -- confirm.
  jax.device_get(sharded)
  while state:
    for shard in range(num_devices):
      _ = sharded[shard]
|
||||
|
||||
|
||||
@google_benchmark.register
@required_devices(1)
def sda_index_1(state):
  """Registered benchmark: sharded-array indexing with `num_devices=1`."""
  _run_sda_index_bench(state, num_devices=1)
|
||||
|
||||
|
||||
@google_benchmark.register
@required_devices(2)
def sda_index_2(state):
  """Registered benchmark: sharded-array indexing with `num_devices=2`."""
  _run_sda_index_bench(state, num_devices=2)
|
||||
|
||||
|
||||
@google_benchmark.register
@required_devices(8)
def sda_index_8(state):
  """Registered benchmark: sharded-array indexing with `num_devices=8`."""
  _run_sda_index_bench(state, num_devices=8)
|
||||
|
||||
|
||||
def swap(a, b):
  """Returns the two arguments in reversed order, as the tuple `(b, a)`."""
  swapped = (b, a)
  return swapped
|
||||
|
||||
|
@ -553,16 +553,16 @@ class ShardedDeviceArray(xla.DeviceArray): # type: ignore
|
||||
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
|
||||
else:
|
||||
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
|
||||
if self._npy_value is None:
|
||||
try:
|
||||
buf_idx = self.indices.index(cidx)
|
||||
except ValueError:
|
||||
buf_idx = None
|
||||
if buf_idx is not None:
|
||||
buf = self.device_buffers[buf_idx]
|
||||
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
|
||||
return xla.make_device_array(aval, None, buf)
|
||||
return xla.DeviceArray.__getitem__(self, idx)
|
||||
try:
|
||||
buf_idx = self.indices.index(cidx)
|
||||
except ValueError:
|
||||
# NOTE: Slow path, this will materialize the sharded array on a single
|
||||
# device and use XLA's Gather to index into the resulting array.
|
||||
return xla.DeviceArray.__getitem__(self, idx)
|
||||
else:
|
||||
buf = self.device_buffers[buf_idx]
|
||||
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
|
||||
return xla.make_device_array(aval, None, buf)
|
||||
|
||||
|
||||
def _hashable_index(idx):
|
||||
|
Loading…
x
Reference in New Issue
Block a user