mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Fixes for shape polynorphism for jvp(scatter).
Fixes: #18348
This commit is contained in:
parent
a23aac5566
commit
c30ce5f0f5
@ -2322,10 +2322,14 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
|
||||
|
||||
# a) attach a positive ID to each update in `updates`, and perform a scatter
|
||||
# on the IDs.
|
||||
ids_shape = np.array(updates.shape, dtype=np.int64)
|
||||
ids_shape[dnums.update_window_dims,] = 1
|
||||
ids_shape = list(updates.shape)
|
||||
for update_dim in dnums.update_window_dims:
|
||||
ids_shape[update_dim] = 1
|
||||
num_ids = math.prod(ids_shape)
|
||||
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
|
||||
if core.is_constant_dim(num_ids):
|
||||
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
|
||||
else:
|
||||
id_dtype = np.uint64
|
||||
update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
|
||||
lax._ones(updates, dtype=id_dtype))
|
||||
|
||||
|
@ -2758,25 +2758,58 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
polymorphic_shapes=["b, ..."]),
|
||||
PolyHarness("scatter_add", "",
|
||||
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True),
|
||||
arg_descriptors=[RandArg((7, 4), _f32),
|
||||
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
||||
np.array([[1], [2]], np.int32), # indices: [2, 1]
|
||||
RandArg((7, 2), _f32), # updates: [7, 2]
|
||||
RandArg((7, 2), _f32), # updates: [b, 2]
|
||||
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))],
|
||||
polymorphic_shapes=["b, ...", None, "b, ..."]),
|
||||
PolyHarness("scatter_add", "clip0",
|
||||
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP),
|
||||
arg_descriptors=[RandArg((7, 4), _f32), # [b, 4]
|
||||
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
||||
np.array([[1], [2]], np.int32), # indices: [2, 1]
|
||||
RandArg((7, 2), _f32), # updates: [b, 2]
|
||||
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))],
|
||||
polymorphic_shapes=["b, ...", None, "b, ..."]),
|
||||
PolyHarness("scatter_add", "clip1",
|
||||
partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True, mode=lax.GatherScatterMode.CLIP),
|
||||
arg_descriptors=[RandArg((7, 4), _f32), # [b, 4]
|
||||
np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32), # indices: [b, 2]
|
||||
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
||||
# indices: [b, 2]
|
||||
np.array([[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0], [3, 0], [0, 5]], np.int32),
|
||||
RandArg((7, 1), _f32), # updates: [b, 1]
|
||||
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1,)))],
|
||||
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
|
||||
PolyHarness("scatter_grad", "",
|
||||
lambda *args: jax.grad(
|
||||
lambda *args:
|
||||
jnp.sum(lax.scatter( # type: ignore
|
||||
*args,
|
||||
indices_are_sorted=False,
|
||||
unique_indices=False,
|
||||
))
|
||||
)(*args),
|
||||
arg_descriptors=[RandArg((7, 4), _f32), # : [b, 4]
|
||||
np.array([[1], [2]], np.int32), # indices: [2, 1]
|
||||
RandArg((7, 2), _f32), # updates: [b, 2]
|
||||
StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,))),
|
||||
],
|
||||
polymorphic_shapes=["b, ...", None, "b, ..."]),
|
||||
PolyHarness("scatter_grad", "poly_indices",
|
||||
lambda *args: jax.grad(
|
||||
lambda *args:
|
||||
jnp.sum(lax.scatter( # type: ignore
|
||||
*args,
|
||||
indices_are_sorted=False,
|
||||
unique_indices=False))
|
||||
)(*args),
|
||||
arg_descriptors=[RandArg((7, 4), _f32), # op: [b, 4]
|
||||
# indices: [b, 2]
|
||||
np.array(
|
||||
[[1, 2], [-2, 0], [6, 4], [7, -1], [1, 0],
|
||||
[3, 0], [0, 5]], np.int32),
|
||||
RandArg((7, 1), _f32), # updates: [b, 1]
|
||||
StaticArg(lax.ScatterDimensionNumbers((1,), (0,), (0, 1))),
|
||||
],
|
||||
polymorphic_shapes=["b, ...", "b, ...", "b, ..."]),
|
||||
[
|
||||
PolyHarness("schur",
|
||||
f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{poly=}_{compute_schur_vectors=}",
|
||||
|
Loading…
x
Reference in New Issue
Block a user