Merge pull request #11580 from jakevdp:fix-dynamic-index

PiperOrigin-RevId: 463311212
This commit is contained in:
jax authors 2022-07-26 05:28:47 -07:00
commit c4b255b527
2 changed files with 40 additions and 9 deletions

View File

@ -841,8 +841,8 @@ def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes):
def _dynamic_slice_jvp(primals, tangents, *, slice_sizes):
# JVP rule for dynamic_slice. Returns a (primal_out, tangent_out) pair.
# primals[0] is sliced at the positions given by primals[1:]; only tangents[0]
# is read, and a symbolic ad_util.Zero tangent is passed through unchanged.
#
# NOTE(review): this is a diff hunk rendered without +/- markers, so each
# changed statement appears twice. The first occurrence (calling the
# `dynamic_slice` wrapper) looks like the removed line and the second
# (calling `dynamic_slice_p.bind`) like the added line - confirm against the
# original commit before treating either as live code.
tangent_out = tangents[0]
if type(tangent_out) is not ad_util.Zero:
# Old form: slices the tangent via the `dynamic_slice` wrapper, passing the
# start indices as one sequence and slice_sizes positionally.
tangent_out = dynamic_slice(tangent_out, primals[1:], slice_sizes)
# Old form of the return statement (same wrapper, applied to the primal).
return dynamic_slice(primals[0], primals[1:], slice_sizes), tangent_out
# New form: binds the primitive directly, unpacking the start indices as
# separate operands and passing slice_sizes by keyword. Presumably this avoids
# the wrapper processing the start indices a second time - TODO confirm
# against the linked issue (google/jax#10984).
tangent_out = dynamic_slice_p.bind(tangent_out, *primals[1:], slice_sizes=slice_sizes)
# New form of the return statement: primal output and tangent output.
return dynamic_slice_p.bind(primals[0], *primals[1:], slice_sizes=slice_sizes), tangent_out
def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
assert ad.is_undefined_primal(operand)
@ -852,7 +852,7 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
return [ad_util.Zero(operand.aval)] + [None] * len(start_indices)
else:
zeros = lax.full(operand_shape, 0, operand_dtype)
return ([dynamic_update_slice(zeros, t, start_indices)] +
return ([dynamic_update_slice_p.bind(zeros, t, *start_indices)] +
[None] * len(start_indices))
def _batch_dynamic_slice_indices(indices, bdims):
@ -935,13 +935,13 @@ def _dynamic_update_slice_jvp(primals, tangents):
operand, update = primals[:2]
start_indices = primals[2:]
g_operand, g_update = tangents[:2]
val_out = dynamic_update_slice(operand, update, start_indices)
val_out = dynamic_update_slice_p.bind(operand, update, *start_indices)
if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero:
tangent_out = ad_util.Zero.from_value(val_out)
else:
g_operand = ad.instantiate_zeros(g_operand)
g_update = ad.instantiate_zeros(g_update)
tangent_out = dynamic_update_slice(g_operand, g_update, start_indices)
tangent_out = dynamic_update_slice_p.bind(g_operand, g_update, *start_indices)
return val_out, tangent_out
def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
@ -954,11 +954,11 @@ def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
update_t = ad_util.Zero(update.aval) if ad.is_undefined_primal(update) else None
else:
dus = dynamic_update_slice
ds = dynamic_slice
dus = dynamic_update_slice_p.bind
ds = dynamic_slice_p.bind
zeros = lax._zeros(t, shape=update_shape)
operand_t = dus(t, zeros, start_indices) if ad.is_undefined_primal(operand) else None
update_t = ds(t, start_indices, update_shape) if ad.is_undefined_primal(update) else None
operand_t = dus(t, zeros, *start_indices) if ad.is_undefined_primal(operand) else None
update_t = ds(t, *start_indices, slice_sizes=update_shape) if ad.is_undefined_primal(update) else None
return [operand_t, update_t] + [None] * len(start_indices)
def _dynamic_update_slice_batching_rule(batched_args, batch_dims):

View File

@ -616,6 +616,37 @@ class LaxAutodiffTest(jtu.JaxTestCase):
dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)
def testDynamicSliceValueAndGrad(self):
  # Regression test for https://github.com/google/jax/issues/10984
  # The failure was triggered by a negative index lying outside the valid
  # range of the indexed axis.
  rng = jtu.rand_default(self.rng())
  dims = (5, 5)
  axis = 0
  # -8 for an axis of length 5: still negative after adding the axis length.
  index = -(dims[axis] + 3)

  def summed_index(x):
    picked = lax.dynamic_index_in_dim(x, index, axis)
    return picked.sum()

  x = rng(dims, np.float32)

  # The value computed directly must match the value reported alongside the
  # gradient; the gradient itself is not inspected here.
  direct_value = summed_index(x)
  value_with_grad, _ = jax.value_and_grad(summed_index, 0)(x)
  self.assertAllClose(direct_value, value_with_grad)
def testDynamicUpdateSliceValueAndGrad(self):
  # Regression test for https://github.com/google/jax/issues/10984
  # The failure was triggered by a negative index lying outside the valid
  # range of the updated axis.
  rng = jtu.rand_default(self.rng())
  dims = (5, 5)
  axis = 0
  # -8 for an axis of length 5: still negative after adding the axis length.
  index = -(dims[axis] + 3)

  def summed_update(x, y):
    updated = lax.dynamic_update_index_in_dim(x, y, index, axis)
    return updated.sum()

  x = rng(dims, np.float32)
  # The update has the same rank as x, with every dimension of size 1.
  y = rng([1] * len(dims), np.float32)

  # The value computed directly must match the value reported alongside the
  # gradient (taken with respect to x only); the gradient is not inspected.
  direct_value = summed_update(x, y)
  value_with_grad, _ = jax.value_and_grad(summed_update, 0)(x, y)
  self.assertAllClose(direct_value, value_with_grad)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_perm={}".format(
jtu.format_shape_dtype_string(shape, dtype), perm),