mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix the AOT check for sharding consistency which skipped checking the devices of the sharding.
Previously, for a TPU-compiled computation, a user could have passed in a committed array on CPU and JAX wouldn't have raised an error, which is wrong. This change fixes that. Also, `is_equivalent_to` should check devices, HloSharding, and memory_kind (so the redundant `memory_kind` check is removed too). PiperOrigin-RevId: 658794885
This commit is contained in:
parent
f85b8e677b
commit
e6851e6b22
@ -3071,18 +3071,8 @@ def check_array_xla_sharding_layout_match(
|
||||
|
||||
db_xs = check_device_backend_on_shardings([xs])
|
||||
|
||||
# Raise memory kind mismatch error even if the arg is uncommitted.
|
||||
if arg.sharding.memory_kind != xs.memory_kind:
|
||||
errors.append(
|
||||
("Got input sharding(s) that compiled object was called with: "
|
||||
f"{arg.sharding} and sharding(s) the computation was compiled "
|
||||
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
|
||||
'sharding'))
|
||||
|
||||
if (not db_xs and arg._committed and
|
||||
not op_shardings.are_op_shardings_equal(
|
||||
arg.sharding._to_xla_hlo_sharding(arg.ndim),
|
||||
xs._to_xla_hlo_sharding(arg.ndim))):
|
||||
not arg.sharding.is_equivalent_to(xs, arg.ndim)):
|
||||
errors.append(
|
||||
("Got input sharding(s) that compiled object was called with: "
|
||||
f"{arg.sharding} and sharding(s) the computation was compiled "
|
||||
|
@ -1329,6 +1329,29 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertIn("input_output_alias", lowered_text)
|
||||
self.assertDeleted(x)
|
||||
|
||||
@jtu.run_on_devices('tpu')
def test_aot_device_implicit_transfer(self):
    """An *uncommitted* CPU array passed to a TPU-compiled computation is
    transferred implicitly instead of raising a device-mismatch error.

    Companion to the committed-array case (which must raise): the AOT
    sharding check only rejects committed arrays on the wrong devices.
    """
    mesh = jtu.create_global_mesh((1,), 'x')
    np_inp = np.arange(8)
    arr = jax.device_put(np_inp, NamedSharding(mesh, P()))

    @jax.jit
    def f(x):
        return x * 2

    # AOT-compile against the TPU-sharded example argument.
    compiled = f.lower(arr).compile()

    cpu_dev = jax.devices('cpu')[0]
    with jax.default_device(cpu_dev):
        cpu_arr = jnp.arange(8)
        self.assertEqual(cpu_arr.sharding, SingleDeviceSharding(cpu_dev))
        # Uncommitted arrays are eligible for implicit transfer to the
        # compiled computation's devices.
        self.assertFalse(cpu_arr._committed)

    # No error: the uncommitted CPU array is moved to the TPU automatically.
    out = compiled(cpu_arr)
    self.assertArraysEqual(out, np_inp * 2)
    self.assertEqual(out.sharding, NamedSharding(mesh, P()))
    self.assertEqual(out.sharding.memory_kind, 'device')
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_memories=True)
|
||||
class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
|
@ -4297,6 +4297,24 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out = f()
|
||||
self.assertEqual(out.sharding, s)
|
||||
|
||||
@jtu.run_on_devices('tpu', 'gpu')
def test_aot_device_mismatch(self):
    """A *committed* CPU array passed to a TPU/GPU-compiled computation
    must raise — the AOT sharding-consistency check covers devices, not
    just the HloSharding (regression test for the previously-skipped
    device check).
    """
    mesh = jtu.create_global_mesh((1,), 'x')
    np_inp = np.arange(8)
    arr = jax.device_put(np_inp, NamedSharding(mesh, P()))

    @jax.jit
    def f(x):
        return x * 2

    # AOT-compile for the accelerator devices in `mesh`.
    compiled = f.lower(arr).compile()

    # device_put to an explicit device commits the array to CPU.
    cpu_arr = jax.device_put(np_inp, jax.devices('cpu')[0])
    with self.assertRaisesRegex(
        ValueError,
        "Compiled object called with input sharding.*does not match"):
        compiled(cpu_arr)
|
||||
|
||||
|
||||
def spec_regex(s):
    """Return ``str(s)`` with parentheses escaped for use inside a regex.

    Only ``(`` and ``)`` are escaped (a PartitionSpec's repr contains no
    other regex metacharacters); deliberately NOT ``re.escape``, which
    would escape far more.
    """
    escape_table = str.maketrans({"(": r"\(", ")": r"\)"})
    return str(s).translate(escape_table)
|
||||
|
Loading…
x
Reference in New Issue
Block a user