add VJP for lax._select_and_gather_add (3rd-order grad of maxpool)

This commit is contained in:
James Bradbury 2019-04-20 17:06:56 -07:00
parent b940245730
commit 55d74d8624
2 changed files with 19 additions and 1 deletions

View File

@ -3393,6 +3393,23 @@ def _select_and_gather_add_translation(
out = c.ConvertElementType(out, uint_etype)
return c.BitcastConvertType(out, etype)
def _select_and_gather_add_jvp(
primals, tangents, select_prim, window_dimensions, window_strides,
padding):
source, operand = primals
g_source, g_operand = tangents
val_out = _select_and_gather_add(
source, operand, select_prim, window_dimensions, window_strides,
padding)
del g_operand
if g_source is ad_util.zero:
tangent_out = ad_util.zero
else:
tangent_out = _select_and_gather_add(
g_source, operand, select_prim, window_dimensions,
window_strides, padding)
return val_out, tangent_out
def _select_and_gather_add_transpose(
t, tangents, operand, select_prim, window_dimensions, window_strides,
padding):
@ -3404,6 +3421,7 @@ def _select_and_gather_add_transpose(
select_and_gather_add_p = standard_primitive(
_select_and_gather_add_shape_rule, _input_dtype, 'select_and_gather_add',
_select_and_gather_add_translation)
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
ad.primitive_transposes[select_and_gather_add_p] = \
_select_and_gather_add_transpose

View File

@ -1945,7 +1945,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
self.assertEqual(onp.unique(operand).size, operand.size,
msg="test requires operand elements to be unique.")
jtu.check_vjp(fun, partial(api.vjp, fun), (operand,), 1e-2, 1e-2, 1e-2)
check_grads(fun, (operand,), 2, 1e-2, 1e-2, 1e-2)
check_grads(fun, (operand,), 3, 1e-2, 1e-2, 1e-2)
# pylint: enable=cell-var-from-loop
# TODO(b/205052657): enable more tests when supported