Fix gradient for np.amin and np.amax.

The JVP rule for `lax.reduce` depends on being able to identify the reducer as a monoid reducer. To get the correct behavior on complex numbers, `np.{amin,amax}` passed a non-standard reducer that compared complex numbers lexicographically as (real, imaginary) pairs. However, this prevented the gradient rule from identifying the reducer.
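
As a rough illustration of the point (the helper below is made up for this explanation; the real check is `_get_monoid_reducer` in `lax.py`), the JVP rule can only apply its simple monoid treatment when the reducer function and its initial value match a known pair, so a custom lexicographic comparator is never recognized:

    from jax import lax

    def identify_monoid_reducer(reducer, init_val):
        # Illustrative only: a custom complex comparator is not `lax.max`
        # itself, so it falls through and no monoid JVP rule applies.
        if reducer is lax.add and init_val == 0:
            return "sum"
        elif reducer is lax.max and init_val == -float("inf"):
            return "max"
        elif reducer is lax.min and init_val == float("inf"):
            return "min"
        return None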

Instead, change `lax.min` and `lax.max` to follow NumPy's semantics when comparing complex numbers, and change `np.amin` and `np.amax` to use them directly.
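
NumPy orders complex numbers lexicographically on their (real, imaginary) parts, which is the behavior the new translation rule gives `lax.max`/`lax.min`; and because the reducer passed to `lax.reduce` is now `lax.max`/`lax.min` itself, differentiating the reductions works again. A small usage sketch (the `jnp` alias and the printed values are illustrative):

    import numpy as onp
    import jax
    import jax.numpy as jnp

    # NumPy's lexicographic ordering on complex values:
    onp.maximum(1 + 2j, 1 + 3j)   # (1+3j): equal real parts, larger imaginary part wins
    onp.maximum(2 + 0j, 1 + 9j)   # (2+0j): larger real part wins

    # With the reducer recognizable as a monoid reducer, the gradient is defined:
    jax.grad(lambda x: jnp.amax(x))(jnp.array([1.0, 3.0, 2.0]))   # [0., 1., 0.]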

Move the `np._broadcast_shapes` helper into `lax.py` as `lax.broadcast_shapes`.
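
The moved helper keeps NumPy's broadcasting rules. A quick sketch of the intended behavior, assuming it is reachable as `lax.broadcast_shapes` as in the diff below:

    from jax import lax

    lax.broadcast_shapes((3, 1), (1, 4))   # (3, 4)
    lax.broadcast_shapes((5,), (2, 5))     # (2, 5)
    lax.broadcast_shapes((2, 3), (4, 3))   # raises ValueError (incompatible shapes)
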
Peter Hawkins 2019-02-01 11:07:45 -05:00
parent 8d52ee899c
commit fb659e22b9
3 changed files with 82 additions and 58 deletions


@@ -42,7 +42,7 @@ from .interpreters import xla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
from .util import curry, safe_zip, unzip2, prod
from .util import curry, memoize, safe_zip, unzip2, prod
from .tree_util import build_tree
from .lib import xla_bridge
@@ -52,6 +52,23 @@ _max = builtins.max
_min = builtins.max
_reduce = six.moves.reduce
@memoize
def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
if len(shapes) == 1:
return shapes[0]
ndim = _max(len(shape) for shape in shapes)
shapes = onp.array([(1,) * (ndim - len(shape)) + shape for shape in shapes])
min_shape = onp.min(shapes, axis=0)
max_shape = onp.max(shapes, axis=0)
result_shape = onp.where(min_shape == 0, 0, max_shape)
if not onp.all((shapes == result_shape) | (shapes == 1)):
raise ValueError("Incompatible shapes for broadcasting: {}"
.format(tuple(map(tuple, shapes))))
return tuple(result_shape)
def identity(x): return x
### traceables
@@ -97,8 +114,19 @@ def mul(x, y): return mul_p.bind(x, y)
def div(x, y): return div_p.bind(x, y)
def rem(x, y): return rem_p.bind(x, y)
def max(x, y): return max_p.bind(x, y)
def min(x, y): return min_p.bind(x, y)
def max(x, y):
"""Elementwise maximum.
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
return max_p.bind(x, y)
def min(x, y):
"""Elementwise minimum.
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
return min_p.bind(x, y)
def shift_left(x, y): return shift_left_p.bind(x, y)
def shift_right_arithmetic(x, y): return shift_right_arithmetic_p.bind(x, y)
@@ -312,7 +340,7 @@ def _get_monoid_reducer(monoid_op, x):
return aval.val == _get_min_identity(aval.dtype) and _reduce_and
def _get_max_identity(dtype):
if onp.issubdtype(dtype, onp.floating):
if onp.issubdtype(dtype, onp.inexact):
return onp.array(-onp.inf, dtype)
elif onp.issubdtype(dtype, onp.integer):
return onp.array(onp.iinfo(dtype).min, dtype)
@@ -320,7 +348,7 @@ def _get_max_identity(dtype):
return onp.array(False, onp.bool_)
def _get_min_identity(dtype):
if onp.issubdtype(dtype, onp.floating):
if onp.issubdtype(dtype, onp.inexact):
return onp.array(onp.inf, dtype)
elif onp.issubdtype(dtype, onp.integer):
return onp.array(onp.iinfo(dtype).max, dtype)
@@ -809,10 +837,11 @@ def broadcasting_shape_rule(name, *avals):
return tuple(result_shape)
def binop(result_dtype, accepted_dtypes, name):
def binop(result_dtype, accepted_dtypes, name, translation_rule=None):
dtype_rule = partial(binop_dtype_rule, result_dtype, accepted_dtypes, name)
shape_rule = partial(broadcasting_shape_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name,
translation_rule=translation_rule)
batching.defbroadcasting(prim)
parallel.defbroadcasting(prim)
return prim
@@ -999,12 +1028,39 @@ ad.defjvp(rem_p,
lambda g, x, y: mul(neg(g), floor(div(x, y))))
max_p = standard_binop([_any, _any], 'max')
def _broadcasting_select(c, which, x, y):
"""Wrapper around XLA `Select` that broadcasts its arguments."""
which_shape, x_shape, y_shape = (
c.GetShape(t).dimensions() for t in (which, x, y))
out_shape = broadcast_shapes(which_shape, x_shape, y_shape)
bcast_dims = lambda shape: tuple(range(len(out_shape) - len(shape),
len(out_shape)))
which = c.BroadcastInDim(which, out_shape, bcast_dims(which_shape))
x = c.BroadcastInDim(x, out_shape, bcast_dims(x_shape))
y = c.BroadcastInDim(y, out_shape, bcast_dims(y_shape))
return c.Select(which, x, y)
def _minmax_translation_rule(c, x, y, minmax=None, cmp=None):
dtype = c.GetShape(x).numpy_dtype()
if onp.issubdtype(dtype, onp.complexfloating):
comparator = cmp(c)
rx = c.Real(x)
ry = c.Real(y)
return _broadcasting_select(
c, c.Select(c.Eq(rx, ry), comparator(c.Imag(x), c.Imag(y)),
comparator(rx, ry)),
x, y)
return minmax(c)(x, y)
max_p = standard_binop([_any, _any], 'max', translation_rule=partial(
_minmax_translation_rule, minmax=lambda c: c.Max, cmp=lambda c: c.Gt))
ad.defjvp2(max_p,
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))
min_p = standard_binop([_any, _any], 'min')
min_p = standard_binop([_any, _any], 'min', translation_rule=partial(
_minmax_translation_rule, minmax=lambda c: c.Min, cmp=lambda c: c.Lt))
ad.defjvp2(min_p,
lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))
@@ -1159,7 +1215,7 @@ def conv_general_dilated_batch_rule(
# convolution isn't the first dimension.
if lhs_dim[0] != 0 or out_dim[0] != 0:
raise NotImplementedError
lhs = batching.move_dim_to_front(lhs, lhs_bdim)
batched_size = lhs.shape[0]
n_size = lhs.shape[1]


@@ -132,27 +132,10 @@ def _promote_shapes(*args):
return args
else:
shapes = [shape(arg) for arg in args]
nd = len(_broadcast_shapes(*shapes))
nd = len(lax.broadcast_shapes(*shapes))
return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)
if len(shp) != nd else arg for arg, shp in zip(args, shapes)]
@memoize
def _broadcast_shapes(*shapes):
"""Apply Numpy broadcasting rules to the given shapes."""
if len(shapes) == 1:
return shapes[0]
ndim = _max(len(shape) for shape in shapes)
shapes = onp.array([(1,) * (ndim - len(shape)) + shape for shape in shapes])
min_shape = onp.min(shapes, axis=0)
max_shape = onp.max(shapes, axis=0)
result_shape = onp.where(min_shape == 0, 0, max_shape)
if not onp.all((shapes == result_shape) | (shapes == 1)):
raise ValueError("Incompatible shapes for broadcasting: {}"
.format(tuple(map(tuple, shapes))))
return tuple(result_shape)
def _promote_dtypes(*args):
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
@@ -292,6 +275,8 @@ not_equal = _one_to_one_binop(onp.not_equal, lax.ne)
subtract = _one_to_one_binop(onp.subtract, lax.sub)
power = _one_to_one_binop(onp.power, lax.pow, True)
arctan2 = _one_to_one_binop(onp.arctan2, lax.atan2, True)
minimum = _one_to_one_binop(onp.minimum, lax.min)
maximum = _one_to_one_binop(onp.maximum, lax.max)
def _comparison_op(numpy_fn, lax_fn):
@@ -312,24 +297,6 @@ greater = _comparison_op(onp.greater, lax.gt)
less_equal = _comparison_op(onp.less_equal, lax.le)
less = _comparison_op(onp.less, lax.lt)
def _minmax_op(numpy_fn, lax_fn, lax_cmp_fn):
def fn(x, y):
x, y = _promote_args(numpy_fn.__name__, x, y)
# Comparison on complex types are defined as a lexicographic ordering on
# the (real, imag) pair.
if issubdtype(_dtype(x), complexfloating):
rx = lax.real(x)
ry = lax.real(y)
return where(
lax.select(lax.eq(rx, ry), lax_cmp_fn(lax.imag(x), lax.imag(y)),
lax_cmp_fn(rx, ry)),
x, y)
return lax_fn(x, y)
return _wraps(numpy_fn)(fn)
maximum = _minmax_op(onp.maximum, lax.max, lax.gt)
minimum = _minmax_op(onp.minimum, lax.min, lax.lt)
def _logical_op(np_op, bitwise_op):
@_wraps(np_op)
@@ -632,7 +599,7 @@ def broadcast_arrays(*args):
if len(set(shapes)) == 1:
return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg)
for arg in args]
result_shape = _broadcast_shapes(*shapes)
result_shape = lax.broadcast_shapes(*shapes)
return [broadcast_to(arg, result_shape) for arg in args]
@@ -642,7 +609,7 @@ def broadcast_to(arr, shape):
if _shape(arr) != shape:
# TODO(mattjj): revise this to call lax.broadcast_in_dim rather than
# lax.broadcast and lax.transpose
_broadcast_shapes(shape, _shape(arr)) # error checking
lax.broadcast_shapes(shape, _shape(arr)) # error checking
nlead = len(shape) - len(_shape(arr))
diff, = onp.where(onp.not_equal(shape[nlead:], _shape(arr)))
@@ -675,11 +642,11 @@ def clip(a, a_min=None, a_max=None):
if a_min is not None:
if _dtype(a_min) != _dtype(a):
a_min = lax.convert_element_type(a_min, _dtype(a))
a = maximum(a_min, a)
a = lax.max(a_min, a)
if a_max is not None:
if _dtype(a_max) != _dtype(a):
a_max = lax.convert_element_type(a_max, _dtype(a))
a = minimum(a_max, a)
a = lax.min(a_max, a)
return a
@@ -817,8 +784,8 @@ _cast_to_bool = partial(lax.convert_element_type, new_dtype=onp.bool_)
sum = _make_reduction(onp.sum, lax.add, 0)
prod = _make_reduction(onp.prod, lax.mul, 1)
amax = max = _make_reduction(onp.max, maximum, -onp.inf)
amin = min = _make_reduction(onp.min, minimum, onp.inf)
amax = max = _make_reduction(onp.max, lax.max, -onp.inf)
amin = min = _make_reduction(onp.min, lax.min, onp.inf)
all = alltrue = _make_reduction(onp.all, lax.bitwise_and, True, _cast_to_bool)
any = sometrue = _make_reduction(onp.any, lax.bitwise_or, False, _cast_to_bool)
@@ -1303,7 +1270,7 @@ def matmul(a, b): # pylint: disable=missing-docstring
b = lax.reshape(b, shape(b) + (1,)) if b_is_vec else b
a, b = _promote_dtypes(a, b)
batch_shape = _broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
batch_shape = lax.broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
a = broadcast_to(a, batch_shape + shape(a)[-2:])
b = broadcast_to(b, batch_shape + shape(b)[-2:])
batch_dims = tuple(range(len(batch_shape)))


@@ -53,6 +53,7 @@ def num_float_bits(dtype):
float_dtypes = [onp.float32, onp.float64]
complex_dtypes = [onp.complex64, onp.complex128]
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
default_dtypes = float_dtypes + int_dtypes
@@ -122,8 +123,8 @@ LAX_OPS = [
op_record(lax.div, 2, default_dtypes + complex_dtypes, jtu.rand_nonzero()),
op_record(lax.rem, 2, default_dtypes, jtu.rand_nonzero()),
op_record(lax.max, 2, default_dtypes, jtu.rand_small()),
op_record(lax.min, 2, default_dtypes, jtu.rand_small()),
op_record(lax.max, 2, all_dtypes, jtu.rand_small()),
op_record(lax.min, 2, all_dtypes, jtu.rand_small()),
op_record(lax.eq, 2, all_dtypes, jtu.rand_some_equal()),
op_record(lax.ne, 2, all_dtypes, jtu.rand_small()),
@@ -1976,9 +1977,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
"op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
"dims": dims, "rng": rng}
for init_val, op, dtypes in [
(0, lax.add, float_dtypes),
(-onp.inf, lax.max, float_dtypes),
(onp.inf, lax.min, float_dtypes),
(0, lax.add, inexact_dtypes),
(-onp.inf, lax.max, inexact_dtypes),
(onp.inf, lax.min, inexact_dtypes),
]
for dtype in dtypes
for shape, dims in [