mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8043 from hawkinsp:iter
PiperOrigin-RevId: 406822933
This commit is contained in:
commit
335857bf93
@ -926,7 +926,7 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(operand, start_indices)
|
||||
return dynamic_slice_p.bind(operand, *start_indices,
|
||||
slice_sizes=tuple(slice_sizes))
|
||||
slice_sizes=core.canonicalize_shape(slice_sizes))
|
||||
|
||||
def dynamic_update_slice(operand: Array, update: Array,
|
||||
start_indices: Array) -> Array:
|
||||
@ -1362,7 +1362,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
|
||||
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
|
||||
operator.
|
||||
"""
|
||||
permutation = tuple(permutation)
|
||||
permutation = tuple(operator.index(d) for d in permutation)
|
||||
if (permutation == tuple(range(np.ndim(operand)))
|
||||
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
|
||||
return operand
|
||||
|
@ -6699,6 +6699,22 @@ def _multi_slice(arr,
|
||||
results.append(sliced)
|
||||
return results
|
||||
|
||||
# The next two functions are related to iter(device_array), implemented here to
|
||||
# avoid circular imports.
|
||||
@jit
|
||||
def _unstack(x):
|
||||
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
|
||||
setattr(DeviceArray, "_unstack", _unstack)
|
||||
def _chunk_iter(x, size):
|
||||
if size > x.shape[0]:
|
||||
yield x
|
||||
else:
|
||||
num_chunks, tail = divmod(x.shape[0], size)
|
||||
for i in range(num_chunks):
|
||||
yield lax.dynamic_slice_in_dim(x, i * size, size)
|
||||
if tail:
|
||||
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
|
||||
setattr(DeviceArray, "_chunk_iter", _chunk_iter)
|
||||
|
||||
# Syntactic sugar for scatter operations.
|
||||
class _IndexUpdateHelper:
|
||||
|
@ -622,21 +622,34 @@ def _sda_value(self):
|
||||
|
||||
|
||||
def _sda__getitem__(self, idx):
|
||||
self._check_if_deleted()
|
||||
if not isinstance(idx, tuple):
|
||||
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
|
||||
else:
|
||||
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
|
||||
try:
|
||||
buf_idx = self.indices.index(cidx)
|
||||
except ValueError:
|
||||
# NOTE: Slow path, this will materialize the sharded array on a single
|
||||
# device and use XLA's Gather to index into the resulting array.
|
||||
return xla.DeviceArray.__getitem__(self, idx)
|
||||
if self._npy_value is None:
|
||||
try:
|
||||
buf_idx = self.indices.index(cidx)
|
||||
except ValueError:
|
||||
buf_idx = None
|
||||
if buf_idx is not None:
|
||||
buf = self.device_buffers[buf_idx]
|
||||
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
|
||||
return xla.make_device_array(aval, None, buf)
|
||||
return super(self.__class__, self).__getitem__(idx)
|
||||
|
||||
|
||||
def _sda__iter__(self):
|
||||
if self.ndim == 0:
|
||||
raise TypeError("iteration over a 0-d array") # same as numpy error
|
||||
else:
|
||||
self._check_if_deleted()
|
||||
buf = self.device_buffers[buf_idx]
|
||||
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
|
||||
return xla.make_device_array(aval, None, buf)
|
||||
return (self[i] for i in range(self.shape[0]))
|
||||
|
||||
def _sda__reversed__(self):
|
||||
if self.ndim == 0:
|
||||
raise TypeError("iteration over a 0-d array") # same as numpy error
|
||||
else:
|
||||
return (self[i] for i in range(self.shape[0] - 1, -1, -1))
|
||||
|
||||
|
||||
for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
|
||||
@ -647,6 +660,8 @@ for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
|
||||
setattr(sda, "block_until_ready", _sda_block_until_ready)
|
||||
setattr(sda, "_value", property(_sda_value))
|
||||
setattr(sda, "__getitem__", _sda__getitem__)
|
||||
setattr(sda, "__iter__", _sda__iter__)
|
||||
setattr(sda, "__reversed__", _sda__reversed__)
|
||||
|
||||
del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
|
||||
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)
|
||||
@ -659,6 +674,7 @@ else:
|
||||
ShardedDeviceArray = _ShardedDeviceArray
|
||||
|
||||
|
||||
|
||||
def _hashable_index(idx):
|
||||
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
|
||||
idx)
|
||||
|
@ -1596,15 +1596,12 @@ for device_array in [DeviceArray]:
|
||||
if self.ndim == 0:
|
||||
raise TypeError("iteration over a 0-d array") # same as numpy error
|
||||
else:
|
||||
return self._value.__iter__()
|
||||
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())
|
||||
|
||||
setattr(device_array, "__iter__", __iter__)
|
||||
|
||||
def __reversed__(self):
|
||||
if self.ndim == 0:
|
||||
raise TypeError("iteration over a 0-d array")
|
||||
else:
|
||||
return reversed(self._value)
|
||||
return iter(self[::-1])
|
||||
|
||||
setattr(device_array, "__reversed__", __reversed__)
|
||||
|
||||
|
@ -2281,7 +2281,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertStartsWith(repr(rep), "DeviceArray")
|
||||
|
||||
def test_device_array_hash(self):
|
||||
rep = jnp.ones(()) + 1.
|
||||
rep = jnp.ones((1,)) + 1.
|
||||
self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
|
||||
self.assertNotIsInstance(rep, collections.abc.Hashable)
|
||||
with self.assertRaisesRegex(TypeError, 'unhashable type'):
|
||||
|
@ -26,6 +26,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
@ -1497,6 +1498,12 @@ class LaxTest(jtu.JaxTestCase):
|
||||
x = rng((6, 7), np.int32)
|
||||
np.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5])
|
||||
|
||||
def testDynamicSliceArraySliceSizes(self):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng((6, 7), np.int32)
|
||||
np.testing.assert_equal(lax.dynamic_slice(x, [2, 3], jnp.array([2, 2])),
|
||||
x[2:4, 3:5])
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_indices={}_update_shape={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
@ -1556,6 +1563,10 @@ class LaxTest(jtu.JaxTestCase):
|
||||
op = lambda x: lax.transpose(x, perm)
|
||||
self._CompileAndCheck(op, args_maker)
|
||||
|
||||
def testTransposeWithArrayPermutation(self):
|
||||
x = lax.transpose(np.ones((2, 3)), jnp.array([1, 0]))
|
||||
self.assertEqual((3, 2), x.shape)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_perm={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), perm),
|
||||
|
Loading…
x
Reference in New Issue
Block a user