mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Make sure we don't return GSPMDSharding in compiled.input_shardings
PiperOrigin-RevId: 624343180
This commit is contained in:
parent
09415607bb
commit
9e989321f1
@ -2574,23 +2574,23 @@ def _get_out_sharding_from_orig_sharding(
|
||||
out.append(o)
|
||||
return out
|
||||
|
||||
def maybe_get_orig_out_sharding(
|
||||
in_shardings, out_shardings, in_avals, out_avals):
|
||||
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in out_shardings):
|
||||
return out_shardings
|
||||
def maybe_recover_user_shardings(
|
||||
old_shardings, new_shardings, old_avals, new_avals):
|
||||
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
|
||||
return new_shardings
|
||||
|
||||
orig_in_s = None
|
||||
orig_aval = None
|
||||
for oi, aval in safe_zip(in_shardings, in_avals):
|
||||
for oi, aval in safe_zip(old_shardings, old_avals):
|
||||
if type(oi) in _orig_out_sharding_handlers:
|
||||
orig_in_s = oi
|
||||
orig_aval = aval
|
||||
break
|
||||
if orig_in_s is not None:
|
||||
return _get_out_sharding_from_orig_sharding(
|
||||
out_shardings, out_avals, orig_in_s, orig_aval)
|
||||
new_shardings, new_avals, orig_in_s, orig_aval)
|
||||
|
||||
return out_shardings
|
||||
return new_shardings
|
||||
|
||||
|
||||
def _get_layouts_from_executable(
|
||||
@ -2744,6 +2744,10 @@ def _maybe_get_and_check_in_shardings(
|
||||
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
|
||||
"(User sharding)")
|
||||
new_in_shardings.append(orig)
|
||||
|
||||
new_in_shardings = maybe_recover_user_shardings(
|
||||
in_shardings, new_in_shardings, global_in_avals, global_in_avals)
|
||||
|
||||
return new_in_shardings
|
||||
|
||||
|
||||
@ -2921,7 +2925,7 @@ class UnloadedMeshExecutable:
|
||||
assert all(i is None for i in in_layouts)
|
||||
assert all(o is None for o in out_layouts)
|
||||
|
||||
out_shardings = maybe_get_orig_out_sharding(
|
||||
out_shardings = maybe_recover_user_shardings(
|
||||
in_shardings, out_shardings, global_in_avals, global_out_avals)
|
||||
|
||||
out_shardings = finalize_out_shardings(out_shardings, da)
|
||||
|
@ -3999,6 +3999,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out4, np_inp * 3)
|
||||
self.assertArraysEqual(out5, np_inp.T)
|
||||
|
||||
def test_input_shardings_aot(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
return x * 2, y.T
|
||||
|
||||
arg_shardings, _ = f.lower(arr, np_inp).compile().input_shardings
|
||||
for s in arg_shardings:
|
||||
self.assertIsInstance(s, NamedSharding)
|
||||
|
||||
def test_parameter_tupled_jit(self):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
|
||||
|
Loading…
x
Reference in New Issue
Block a user