[shape_poly] Fixes for shape polymorphism for jvp(scatter).

Fixes: #18348
This commit is contained in:
George Necula 2023-11-06 18:38:12 +01:00
parent a23aac5566
commit c30ce5f0f5
2 changed files with 45 additions and 8 deletions

View File

@ -2322,10 +2322,14 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
# a) attach a positive ID to each update in `updates`, and perform a scatter
# on the IDs.
ids_shape = np.array(updates.shape, dtype=np.int64)
ids_shape[dnums.update_window_dims,] = 1
ids_shape = list(updates.shape)
for update_dim in dnums.update_window_dims:
ids_shape[update_dim] = 1
num_ids = math.prod(ids_shape)
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
if core.is_constant_dim(num_ids):
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
else:
id_dtype = np.uint64
update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
lax._ones(updates, dtype=id_dtype))

View File

@ -2758,25 +2758,58 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=["b, ..."]),
PolyHarness("scatter_add", "",
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True),
arg_descriptors=[RandArg((7, 4), _f32),
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
np.array([[1], [2]], np.int32), # indices: [2, 1]
RandArg((7, 2), _f32), # updates: [7, 2]
RandArg((7, 2), _f32), # updates: [b, 2]
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))],
polymorphic_shapes=["b, ...", None, "b, ..."]),
PolyHarness("scatter_add", "clip0",
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP),
arg_descriptors=[RandArg((7, 4), _f32), # [b, 4]
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
np.array([[1], [2]], np.int32), # indices: [2, 1]
RandArg((7, 2), _f32), # updates: [b, 2]
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))],
polymorphic_shapes=["b, ...", None, "b, ..."]),
PolyHarness("scatter_add", "clip1",
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP),
arg_descriptors=[RandArg((7, 4), _f32), # [b, 4]
np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32), # indices: [b, 2]
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
# indices: [b, 2]
np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32),
RandArg((7, 1), _f32), # updates: [b, 1]
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))],
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
PolyHarness("scatter_grad", "",
lambda *args: jax.grad(
lambda *args:
jnp.sum(lax.scatter( # type: ignore
*args,
indices_are_sorted=False,
unique_indices=False,
))
)(*args),
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
np.array([[1], [2]], np.int32), # indices: [2, 1]
RandArg((7, 2), _f32), # updates: [b, 2]
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,))),
],
polymorphic_shapes=["b, ...", None, "b, ..."]),
PolyHarness("scatter_grad", "poly_indices",
lambda *args: jax.grad(
lambda *args:
jnp.sum(lax.scatter( # type: ignore
*args,
indices_are_sorted=False,
unique_indices=False))
)(*args),
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
# indices: [b, 2]
np.array(
[[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0],
[3, 0], [0, 5]], np.int32),
RandArg((7, 1), _f32), # updates: [b, 1]
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1))),
],
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
[
PolyHarness("schur",
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{poly=}_{compute_schur_vectors=}",