Prefer using broadcast_in_dim/squeeze instead of reshape (#3217)
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py

  `reshape()` is quite powerful, but does not necessarily preserve a notion of
  axis identity (particularly for axes of length 1). This is problematic for
  transformation rules that need to preserve a notion of axis identity, such
  as for masking and a new transformation rule I'm exploring for unraveling
  pytrees.

  This PR rewrites these rules in terms of expand_dims / lax.broadcast_in_dim,
  when feasible, which have a well-defined mapping between input and output
  axes. In particular: `matmul`, various `stack` functions, the `array`
  constructor, broadcasting arithmetic, array indexing, `squeeze`, and
  reductions with `keepdims=True` no longer use `lax.reshape`.

  I also implemented support for multiple axes in `expand_dims` (added in
  NumPy 1.18), since it was convenient for some of these other functions.

  I considered trying to write a masking rule for broadcast_in_dim as well,
  but it was trickier than I expected and @JuliusKunze has probably already
  thought about it :)

* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
This commit is contained in:
parent 7944879cdd
commit cc8fbb7669
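
A minimal illustration (not part of the commit) of the axis-identity problem
described in the message above, assuming a JAX version that includes this
change (so that lax.squeeze exists):

    import jax
    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((1, 1))

    # reshape records only the output shape; for (1, 1) -> (1,) the jaxpr
    # cannot say which input axis survived:
    print(jax.make_jaxpr(lambda a: jnp.reshape(a, (1,)))(x))

    # squeeze records exactly which size-1 axis was removed, so the
    # input/output axis correspondence is unambiguous:
    print(jax.make_jaxpr(lambda a: lax.squeeze(a, (0,)))(x))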
@@ -69,6 +69,7 @@ Operators
     erfc
     erf_inv
     exp
+    expand_dims
     expm1
     fft
     floor
@@ -121,6 +122,7 @@ Operators
     sort_key_val
     sqrt
     square
+    squeeze
     sub
     tan
     tie_in
@@ -120,6 +120,7 @@ from .lax import (
   erfc_p,
   exp,
   exp_p,
+  expand_dims,
   expm1,
   expm1_p,
   floor,
@@ -254,6 +255,8 @@ from .lax import (
   sqrt,
   sqrt_p,
   square,
+  squeeze,
+  squeeze_p,
   standard_abstract_eval,
   standard_naryop,
   standard_primitive,
@@ -283,7 +286,7 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
                   _const, _eq_meet, _broadcasting_select,
                   _check_user_dtype_supported, _one, _const,
                   _upcast_fp16_for_computation, _broadcasting_shape_rule,
-                  _eye, _tri, _delta, _ones, _zeros)
+                  _eye, _tri, _delta, _ones, _zeros, _canonicalize_axis)
 from .lax_control_flow import (
   cond,
   cond_p,
jax/lax/lax.py (122 changed lines)
@@ -627,14 +627,15 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
                   lhs_batch_dims + lhs_noncontract_dims + lhs_contract_dims)
   rhs = transpose(rhs,
                   rhs_batch_dims + rhs_noncontract_dims + rhs_contract_dims)
-  new_lhs_shape = onp.insert(onp.array(onp.shape(lhs), dtype=onp.int64),
-                             len(lhs_batch_dims) + len(lhs_noncontract_dims),
-                             (1,) * len(rhs_noncontract_dims))
-  new_rhs_shape = onp.insert(onp.array(onp.shape(rhs), dtype=onp.int64),
-                             len(lhs_batch_dims),
-                             (1,) * len(lhs_noncontract_dims))
-  lhs = reshape(lhs, new_lhs_shape)
-  rhs = reshape(rhs, new_rhs_shape)
+
+  lhs_start_expand = len(lhs_batch_dims) + len(lhs_noncontract_dims)
+  lhs_end_expand = lhs_start_expand + len(rhs_noncontract_dims)
+  lhs = expand_dims(lhs, tuple(range(lhs_start_expand, lhs_end_expand)))
+
+  rhs_start_expand = len(lhs_batch_dims)
+  rhs_end_expand = rhs_start_expand + len(lhs_noncontract_dims)
+  rhs = expand_dims(rhs, tuple(range(rhs_start_expand, rhs_end_expand)))
+
   out_ndim = (len(lhs_batch_dims) + len(lhs_noncontract_dims) +
               len(rhs_noncontract_dims))
   op_product = bitwise_and if lhs.dtype == onp.bool_ else mul
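
A sketch of why the rewrite above is shape-equivalent (not part of the diff;
the operand shape and the start/count values below are made up for
illustration): inserting a contiguous run of size-1 axes with expand_dims
matches the old onp.insert/reshape computation, while keeping the
input-to-output axis mapping explicit.

    import numpy as onp
    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((2, 3, 4))  # hypothetical (batch, noncontract, contract)
    start, count = 2, 2        # insert two size-1 axes starting at position 2

    new_shape = onp.insert(onp.array(lhs.shape, dtype=onp.int64),
                           start, (1,) * count)
    via_reshape = lax.reshape(lhs, tuple(new_shape))
    via_expand = lax.expand_dims(lhs, tuple(range(start, start + count)))
    assert via_reshape.shape == via_expand.shape == (2, 3, 1, 1, 4)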
@@ -687,6 +688,10 @@ def reshape(operand: Array, new_sizes: Shape,
   """Wraps XLA's `Reshape
   <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
   operator.
+
+  For inserting/removing dimensions of size 1, prefer using ``lax.squeeze`` /
+  ``lax.expand_dims``. These preserve information about axis identity that may
+  be useful for advanced transformation rules.
   """
   new_sizes = canonicalize_shape(new_sizes)  # TODO
   new_sizes = tuple(new_sizes)
@@ -999,7 +1004,7 @@ def scatter(operand: Array, scatter_indices:Array, updates: Array,
                    update_consts=consts, dimension_numbers=dimension_numbers)

 def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
-  indices = concatenate([reshape(i, [i.shape[0], 1]) for i in idxs], 1)
+  indices = concatenate([expand_dims(i, (1,)) for i in idxs], 1)
   indices = indices % onp.array([src.shape[ax] for ax in axes])
   slice_sizes = list(src.shape)
   for ax in axes:
@@ -1559,7 +1564,7 @@ def index_in_dim(operand: Array, index: int, axis: int = 0,
   if keepdims:
     return result
   else:
-    return reshape(result, onp.delete(operand.shape, axis))
+    return squeeze(result, (axis,))


 def dynamic_slice_in_dim(operand: Array, start_index: Array,
@@ -1581,7 +1586,7 @@ def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
   if keepdims:
     return result
   else:
-    return reshape(result, onp.delete(operand.shape, axis))
+    return squeeze(result, (axis,))


 def dynamic_update_slice_in_dim(operand: Array, update: Array,
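
A short usage sketch (not from the diff) of the keepdims behavior these two
hunks preserve; the array and shapes are illustrative:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(12).reshape(3, 4)
    kept = lax.index_in_dim(x, 1, axis=0, keepdims=True)      # shape (1, 4)
    dropped = lax.index_in_dim(x, 1, axis=0, keepdims=False)  # shape (4,)
    assert kept.shape == (1, 4) and dropped.shape == (4,)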
@@ -1597,8 +1602,7 @@ def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
   axis = int(axis)
   if _ndim(update) != _ndim(operand):
     assert _ndim(update) + 1 == _ndim(operand)
-    ax = axis % _ndim(operand)
-    update = reshape(update, operand.shape[:ax] + (1,) + operand.shape[ax+1:])
+    update = expand_dims(update, (axis,))
   return dynamic_update_slice_in_dim(operand, update, index, axis)

@@ -1849,8 +1853,8 @@ def _brcast_to(x, shape):
     assert len(x_shape) == len(shape)
     broadcast_dimensions, = onp.where(onp.equal(x_shape, shape))
     squeezed_dimensions, = onp.where(onp.not_equal(x_shape, shape))
-    inshape = onp.delete(x_shape, squeezed_dimensions)
-    return broadcast_in_dim(reshape(x, inshape), shape, broadcast_dimensions)
+    squeezed = squeeze(x, squeezed_dimensions)
+    return broadcast_in_dim(squeezed, shape, broadcast_dimensions)
   else:
     return broadcast(x, shape)

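
The pattern introduced here, squeeze away the mismatched size-1 axes and then
broadcast_in_dim back to the target shape, can be exercised directly; a hedged
sketch with made-up shapes, not from the diff:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((3, 1))
    # Axis 0 already matches the target and keeps its identity; the size-1
    # axis 1 is squeezed away and re-created by the broadcast.
    y = lax.broadcast_in_dim(lax.squeeze(x, (1,)), (3, 4),
                             broadcast_dimensions=(0,))
    assert y.shape == (3, 4)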
@@ -2926,6 +2930,74 @@ ad.primitive_transposes[pad_p] = _pad_transpose
 batching.primitive_batchers[pad_p] = _pad_batch_rule


+# The squeeze primitive exists for the benefit of masking and other
+# transformations that need to keep track of axis identity.
+# For example, consider reshaping a 2D array with shape (1, N) into a 1D array
+# with shape (N,). This results in the following JAXpr:
+#   reshape[ dimensions=None new_sizes=(N,) ]
+# For N > 1, we can match up the output array axis with the second axis of the
+# input. But for N = 1, it is not clear how axes match up: all we know from the
+# JAXpr is that we are reshaping from (1, 1) to (1,).
+# In contrast, squeeze[ dimensions=(0,) ] is unambiguous.
+
+def squeeze(array: Array, dimensions: Tuple[int, ...]) -> Array:
+  """Squeeze any number of size 1 dimensions from an array."""
+  ndim = onp.ndim(array)
+  dimensions = tuple(sorted(_canonicalize_axis(i, ndim) for i in dimensions))
+  if not dimensions:
+    return array
+  return squeeze_p.bind(array, dimensions=dimensions)
+
+def _squeeze_dtype_rule(operand, *, dimensions):
+  return operand.dtype
+
+def _squeeze_shape_rule(operand, *, dimensions):
+  return _compute_squeeze_shape(onp.shape(operand), dimensions)
+
+def _compute_squeeze_shape(shape, dimensions):
+  dims_set = set(dimensions)
+  if len(dims_set) != len(dimensions):
+    raise ValueError(f"dimensions are not unique: {dimensions}")
+  if not all(0 <= d < len(shape) for d in dims_set):
+    raise ValueError(f"dimensions outside range [0, ndim): {dimensions}")
+  if any(shape[d] != 1 for d in dimensions):
+    raise ValueError(
+        "cannot select an axis to squeeze out which has size not equal to "
+        f"one, got shape={shape} and dimensions={dimensions}")
+  return tuple(s for i, s in enumerate(shape) if i not in dims_set)
+
+def _squeeze_translation_rule(c, arg, *, dimensions):
+  new_shape = _compute_squeeze_shape(c.get_shape(arg).dimensions(), dimensions)
+  return xops.Reshape(arg, new_shape)
+
+def _squeeze_transpose_rule(t, operand, *, dimensions):
+  assert ad.is_undefined_primal(operand)
+  return [expand_dims(t, dimensions)]
+
+def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
+  operand, = batched_args
+  bdim, = batch_dims
+  operand = batching.moveaxis(operand, bdim, 0)
+  dimensions = tuple(onp.add(1, dimensions))
+  return squeeze(operand, dimensions=dimensions), 0
+
+squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
+                               'squeeze', _squeeze_translation_rule)
+ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
+batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
+
+
+def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:
+  """Insert any number of size 1 dimensions into an array."""
+  ndim_out = onp.ndim(array) + len(dimensions)
+  dims_set = frozenset(_canonicalize_axis(i, ndim_out) for i in dimensions)
+  result_shape = list(onp.shape(array))
+  for i in sorted(dims_set):
+    result_shape.insert(i, 1)
+  broadcast_dims = [i for i in range(ndim_out) if i not in dims_set]
+  return broadcast_in_dim(array, result_shape, broadcast_dims)
+
+
 # We have a nonstandard reshape impl so that we can be lazy about data movement.
 def _reshape_impl(operand, *, new_sizes, dimensions):
   old_sizes = onp.shape(operand)
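
Two properties of the new functions, shown as a sketch (not part of the diff):
expand_dims is sugar for broadcast_in_dim, and differentiating through squeeze
transposes to expand_dims, matching _squeeze_transpose_rule above.

    import jax
    import jax.numpy as jnp
    from jax import lax

    # The jaxpr shows a broadcast_in_dim, not a reshape:
    print(jax.make_jaxpr(lambda a: lax.expand_dims(a, (1,)))(jnp.ones((2, 3))))

    # The VJP of squeeze reinserts the squeezed axis:
    y = jnp.ones((2, 1, 3))
    g = jax.vjp(lambda a: lax.squeeze(a, (1,)), y)[1](jnp.ones((2, 3)))[0]
    assert g.shape == (2, 1, 3)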
@@ -3324,7 +3396,7 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
 def _batch_dynamic_slice_indices(indices, bdims):
   size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), -1)
   if size < 0:
-    return concatenate([reshape(i, [1]) for i in indices], 0), None
+    return concatenate([broadcast(i, (1,)) for i in indices], 0), None
   indices = concatenate(
     [broadcast_in_dim(x, (size, 1),
                       broadcast_dimensions=((0,) if i is not None else ()))
@@ -4389,7 +4461,7 @@ def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
     operand = batching.moveaxis(operand, o_bdims, 0)
     outputs = [
       _select_and_scatter_add(s, o, **kwargs) for s, o in zip(source, operand)]
-    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
+    outputs = [broadcast(out, (1,)) for out in outputs]
     outputs = concatenate(outputs, 0)
     return outputs, 0
   elif s_bdims is not None:
@@ -4397,7 +4469,7 @@ def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
     source = batching.moveaxis(source, s_bdims, 0)
     outputs = [
       _select_and_scatter_add(s, operand, **kwargs) for s in source]
-    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
+    outputs = [broadcast(out, (1,)) for out in outputs]
     outputs = concatenate(outputs, 0)
     return outputs, 0
   elif o_bdims is not None:
@@ -4405,7 +4477,7 @@ def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
     operand = batching.moveaxis(operand, o_bdims, 0)
     outputs = [
      _select_and_scatter_add(source, o, **kwargs) for o in operand]
-    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
+    outputs = [broadcast(out, (1,)) for out in outputs]
     outputs = concatenate(outputs, 0)
     return outputs, 0

@@ -5191,7 +5263,7 @@ def _dynamic_slice_indices(operand, start_indices):
     if start_indices.ndim != 1:
       raise ValueError("Slice indices must be a 1D sequence, got {}"
                        .format(start_indices.shape))
-    start_indices = [reshape(slice(start_indices, [i], [i+1]), ())
+    start_indices = [squeeze(slice(start_indices, [i], [i+1]), dimensions=(0,))
                      for i in range(operand.ndim)]
   else:
     start_indices = [onp.asarray(i, dtype=dtypes.int_) if isinstance(i, int)
@@ -5395,12 +5467,12 @@ def _check_user_dtype_supported(dtype, fun_name=None):


 def _canonicalize_axis(axis, num_dims):
-  """Canonicalize an axis in (-num_dims, num_dims) to [0, num_dims)."""
-  axis = int(axis)
-  if axis < 0:
-    axis = axis + num_dims
-  if axis < 0 or axis >= num_dims:
+  """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
+  axis = operator.index(axis)
+  if not -num_dims <= axis < num_dims:
     raise ValueError(
         "axis {} is out of bounds for array of dimension {}".format(
             axis, num_dims))
+  if axis < 0:
+    axis = axis + num_dims
   return axis
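
A standalone sketch of the new canonicalization semantics (the real helper is
private to jax.lax; this copy exists only for illustration):

    import operator

    def canonicalize_axis(axis, num_dims):
        # Accept any integer-like index, per the operator.index change above.
        axis = operator.index(axis)
        if not -num_dims <= axis < num_dims:
            raise ValueError(
                "axis {} is out of bounds for array of dimension {}".format(
                    axis, num_dims))
        return axis + num_dims if axis < 0 else axis

    assert canonicalize_axis(-1, 3) == 2
    assert canonicalize_axis(0, 3) == 0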
@@ -223,6 +223,8 @@ def broadcast_in_dim(operand, shape, broadcast_dimensions):

 sum = np.sum

+squeeze = np.squeeze
+
 def reshape(operand, new_sizes, dimensions=None):
   if dimensions is None:
     dimensions = range(len(np.shape(operand)))
@@ -204,6 +204,8 @@ load = np.load

 ### utility functions

+_canonicalize_axis = lax._canonicalize_axis
+
 def _promote_shapes(fun_name, *args):
   """Prepend implicit leading singleton dimensions for Numpy broadcasting."""
   if len(args) < 2:
@@ -217,8 +219,7 @@ def _promote_shapes(fun_name, *args):
     if FLAGS.jax_numpy_rank_promotion != "allow":
       _rank_promotion_warning_or_error(fun_name, shapes)
     result_rank = len(lax.broadcast_shapes(*shapes))
-    return [lax.reshape(arg, (1,) * (result_rank - len(shp)) + shp)
-            if shp and len(shp) != result_rank else arg
+    return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
             for arg, shp in zip(args, shapes)]

 def _rank_promotion_warning_or_error(fun_name, shapes):
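
What the broadcast_to version of _promote_shapes computes, as a sketch with
made-up shapes (not from the diff): leading singleton axes are prepended so
the lower-rank operand reaches the broadcast rank.

    import jax.numpy as jnp

    a = jnp.ones((3, 4))
    b = jnp.ones((4,))
    b_promoted = jnp.broadcast_to(b, (1,) * (a.ndim - b.ndim) + b.shape)
    assert b_promoted.shape == (1, 4)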
@@ -291,17 +292,6 @@ def _promote_args_inexact(fun_name, *args):
 def _constant_like(x, const):
   return np.array(const, dtype=_dtype(x))

-def _canonicalize_axis(axis, num_dims):
-  """Canonicalize an axis in (-num_dims, num_dims) to [0, num_dims)."""
-  axis = int(axis)
-  if axis < 0:
-    axis = axis + num_dims
-  if axis < 0 or axis >= num_dims:
-    raise ValueError(
-        "axis {} is out of bounds for array of dimension {}".format(
-            axis, num_dims))
-  return axis
-
 ### implementations of numpy functions in terms of lax

 @_wraps(np.fmin)
@@ -1151,29 +1141,20 @@ def unravel_index(indices, shape):


 @_wraps(np.squeeze)
-def squeeze(a, axis=None):
-  shape_a = shape(a)
+def squeeze(a, axis: Union[int, Tuple[int, ...]] = None):
   if axis is None:
-    if 1 not in shape_a:
-      return a
-    newshape = [d for d in shape_a if d != 1]
-  else:
-    if isinstance(axis, int):
-      axis = (axis,)
-    axis = frozenset(_canonicalize_axis(i, ndim(a)) for i in axis)
-    if _any(shape_a[a] != 1 for a in axis):
-      raise ValueError("cannot select an axis to squeeze out which has size "
-                       "not equal to one")
-    newshape = [d for i, d in enumerate(shape_a)
-                if d != 1 or i not in axis]
-  return lax.reshape(a, newshape)
+    a_shape = shape(a)
+    axis = tuple(i for i, d in enumerate(a_shape) if d == 1)
+  elif isinstance(axis, int):
+    axis = (axis,)
+  return lax.squeeze(a, axis)


 @_wraps(np.expand_dims)
-def expand_dims(a, axis):
-  shape = _shape(a)
-  axis = _canonicalize_axis(axis, ndim(a) + 1)
-  return lax.reshape(a, shape[:axis] + (1,) + shape[axis:])
+def expand_dims(a, axis: Union[int, Tuple[int, ...]]):
+  if isinstance(axis, int):
+    axis = (axis,)
+  return lax.expand_dims(a, axis)


 @_wraps(np.swapaxes)
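
Expected behavior of the rewritten wrappers, as a usage sketch (tuple-axis
support in NumPy's own expand_dims requires NumPy >= 1.18, per the commit
message):

    import jax.numpy as jnp

    x = jnp.zeros((1, 3, 1))
    assert jnp.squeeze(x).shape == (3,)      # axis=None drops all size-1 axes
    assert jnp.squeeze(x, 0).shape == (3, 1)
    assert jnp.expand_dims(jnp.zeros((3,)), (0, 2)).shape == (1, 3, 1)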
@@ -1370,7 +1351,7 @@ def broadcast_to(arr, shape):
     diff, = np.where(np.not_equal(shape[nlead:], arr_shape))
     new_dims = tuple(range(nlead)) + tuple(nlead + diff)
     kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
-    return lax.broadcast_in_dim(squeeze(arr, diff), shape, kept_dims)
+    return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims)


 @_wraps(np.split)
@@ -1550,8 +1531,7 @@ def _make_reduction(np_fun, op, init_val, preproc=None, bool_op=None,
     result = lax.reduce(a, _reduction_init_val(a, init_val),
                         op if computation_dtype != np.bool_ else bool_op, dims)
     if keepdims:
-      shape_with_singletons = subvals(shape(a), zip(dims, (1,) * len(dims)))
-      result = lax.reshape(result, shape_with_singletons)
+      result = expand_dims(result, dims)
     return lax.convert_element_type(result, dtype or result_dtype)

   return reduction
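
The keepdims rewrite in this hunk amounts to the identity sketched below (not
part of the diff; shapes are illustrative): reduce first, then reinsert the
reduced axes as size-1 dimensions.

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((2, 3, 4))
    dims = (0, 2)
    reduced = jnp.sum(x, axis=dims)            # shape (3,)
    restored = lax.expand_dims(reduced, dims)  # shape (1, 3, 1)
    assert restored.shape == jnp.sum(x, axis=dims, keepdims=True).shape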
@@ -1992,13 +1972,11 @@ def stack(arrays, axis=0):
     raise ValueError("Need at least one array to stack.")
   shape0 = shape(arrays[0])
   axis = _canonicalize_axis(axis, len(shape0) + 1)
-  new_shape = list(shape0)
-  new_shape.insert(axis, 1)
   new_arrays = []
   for a in arrays:
     if shape(a) != shape0:
       raise ValueError("All input arrays must have the same shape.")
-    new_arrays.append(reshape(a, new_shape))
+    new_arrays.append(expand_dims(a, axis))
   return concatenate(new_arrays, axis=axis)

 @_wraps(np.tile)
@@ -2057,7 +2035,7 @@ def column_stack(tup):
   for v in tup:
     arr = array(v)
     if arr.ndim < 2:
-      arr = arr.reshape((-1, 1))
+      arr = expand_dims(arr, axis=0)
     arrays.append(arr)
   return concatenate(arrays, 1)

@@ -2102,7 +2080,12 @@ def atleast_1d(*arys):
 def atleast_2d(*arys):
   if len(arys) == 1:
     arr = array(arys[0])
-    return arr if ndim(arr) >= 2 else reshape(arr, (1, -1))
+    if ndim(arr) >= 2:
+      return arr
+    elif ndim(arr) == 1:
+      return expand_dims(arr, axis=0)
+    else:
+      return expand_dims(arr, axis=(0, 1))
   else:
     return [atleast_2d(arr) for arr in arys]

@@ -2111,10 +2094,12 @@ def atleast_2d(*arys):
 def atleast_3d(*arys):
   if len(arys) == 1:
     arr = array(arys[0])
-    if ndim(arr) <= 1:
-      arr = reshape(arr, (1, -1, 1))
+    if ndim(arr) == 0:
+      arr = expand_dims(arr, axis=(0, 1, 2))
+    elif ndim(arr) == 1:
+      arr = expand_dims(arr, axis=(0, 2))
     elif ndim(arr) == 2:
-      arr = reshape(arr, shape(arr) + (1,))
+      arr = expand_dims(arr, axis=2)
     return arr
   else:
     return [atleast_3d(arr) for arr in arys]
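
The new branches reproduce NumPy's axis placement for atleast_3d; a quick
check (not from the diff):

    import jax.numpy as jnp

    assert jnp.atleast_3d(jnp.zeros(())).shape == (1, 1, 1)
    assert jnp.atleast_3d(jnp.zeros((4,))).shape == (1, 4, 1)    # middle axis kept
    assert jnp.atleast_3d(jnp.zeros((2, 3))).shape == (2, 3, 1)  # axis appended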
@@ -2154,7 +2139,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
     raise TypeError("Unexpected input type for array: {}".format(type(object)))

   if ndmin > ndim(out):
-    out = lax.reshape(out, (1,) * (ndmin - ndim(out)) + shape(out))
+    out = lax.broadcast(out, (1,) * (ndmin - ndim(out)))
   return out

 @_wraps(np.asarray)
@@ -2397,7 +2382,7 @@ def ix_(*args):
       # Numpy uses an integer index type for empty arrays.
       output.append(lax.full(shape, np.zeros((), np.intp)))
     else:
-      output.append(lax.reshape(a, shape))
+      output.append(lax.broadcast_in_dim(a, shape, (i,)))
   return tuple(output)

@@ -2642,8 +2627,8 @@ def dot(a, b, *, precision=None):  # pylint: disable=missing-docstring
 def matmul(a, b, *, precision=None):  # pylint: disable=missing-docstring
   _check_arraylike("matmul", a, b)
   a_is_vec, b_is_vec = (ndim(a) == 1), (ndim(b) == 1)
-  a = lax.reshape(a, (1,) + shape(a)) if a_is_vec else a
-  b = lax.reshape(b, shape(b) + (1,)) if b_is_vec else b
+  a = expand_dims(a, axis=0) if a_is_vec else a
+  b = expand_dims(b, axis=-1) if b_is_vec else b

   a, b = _promote_dtypes(a, b)
   batch_shape = lax.broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
@@ -2653,13 +2638,8 @@ def matmul(a, b, *, precision=None):  # pylint: disable=missing-docstring
   dim_numbers = (((ndim(a) - 1,), (ndim(b) - 2,)), (batch_dims, batch_dims))
   result = lax.dot_general(a, b, dim_numbers, precision)

-  if a_is_vec or b_is_vec:
-    m, n = shape(result)[-2:]
-    new_m = () if a_is_vec else (m,)
-    new_n = () if b_is_vec else (n,)
-    return lax.reshape(result, batch_shape + new_m + new_n)
-  else:
-    return result
+  squeeze_dims = ((-2,) if a_is_vec else ()) + ((-1,) if b_is_vec else ())
+  return squeeze(result, squeeze_dims)


 @_wraps(np.vdot, lax_description=_PRECISION_DOC)
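
Shape bookkeeping for the 1-D cases in matmul, as a usage sketch (not from
the diff): vector operands are promoted with expand_dims, and exactly those
promoted axes are squeezed back out of the result.

    import jax.numpy as jnp

    A = jnp.ones((4, 3))
    v = jnp.ones((3,))
    assert jnp.matmul(A, v).shape == (4,)  # v -> (3, 1); trailing axis squeezed
    assert jnp.matmul(v, v).shape == ()    # both promoted; both axes squeezed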
@@ -3325,7 +3305,7 @@ def unique(ar, return_index=False, return_inverse=False,
 def _rewriting_take(arr, idx):
   # Computes arr[idx].
   # All supported cases of indexing can be implemented as an XLA gather,
-  # followed by an optional reverse and a reshape.
+  # followed by an optional reverse and broadcast_in_dim.
   arr = asarray(arr)
   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
   return _gather(arr, treedef, static_idx, dynamic_idx)
@@ -3353,7 +3333,7 @@ def _gather(arr, treedef, static_idx, dynamic_idx):
     y = lax.rev(y, indexer.reversed_y_dims)

   # This adds np.newaxis/None dimensions.
-  return lax.reshape(y, indexer.slice_shape)
+  return expand_dims(y, indexer.newaxis_dims)

 _Indexer = collections.namedtuple("_Indexer", [
   # The expected shape of the slice output.
@@ -3372,9 +3352,8 @@ _Indexer = collections.namedtuple("_Indexer", [
   # the gather.
   "reversed_y_dims",

-  # For scatters, we must eliminate any axes created by `newaxis`, which
-  # are the following dimensions, which must be of size 1. For gathers, we
-  # simply reshape to `slice_shape` to introduce the new axes.
+  # Keep track of any axes created by `newaxis`. These must be inserted for
+  # gathers and eliminated for scatters.
   "newaxis_dims",
 ])

@@ -103,8 +103,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
   if b is not None or return_sign:
     raise NotImplementedError("Only implemented for b=None, return_sign=False")
   dims = _reduction_dims(a, axis)
-  shape = util.subvals(np.shape(a), zip(dims, (1,) * len(dims)))
-  dimadd = lambda x: lax.reshape(x, shape)
+  dimadd = lambda x: lax.expand_dims(x, dims)
   amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
   amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
   amax_singletons = dimadd(amax)
@@ -1890,15 +1890,19 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
      "rng_factory": jtu.rand_default}
     for arg_shape in [(), (3,), (3, 4)]
     for dtype in default_dtypes
-    for dim in range(-len(arg_shape)+1, len(arg_shape))))
+    for dim in (list(range(-len(arg_shape)+1, len(arg_shape)))
+                + [(0,), (len(arg_shape), len(arg_shape) + 1)])))
   def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng_factory):
     rng = rng_factory(self.rng())
     np_fun = lambda x: np.expand_dims(x, dim)
     jnp_fun = lambda x: jnp.expand_dims(x, dim)
     args_maker = lambda: [rng(arg_shape, dtype)]
-    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
     self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
+
+    if isinstance(dim, tuple) and numpy_version < (1, 18, 0):
+      raise SkipTest("support for multiple axes added in NumPy 1.18.0")
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)

   @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_inshape={}_axes=({},{})".format(
       jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
@@ -1935,23 +1939,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
     self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)

-  @parameterized.named_parameters(jtu.cases_from_list(
-    {"testcase_name": "_inshape={}_axis={}".format(
-      jtu.format_shape_dtype_string(arg_shape, dtype), ax),
-      "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
-      "rng_factory": jtu.rand_default}
-    for arg_shape, ax in [
-        ((3,), 0),
-        ((1, 3), 1),
-        ((1, 3, 1), (0, 1))]
-    for dtype in default_dtypes))
-  def testSqueezeFailsOnNonsingletonAxis(self, arg_shape, dtype, ax,
-                                         rng_factory):
-    rng = rng_factory(self.rng())
-    x = jnp.zeros(arg_shape, dtype=dtype)
-    fun = lambda: jnp.squeeze(x, ax)
-    self.assertRaisesRegex(ValueError, "cannot select an axis to squeeze", fun)
-
   @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format(
       jtu.format_shape_dtype_string(shape, dtype),
@@ -890,6 +890,47 @@ class LaxTest(jtu.JaxTestCase):
     numpy_op = lambda x: lax_reference.broadcast_in_dim(x, outshape, dimensions)
     self._CheckAgainstNumpy(op, numpy_op, args_maker)

+  @parameterized.named_parameters(jtu.cases_from_list(
+    {"testcase_name": "_inshape={}_dimensions={}".format(
+      jtu.format_shape_dtype_string(inshape, onp.float32), dimensions),
+      "inshape": inshape, "dimensions": dimensions, "error_type": error_type,
+      "err_msg": err_msg}
+    for inshape, dimensions, error_type, err_msg in [
+      ((1, 2, 3), (0, 0), ValueError, 'dimensions are not unique'),
+      ((1, 2, 3), (3,), ValueError, 'axis 3 is out of bounds'),
+      ((1, 2, 3), (-4,), ValueError, 'axis -4 is out of bounds'),
+      ((1, 2, 3), (1,), ValueError, 'cannot select an axis to squeeze out'),
+      ((1, 2, 3), (None,), TypeError, 'cannot be interpreted as an integer'),
+    ]))
+  def testSqueezeShapeCheck(self, inshape, dimensions, error_type, err_msg):
+    rng = jtu.rand_default(self.rng())
+    x = rng(inshape, onp.float32)
+    with self.assertRaisesRegex(error_type, err_msg):
+      lax.squeeze(x, dimensions=dimensions)
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+    {"testcase_name": "_inshape={}_dimensions={}".format(
+      jtu.format_shape_dtype_string(arg_shape, onp.float32), dimensions),
+      "arg_shape": arg_shape, "dimensions": dimensions,
+      "rng_factory": rng_factory}
+    for arg_shape, dimensions in [
+      [(1,), (0,)],
+      [(1,), (-1,)],
+      [(2, 1, 4), (1,)],
+      [(2, 1, 3, 1), (1,)],
+      [(2, 1, 3, 1), (1, 3)],
+      [(2, 1, 3, 1), (3,)],
+    ]
+    for rng_factory in [jtu.rand_default]))
+  def testSqueeze(self, arg_shape, dimensions, rng_factory):
+    rng = rng_factory(self.rng())
+    args_maker = lambda: [rng(arg_shape, onp.float32)]
+    op = lambda x: lax.squeeze(x, dimensions)
+    numpy_op = lambda x: lax_reference.squeeze(x, dimensions)
+    self._CompileAndCheck(op, args_maker, check_dtypes=True)
+    self._CheckAgainstNumpy(op, numpy_op, args_maker)
+    check_grads(op, args_maker(), 2, ["fwd", "rev"], eps=1.)
+
   @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_inshape={}_outshape={}".format(
       jtu.format_shape_dtype_string(arg_shape, dtype),
@@ -3010,6 +3051,30 @@ class LaxVmapTest(jtu.JaxTestCase):
     op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
     self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng)

+  @parameterized.named_parameters(jtu.cases_from_list(
+    {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
+      jtu.format_shape_dtype_string(arg_shape, onp.float32),
+      dimensions, bdims),
+      "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims,
+      "rng_factory": rng_factory}
+    for arg_shape, dimensions in [
+      [(1,), (0,)],
+      [(1,), (-1,)],
+      [(2, 1, 4), (1,)],
+      [(2, 1, 4), (-2,)],
+      [(2, 1, 3, 1), (1,)],
+      [(2, 1, 3, 1), (1, 3)],
+      [(2, 1, 3, 1), (3,)],
+      [(2, 1, 3, 1), (1, -1)],
+    ]
+    for bdims in all_bdims(arg_shape)
+    for rng_factory in [jtu.rand_default]))
+  def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory):
+    dtype = onp.float32
+    rng = rng_factory(self.rng())
+    op = lambda x: lax.squeeze(x, dimensions)
+    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)
+
   @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
       jtu.format_shape_dtype_string(arg_shape, dtype),
@@ -416,6 +416,29 @@ class MaskingTest(jtu.JaxTestCase):
     expected = 8 / 3
     self.assertAllClose(ans, expected, check_dtypes=False)

+  def test_arithmetic(self):
+    @partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)')
+    def times(x, y):
+      return x * y
+
+    # TODO(shoyer): enable this check when broadcast_in_dim supports masking
+    with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
+      ans = times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
+                  dict(n=4, m=5))
+    # expected = np.array([[1, 2, 3], [8, 10, 12]])
+    # self.assertAllClose(ans, expected, check_dtypes=False)
+
+  def test_stack(self):
+    @partial(mask, in_shapes=['n','n'], out_shape='(2, n)')
+    def stack(x, y):
+      return jnp.stack([x, y], 0)
+
+    # TODO(shoyer): enable this check when broadcast_in_dim supports masking
+    with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
+      ans = stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
+    # expected = np.array([[1, 2, 3], [4, 5, 6]])
+    # self.assertAllClose(ans, expected, check_dtypes=False)
+
   def test_monomorphic(self):
     @partial(mask, in_shapes=['(_, n)'], out_shape='')
     def padded_sum(x):