From 16712b5116d9117cf200c3c26614f10af0a594c4 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 19 Dec 2024 14:21:09 -0800 Subject: [PATCH] Ensure that the two offsets of a dynamic_slice have the same dtype regardless the value of config.enable_x64. PiperOrigin-RevId: 708031525 --- tests/pjit_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8a7c0b2e6..63620e5ad 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1347,8 +1347,10 @@ class CustomPartitionerTest(jtu.JaxTestCase): def lower_fn(x, y): axis_name = arg_shardings[1].spec[0][0] i = jax.lax.axis_index(axis_name) + # Use offset i * 0 instead of 0 to ensure that the two offsets have the + # same dtype regardless the value of config.enable_x64. z = jax.lax.psum( - jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name) + jax.lax.dynamic_slice(x, (i * 0, i * 8), (8, 8)) @ y, (axis_name) ) return z, z * z