Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Fix gradient for np.amin and np.amax.
The JVP rule for `lax.reduce` depends on being able to identify the reducer as a monoid reducer. To get the correct behavior on complex numbers, `np.{amin,amax}` passed a non-standard reducer that compared complex numbers lexicographically as (real, imaginary) pairs. However, this prevented the gradient rule from identifying the reducer. Instead, change `lax.min` and `lax.max` to use NumPy semantics when comparing complex numbers, and change `np.amin` and `np.amax` to use them. Move the `np._broadcast_shapes` helper into `lax.py` as `lax.broadcast_shapes`.
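As an illustration of the behavior this fixes (not part of the commit; it assumes the public `jax.grad` and `jax.numpy` entry points and the conventional `jnp` alias):

import jax
import jax.numpy as jnp

# The gradient of a max-reduction routes the cotangent to the arg-max entry,
# which requires the reducer of lax.reduce to be recognized as lax.max.
x = jnp.array([1.0, 5.0, 3.0])
print(jax.grad(lambda v: jnp.amax(v))(x))  # expected: [0., 1., 0.]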
parent 8d52ee899c
commit fb659e22b9
jax/lax.py (76 lines changed)
jax/lax.py

@@ -42,7 +42,7 @@ from .interpreters import xla
 from .interpreters import ad
 from .interpreters import batching
 from .interpreters import parallel
-from .util import curry, safe_zip, unzip2, prod
+from .util import curry, memoize, safe_zip, unzip2, prod
 from .tree_util import build_tree
 from .lib import xla_bridge

@@ -52,6 +52,23 @@ _max = builtins.max
 _min = builtins.min
 _reduce = six.moves.reduce

+
+@memoize
+def broadcast_shapes(*shapes):
+  """Returns the shape that results from NumPy broadcasting of `shapes`."""
+  if len(shapes) == 1:
+    return shapes[0]
+  ndim = _max(len(shape) for shape in shapes)
+  shapes = onp.array([(1,) * (ndim - len(shape)) + shape for shape in shapes])
+  min_shape = onp.min(shapes, axis=0)
+  max_shape = onp.max(shapes, axis=0)
+  result_shape = onp.where(min_shape == 0, 0, max_shape)
+  if not onp.all((shapes == result_shape) | (shapes == 1)):
+    raise ValueError("Incompatible shapes for broadcasting: {}"
+                     .format(tuple(map(tuple, shapes))))
+  return tuple(result_shape)
+
 def identity(x): return x

 ### traceables
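For reference, a small sketch (not part of the diff) of the broadcasting rule the new `lax.broadcast_shapes` helper implements, checked here against NumPy's own broadcasting:

import numpy as onp  # the alias used inside jax/lax.py

# Shapes are aligned from the right and size-1 axes are stretched;
# mismatched non-1 sizes raise an error.
assert onp.broadcast(onp.empty((3, 1)), onp.empty((1, 4))).shape == (3, 4)
assert onp.broadcast(onp.empty((2, 3)), onp.empty((3,))).shape == (2, 3)
# lax.broadcast_shapes((3, 1), (1, 4)) is expected to return (3, 4) as well.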
@@ -97,8 +114,19 @@ def mul(x, y): return mul_p.bind(x, y)
 def div(x, y): return div_p.bind(x, y)
 def rem(x, y): return rem_p.bind(x, y)

-def max(x, y): return max_p.bind(x, y)
-def min(x, y): return min_p.bind(x, y)
+def max(x, y):
+  """Elementwise maximum.
+
+  For complex numbers, uses a lexicographic comparison on the
+  `(real, imaginary)` pairs."""
+  return max_p.bind(x, y)
+
+def min(x, y):
+  """Elementwise minimum.
+
+  For complex numbers, uses a lexicographic comparison on the
+  `(real, imaginary)` pairs."""
+  return min_p.bind(x, y)

 def shift_left(x, y): return shift_left_p.bind(x, y)
 def shift_right_arithmetic(x, y): return shift_right_arithmetic_p.bind(x, y)
@@ -312,7 +340,7 @@ def _get_monoid_reducer(monoid_op, x):
     return aval.val == _get_min_identity(aval.dtype) and _reduce_and

 def _get_max_identity(dtype):
-  if onp.issubdtype(dtype, onp.floating):
+  if onp.issubdtype(dtype, onp.inexact):
     return onp.array(-onp.inf, dtype)
   elif onp.issubdtype(dtype, onp.integer):
     return onp.array(onp.iinfo(dtype).min, dtype)
@@ -320,7 +348,7 @@ def _get_max_identity(dtype):
     return onp.array(False, onp.bool_)

 def _get_min_identity(dtype):
-  if onp.issubdtype(dtype, onp.floating):
+  if onp.issubdtype(dtype, onp.inexact):
     return onp.array(onp.inf, dtype)
   elif onp.issubdtype(dtype, onp.integer):
     return onp.array(onp.iinfo(dtype).max, dtype)
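The switch from `onp.floating` to `onp.inexact` matters because complex dtypes are inexact but not floating. A quick check (illustrative, not part of the diff):

import numpy as onp

assert onp.issubdtype(onp.complex64, onp.inexact)       # complex is inexact...
assert not onp.issubdtype(onp.complex64, onp.floating)  # ...but not floating
# So _get_max_identity/_get_min_identity now also cover complex inputs.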
@@ -809,10 +837,11 @@ def broadcasting_shape_rule(name, *avals):
   return tuple(result_shape)


-def binop(result_dtype, accepted_dtypes, name):
+def binop(result_dtype, accepted_dtypes, name, translation_rule=None):
   dtype_rule = partial(binop_dtype_rule, result_dtype, accepted_dtypes, name)
   shape_rule = partial(broadcasting_shape_rule, name)
-  prim = standard_primitive(shape_rule, dtype_rule, name)
+  prim = standard_primitive(shape_rule, dtype_rule, name,
+                            translation_rule=translation_rule)
   batching.defbroadcasting(prim)
   parallel.defbroadcasting(prim)
   return prim
@@ -999,12 +1028,39 @@ ad.defjvp(rem_p,
           lambda g, x, y: mul(neg(g), floor(div(x, y))))


-max_p = standard_binop([_any, _any], 'max')
+def _broadcasting_select(c, which, x, y):
+  """Wrapper around XLA `Select` that broadcasts its arguments."""
+  which_shape, x_shape, y_shape = (
+    c.GetShape(t).dimensions() for t in (which, x, y))
+  out_shape = broadcast_shapes(which_shape, x_shape, y_shape)
+  bcast_dims = lambda shape: tuple(range(len(out_shape) - len(shape),
+                                         len(out_shape)))
+  which = c.BroadcastInDim(which, out_shape, bcast_dims(which_shape))
+  x = c.BroadcastInDim(x, out_shape, bcast_dims(x_shape))
+  y = c.BroadcastInDim(y, out_shape, bcast_dims(y_shape))
+  return c.Select(which, x, y)
+
+
+def _minmax_translation_rule(c, x, y, minmax=None, cmp=None):
+  dtype = c.GetShape(x).numpy_dtype()
+  if onp.issubdtype(dtype, onp.complexfloating):
+    comparator = cmp(c)
+    rx = c.Real(x)
+    ry = c.Real(y)
+    return _broadcasting_select(
+        c, c.Select(c.Eq(rx, ry), comparator(c.Imag(x), c.Imag(y)),
+                    comparator(rx, ry)),
+        x, y)
+  return minmax(c)(x, y)
+
+max_p = standard_binop([_any, _any], 'max', translation_rule=partial(
+    _minmax_translation_rule, minmax=lambda c: c.Max, cmp=lambda c: c.Gt))
 ad.defjvp2(max_p,
            lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))

-min_p = standard_binop([_any, _any], 'min')
+min_p = standard_binop([_any, _any], 'min', translation_rule=partial(
+    _minmax_translation_rule, minmax=lambda c: c.Min, cmp=lambda c: c.Lt))
 ad.defjvp2(min_p,
            lambda g, ans, x, y: mul(_brcast(g, y), _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(_brcast(g, x), _balanced_eq(y, ans, x)))
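For context, the NumPy semantics the new translation rule is meant to match: complex values are ordered lexicographically by (real, imag). A hedged sketch, not part of the diff:

import numpy as onp

a = onp.array([1 + 5j, 2 + 1j])
b = onp.array([1 + 2j, 3 + 0j])
# Equal real parts fall back to the imaginary parts; otherwise the real parts decide.
print(onp.maximum(a, b))  # [1.+5.j  3.+0.j]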
@@ -1159,7 +1215,7 @@ def conv_general_dilated_batch_rule(
   # convolution isn't the first dimension.
   if lhs_dim[0] != 0 or out_dim[0] != 0:
     raise NotImplementedError

   lhs = batching.move_dim_to_front(lhs, lhs_bdim)
   batched_size = lhs.shape[0]
   n_size = lhs.shape[1]
jax/numpy/lax_numpy.py

@@ -132,27 +132,10 @@ def _promote_shapes(*args):
     return args
   else:
     shapes = [shape(arg) for arg in args]
-    nd = len(_broadcast_shapes(*shapes))
+    nd = len(lax.broadcast_shapes(*shapes))
     return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)
             if len(shp) != nd else arg for arg, shp in zip(args, shapes)]


-@memoize
-def _broadcast_shapes(*shapes):
-  """Apply Numpy broadcasting rules to the given shapes."""
-  if len(shapes) == 1:
-    return shapes[0]
-  ndim = _max(len(shape) for shape in shapes)
-  shapes = onp.array([(1,) * (ndim - len(shape)) + shape for shape in shapes])
-  min_shape = onp.min(shapes, axis=0)
-  max_shape = onp.max(shapes, axis=0)
-  result_shape = onp.where(min_shape == 0, 0, max_shape)
-  if not onp.all((shapes == result_shape) | (shapes == 1)):
-    raise ValueError("Incompatible shapes for broadcasting: {}"
-                     .format(tuple(map(tuple, shapes))))
-  return tuple(result_shape)
-
-
 def _promote_dtypes(*args):
   """Convenience function to apply Numpy argument dtype promotion."""
   # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
@@ -292,6 +275,8 @@ not_equal = _one_to_one_binop(onp.not_equal, lax.ne)
 subtract = _one_to_one_binop(onp.subtract, lax.sub)
 power = _one_to_one_binop(onp.power, lax.pow, True)
 arctan2 = _one_to_one_binop(onp.arctan2, lax.atan2, True)
+minimum = _one_to_one_binop(onp.minimum, lax.min)
+maximum = _one_to_one_binop(onp.maximum, lax.max)


 def _comparison_op(numpy_fn, lax_fn):
@@ -312,24 +297,6 @@ greater = _comparison_op(onp.greater, lax.gt)
 less_equal = _comparison_op(onp.less_equal, lax.le)
 less = _comparison_op(onp.less, lax.lt)

-def _minmax_op(numpy_fn, lax_fn, lax_cmp_fn):
-  def fn(x, y):
-    x, y = _promote_args(numpy_fn.__name__, x, y)
-    # Comparison on complex types are defined as a lexicographic ordering on
-    # the (real, imag) pair.
-    if issubdtype(_dtype(x), complexfloating):
-      rx = lax.real(x)
-      ry = lax.real(y)
-      return where(
-          lax.select(lax.eq(rx, ry), lax_cmp_fn(lax.imag(x), lax.imag(y)),
-                     lax_cmp_fn(rx, ry)),
-          x, y)
-    return lax_fn(x, y)
-  return _wraps(numpy_fn)(fn)
-
-maximum = _minmax_op(onp.maximum, lax.max, lax.gt)
-minimum = _minmax_op(onp.minimum, lax.min, lax.lt)
-

 def _logical_op(np_op, bitwise_op):
   @_wraps(np_op)
@@ -632,7 +599,7 @@ def broadcast_arrays(*args):
   if len(set(shapes)) == 1:
     return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg)
             for arg in args]
-  result_shape = _broadcast_shapes(*shapes)
+  result_shape = lax.broadcast_shapes(*shapes)
   return [broadcast_to(arg, result_shape) for arg in args]

@@ -642,7 +609,7 @@ def broadcast_to(arr, shape):
   if _shape(arr) != shape:
     # TODO(mattjj): revise this to call lax.broadcast_in_dim rather than
    # lax.broadcast and lax.transpose
-    _broadcast_shapes(shape, _shape(arr))  # error checking
+    lax.broadcast_shapes(shape, _shape(arr))  # error checking
     nlead = len(shape) - len(_shape(arr))
     diff, = onp.where(onp.not_equal(shape[nlead:], _shape(arr)))
@@ -675,11 +642,11 @@ def clip(a, a_min=None, a_max=None):
   if a_min is not None:
     if _dtype(a_min) != _dtype(a):
       a_min = lax.convert_element_type(a_min, _dtype(a))
-    a = maximum(a_min, a)
+    a = lax.max(a_min, a)
   if a_max is not None:
     if _dtype(a_max) != _dtype(a):
       a_max = lax.convert_element_type(a_max, _dtype(a))
-    a = minimum(a_max, a)
+    a = lax.min(a_max, a)
   return a
@@ -817,8 +784,8 @@ _cast_to_bool = partial(lax.convert_element_type, new_dtype=onp.bool_)

 sum = _make_reduction(onp.sum, lax.add, 0)
 prod = _make_reduction(onp.prod, lax.mul, 1)
-amax = max = _make_reduction(onp.max, maximum, -onp.inf)
-amin = min = _make_reduction(onp.min, minimum, onp.inf)
+amax = max = _make_reduction(onp.max, lax.max, -onp.inf)
+amin = min = _make_reduction(onp.min, lax.min, onp.inf)
 all = alltrue = _make_reduction(onp.all, lax.bitwise_and, True, _cast_to_bool)
 any = sometrue = _make_reduction(onp.any, lax.bitwise_or, False, _cast_to_bool)
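With `np.amax`/`np.amin` now reducing with `lax.max`/`lax.min` directly, the reducer is a recognizable monoid and complex inputs follow NumPy's ordering. A hedged usage sketch (assumes the conventional `jnp` alias; not part of the diff):

import jax.numpy as jnp
import numpy as onp

x = onp.array([1 + 5j, 3 + 0j, 1 + 2j])
# jnp.amax is expected to agree with onp.amax on complex inputs.
print(jnp.amax(x))  # (3+0j)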
@@ -1303,7 +1270,7 @@ def matmul(a, b):  # pylint: disable=missing-docstring
   b = lax.reshape(b, shape(b) + (1,)) if b_is_vec else b

   a, b = _promote_dtypes(a, b)
-  batch_shape = _broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
+  batch_shape = lax.broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
   a = broadcast_to(a, batch_shape + shape(a)[-2:])
   b = broadcast_to(b, batch_shape + shape(b)[-2:])
   batch_dims = tuple(range(len(batch_shape)))
tests/lax_test.py

@@ -53,6 +53,7 @@ def num_float_bits(dtype):

 float_dtypes = [onp.float32, onp.float64]
 complex_dtypes = [onp.complex64, onp.complex128]
+inexact_dtypes = float_dtypes + complex_dtypes
 int_dtypes = [onp.int32, onp.int64]
 bool_dtypes = [onp.bool_]
 default_dtypes = float_dtypes + int_dtypes
@@ -122,8 +123,8 @@ LAX_OPS = [
     op_record(lax.div, 2, default_dtypes + complex_dtypes, jtu.rand_nonzero()),
     op_record(lax.rem, 2, default_dtypes, jtu.rand_nonzero()),

-    op_record(lax.max, 2, default_dtypes, jtu.rand_small()),
-    op_record(lax.min, 2, default_dtypes, jtu.rand_small()),
+    op_record(lax.max, 2, all_dtypes, jtu.rand_small()),
+    op_record(lax.min, 2, all_dtypes, jtu.rand_small()),

     op_record(lax.eq, 2, all_dtypes, jtu.rand_some_equal()),
     op_record(lax.ne, 2, all_dtypes, jtu.rand_small()),
@@ -1976,9 +1977,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
         "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
         "dims": dims, "rng": rng}
       for init_val, op, dtypes in [
-          (0, lax.add, float_dtypes),
-          (-onp.inf, lax.max, float_dtypes),
-          (onp.inf, lax.min, float_dtypes),
+          (0, lax.add, inexact_dtypes),
+          (-onp.inf, lax.max, inexact_dtypes),
+          (onp.inf, lax.min, inexact_dtypes),
       ]
       for dtype in dtypes
       for shape, dims in [