add VJP for lax._select_and_scatter_add (2nd-order grad of maxpool)

This commit is contained in:
James Bradbury 2019-04-20 17:06:35 -07:00
parent 9baf42d978
commit b940245730
2 changed files with 22 additions and 6 deletions

View File

@ -3252,13 +3252,30 @@ def _select_and_scatter_add_translation(
return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, zero, scatter)
def _select_and_scatter_add_jvp(
primals, tangents, select_prim, window_dimensions, window_strides,
padding):
source, operand = primals
g_source, g_operand = tangents
val_out = _select_and_scatter_add(
source, operand, select_prim, window_dimensions, window_strides,
padding)
del g_operand
if g_source is ad_util.zero:
tangent_out = ad_util.zero
else:
tangent_out = _select_and_scatter_add(
g_source, operand, select_prim, window_dimensions,
window_strides, padding)
return val_out, tangent_out
def _select_and_scatter_add_transpose(
t, source, operand, select_prim, window_dimensions, window_strides,
padding):
assert source is None and operand is not None
result = _select_and_gather_add(t, operand, select_prim, window_dimensions,
window_strides, padding)
return [result, None]
source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions,
window_strides, padding)
return [source_t, None]
def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
source, operand = batched_args
@ -3295,6 +3312,7 @@ select_and_scatter_add_p = standard_primitive(
_select_and_scatter_add_translation)
ad.primitive_transposes[select_and_scatter_add_p] = \
_select_and_scatter_add_transpose
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
batching.primitive_batchers[select_and_scatter_add_p] = \
_select_and_scatter_add_batch_rule

View File

@ -1945,9 +1945,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
self.assertEqual(onp.unique(operand).size, operand.size,
msg="test requires operand elements to be unique.")
jtu.check_vjp(fun, partial(api.vjp, fun), (operand,), 1e-2, 1e-2, 1e-2)
# TODO(phawkins): enable both gradients after a jaxlib update.
# check_grads(fun, (operand,), 1, 1e-2, 1e-2, 1e-2)
check_grads(fun, (operand,), 2, 1e-2, 1e-2, 1e-2)
# pylint: enable=cell-var-from-loop
# TODO(b/205052657): enable more tests when supported