diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 3774fefe7..831b905c9 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -3071,18 +3071,8 @@ def check_array_xla_sharding_layout_match(
 
     db_xs = check_device_backend_on_shardings([xs])
 
-    # Raise memory kind mismatch error even if the arg is uncommitted.
-    if arg.sharding.memory_kind != xs.memory_kind:
-      errors.append(
-          ("Got input sharding(s) that compiled object was called with: "
-          f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
-          'sharding'))
-
     if (not db_xs and arg._committed and
-        not op_shardings.are_op_shardings_equal(
-            arg.sharding._to_xla_hlo_sharding(arg.ndim),
-            xs._to_xla_hlo_sharding(arg.ndim))):
+        not arg.sharding.is_equivalent_to(xs, arg.ndim)):
       errors.append(
           ("Got input sharding(s) that compiled object was called with: "
           f"{arg.sharding} and sharding(s) the computation was compiled "
diff --git a/tests/memories_test.py b/tests/memories_test.py
index e1c4c6df7..d3991c955 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -1329,6 +1329,29 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertIn("input_output_alias", lowered_text)
     self.assertDeleted(x)
 
+  @jtu.run_on_devices('tpu')
+  def test_aot_device_implicit_transfer(self):
+    mesh = jtu.create_global_mesh((1,), 'x')
+    np_inp = np.arange(8)
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P()))
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    compiled = f.lower(arr).compile()
+
+    cpu_dev = jax.devices('cpu')[0]
+    with jax.default_device(cpu_dev):
+      cpu_arr = jnp.arange(8)
+      self.assertEqual(cpu_arr.sharding, SingleDeviceSharding(cpu_dev))
+      self.assertFalse(cpu_arr._committed)
+
+    out = compiled(cpu_arr)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P()))
+    self.assertEqual(out.sharding.memory_kind, 'device')
+
 
 @jtu.with_config(jax_enable_memories=True)
 class ActivationOffloadingTest(jtu.JaxTestCase):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index cfc93f970..c31c2abee 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4297,6 +4297,24 @@ class ArrayPjitTest(jtu.JaxTestCase):
     out = f()
     self.assertEqual(out.sharding, s)
 
+  @jtu.run_on_devices('tpu', 'gpu')
+  def test_aot_device_mismatch(self):
+    mesh = jtu.create_global_mesh((1,), 'x')
+    np_inp = np.arange(8)
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P()))
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    compiled = f.lower(arr).compile()
+
+    cpu_arr = jax.device_put(np_inp, jax.devices('cpu')[0])
+    with self.assertRaisesRegex(
+        ValueError,
+        "Compiled object called with input sharding.*does not match"):
+      compiled(cpu_arr)
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")