mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Ensure that the two offsets of a dynamic_slice have the same dtype regardless
the value of config.enable_x64. PiperOrigin-RevId: 708031525
This commit is contained in:
parent
de8fa8fd19
commit
16712b5116
@ -1347,8 +1347,10 @@ class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
def lower_fn(x, y):
|
||||
axis_name = arg_shardings[1].spec[0][0]
|
||||
i = jax.lax.axis_index(axis_name)
|
||||
# Use offset i * 0 instead of 0 to ensure that the two offsets have the
|
||||
# same dtype regardless the value of config.enable_x64.
|
||||
z = jax.lax.psum(
|
||||
jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name)
|
||||
jax.lax.dynamic_slice(x, (i * 0, i * 8), (8, 8)) @ y, (axis_name)
|
||||
)
|
||||
return z, z * z
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user