Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)

* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py

`reshape()` is quite powerful, but it does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to track axis identity, such as masking and a
new transformation rule I'm exploring for unraveling pytrees.
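As a minimal sketch of this ambiguity (inspecting the traced primitives with
`jax.make_jaxpr`; the exact jaxpr text may differ across versions):

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 1))

# reshape only records the output shape; for (1, 1) -> (1,) nothing says which
# input axis survived.
print(jax.make_jaxpr(lambda a: lax.reshape(a, (1,)))(x))

# squeeze (added in this PR) records which axes were removed, so axis identity
# stays explicit.
print(jax.make_jaxpr(lambda a: lax.squeeze(a, dimensions=(0,)))(x))
```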

This PR rewrites these operations in terms of expand_dims / lax.broadcast_in_dim
where feasible, both of which have a well-defined mapping between input and
output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
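As a rough illustration of the effect (assuming the rewrites above; the printed
jaxprs should now show `broadcast_in_dim` / `squeeze` where `reshape` used to
appear):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((3, 4))
v = jnp.ones((4,))

# keepdims=True reductions re-insert the reduced axes via expand_dims
print(jax.make_jaxpr(lambda x: jnp.sum(x, axis=1, keepdims=True))(a))

# matmul with a 1D argument promotes/demotes rank via expand_dims / squeeze
print(jax.make_jaxpr(lambda x, y: jnp.matmul(x, y))(a, v))
```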

I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
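For example (a small usage sketch; NumPy 1.18+ accepts the same tuple form):

```python
import jax.numpy as jnp

a = jnp.zeros((3, 4))
print(jnp.expand_dims(a, (0, 2)).shape)  # -> (1, 3, 1, 4)
print(jnp.expand_dims(a, -1).shape)      # a single int axis still works: (3, 4, 1)
```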

I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)

* Remove unnecessary branch

* Add lax.squeeze primitive

* Changes per review

* Fix typing

* Move expand_dims into lax

* Update per review; add comments/documentation

* Type annotations for squeeze/expand_dims

Stephan Hoyer, 2020-05-28 19:12:50 -07:00, committed by GitHub
commit cc8fbb7669 (parent 7944879cdd)
9 changed files with 238 additions and 106 deletions


@ -69,6 +69,7 @@ Operators
erfc
erf_inv
exp
expand_dims
expm1
fft
floor
@ -121,6 +122,7 @@ Operators
sort_key_val
sqrt
square
squeeze
sub
tan
tie_in


@ -120,6 +120,7 @@ from .lax import (
erfc_p,
exp,
exp_p,
expand_dims,
expm1,
expm1_p,
floor,
@ -254,6 +255,8 @@ from .lax import (
sqrt,
sqrt_p,
square,
squeeze,
squeeze_p,
standard_abstract_eval,
standard_naryop,
standard_primitive,
@ -283,7 +286,7 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
_const, _eq_meet, _broadcasting_select,
_check_user_dtype_supported, _one, _const,
_upcast_fp16_for_computation, _broadcasting_shape_rule,
_eye, _tri, _delta, _ones, _zeros)
_eye, _tri, _delta, _ones, _zeros, _canonicalize_axis)
from .lax_control_flow import (
cond,
cond_p,


@ -627,14 +627,15 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
rhs = transpose(rhs,
rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
new_lhs_shape = onp.insert(onp.array(onp.shape(lhs), dtype=onp.int64),
len(lhs_batch_dims) + len(lhs_noncontract_dims),
(1,) * len(rhs_noncontract_dims))
new_rhs_shape = onp.insert(onp.array(onp.shape(rhs), dtype=onp.int64),
len(lhs_batch_dims),
(1,) * len(lhs_noncontract_dims))
lhs = reshape(lhs, new_lhs_shape)
rhs = reshape(rhs, new_rhs_shape)
lhs_start_expand = len(lhs_batch_dims) + len(lhs_noncontract_dims)
lhs_end_expand = lhs_start_expand + len(rhs_noncontract_dims)
lhs = expand_dims(lhs, tuple(range(lhs_start_expand, lhs_end_expand)))
rhs_start_expand = len(lhs_batch_dims)
rhs_end_expand = rhs_start_expand + len(lhs_noncontract_dims)
rhs = expand_dims(rhs, tuple(range(rhs_start_expand, rhs_end_expand)))
out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
len(rhs_noncontract_dims))
op_product = bitwise_and if lhs.dtype == onp.bool_ else mul
@ -687,6 +688,10 @@ def reshape(operand: Array, new_sizes: Shape,
"""Wraps XLA's `Reshape
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.
For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
``lax.expand_dims``. These preserve information about axis identity that may
be useful for advanced transformation rules.
"""
new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
@ -999,7 +1004,7 @@ def scatter(operand: Array, scatter_indices:Array, updates: Array,
update_consts=consts, dimension_numbers=dimension_numbers)
def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
indices = concatenate([reshape(i, [i.shape[0], 1]) for i in idxs], 1)
indices = concatenate([expand_dims(i, (1,)) for i in idxs], 1)
indices = indices % onp.array([src.shape[ax] for ax in axes])
slice_sizes = list(src.shape)
for ax in axes:
@ -1559,7 +1564,7 @@ def index_in_dim(operand: Array, index: int, axis: int = 0,
if keepdims:
return result
else:
return reshape(result, onp.delete(operand.shape, axis))
return squeeze(result, (axis,))
def dynamic_slice_in_dim(operand: Array, start_index: Array,
@ -1581,7 +1586,7 @@ def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
if keepdims:
return result
else:
return reshape(result, onp.delete(operand.shape, axis))
return squeeze(result, (axis,))
def dynamic_update_slice_in_dim(operand: Array, update: Array,
@ -1597,8 +1602,7 @@ def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
axis = int(axis)
if _ndim(update) != _ndim(operand):
assert _ndim(update) + 1 == _ndim(operand)
ax = axis % _ndim(operand)
update = reshape(update, operand.shape[:ax] + (1,) + operand.shape[ax+1:])
update = expand_dims(update, (axis,))
return dynamic_update_slice_in_dim(operand, update, index, axis)
@ -1849,8 +1853,8 @@ def _brcast_to(x, shape):
assert len(x_shape) == len(shape)
broadcast_dimensions, = onp.where(onp.equal(x_shape, shape))
squeezed_dimensions, = onp.where(onp.not_equal(x_shape, shape))
inshape = onp.delete(x_shape, squeezed_dimensions)
return broadcast_in_dim(reshape(x, inshape), shape, broadcast_dimensions)
squeezed = squeeze(x, squeezed_dimensions)
return broadcast_in_dim(squeezed, shape, broadcast_dimensions)
else:
return broadcast(x, shape)
@ -2926,6 +2930,74 @@ ad.primitive_transposes[pad_p] = _pad_transpose
batching.primitive_batchers[pad_p] = _pad_batch_rule
# The squeeze primitive exists for the benefit of masking and other
# transformations that need to keep track of axis identity.
# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
# with shape (N,). This results in the following JAXpr:
# reshape[ dimension=None new_sizes=(N,) ]
# For N > 1, we can match up the output array axis with the second axis of the
# input. But for N = 1, it is not clear how axes match up: all we know from the
# JAXpr is that we are reshaping from (1, 1) to (1,).
# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.
def squeeze(array: Array, dimensions: Tuple[int, ...]) -> Array:
"""Squeeze any number of size 1 dimensions from an array."""
ndim = onp.ndim(array)
dimensions = tuple(sorted(_canonicalize_axis(i, ndim) for i in dimensions))
if not dimensions:
return array
return squeeze_p.bind(array, dimensions=dimensions)
def _squeeze_dtype_rule(operand, *, dimensions):
return operand.dtype
def _squeeze_shape_rule(operand, *, dimensions):
return _compute_squeeze_shape(onp.shape(operand), dimensions)
def _compute_squeeze_shape(shape, dimensions):
dims_set = set(dimensions)
if len(dims_set) != len(dimensions):
raise ValueError(f"dimensions are not unique: {dimensions}")
if not all(0 <= d < len(shape) for d in dims_set):
raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
if any(shape[d] != 1 for d in dimensions):
raise ValueError(
"cannot select an axis to squeeze out which has size not equal to "
f"one, got shape={shape} and dimensions={dimensions}")
return tuple(s for i, s in enumerate(shape) if i not in dims_set)
def _squeeze_translation_rule(c, arg, *, dimensions):
new_shape = _compute_squeeze_shape(c.get_shape(arg).dimensions(), dimensions)
return xops.Reshape(arg, new_shape)
def _squeeze_transpose_rule(t, operand, *, dimensions):
assert ad.is_undefined_primal(operand)
return [expand_dims(t, dimensions)]
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
operand, = batched_args
bdim, = batch_dims
operand = batching.moveaxis(operand, bdim, 0)
dimensions = tuple(onp.add(1, dimensions))
return squeeze(operand, dimensions=dimensions), 0
squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
'squeeze', _squeeze_translation_rule)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:
"""Insert any number of size 1 dimensions into an array."""
ndim_out = onp.ndim(array) + len(dimensions)
dims_set = frozenset(_canonicalize_axis(i, ndim_out) for i in dimensions)
result_shape = list(onp.shape(array))
for i in sorted(dims_set):
result_shape.insert(i, 1)
broadcast_dims = [i for i in range(ndim_out) if i not in dims_set]
return broadcast_in_dim(array, result_shape, broadcast_dims)
# We have a nonstandard reshape impl so that we can be lazy about data movement.
def _reshape_impl(operand, *, new_sizes, dimensions):
old_sizes = onp.shape(operand)
@ -3324,7 +3396,7 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
def _batch_dynamic_slice_indices(indices, bdims):
size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), -1)
if size < 0:
return concatenate([reshape(i, [1]) for i in indices], 0), None
return concatenate([broadcast(i, (1,)) for i in indices], 0), None
indices = concatenate(
[broadcast_in_dim(x, (size, 1),
broadcast_dimensions=((0,) if i is not None else ()))
@ -4389,7 +4461,7 @@ def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
operand = batching.moveaxis(operand, o_bdims, 0)
outputs = [
_select_and_scatter_add(s, o, **kwargs) for s, o in zip(source, operand)]
outputs = [reshape(out, (1,) + out.shape) for out in outputs]
outputs = [broadcast(out, (1,)) for out in outputs]
outputs = concatenate(outputs, 0)
return outputs, 0
elif s_bdims is not None:
@ -4397,7 +4469,7 @@ def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
source = batching.moveaxis(source, s_bdims, 0)
outputs = [
_select_and_scatter_add(s, operand, **kwargs) for s in source]
outputs = [reshape(out, (1,) + out.shape) for out in outputs]
outputs = [broadcast(out, (1,)) for out in outputs]
outputs = concatenate(outputs, 0)
return outputs, 0
elif o_bdims is not None:
@ -4405,7 +4477,7 @@ def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
operand = batching.moveaxis(operand, o_bdims, 0)
outputs = [
_select_and_scatter_add(source, o, **kwargs) for o in operand]
outputs = [reshape(out, (1,) + out.shape) for out in outputs]
outputs = [broadcast(out, (1,)) for out in outputs]
outputs = concatenate(outputs, 0)
return outputs, 0
@ -5191,7 +5263,7 @@ def _dynamic_slice_indices(operand, start_indices):
if start_indices.ndim != 1:
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
start_indices = [reshape(slice(start_indices, [i], [i+1]), ())
start_indices = [squeeze(slice(start_indices, [i], [i+1]), dimensions=(0,))
for i in range(operand.ndim)]
else:
start_indices = [onp.asarray(i, dtype=dtypes.int_) if isinstance(i, int)
@ -5395,12 +5467,12 @@ def _check_user_dtype_supported(dtype, fun_name=None):
def _canonicalize_axis(axis, num_dims):
"""Canonicalize an axis in (-num_dims, num_dims) to [0, num_dims)."""
axis = int(axis)
if axis < 0:
axis = axis + num_dims
if axis < 0 or axis >= num_dims:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
if not -num_dims <= axis < num_dims:
raise ValueError(
"axis {} is out of bounds for array of dimension {}".format(
axis, num_dims))
if axis < 0:
axis = axis + num_dims
return axis


@ -223,6 +223,8 @@ def broadcast_in_dim(operand, shape, broadcast_dimensions):
sum = np.sum
squeeze = np.squeeze
def reshape(operand, new_sizes, dimensions=None):
if dimensions is None:
dimensions = range(len(np.shape(operand)))


@ -204,6 +204,8 @@ load = np.load
### utility functions
_canonicalize_axis = lax._canonicalize_axis
def _promote_shapes(fun_name, *args):
"""Prepend implicit leading singleton dimensions for Numpy broadcasting."""
if len(args) < 2:
@ -217,8 +219,7 @@ def _promote_shapes(fun_name, *args):
if FLAGS.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
result_rank = len(lax.broadcast_shapes(*shapes))
return [lax.reshape(arg, (1,) * (result_rank - len(shp)) + shp)
if shp and len(shp) != result_rank else arg
return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
for arg, shp in zip(args, shapes)]
def _rank_promotion_warning_or_error(fun_name, shapes):
@ -291,17 +292,6 @@ def _promote_args_inexact(fun_name, *args):
def _constant_like(x, const):
return np.array(const, dtype=_dtype(x))
def _canonicalize_axis(axis, num_dims):
"""Canonicalize an axis in (-num_dims, num_dims) to [0, num_dims)."""
axis = int(axis)
if axis < 0:
axis = axis + num_dims
if axis < 0 or axis >= num_dims:
raise ValueError(
"axis {} is out of bounds for array of dimension {}".format(
axis, num_dims))
return axis
### implementations of numpy functions in terms of lax
@_wraps(np.fmin)
@ -1151,29 +1141,20 @@ def unravel_index(indices, shape):
@_wraps(np.squeeze)
def squeeze(a, axis=None):
shape_a = shape(a)
def squeeze(a, axis: Union[int, Tuple[int, ...]] = None):
if axis is None:
if 1 not in shape_a:
return a
newshape = [d for d in shape_a if d != 1]
else:
if isinstance(axis, int):
axis = (axis,)
axis = frozenset(_canonicalize_axis(i, ndim(a)) for i in axis)
if _any(shape_a[a] != 1 for a in axis):
raise ValueError("cannot select an axis to squeeze out which has size "
"not equal to one")
newshape = [d for i, d in enumerate(shape_a)
if d != 1 or i not in axis]
return lax.reshape(a, newshape)
a_shape = shape(a)
axis = tuple(i for i, d in enumerate(a_shape) if d == 1)
elif isinstance(axis, int):
axis = (axis,)
return lax.squeeze(a, axis)
@_wraps(np.expand_dims)
def expand_dims(a, axis):
shape = _shape(a)
axis = _canonicalize_axis(axis, ndim(a) + 1)
return lax.reshape(a, shape[:axis] + (1,) + shape[axis:])
def expand_dims(a, axis: Union[int, Tuple[int, ...]]):
if isinstance(axis, int):
axis = (axis,)
return lax.expand_dims(a, axis)
@_wraps(np.swapaxes)
@ -1370,7 +1351,7 @@ def broadcast_to(arr, shape):
diff, = np.where(np.not_equal(shape[nlead:], arr_shape))
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
return lax.broadcast_in_dim(squeeze(arr, diff), shape, kept_dims)
return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims)
@_wraps(np.split)
@ -1550,8 +1531,7 @@ def _make_reduction(np_fun, op, init_val, preproc=None, bool_op=None,
result = lax.reduce(a, _reduction_init_val(a, init_val),
op if computation_dtype != np.bool_ else bool_op, dims)
if keepdims:
shape_with_singletons = subvals(shape(a), zip(dims, (1,) * len(dims)))
result = lax.reshape(result, shape_with_singletons)
result = expand_dims(result, dims)
return lax.convert_element_type(result, dtype or result_dtype)
return reduction
@ -1992,13 +1972,11 @@ def stack(arrays, axis=0):
raise ValueError("Need at least one array to stack.")
shape0 = shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_shape = list(shape0)
new_shape.insert(axis, 1)
new_arrays = []
for a in arrays:
if shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(reshape(a, new_shape))
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis)
@_wraps(np.tile)
@ -2057,7 +2035,7 @@ def column_stack(tup):
for v in tup:
arr = array(v)
if arr.ndim < 2:
arr = arr.reshape((-1, 1))
arr = expand_dims(arr, axis=0)
arrays.append(arr)
return concatenate(arrays, 1)
@ -2102,7 +2080,12 @@ def atleast_1d(*arys):
def atleast_2d(*arys):
if len(arys) == 1:
arr = array(arys[0])
return arr if ndim(arr) >= 2 else reshape(arr, (1, -1))
if ndim(arr) >= 2:
return arr
elif ndim(arr) == 1:
return expand_dims(arr, axis=0)
else:
return expand_dims(arr, axis=(0, 1))
else:
return [atleast_2d(arr) for arr in arys]
@ -2111,10 +2094,12 @@ def atleast_2d(*arys):
def atleast_3d(*arys):
if len(arys) == 1:
arr = array(arys[0])
if ndim(arr) <= 1:
arr = reshape(arr, (1, -1, 1))
if ndim(arr) == 0:
arr = expand_dims(arr, axis=(0, 1, 2))
elif ndim(arr) == 1:
arr = expand_dims(arr, axis=(0, 2))
elif ndim(arr) == 2:
arr = reshape(arr, shape(arr) + (1,))
arr = expand_dims(arr, axis=2)
return arr
else:
return [atleast_3d(arr) for arr in arys]
@ -2154,7 +2139,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
raise TypeError("Unexpected input type for array: {}".format(type(object)))
if ndmin > ndim(out):
out = lax.reshape(out, (1,) * (ndmin - ndim(out)) + shape(out))
out = lax.broadcast(out, (1,) * (ndmin - ndim(out)))
return out
@_wraps(np.asarray)
@ -2397,7 +2382,7 @@ def ix_(*args):
# Numpy uses an integer index type for empty arrays.
output.append(lax.full(shape, np.zeros((), np.intp)))
else:
output.append(lax.reshape(a, shape))
output.append(lax.broadcast_in_dim(a, shape, (i,)))
return tuple(output)
@ -2642,8 +2627,8 @@ def dot(a, b, *, precision=None): # pylint: disable=missing-docstring
def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
_check_arraylike("matmul", a, b)
a_is_vec, b_is_vec = (ndim(a) == 1), (ndim(b) == 1)
a = lax.reshape(a, (1,) + shape(a)) if a_is_vec else a
b = lax.reshape(b, shape(b) + (1,)) if b_is_vec else b
a = expand_dims(a, axis=0) if a_is_vec else a
b = expand_dims(b, axis=-1) if b_is_vec else b
a, b = _promote_dtypes(a, b)
batch_shape = lax.broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
@ -2653,13 +2638,8 @@ def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
dim_numbers = (((ndim(a) - 1,), (ndim(b) - 2,)), (batch_dims, batch_dims))
result = lax.dot_general(a, b, dim_numbers, precision)
if a_is_vec or b_is_vec:
m, n = shape(result)[-2:]
new_m = () if a_is_vec else (m,)
new_n = () if b_is_vec else (n,)
return lax.reshape(result, batch_shape + new_m + new_n)
else:
return result
squeeze_dims = ((-2,) if a_is_vec else ()) + ((-1,) if b_is_vec else ())
return squeeze(result, squeeze_dims)
@_wraps(np.vdot, lax_description=_PRECISION_DOC)
@ -3325,7 +3305,7 @@ def unique(ar, return_index=False, return_inverse=False,
def _rewriting_take(arr, idx):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and a reshape.
# followed by an optional reverse and broadcast_in_dim.
arr = asarray(arr)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
return _gather(arr, treedef, static_idx, dynamic_idx)
@ -3353,7 +3333,7 @@ def _gather(arr, treedef, static_idx, dynamic_idx):
y = lax.rev(y, indexer.reversed_y_dims)
# This adds np.newaxis/None dimensions.
return lax.reshape(y, indexer.slice_shape)
return expand_dims(y, indexer.newaxis_dims)
_Indexer = collections.namedtuple("_Indexer", [
# The expected shape of the slice output.
@ -3372,9 +3352,8 @@ _Indexer = collections.namedtuple("_Indexer", [
# the gather.
"reversed_y_dims",
# For scatters, we must eliminate any axes created by `newaxis`, which
# are the following dimensions, which must be of size 1. For gathers, we
# simply reshape to `slice_shape` to introduce the new axes.
# Keep track of any axes created by `newaxis`. These must be inserted for
# gathers and eliminated for scatters.
"newaxis_dims",
])


@ -103,8 +103,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
if b is not None or return_sign:
raise NotImplementedError("Only implemented for b=None, return_sign=False")
dims = _reduction_dims(a, axis)
shape = util.subvals(np.shape(a), zip(dims, (1,) * len(dims)))
dimadd = lambda x: lax.reshape(x, shape)
dimadd = lambda x: lax.expand_dims(x, dims)
amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
amax_singletons = dimadd(amax)


@ -1890,15 +1890,19 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
"rng_factory": jtu.rand_default}
for arg_shape in [(), (3,), (3, 4)]
for dtype in default_dtypes
for dim in range(-len(arg_shape)+1, len(arg_shape))))
for dim in (list(range(-len(arg_shape)+1, len(arg_shape)))
+ [(0,), (len(arg_shape), len(arg_shape) + 1)])))
def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng_factory):
rng = rng_factory(self.rng())
np_fun = lambda x: np.expand_dims(x, dim)
jnp_fun = lambda x: jnp.expand_dims(x, dim)
args_maker = lambda: [rng(arg_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
if isinstance(dim, tuple) and numpy_version < (1, 18, 0):
raise SkipTest("support for multiple axes added in NumPy 1.18.0")
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_axes=({},{})".format(
jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
@ -1935,23 +1939,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_axis={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype), ax),
"arg_shape": arg_shape, "dtype": dtype, "ax": ax,
"rng_factory": jtu.rand_default}
for arg_shape, ax in [
((3,), 0),
((1, 3), 1),
((1, 3, 1), (0, 1))]
for dtype in default_dtypes))
def testSqueezeFailsOnNonsingletonAxis(self, arg_shape, dtype, ax,
rng_factory):
rng = rng_factory(self.rng())
x = jnp.zeros(arg_shape, dtype=dtype)
fun = lambda: jnp.squeeze(x, ax)
self.assertRaisesRegex(ValueError, "cannot select an axis to squeeze", fun)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
jtu.format_shape_dtype_string(shape, dtype),


@ -890,6 +890,47 @@ class LaxTest(jtu.JaxTestCase):
numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_dimensions={}".format(
jtu.format_shape_dtype_string(inshape, onp.float32), dimensions),
"inshape": inshape, "dimensions": dimensions, "error_type": error_type,
"err_msg": err_msg}
for inshape, dimensions, error_type, err_msg in [
((1, 2, 3), (0, 0), ValueError, 'dimensions are not unique'),
((1, 2, 3), (3,), ValueError, 'axis 3 is out of bounds'),
((1, 2, 3), (-4,), ValueError, 'axis -4 is out of bounds'),
((1, 2, 3), (1,), ValueError, 'cannot select an axis to squeeze out'),
((1, 2, 3), (None,), TypeError, 'cannot be interpreted as an integer'),
]))
def testSqueezeShapeCheck(self, inshape, dimensions, error_type, err_msg):
rng = jtu.rand_default(self.rng())
x = rng(inshape, onp.float32)
with self.assertRaisesRegex(error_type, err_msg):
lax.squeeze(x, dimensions=dimensions)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_dimensions={}".format(
jtu.format_shape_dtype_string(arg_shape, onp.float32), dimensions),
"arg_shape": arg_shape, "dimensions": dimensions,
"rng_factory": rng_factory}
for arg_shape, dimensions in [
[(1,), (0,)],
[(1,), (-1,)],
[(2, 1, 4), (1,)],
[(2, 1, 3, 1), (1,)],
[(2, 1, 3, 1), (1, 3)],
[(2, 1, 3, 1), (3,)],
]
for rng_factory in [jtu.rand_default]))
def testSqueeze(self, arg_shape, dimensions, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(arg_shape, onp.float32)]
op = lambda x: lax.squeeze(x, dimensions)
numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
self._CompileAndCheck(op, args_maker, check_dtypes=True)
self._CheckAgainstNumpy(op, numpy_op, args_maker)
check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),
@ -3010,6 +3051,30 @@ class LaxVmapTest(jtu.JaxTestCase):
op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
jtu.format_shape_dtype_string(arg_shape, onp.float32),
dimensions, bdims),
"arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims,
"rng_factory": rng_factory}
for arg_shape, dimensions in [
[(1,), (0,)],
[(1,), (-1,)],
[(2, 1, 4), (1,)],
[(2, 1, 4), (-2,)],
[(2, 1, 3, 1), (1,)],
[(2, 1, 3, 1), (1, 3)],
[(2, 1, 3, 1), (3,)],
[(2, 1, 3, 1), (1, -1)],
]
for bdims in all_bdims(arg_shape)
for rng_factory in [jtu.rand_default]))
def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory):
dtype = onp.float32
rng = rng_factory(self.rng())
op = lambda x: lax.squeeze(x, dimensions)
self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
jtu.format_shape_dtype_string(arg_shape, dtype),


@ -416,6 +416,29 @@ class MaskingTest(jtu.JaxTestCase):
expected = 8 / 3
self.assertAllClose(ans, expected, check_dtypes=False)
def test_arithmetic(self):
@partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)')
def times(x, y):
return x * y
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
ans = times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
dict(n=4, m=5))
# expected = np.array([[1, 2, 3], [8, 10, 12]])
# self.assertAllClose(ans, expected, check_dtypes=False)
def test_stack(self):
@partial(mask, in_shapes=['n','n'], out_shape='(2, n)')
def stack(x, y):
return jnp.stack([x, y], 0)
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
ans = stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
# expected = np.array([[1, 2, 3], [4, 5, 6]])
# self.assertAllClose(ans, expected, check_dtypes=False)
def test_monomorphic(self):
@partial(mask, in_shapes=['(_, n)'], out_shape='')
def padded_sum(x):