Add a common lax._canonicalize_shape function and use it in the lax functions that accept shapes.

Explicitly convert shape entries to integers using the Python __index__() method.
Implement __index__ on DeviceArrays so shapes like (1, DeviceArray(2)) work.

Fixes a bug where np.full accepted floating-point shapes: __index__() raises an error for non-integer inputs, whereas int() would silently truncate them and drop information.
This commit is contained in:
Peter Hawkins 2019-07-23 16:18:10 -04:00
parent 97a5148a0d
commit 1479ae9066
5 changed files with 38 additions and 7 deletions

View File

@ -205,7 +205,8 @@ def add_batch_dim_to_aval(bdim, size, aval):
return ShapedArray(aval.shape, aval.dtype)
else:
assert 0 <= bdim <= aval.ndim
batched_shape = tuple(onp.insert(aval.shape, bdim, size))
batched_shape = tuple(
onp.insert(onp.asarray(aval.shape, onp.intp), bdim, size))
return ShapedArray(batched_shape, aval.dtype)
else:
raise TypeError(t)

View File

@ -619,6 +619,7 @@ class DeviceArray(DeviceValue):
__complex__ = partialmethod(_forward_to_value, complex)
__hex__ = partialmethod(_forward_to_value, hex)
__oct__ = partialmethod(_forward_to_value, oct)
__index__ = partialmethod(_forward_to_value, op.index)
# pickle saves and loads just like an ndarray
__reduce__ = partialmethod(_forward_to_value, op.methodcaller("__reduce__"))

View File

@ -72,6 +72,22 @@ def broadcast_shapes(*shapes):
.format(tuple(map(tuple, shapes))))
return tuple(result_shape)
def _canonicalize_shape(shape):
"""Canonicalizes and checks for errors in a user-provided shape value.
Args:
shape: a Python value that represents a shape.
Returns:
A tuple of integers.
"""
try:
return tuple(map(operator.index, shape))
except TypeError:
pass
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
"got {}")
raise TypeError(msg.format(shape))
def _identity(x): return x
@ -576,7 +592,7 @@ def reshape(operand, new_sizes, dimensions=None):
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.
"""
new_sizes = tuple(new_sizes)
new_sizes = _canonicalize_shape(new_sizes)
same_shape = onp.shape(operand) == new_sizes
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
if onp.shape(operand) and same_shape and same_dims:
@ -666,7 +682,7 @@ def gather(operand, start_indices, dimension_numbers, slice_sizes):
"""
return gather_p.bind(
operand, start_indices, dimension_numbers=dimension_numbers,
slice_sizes=tuple(slice_sizes), operand_shape=operand.shape)
slice_sizes=_canonicalize_shape(slice_sizes), operand_shape=operand.shape)
def scatter_add(operand, scatter_indices, updates, dimension_numbers):
"""Scatter-add operator.
@ -980,7 +996,7 @@ def full(shape, fill_value, dtype=None):
will be cast to `dtype`.
"""
try:
shape = tuple(map(int, shape))
shape = _canonicalize_shape(shape)
except TypeError:
msg = ("`full` requires shapes to be concrete. If using `jit`, try using "
"`static_argnums` or applying `jit` to smaller subfunctions instead.")
@ -1015,7 +1031,7 @@ def broadcasted_iota(dtype, shape, dimension):
operator.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
shape = tuple(map(int, shape))
shape = _canonicalize_shape(shape)
dimension = int(dimension)
return _IotaConstant(dtype, shape, dimension)
@ -1026,7 +1042,7 @@ def broadcasted_eye(dtype, shape, axes):
if not isinstance(axes, (list, tuple)) or not len(axes) >= 2:
raise TypeError("make_diagonal `axes` must be a tuple with len at least 2.")
dtype = xla_bridge.canonicalize_dtype(dtype)
shape = tuple(map(int, shape))
shape = _canonicalize_shape(shape)
axes = tuple(map(int, axes))
return _EyeConstant(shape, axes, dtype)
@ -1212,7 +1228,7 @@ def full_like(x, fill_value, dtype=None, shape=None):
An ndarray with the same shape as `x` with its entries set equal to
`fill_value`, similar to the output of np.full.
"""
shape = onp.shape(x) if shape is None else shape
shape = onp.shape(x) if shape is None else _canonicalize_shape(shape)
out = full(shape, fill_value, dtype or _dtype(x))
return tie_in(x, out)

View File

@ -1758,6 +1758,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def testIssue967(self):
  """Checks that `lnp.zeros` raises TypeError for a floating-point shape."""
  with self.assertRaises(TypeError):
    lnp.zeros(1.5)
if __name__ == "__main__":
absltest.main()

View File

@ -2240,6 +2240,17 @@ class LaxAutodiffTest(jtu.JaxTestCase):
expected = onp.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReshapeWithUnusualShapes(self):
  """Exercises `lax.reshape` with shape entries that are not plain ints."""
  # A shape entry produced by `lax.add` is accepted and behaves like 3.
  reshaped = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
  expected = onp.ones((3, 1), onp.float32)
  self.assertAllClose(reshaped, expected, check_dtypes=True)

  error_pattern = (
      "Shapes must be 1D sequences of concrete values of integer type.*")

  # A shape whose single entry is itself an array must be rejected.
  def reshape_with_nested_array():
    return lax.reshape(onp.ones(3,), (onp.array([3, 1]),))
  jtu.check_raises_regexp(reshape_with_nested_array, TypeError, error_pattern)

  # Floating-point entries must be rejected as well.
  def reshape_with_float_entries():
    return lax.reshape(onp.ones(3,), (1.5, 2.0))
  jtu.check_raises_regexp(reshape_with_float_entries, TypeError, error_pattern)
def all_bdims(*shapes):
bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes)