Ensure that the two offsets of a dynamic_slice have the same dtype regardless

the value of config.enable_x64.

PiperOrigin-RevId: 708031525
This commit is contained in:
Bixia Zheng 2024-12-19 14:21:09 -08:00 committed by jax authors
parent de8fa8fd19
commit 16712b5116

View File

@ -1347,8 +1347,10 @@ class CustomPartitionerTest(jtu.JaxTestCase):
def lower_fn(x, y):
axis_name = arg_shardings[1].spec[0][0]
i = jax.lax.axis_index(axis_name)
# Use offset i * 0 instead of 0 to ensure that the two offsets have the
# same dtype regardless the value of config.enable_x64.
z = jax.lax.psum(
jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name)
jax.lax.dynamic_slice(x, (i * 0, i * 8), (8, 8)) @ y, (axis_name)
)
return z, z * z