Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.

PiperOrigin-RevId: 510562354
Author: Yash Katariya (committed by jax authors)
Date:   2023-02-17 17:52:37 -08:00
Parent: 0ffdeb3de2
Commit: d93aa70801

9 changed files with 31 additions and 33 deletions


@@ -779,7 +779,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
       # assume that `x` is committed. This might happen when the input is
       # a `ShapedDtypeStruct` or `types.SimpleNamespace`, etc that might
       # only have a `sharding` attribute on them.
-      return aval, (pjit.to_op_sharding_sharding(x.sharding, x.ndim)
+      return aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim)
                     if getattr(x, '_committed', True) else None)
     else:
       return aval, None


@@ -110,7 +110,7 @@ def arg_spec(x: Any) -> ArgSpec:
   if config.jax_array:
     if isinstance(x.sharding, PmapSharding):
       return aval, None
-    return aval, (pjit.to_op_sharding_sharding(x.sharding, x.ndim) # type: ignore
+    return aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim) # type: ignore
                   if x._committed else None)
   else:
     return aval, x._device


@@ -3353,7 +3353,7 @@ def _get_input_indices(
   return input_indices


-def get_op_sharding_shardings_from_executable(
+def get_gspmd_shardings_from_executable(
     xla_executable, device_assignment: Sequence[xc.Device],
     num_in_avals: int, num_out_avals: int
 ) -> Tuple[Sequence[sharding_internal.XLACompatibleSharding],
@@ -3550,7 +3550,7 @@ class UnloadedMeshExecutable:
       out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
     elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
       assert mesh is None
-      _, out_shardings_xla = get_op_sharding_shardings_from_executable( # type: ignore
+      _, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
          xla_executable, device_assignment,
          len(global_in_avals), len(global_out_avals))
      orig_out_shardings = out_shardings


@@ -1878,12 +1878,11 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
       mps = NamedSharding._from_parsed_pspec(
           resource_env.physical_mesh, ParsedPartitionSpec((), ()))
       unconstrained_dims = get_unconstrained_dims(mps)
-      op_sharding_sharding = GSPMDSharding.get_replicated(
-          mps._device_assignment)
+      gspmd_sharding = GSPMDSharding.get_replicated(mps._device_assignment)
       new_eqns.append(core.JaxprEqn(
           [tmpvar], [outvar], sharding_constraint_p,
           dict(resource_env=resource_env,
-               sharding=op_sharding_sharding,
+               sharding=gspmd_sharding,
                unconstrained_dims=unconstrained_dims),
           set(),
           eqn.source_info))
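For context, a minimal sketch (not part of the commit) of what `GSPMDSharding.get_replicated` used above produces. It assumes a stock JAX install at roughly this revision; only names that already appear in this diff are used.

  # Sketch only: a fully replicated GSPMDSharding over all local devices.
  import jax
  from jax.sharding import GSPMDSharding

  replicated = GSPMDSharding.get_replicated(tuple(jax.devices()))
  # Replicated means every device owns the whole array, so every device maps
  # to the trivial slice over the full (8, 8) shape.
  print(replicated.devices_indices_map((8, 8)))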


@@ -865,7 +865,7 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
     # TODO(yashkatariya): Only check for is_auto or _is_unspecified when
     # FROM_GDA is removed.
     canonicalized_shardings = tuple(
-        i if _is_unspecified_or_from_gda_or_auto(i) else to_op_sharding_sharding(i, aval.ndim)
+        i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
         for i, aval in safe_zip(in_shardings_flat, global_in_avals))
     return tuple(global_in_avals), canonicalized_shardings
@@ -908,7 +908,7 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
     # Local or global avals doesn't matter for converting to op sharding because
     # the `ndim` does not change.
     canonicalized_in_shardings_flat = tuple(
-        i if _is_from_gda(i) or is_auto(i) else to_op_sharding_sharding(i, aval.ndim)
+        i if _is_from_gda(i) or is_auto(i) else to_gspmd_sharding(i, aval.ndim)
         for i, aval in safe_zip(in_shardings_flat, local_in_avals))

     global_in_avals = local_to_global(
@@ -961,7 +961,7 @@ def _check_and_canonicalize_out_shardings(
                              allow_uneven_sharding=False)

   canonicalized_out_shardings_flat = tuple(
-      o if _is_unspecified(o) or is_auto(o) else to_op_sharding_sharding(o, aval.ndim)
+      o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
       for o, aval in safe_zip(out_shardings_flat, global_out_avals)
   )
   return canonicalized_out_shardings_flat
@@ -1233,7 +1233,7 @@ def _resolve_in_shardings(
       if isinstance(arg_s, PmapSharding):
         resolved_in_shardings.append(_UNSPECIFIED)
       else:
-        resolved_in_shardings.append(to_op_sharding_sharding(
+        resolved_in_shardings.append(to_gspmd_sharding(
             cast(XLACompatibleSharding, arg_s), arg.ndim))
     else:
       if dispatch.is_single_device_sharding(arg_s):
@@ -1671,13 +1671,13 @@ def _pjit_partial_eval(trace, *in_tracers,
         _allow_propagation_to_outputs=[True] * len(known_params['out_shardings']),
         _allow_compile_replicated=False)
     da = compiled._device_assignment
-    _, out_op_sharding_shardings = pxla.get_op_sharding_shardings_from_executable(
+    _, out_gspmd_shardings = pxla.get_gspmd_shardings_from_executable(
         compiled.xla_executable, da, len(known_jaxpr.in_avals),
         len(known_jaxpr.out_avals))
-    assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
-        len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
+    assert len(out_gspmd_shardings) == len(known_jaxpr.out_avals), (
+        len(out_gspmd_shardings), len(known_jaxpr.out_avals))
     out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
-                        safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
+                        safe_zip(out_gspmd_shardings, known_jaxpr.out_avals)]
     residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
   else:
     residual_op_shardings = ()
@@ -2057,7 +2057,7 @@ def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
   pjit_check_aval_sharding(shardings_flat, x_flat, "with_sharding_constraint arguments",
                            allow_uneven_sharding=True)
-  outs = [sharding_constraint_p.bind(xf, sharding=to_op_sharding_sharding(i, xf.ndim),
+  outs = [sharding_constraint_p.bind(xf, sharding=to_gspmd_sharding(i, xf.ndim),
                                      resource_env=resource_env,
                                      unconstrained_dims=ud)
           for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
@@ -2152,13 +2152,13 @@ def get_array_mapping(
              if axes is not None for axis in axes)


-def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
+def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
-  op_sharding_sharding = GSPMDSharding(
+  gspmd_sharding = GSPMDSharding(
       s._device_assignment, s._to_xla_op_sharding(ndim))
-  op_sharding_sharding._original_sharding = s
-  return op_sharding_sharding
+  gspmd_sharding._original_sharding = s
+  return gspmd_sharding


 def get_unconstrained_dims(sharding: NamedSharding):
@@ -2242,14 +2242,14 @@ def _maybe_replace_from_gda_with_pspec(
           "use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. "
           f"Got GDA sharding: {gda_sharding} and "
           f"pjit sharding: {in_sharding._original_sharding}")  # type: ignore
-    return to_op_sharding_sharding(gda_sharding, ndim)
+    return to_gspmd_sharding(gda_sharding, ndim)

   out = []
   for in_sharding_flat, arg in safe_zip(in_shardings_flat, args_flat):
     if is_auto(in_sharding_flat):
       out.append(in_sharding_flat)
     elif isinstance(arg, array.ArrayImpl):
-      out.append(to_op_sharding_sharding(arg.sharding, arg.ndim))
+      out.append(to_gspmd_sharding(arg.sharding, arg.ndim))
     elif isinstance(arg, GDA):
       gda_sharding = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
       out.append(_gda_check_and_get_sharding(gda_sharding, in_sharding_flat, arg.ndim))
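As a usage sketch (not part of the diff): the renamed helper wraps any `XLACompatibleSharding` in a `GSPMDSharding` and records the original object on `_original_sharding`. The import paths below are assumptions for this revision; the call itself mirrors the tests further down, and `create_global_mesh((4, 2), ...)` assumes 8 available devices.

  # Sketch only; `pjit_lib` and `jtu` import paths are assumptions.
  from jax._src import pjit as pjit_lib
  from jax._src import test_util as jtu
  from jax.sharding import NamedSharding, PartitionSpec as P

  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  named = NamedSharding(mesh, P('x', 'y'))

  # NamedSharding -> GSPMDSharding, keeping a pointer back to the original.
  gspmd = pjit_lib.to_gspmd_sharding(named, ndim=2)
  assert gspmd._original_sharding is named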


@@ -137,9 +137,8 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
     op_sharding = self._to_xla_op_sharding(len(global_shape))
-    op_sharding_sharding = GSPMDSharding(self._device_assignment,
-                                         op_sharding)
-    return op_sharding_sharding.devices_indices_map(global_shape)
+    gspmd_sharding = GSPMDSharding(self._device_assignment, op_sharding)
+    return gspmd_sharding.devices_indices_map(global_shape)

   @functools.cached_property
   def _addressable_device_assignment(self) -> XLADeviceAssignment:
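A short sketch (not part of the diff) of the behavior this method implements: the sharding is lowered to an `OpSharding`, wrapped in a `GSPMDSharding`, and the per-device index computation is delegated to it. The 4x2 mesh below assumes at least 8 devices; the printed slices are illustrative.

  # Sketch only: per-device index slices for an (8, 8) array on a 4x2 mesh.
  import numpy as np
  import jax
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.asarray(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
  s = NamedSharding(mesh, P('x', 'y'))

  # devices_indices_map returns a mapping from device to the tuple of slices
  # that device owns in the global array.
  for dev, idx in s.devices_indices_map((8, 8)).items():
    print(dev, idx)  # e.g. (slice(0, 2, None), slice(0, 4, None)) for the first device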


@@ -765,7 +765,7 @@ class ShardingTest(jtu.JaxTestCase):
     self.assertTrue(issubclass(array.ArrayImpl, jax.Array))
     self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))

-  def test_op_sharding_sharding_repr(self):
+  def test_gspmd_sharding_repr(self):
     op = xc.OpSharding()
     op.type = xc.OpSharding.Type.OTHER
     op.tile_assignment_dimensions = [4, 1, 2]


@@ -182,7 +182,7 @@ class PickleTest(jtu.JaxTestCase):
     s = jax.sharding.PmapSharding(jax.devices(), ss)
     self.assertEqual(s, pickle.loads(pickle.dumps(s)))

-  def test_pickle_op_sharding_sharding(self):
+  def test_pickle_gspmd_sharding(self):
     op_sharding = xla.xc.OpSharding()
     op_sharding.type = xla.xc.OpSharding.Type.REPLICATED
     s = jax.sharding.GSPMDSharding(jax.devices(), op_sharding)
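The renamed test is truncated above; a minimal standalone sketch of the same round trip (not the actual test; the `xla_client` import path is an assumption) looks like this, following the equality pattern used for `PmapSharding` just before it.

  # Sketch only: pickle round-trip of a replicated GSPMDSharding.
  import pickle
  import jax
  from jax._src.lib import xla_client as xc  # import path is an assumption

  op_sharding = xc.OpSharding()
  op_sharding.type = xc.OpSharding.Type.REPLICATED
  s = jax.sharding.GSPMDSharding(jax.devices(), op_sharding)
  assert s == pickle.loads(pickle.dumps(s))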


@@ -393,7 +393,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     shape = (8, 8)
     mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
     s = NamedSharding(mesh, P(None))
-    ops = pjit_lib.to_op_sharding_sharding(
+    ops = pjit_lib.to_gspmd_sharding(
         NamedSharding(mesh, P('x', 'y')), len(shape))

     @partial(pjit, in_axis_resources=s, out_axis_resources=s)
@@ -3960,18 +3960,18 @@ class UtilTest(jtu.JaxTestCase):
       mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
       mps1 = NamedSharding(mesh, P('x', 'y'))
-      op_sharding_sharding = pjit_lib.to_op_sharding_sharding(mps1, ndim)
-      next_loop_sharding = simulated_cached_fun(op_sharding_sharding)
+      gspmd_sharding = pjit_lib.to_gspmd_sharding(mps1, ndim)
+      next_loop_sharding = simulated_cached_fun(gspmd_sharding)
       cache_info1 = simulated_cached_fun.cache_info()

-      next_op_sharding_sharding = pjit_lib.to_op_sharding_sharding(
+      next_gspmd_sharding = pjit_lib.to_gspmd_sharding(
           next_loop_sharding, ndim)
-      simulated_cached_fun(next_op_sharding_sharding)
+      simulated_cached_fun(next_gspmd_sharding)
       cache_info2 = simulated_cached_fun.cache_info()

       self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
       self.assertEqual(cache_info2.misses, cache_info1.misses)
-      self.assertEqual(id(next_op_sharding_sharding), id(op_sharding_sharding))
+      self.assertEqual(id(next_gspmd_sharding), id(gspmd_sharding))

   def test_get_partition_spec(self):
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
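To make the intent of the cache test above explicit, here is a minimal standalone sketch (not the actual test; `simulated_cached_fun` is modeled here as a plain `lru_cache`'d identity, and the import paths are assumptions): because `to_gspmd_sharding` returns an existing `GSPMDSharding` unchanged, re-canonicalizing between loop iterations yields the identical object and therefore a cache hit rather than a miss.

  # Sketch only: why the identity pass-through in to_gspmd_sharding matters.
  import functools
  from jax._src import pjit as pjit_lib   # import path is an assumption
  from jax._src import test_util as jtu
  from jax.sharding import NamedSharding, PartitionSpec as P

  @functools.lru_cache(maxsize=None)
  def simulated_cached_fun(s):
    return s

  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))   # assumes 8 devices
  first = pjit_lib.to_gspmd_sharding(NamedSharding(mesh, P('x', 'y')), 2)
  next_loop = simulated_cached_fun(first)             # miss: new cache key
  second = pjit_lib.to_gspmd_sharding(next_loop, 2)   # pass-through: same object
  simulated_cached_fun(second)                        # hit: identical key

  assert second is first
  info = simulated_cached_fun.cache_info()
  assert (info.hits, info.misses) == (1, 1)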