Add gradients to the scatter_max and scatter_min operations. (#3111)

This is done to allow the creation of a differentiable segment_max (a minimal sketch follows the commit metadata below). segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255

Co-authored-by: Alex Davies <adavies@google.com>
alexdavies 2020-05-19 07:06:32 +01:00 committed by GitHub
parent 8d0749f13e
commit 85fe5a28f1
2 changed files with 158 additions and 3 deletions
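
For context, here is a minimal sketch (not part of this commit) of the kind of differentiable segment_max these gradients enable; the segment_max helper and its signature are illustrative only, not an API added by this PR:

import jax.numpy as jnp
from jax import grad, lax

def segment_max(data, segment_ids, num_segments):
  # scatter_max each element into its segment's slot; the -inf identity
  # ensures every segment ends up with the max of its members.
  init = jnp.full((num_segments,), -jnp.inf, dtype=data.dtype)
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=(), inserted_window_dims=(0,),
      scatter_dims_to_operand_dims=(0,))
  return lax.scatter_max(init, segment_ids[:, None], data, dnums)

data = jnp.array([1.0, 3.0, 2.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 1])
print(segment_max(data, segment_ids, 2))                           # [3. 5.]
print(grad(lambda d: segment_max(d, segment_ids, 2).sum())(data))  # [0. 1. 0. 1.]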


@@ -3729,20 +3729,113 @@ ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
batching.primitive_batchers[scatter_mul_p] = (
    partial(_scatter_batching_rule, scatter_mul))
def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
                          update_consts, dimension_numbers):
  operand, scatter_indices, updates = primals
  g_operand, g_scatter_indices, g_updates = tangents

  scatter_dnums = dimension_numbers
  updates_shape = updates.shape

  val_out = scatter_op.bind(
      operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=scatter_dnums)

  if g_operand is ad_util.zero and g_updates is ad_util.zero:
    tangent_out = ad_util.zero
  else:
    g_operand = ad.instantiate_zeros(operand, g_operand)
    g_updates = ad.instantiate_zeros(updates, g_updates)

    # gather_dnums and slice_sizes define the gather op that is the inverse of
    # the scatter op specified by scatter_dnums.
    gather_dnums = GatherDimensionNumbers(
        offset_dims=scatter_dnums.update_window_dims,
        collapsed_slice_dims=scatter_dnums.inserted_window_dims,
        start_index_map=scatter_dnums.scatter_dims_to_operand_dims)

    slice_sizes = []
    pos = 0
    for i in range(len(operand.shape)):
      if i in scatter_dnums.inserted_window_dims:
        slice_sizes.append(1)
      else:
        slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]])
        pos += 1

    # For consistency with other max operations, if there are two or more
    # values in updates that are contending to replace the same index
    # location, the resulting tangent at that location will be the average of
    # the associated tangents for the values in updates.

    initial_vals = gather(
        operand, scatter_indices, gather_dnums, onp.array(slice_sizes))

    target_vals = gather(
        val_out, scatter_indices, gather_dnums, onp.array(slice_sizes))

    successful_updates = (updates == target_vals)
    retained_values = (initial_vals == target_vals)

    num_updates = gather(
        scatter_add(_zeros(operand),
                    scatter_indices,
                    select(successful_updates, _ones(updates), _zeros(updates)),
                    scatter_dnums),
        scatter_indices,
        gather_dnums,
        onp.array(slice_sizes))

    num_refs = gather(
        scatter_add(_zeros(operand),
                    scatter_indices,
                    _ones(updates),
                    scatter_dnums),
        scatter_indices,
        gather_dnums,
        onp.array(slice_sizes))

    updates_normalizer = select(retained_values,
                                1.0 / (num_updates + 1),
                                1.0 / num_updates)

    updates_coef = select(successful_updates,
                          updates_normalizer,
                          _zeros(updates))

    operand_normalizer = select(retained_values,
                                1.0 / (num_updates + 1),
                                _zeros(num_updates))

    operand_coef = (-1.0 + operand_normalizer) / num_refs

    # This can be simplified once scatter has a transpose implemented.
    target_tangents = gather(
        g_operand, scatter_indices, gather_dnums, onp.array(slice_sizes))

    tangent_updates = (target_tangents * operand_coef +
                       g_updates * updates_coef)

    tangent_out = scatter_add(g_operand,
                              scatter_indices,
                              tangent_updates,
                              scatter_dnums)

  return val_out, tangent_out
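
To make the averaging rule and the inverse gather described in the comments above concrete, here is a small hand check (an illustration, not code from this commit): two updates tie for the same output slot, so each receives half of the output cotangent, and the gather built from gather_dnums reads back each update's target value.

import jax.numpy as jnp
from jax import grad, lax

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(), inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))
idxs = jnp.array([[0], [0]])               # both updates target slot 0
out = lax.scatter_max(jnp.zeros(1), idxs, jnp.array([7.0, 7.0]), dnums)

# The inverse gather reads back, per update row, the slot that row touched.
gdnums = lax.GatherDimensionNumbers(
    offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
print(lax.gather(out, idxs, gdnums, (1,)))  # [7. 7.]

# Two tied updates share the winning slot's gradient equally.
f = lambda y: lax.scatter_max(jnp.zeros(1), idxs, y, dnums).sum()
print(grad(f)(jnp.array([7.0, 7.0])))       # [0.5 0.5]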
scatter_min_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
    _scatter_translation_rule)
batching.primitive_batchers[scatter_min_p] = (
    partial(_scatter_batching_rule, scatter_min))
ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)

scatter_max_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
    _scatter_translation_rule)
batching.primitive_batchers[scatter_max_p] = (
    partial(_scatter_batching_rule, scatter_max))
ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)

def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
                 dimension_numbers):
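
With the JVPs registered above, jax.grad flows through scatter_min and scatter_max end to end. A quick check on made-up numbers (not taken from the test suite): where an operand value survives the min, it keeps its full gradient, while an update that loses receives zero.

import jax.numpy as jnp
from jax import grad, lax

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(), inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))
idxs = jnp.array([[0], [2]])
f = lambda x, y: lax.scatter_min(x, idxs, y, dnums).sum()
x = jnp.array([1.0, 5.0, -2.0])   # x wins at both targeted slots
y = jnp.array([4.0, 0.0])         # both updates lose the min
print(grad(f, argnums=(0, 1))(x, y))  # ([1., 1., 1.], [0., 0.])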


@@ -2629,6 +2629,68 @@ class LaxAutodiffTest(jtu.JaxTestCase):
    check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
              update_window_dims=(), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
              update_window_dims=(1,), inserted_window_dims=(),
              scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
              update_window_dims=(1,), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
              update_window_dims=(), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
              update_window_dims=(1,), inserted_window_dims=(),
              scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
              update_window_dims=(1,), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))