Merge pull request #8043 from hawkinsp:iter

PiperOrigin-RevId: 406822933
commit 335857bf93
Author: jax authors
Date: 2021-11-01 07:41:40 -07:00
6 changed files with 58 additions and 18 deletions

View File

@@ -926,7 +926,7 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
   """
   start_indices = _dynamic_slice_indices(operand, start_indices)
   return dynamic_slice_p.bind(operand, *start_indices,
-                              slice_sizes=tuple(slice_sizes))
+                              slice_sizes=core.canonicalize_shape(slice_sizes))
 
 def dynamic_update_slice(operand: Array, update: Array,
                          start_indices: Array) -> Array:
@@ -1362,7 +1362,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
   <https://www.tensorflow.org/xla/operation_semantics#transpose>`_
   operator.
   """
-  permutation = tuple(permutation)
+  permutation = tuple(operator.index(d) for d in permutation)
   if (permutation == tuple(range(np.ndim(operand)))
       and isinstance(operand, (core.Tracer, xla.DeviceArray))):
     return operand
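
Taken together, the two lax.py changes above let slice_sizes and permutation be given as integer arrays (anything accepted by core.canonicalize_shape or convertible through operator.index) rather than only Python tuples of ints. A minimal sketch of the new behaviour (not part of the diff; it mirrors the tests added further down):

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(42).reshape(6, 7)
# slice_sizes may now be a JAX integer array, not just a tuple of Python ints.
assert (lax.dynamic_slice(x, [2, 3], jnp.array([2, 2])) == x[2:4, 3:5]).all()
# permutation may likewise be a JAX integer array.
assert lax.transpose(np.ones((2, 3)), jnp.array([1, 0])).shape == (3, 2)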

View File

@@ -6699,6 +6699,22 @@ def _multi_slice(arr,
     results.append(sliced)
   return results
 
+# The next two functions are related to iter(device_array), implemented here to
+# avoid circular imports.
+@jit
+def _unstack(x):
+  return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
+setattr(DeviceArray, "_unstack", _unstack)
+def _chunk_iter(x, size):
+  if size > x.shape[0]:
+    yield x
+  else:
+    num_chunks, tail = divmod(x.shape[0], size)
+    for i in range(num_chunks):
+      yield lax.dynamic_slice_in_dim(x, i * size, size)
+    if tail:
+      yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
+setattr(DeviceArray, "_chunk_iter", _chunk_iter)
 
 # Syntactic sugar for scatter operations.
 class _IndexUpdateHelper:
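
These two private helpers are what the new DeviceArray.__iter__ (in the xla.py hunk below) composes: _chunk_iter slices the leading axis into chunks of at most size rows with lax.dynamic_slice_in_dim, and the jitted _unstack splits one chunk into its rows. A rough sketch of the resulting behaviour, not part of the diff, using the private helpers attached above and assuming a (6, 2) input:

import jax.numpy as jnp

x = jnp.arange(12).reshape(6, 2)
chunks = list(x._chunk_iter(4))          # one full chunk of 4 rows, then a 2-row tail
assert [c.shape for c in chunks] == [(4, 2), (2, 2)]
rows = [row for c in chunks for row in c._unstack()]
assert len(rows) == 6 and rows[0].shape == (2,)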

View File

@@ -622,21 +622,34 @@ def _sda_value(self):
 def _sda__getitem__(self, idx):
+  self._check_if_deleted()
   if not isinstance(idx, tuple):
     cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
   else:
     cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
-  try:
-    buf_idx = self.indices.index(cidx)
-  except ValueError:
-    # NOTE: Slow path, this will materialize the sharded array on a single
-    # device and use XLA's Gather to index into the resulting array.
-    return xla.DeviceArray.__getitem__(self, idx)
+  if self._npy_value is None:
+    try:
+      buf_idx = self.indices.index(cidx)
+    except ValueError:
+      buf_idx = None
+    if buf_idx is not None:
+      buf = self.device_buffers[buf_idx]
+      aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
+      return xla.make_device_array(aval, None, buf)
+  return super(self.__class__, self).__getitem__(idx)
+
+def _sda__iter__(self):
+  if self.ndim == 0:
+    raise TypeError("iteration over a 0-d array")  # same as numpy error
+  else:
-  self._check_if_deleted()
-  buf = self.device_buffers[buf_idx]
-  aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
-  return xla.make_device_array(aval, None, buf)
+    return (self[i] for i in range(self.shape[0]))
+
+def _sda__reversed__(self):
+  if self.ndim == 0:
+    raise TypeError("iteration over a 0-d array")  # same as numpy error
+  else:
+    return (self[i] for i in range(self.shape[0] - 1, -1, -1))
 
 for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
@@ -647,6 +660,8 @@ for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
   setattr(sda, "block_until_ready", _sda_block_until_ready)
   setattr(sda, "_value", property(_sda_value))
   setattr(sda, "__getitem__", _sda__getitem__)
+  setattr(sda, "__iter__", _sda__iter__)
+  setattr(sda, "__reversed__", _sda__reversed__)
 
 del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
      _sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)
@@ -659,6 +674,7 @@ else:
   ShardedDeviceArray = _ShardedDeviceArray
 
 def _hashable_index(idx):
   return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
                   idx)
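
With __iter__ and __reversed__ registered on the sharded-array classes, iterating a pmap output yields one leading-axis element at a time through _sda__getitem__, which serves a single-shard index directly from its device buffer when it can and otherwise falls back to the base DeviceArray path. A minimal sketch, not part of the diff, using however many devices the host provides:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
y = jax.pmap(lambda a: a * 2)(jnp.arange(n))
for row in y:                    # __iter__ -> (self[i] for i in range(self.shape[0]))
    print(row)
back_to_front = list(reversed(y))  # __reversed__ walks the leading axis backwards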

View File

@@ -1596,15 +1596,12 @@ for device_array in [DeviceArray]:
     if self.ndim == 0:
       raise TypeError("iteration over a 0-d array")  # same as numpy error
     else:
-      return self._value.__iter__()
+      return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())
   setattr(device_array, "__iter__", __iter__)
 
   def __reversed__(self):
-    if self.ndim == 0:
-      raise TypeError("iteration over a 0-d array")
-    else:
-      return reversed(self._value)
+    return iter(self[::-1])
   setattr(device_array, "__reversed__", __reversed__)
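
The user-visible effect of this hunk: iterating a DeviceArray no longer materializes self._value as a NumPy array; each element is itself a DeviceArray produced by on-device slicing (in chunks of 100 rows), and reversed() simply iterates a reversed slice. A small sketch, not part of the diff:

import jax.numpy as jnp

x = jnp.arange(6).reshape(3, 2)
first = next(iter(x))
print(type(first))                            # a JAX DeviceArray, not numpy.ndarray
print([row.tolist() for row in reversed(x)])  # [[4, 5], [2, 3], [0, 1]]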

View File

@@ -2281,7 +2281,7 @@ class APITest(jtu.JaxTestCase):
     self.assertStartsWith(repr(rep), "DeviceArray")
 
   def test_device_array_hash(self):
-    rep = jnp.ones(()) + 1.
+    rep = jnp.ones((1,)) + 1.
     self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
     self.assertNotIsInstance(rep, collections.abc.Hashable)
     with self.assertRaisesRegex(TypeError, 'unhashable type'):

View File

@@ -26,6 +26,7 @@ from absl.testing import parameterized
 import numpy as np
 
 import jax
 import jax.numpy as jnp
+from jax import core
 from jax._src import dtypes
 from jax import lax
@@ -1497,6 +1498,12 @@ class LaxTest(jtu.JaxTestCase):
     x = rng((6, 7), np.int32)
     np.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5])
 
+  def testDynamicSliceArraySliceSizes(self):
+    rng = jtu.rand_default(self.rng())
+    x = rng((6, 7), np.int32)
+    np.testing.assert_equal(lax.dynamic_slice(x, [2, 3], jnp.array([2, 2])),
+                            x[2:4, 3:5])
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_indices={}_update_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
@@ -1556,6 +1563,10 @@
     op = lambda x: lax.transpose(x, perm)
     self._CompileAndCheck(op, args_maker)
 
+  def testTransposeWithArrayPermutation(self):
+    x = lax.transpose(np.ones((2, 3)), jnp.array([1, 0]))
+    self.assertEqual((3, 2), x.shape)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),