mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add VJP for lax._select_and_gather_add
(3rd-order grad of maxpool)
This commit is contained in:
parent
b940245730
commit
55d74d8624
@ -3393,6 +3393,23 @@ def _select_and_gather_add_translation(
|
||||
out = c.ConvertElementType(out, uint_etype)
|
||||
return c.BitcastConvertType(out, etype)
|
||||
|
||||
def _select_and_gather_add_jvp(
|
||||
primals, tangents, select_prim, window_dimensions, window_strides,
|
||||
padding):
|
||||
source, operand = primals
|
||||
g_source, g_operand = tangents
|
||||
val_out = _select_and_gather_add(
|
||||
source, operand, select_prim, window_dimensions, window_strides,
|
||||
padding)
|
||||
del g_operand
|
||||
if g_source is ad_util.zero:
|
||||
tangent_out = ad_util.zero
|
||||
else:
|
||||
tangent_out = _select_and_gather_add(
|
||||
g_source, operand, select_prim, window_dimensions,
|
||||
window_strides, padding)
|
||||
return val_out, tangent_out
|
||||
|
||||
def _select_and_gather_add_transpose(
|
||||
t, tangents, operand, select_prim, window_dimensions, window_strides,
|
||||
padding):
|
||||
@ -3404,6 +3421,7 @@ def _select_and_gather_add_transpose(
|
||||
select_and_gather_add_p = standard_primitive(
|
||||
_select_and_gather_add_shape_rule, _input_dtype, 'select_and_gather_add',
|
||||
_select_and_gather_add_translation)
|
||||
ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp
|
||||
ad.primitive_transposes[select_and_gather_add_p] = \
|
||||
_select_and_gather_add_transpose
|
||||
|
||||
|
@ -1945,7 +1945,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
self.assertEqual(onp.unique(operand).size, operand.size,
|
||||
msg="test requires operand elements to be unique.")
|
||||
jtu.check_vjp(fun, partial(api.vjp, fun), (operand,), 1e-2, 1e-2, 1e-2)
|
||||
check_grads(fun, (operand,), 2, 1e-2, 1e-2, 1e-2)
|
||||
check_grads(fun, (operand,), 3, 1e-2, 1e-2, 1e-2)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
# TODO(b/205052657): enable more tests when supported
|
||||
|
Loading…
x
Reference in New Issue
Block a user