Make sure we don't return GSPMDSharding in compiled.input_shardings

PiperOrigin-RevId: 624343180
This commit is contained in:
Yash Katariya 2024-04-12 17:52:08 -07:00 committed by jax authors
parent 09415607bb
commit 9e989321f1
2 changed files with 25 additions and 8 deletions

View File

@ -2574,23 +2574,23 @@ def _get_out_sharding_from_orig_sharding(
out.append(o)
return out
def maybe_get_orig_out_sharding(
in_shardings, out_shardings, in_avals, out_avals):
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in out_shardings):
return out_shardings
def maybe_recover_user_shardings(
old_shardings, new_shardings, old_avals, new_avals):
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
return new_shardings
orig_in_s = None
orig_aval = None
for oi, aval in safe_zip(in_shardings, in_avals):
for oi, aval in safe_zip(old_shardings, old_avals):
if type(oi) in _orig_out_sharding_handlers:
orig_in_s = oi
orig_aval = aval
break
if orig_in_s is not None:
return _get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_in_s, orig_aval)
new_shardings, new_avals, orig_in_s, orig_aval)
return out_shardings
return new_shardings
def _get_layouts_from_executable(
@ -2744,6 +2744,10 @@ def _maybe_get_and_check_in_shardings(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
new_in_shardings.append(orig)
new_in_shardings = maybe_recover_user_shardings(
in_shardings, new_in_shardings, global_in_avals, global_in_avals)
return new_in_shardings
@ -2921,7 +2925,7 @@ class UnloadedMeshExecutable:
assert all(i is None for i in in_layouts)
assert all(o is None for o in out_layouts)
out_shardings = maybe_get_orig_out_sharding(
out_shardings = maybe_recover_user_shardings(
in_shardings, out_shardings, global_in_avals, global_out_avals)
out_shardings = finalize_out_shardings(out_shardings, da)

View File

@ -3999,6 +3999,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(out4, np_inp * 3)
self.assertArraysEqual(out5, np_inp.T)
def test_input_shardings_aot(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
@jax.jit
def f(x, y):
return x * 2, y.T
arg_shardings, _ = f.lower(arr, np_inp).compile().input_shardings
for s in arg_shardings:
self.assertIsInstance(s, NamedSharding)
def test_parameter_tupled_jit(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')