mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add gradients to the scatter_max and scatter_min operations. (#3111)
This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255 Co-authored-by: Alex Davies <adavies@google.com>
This commit is contained in:
parent
8d0749f13e
commit
85fe5a28f1
@ -3729,20 +3729,113 @@ ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
|
||||
batching.primitive_batchers[scatter_mul_p] = (
|
||||
partial(_scatter_batching_rule, scatter_mul))
|
||||
|
||||
# TODO(jlebar): Add derivatives.
|
||||
def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr,
|
||||
update_consts, dimension_numbers):
|
||||
operand, scatter_indices, updates = primals
|
||||
g_operand, g_scatter_indices, g_updates = tangents
|
||||
|
||||
scatter_dnums = dimension_numbers
|
||||
updates_shape = updates.shape
|
||||
|
||||
val_out = scatter_op.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts, dimension_numbers=scatter_dnums)
|
||||
|
||||
if g_operand is ad_util.zero and g_updates is ad_util.zero:
|
||||
tangent_out = ad_util.zero
|
||||
else:
|
||||
g_operand = ad.instantiate_zeros(operand, g_operand)
|
||||
g_updates = ad.instantiate_zeros(updates, g_updates)
|
||||
|
||||
# gather_dnums and slice_sizes define the gather op that is the inverse of
|
||||
# the scatter op specified by scatter_dnums
|
||||
gather_dnums = GatherDimensionNumbers(
|
||||
offset_dims=scatter_dnums.update_window_dims,
|
||||
collapsed_slice_dims=scatter_dnums.inserted_window_dims,
|
||||
start_index_map=scatter_dnums.scatter_dims_to_operand_dims)
|
||||
|
||||
slice_sizes = []
|
||||
pos = 0
|
||||
for i in range(len(operand.shape)):
|
||||
if i in scatter_dnums.inserted_window_dims:
|
||||
slice_sizes.append(1)
|
||||
else:
|
||||
slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]])
|
||||
pos += 1
|
||||
|
||||
# For consistency with other max operations, if there are two or more values
|
||||
# in updates that are contending to replace the same index location, the
|
||||
# resulting tangent at that location will be the average of the associated
|
||||
# tangents for the values in updates.
|
||||
|
||||
initial_vals = gather(
|
||||
operand, scatter_indices, gather_dnums, onp.array(slice_sizes))
|
||||
|
||||
target_vals = gather(
|
||||
val_out, scatter_indices, gather_dnums, onp.array(slice_sizes))
|
||||
|
||||
successful_updates = (updates == target_vals)
|
||||
retained_values = (initial_vals == target_vals)
|
||||
|
||||
num_updates = gather(
|
||||
scatter_add(_zeros(operand),
|
||||
scatter_indices,
|
||||
select(successful_updates, _ones(updates), _zeros(updates)),
|
||||
scatter_dnums),
|
||||
scatter_indices,
|
||||
gather_dnums,
|
||||
onp.array(slice_sizes))
|
||||
|
||||
num_refs = gather(
|
||||
scatter_add(_zeros(operand),
|
||||
scatter_indices,
|
||||
_ones(updates),
|
||||
scatter_dnums),
|
||||
scatter_indices,
|
||||
gather_dnums,
|
||||
onp.array(slice_sizes))
|
||||
|
||||
updates_normalizer = select(retained_values,
|
||||
1.0 / (num_updates + 1),
|
||||
1.0 / num_updates)
|
||||
|
||||
updates_coef = select(successful_updates,
|
||||
updates_normalizer,
|
||||
_zeros(updates))
|
||||
|
||||
operand_normalizer = select(retained_values,
|
||||
1.0 / (num_updates + 1),
|
||||
_zeros(num_updates))
|
||||
|
||||
operand_coef = (-1.0 + operand_normalizer) / num_refs
|
||||
|
||||
# This can be simplified once scatter has transpose implemented
|
||||
target_tangents = gather(
|
||||
g_operand, scatter_indices, gather_dnums, onp.array(slice_sizes))
|
||||
|
||||
tangent_updates = (target_tangents * operand_coef +
|
||||
g_updates * updates_coef)
|
||||
|
||||
tangent_out = scatter_add(g_operand,
|
||||
scatter_indices,
|
||||
tangent_updates,
|
||||
scatter_dnums)
|
||||
|
||||
return val_out, tangent_out
|
||||
|
||||
scatter_min_p = standard_primitive(
|
||||
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-min',
|
||||
_scatter_translation_rule)
|
||||
batching.primitive_batchers[scatter_min_p] = (
|
||||
partial(_scatter_batching_rule, scatter_min))
|
||||
ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p)
|
||||
|
||||
# TODO(jlebar): Add derivatives.
|
||||
scatter_max_p = standard_primitive(
|
||||
_scatter_shape_rule, _scatter_dtype_rule, 'scatter-max',
|
||||
_scatter_translation_rule)
|
||||
batching.primitive_batchers[scatter_max_p] = (
|
||||
partial(_scatter_batching_rule, scatter_max))
|
||||
|
||||
ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p)
|
||||
|
||||
def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
|
||||
dimension_numbers):
|
||||
|
@ -2629,6 +2629,68 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
|
||||
1.)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
jtu.format_shape_dtype_string(arg_shape, dtype),
|
||||
idxs, update_shape, dnums),
|
||||
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
|
||||
"update_shape": update_shape, "dnums": dnums,
|
||||
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
|
||||
for dtype in grad_float_dtypes
|
||||
for arg_shape, idxs, update_shape, dnums in [
|
||||
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]
|
||||
for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
|
||||
rng_factory, rng_idx_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
rng_idx = rng_idx_factory(self.rng())
|
||||
idxs = rng_idx(idxs.shape, idxs.dtype)
|
||||
scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
|
||||
x = rng(arg_shape, dtype)
|
||||
y = rng(update_shape, dtype)
|
||||
check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
|
||||
jtu.format_shape_dtype_string(arg_shape, dtype),
|
||||
idxs, update_shape, dnums),
|
||||
"arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
|
||||
"update_shape": update_shape, "dnums": dnums,
|
||||
"rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
|
||||
for dtype in grad_float_dtypes
|
||||
for arg_shape, idxs, update_shape, dnums in [
|
||||
((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]
|
||||
for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
|
||||
rng_factory, rng_idx_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
rng_idx = rng_idx_factory(self.rng())
|
||||
idxs = rng_idx(idxs.shape, idxs.dtype)
|
||||
scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
|
||||
x = rng(arg_shape, dtype)
|
||||
y = rng(update_shape, dtype)
|
||||
check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)
|
||||
|
||||
def testStopGradient(self):
|
||||
def f(x):
|
||||
return lax.sin(x) * lax.cos(lax.stop_gradient(x))
|
||||
|
Loading…
x
Reference in New Issue
Block a user