Fix the AOT check for sharding consistency, which previously skipped checking the devices of the sharding.

Before this change, a user could pass a committed CPU array into a computation compiled for TPU and JAX would not raise an error, which is wrong.

This change fixes that. `is_equivalent_to` checks devices, HloSharding, and memory_kind, which makes the standalone `memory_kind` check redundant, so it is removed too.

PiperOrigin-RevId: 658794885
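
A minimal sketch of the behavior this change enforces (it assumes a machine with both a TPU backend and a CPU device; the function `f` is illustrative, not part of the change):

import jax
import numpy as np

@jax.jit
def f(x):
  return x * 2

# Compile ahead-of-time against an array committed to a TPU device.
tpu_arr = jax.device_put(np.arange(8), jax.devices('tpu')[0])
compiled = f.lower(tpu_arr).compile()

# Before this change, a *committed* CPU array slipped through the
# device-consistency check; now it raises a ValueError at call time.
cpu_arr = jax.device_put(np.arange(8), jax.devices('cpu')[0])
compiled(cpu_arr)  # ValueError: Compiled object called with input sharding ...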
Yash Katariya, 2024-08-02 08:15:01 -07:00 (committed by jax authors)
commit e6851e6b22, parent f85b8e677b
3 changed files with 42 additions and 11 deletions


@@ -3071,18 +3071,8 @@ def check_array_xla_sharding_layout_match(
       db_xs = check_device_backend_on_shardings([xs])
 
-      # Raise memory kind mismatch error even if the arg is uncommitted.
-      if arg.sharding.memory_kind != xs.memory_kind:
-        errors.append(
-            ("Got input sharding(s) that compiled object was called with: "
-             f"{arg.sharding} and sharding(s) the computation was compiled "
-             f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
-             'sharding'))
-
       if (not db_xs and arg._committed and
-          not op_shardings.are_op_shardings_equal(
-              arg.sharding._to_xla_hlo_sharding(arg.ndim),
-              xs._to_xla_hlo_sharding(arg.ndim))):
+          not arg.sharding.is_equivalent_to(xs, arg.ndim)):
         errors.append(
             ("Got input sharding(s) that compiled object was called with: "
              f"{arg.sharding} and sharding(s) the computation was compiled "


@@ -1329,6 +1329,29 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertIn("input_output_alias", lowered_text)
     self.assertDeleted(x)
 
+  @jtu.run_on_devices('tpu')
+  def test_aot_device_implicit_transfer(self):
+    mesh = jtu.create_global_mesh((1,), 'x')
+    np_inp = np.arange(8)
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P()))
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    compiled = f.lower(arr).compile()
+
+    cpu_dev = jax.devices('cpu')[0]
+    with jax.default_device(cpu_dev):
+      cpu_arr = jnp.arange(8)
+      self.assertEqual(cpu_arr.sharding, SingleDeviceSharding(cpu_dev))
+      self.assertFalse(cpu_arr._committed)
+
+    out = compiled(cpu_arr)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P()))
+    self.assertEqual(out.sharding.memory_kind, 'device')
+
 
 @jtu.with_config(jax_enable_memories=True)
 class ActivationOffloadingTest(jtu.JaxTestCase):
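
The two new tests hinge on the committed/uncommitted distinction; as a small illustration of how each kind of array typically arises (assuming a CPU device is available):

import jax
import jax.numpy as jnp

# Uncommitted: created without an explicit placement; a compiled
# computation may transfer it implicitly (test_aot_device_implicit_transfer).
uncommitted = jnp.arange(8)

# Committed: explicitly placed with device_put; passing it to a computation
# compiled for a different device now errors (test_aot_device_mismatch).
committed = jax.device_put(jnp.arange(8), jax.devices('cpu')[0])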


@@ -4297,6 +4297,24 @@ class ArrayPjitTest(jtu.JaxTestCase):
     out = f()
     self.assertEqual(out.sharding, s)
 
+  @jtu.run_on_devices('tpu', 'gpu')
+  def test_aot_device_mismatch(self):
+    mesh = jtu.create_global_mesh((1,), 'x')
+    np_inp = np.arange(8)
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P()))
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    compiled = f.lower(arr).compile()
+
+    cpu_arr = jax.device_put(np_inp, jax.devices('cpu')[0])
+    with self.assertRaisesRegex(
+        ValueError,
+        "Compiled object called with input sharding.*does not match"):
+      compiled(cpu_arr)
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")