mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11580 from jakevdp:fix-dynamic-index
PiperOrigin-RevId: 463311212
This commit is contained in:
commit
c4b255b527
@ -841,8 +841,8 @@ def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes):
|
||||
def _dynamic_slice_jvp(primals, tangents, *, slice_sizes):
|
||||
tangent_out = tangents[0]
|
||||
if type(tangent_out) is not ad_util.Zero:
|
||||
tangent_out = dynamic_slice(tangent_out, primals[1:], slice_sizes)
|
||||
return dynamic_slice(primals[0], primals[1:], slice_sizes), tangent_out
|
||||
tangent_out = dynamic_slice_p.bind(tangent_out, *primals[1:], slice_sizes=slice_sizes)
|
||||
return dynamic_slice_p.bind(primals[0], *primals[1:], slice_sizes=slice_sizes), tangent_out
|
||||
|
||||
def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
|
||||
assert ad.is_undefined_primal(operand)
|
||||
@ -852,7 +852,7 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
|
||||
return [ad_util.Zero(operand.aval)] + [None] * len(start_indices)
|
||||
else:
|
||||
zeros = lax.full(operand_shape, 0, operand_dtype)
|
||||
return ([dynamic_update_slice(zeros, t, start_indices)] +
|
||||
return ([dynamic_update_slice_p.bind(zeros, t, *start_indices)] +
|
||||
[None] * len(start_indices))
|
||||
|
||||
def _batch_dynamic_slice_indices(indices, bdims):
|
||||
@ -935,13 +935,13 @@ def _dynamic_update_slice_jvp(primals, tangents):
|
||||
operand, update = primals[:2]
|
||||
start_indices = primals[2:]
|
||||
g_operand, g_update = tangents[:2]
|
||||
val_out = dynamic_update_slice(operand, update, start_indices)
|
||||
val_out = dynamic_update_slice_p.bind(operand, update, *start_indices)
|
||||
if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero:
|
||||
tangent_out = ad_util.Zero.from_value(val_out)
|
||||
else:
|
||||
g_operand = ad.instantiate_zeros(g_operand)
|
||||
g_update = ad.instantiate_zeros(g_update)
|
||||
tangent_out = dynamic_update_slice(g_operand, g_update, start_indices)
|
||||
tangent_out = dynamic_update_slice_p.bind(g_operand, g_update, *start_indices)
|
||||
return val_out, tangent_out
|
||||
|
||||
def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
|
||||
@ -954,11 +954,11 @@ def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices):
|
||||
operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
|
||||
update_t = ad_util.Zero(update.aval) if ad.is_undefined_primal(update) else None
|
||||
else:
|
||||
dus = dynamic_update_slice
|
||||
ds = dynamic_slice
|
||||
dus = dynamic_update_slice_p.bind
|
||||
ds = dynamic_slice_p.bind
|
||||
zeros = lax._zeros(t, shape=update_shape)
|
||||
operand_t = dus(t, zeros, start_indices) if ad.is_undefined_primal(operand) else None
|
||||
update_t = ds(t, start_indices, update_shape) if ad.is_undefined_primal(update) else None
|
||||
operand_t = dus(t, zeros, *start_indices) if ad.is_undefined_primal(operand) else None
|
||||
update_t = ds(t, *start_indices, slice_sizes=update_shape) if ad.is_undefined_primal(update) else None
|
||||
return [operand_t, update_t] + [None] * len(start_indices)
|
||||
|
||||
def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
|
||||
|
@ -616,6 +616,37 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
|
||||
check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)
|
||||
|
||||
def testDynamicSliceValueAndGrad(self):
|
||||
# Regression test for https://github.com/google/jax/issues/10984
|
||||
# Issue arose due to an out-of-range negative index.
|
||||
rng = jtu.rand_default(self.rng())
|
||||
shape = (5, 5)
|
||||
axis = 0
|
||||
index = -(shape[axis] + 3)
|
||||
def f(x):
|
||||
return lax.dynamic_index_in_dim(x, index, axis).sum()
|
||||
x = rng(shape, np.float32)
|
||||
|
||||
result1 = f(x)
|
||||
result2, _ = jax.value_and_grad(f, 0)(x)
|
||||
self.assertAllClose(result1, result2)
|
||||
|
||||
def testDynamicUpdateSliceValueAndGrad(self):
|
||||
# Regression test for https://github.com/google/jax/issues/10984
|
||||
# Issue arose due to an out-of-range negative index.
|
||||
rng = jtu.rand_default(self.rng())
|
||||
shape = (5, 5)
|
||||
axis = 0
|
||||
index = -(shape[axis] + 3)
|
||||
def f(x, y):
|
||||
return lax.dynamic_update_index_in_dim(x, y, index, axis).sum()
|
||||
x = rng(shape, np.float32)
|
||||
y = rng([1 for s in shape], np.float32)
|
||||
|
||||
result1 = f(x, y)
|
||||
result2, _ = jax.value_and_grad(f, 0)(x, y)
|
||||
self.assertAllClose(result1, result2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_perm={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), perm),
|
||||
|
Loading…
x
Reference in New Issue
Block a user