mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add VJP for lax._select_and_scatter_add
(2nd-order grad of maxpool)
This commit is contained in:
parent
9baf42d978
commit
b940245730
@ -3252,13 +3252,30 @@ def _select_and_scatter_add_translation(
|
||||
return c.SelectAndScatter(operand, select, window_dimensions, window_strides,
|
||||
padding, source, zero, scatter)
|
||||
|
||||
def _select_and_scatter_add_jvp(
|
||||
primals, tangents, select_prim, window_dimensions, window_strides,
|
||||
padding):
|
||||
source, operand = primals
|
||||
g_source, g_operand = tangents
|
||||
val_out = _select_and_scatter_add(
|
||||
source, operand, select_prim, window_dimensions, window_strides,
|
||||
padding)
|
||||
del g_operand
|
||||
if g_source is ad_util.zero:
|
||||
tangent_out = ad_util.zero
|
||||
else:
|
||||
tangent_out = _select_and_scatter_add(
|
||||
g_source, operand, select_prim, window_dimensions,
|
||||
window_strides, padding)
|
||||
return val_out, tangent_out
|
||||
|
||||
def _select_and_scatter_add_transpose(
|
||||
t, source, operand, select_prim, window_dimensions, window_strides,
|
||||
padding):
|
||||
assert source is None and operand is not None
|
||||
result = _select_and_gather_add(t, operand, select_prim, window_dimensions,
|
||||
window_strides, padding)
|
||||
return [result, None]
|
||||
source_t = _select_and_gather_add(t, operand, select_prim, window_dimensions,
|
||||
window_strides, padding)
|
||||
return [source_t, None]
|
||||
|
||||
def _select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
|
||||
source, operand = batched_args
|
||||
@ -3295,6 +3312,7 @@ select_and_scatter_add_p = standard_primitive(
|
||||
_select_and_scatter_add_translation)
|
||||
ad.primitive_transposes[select_and_scatter_add_p] = \
|
||||
_select_and_scatter_add_transpose
|
||||
ad.primitive_jvps[select_and_scatter_add_p] = _select_and_scatter_add_jvp
|
||||
batching.primitive_batchers[select_and_scatter_add_p] = \
|
||||
_select_and_scatter_add_batch_rule
|
||||
|
||||
|
@ -1945,9 +1945,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
self.assertEqual(onp.unique(operand).size, operand.size,
|
||||
msg="test requires operand elements to be unique.")
|
||||
jtu.check_vjp(fun, partial(api.vjp, fun), (operand,), 1e-2, 1e-2, 1e-2)
|
||||
|
||||
# TODO(phawkins): enable both gradients after a jaxlib update.
|
||||
# check_grads(fun, (operand,), 1, 1e-2, 1e-2, 1e-2)
|
||||
check_grads(fun, (operand,), 2, 1e-2, 1e-2, 1e-2)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
# TODO(b/205052657): enable more tests when supported
|
||||
|
Loading…
x
Reference in New Issue
Block a user