mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add a common lax._canonicalize_shape method and use it in the lax functions that accept shapes.
Explicitly convert shape entries to integers using the Python __index__() method. Implement __index__ on DeviceArrays so shapes like (1, DeviceArray(2)) work. Fixes bug where np.full accepted floating point shapes; __index__() errors for non-integer inputs, where int() would silently cast and drop information.
This commit is contained in:
parent
97a5148a0d
commit
1479ae9066
@ -205,7 +205,8 @@ def add_batch_dim_to_aval(bdim, size, aval):
|
||||
return ShapedArray(aval.shape, aval.dtype)
|
||||
else:
|
||||
assert 0 <= bdim <= aval.ndim
|
||||
batched_shape = tuple(onp.insert(aval.shape, bdim, size))
|
||||
batched_shape = tuple(
|
||||
onp.insert(onp.asarray(aval.shape, onp.intp), bdim, size))
|
||||
return ShapedArray(batched_shape, aval.dtype)
|
||||
else:
|
||||
raise TypeError(t)
|
||||
|
@ -619,6 +619,7 @@ class DeviceArray(DeviceValue):
|
||||
__complex__ = partialmethod(_forward_to_value, complex)
|
||||
__hex__ = partialmethod(_forward_to_value, hex)
|
||||
__oct__ = partialmethod(_forward_to_value, oct)
|
||||
__index__ = partialmethod(_forward_to_value, op.index)
|
||||
|
||||
# pickle saves and loads just like an ndarray
|
||||
__reduce__ = partialmethod(_forward_to_value, op.methodcaller("__reduce__"))
|
||||
|
@ -72,6 +72,22 @@ def broadcast_shapes(*shapes):
|
||||
.format(tuple(map(tuple, shapes))))
|
||||
return tuple(result_shape)
|
||||
|
||||
def _canonicalize_shape(shape):
|
||||
"""Canonicalizes and checks for errors in a user-provided shape value.
|
||||
|
||||
Args:
|
||||
shape: a Python value that represents a shape.
|
||||
|
||||
Returns:
|
||||
A tuple of integers.
|
||||
"""
|
||||
try:
|
||||
return tuple(map(operator.index, shape))
|
||||
except TypeError:
|
||||
pass
|
||||
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
|
||||
"got {}")
|
||||
raise TypeError(msg.format(shape))
|
||||
|
||||
def _identity(x): return x
|
||||
|
||||
@ -576,7 +592,7 @@ def reshape(operand, new_sizes, dimensions=None):
|
||||
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
|
||||
operator.
|
||||
"""
|
||||
new_sizes = tuple(new_sizes)
|
||||
new_sizes = _canonicalize_shape(new_sizes)
|
||||
same_shape = onp.shape(operand) == new_sizes
|
||||
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
|
||||
if onp.shape(operand) and same_shape and same_dims:
|
||||
@ -666,7 +682,7 @@ def gather(operand, start_indices, dimension_numbers, slice_sizes):
|
||||
"""
|
||||
return gather_p.bind(
|
||||
operand, start_indices, dimension_numbers=dimension_numbers,
|
||||
slice_sizes=tuple(slice_sizes), operand_shape=operand.shape)
|
||||
slice_sizes=_canonicalize_shape(slice_sizes), operand_shape=operand.shape)
|
||||
|
||||
def scatter_add(operand, scatter_indices, updates, dimension_numbers):
|
||||
"""Scatter-add operator.
|
||||
@ -980,7 +996,7 @@ def full(shape, fill_value, dtype=None):
|
||||
will be cast to `dtype`.
|
||||
"""
|
||||
try:
|
||||
shape = tuple(map(int, shape))
|
||||
shape = _canonicalize_shape(shape)
|
||||
except TypeError:
|
||||
msg = ("`full` requires shapes to be concrete. If using `jit`, try using "
|
||||
"`static_argnums` or applying `jit` to smaller subfunctions instead.")
|
||||
@ -1015,7 +1031,7 @@ def broadcasted_iota(dtype, shape, dimension):
|
||||
operator.
|
||||
"""
|
||||
dtype = xla_bridge.canonicalize_dtype(dtype)
|
||||
shape = tuple(map(int, shape))
|
||||
shape = _canonicalize_shape(shape)
|
||||
dimension = int(dimension)
|
||||
return _IotaConstant(dtype, shape, dimension)
|
||||
|
||||
@ -1026,7 +1042,7 @@ def broadcasted_eye(dtype, shape, axes):
|
||||
if not isinstance(axes, (list, tuple)) or not len(axes) >= 2:
|
||||
raise TypeError("make_diagonal `axes` must be a tuple with len at least 2.")
|
||||
dtype = xla_bridge.canonicalize_dtype(dtype)
|
||||
shape = tuple(map(int, shape))
|
||||
shape = _canonicalize_shape(shape)
|
||||
axes = tuple(map(int, axes))
|
||||
return _EyeConstant(shape, axes, dtype)
|
||||
|
||||
@ -1212,7 +1228,7 @@ def full_like(x, fill_value, dtype=None, shape=None):
|
||||
An ndarray with the same shape as `x` with its entries set equal to
|
||||
`fill_value`, similar to the output of np.full.
|
||||
"""
|
||||
shape = onp.shape(x) if shape is None else shape
|
||||
shape = onp.shape(x) if shape is None else _canonicalize_shape(shape)
|
||||
out = full(shape, fill_value, dtype or _dtype(x))
|
||||
return tie_in(x, out)
|
||||
|
||||
|
@ -1758,6 +1758,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
def testIssue967(self):
  # A floating-point shape must be rejected, not silently truncated to int.
  with self.assertRaises(TypeError):
    lnp.zeros(1.5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
@ -2240,6 +2240,17 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
expected = onp.array(0.0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testReshapeWithUnusualShapes(self):
  # Shape entries may be anything supporting __index__ (e.g. a DeviceArray
  # produced by lax.add), not just plain Python ints.
  ans = lax.reshape(onp.ones((3,), onp.float32), (lax.add(1, 2), 1))
  self.assertAllClose(ans, onp.ones((3, 1), onp.float32), check_dtypes=True)

  # Non-scalar and non-integer entries must both raise the canonical error.
  shape_error = "Shapes must be 1D sequences of concrete values of integer type.*"
  jtu.check_raises_regexp(
      lambda: lax.reshape(onp.ones(3,), (onp.array([3, 1]),)), TypeError,
      shape_error)
  jtu.check_raises_regexp(
      lambda: lax.reshape(onp.ones(3,), (1.5, 2.0)), TypeError,
      shape_error)
|
||||
|
||||
def all_bdims(*shapes):
|
||||
bdims = (itertools.chain([None], range(len(shape) + 1)) for shape in shapes)
|
||||
|
Loading…
x
Reference in New Issue
Block a user