mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Replace `op_sharding_sharding` with `gspmd_sharding`. This is purely an internal change.
PiperOrigin-RevId: 510562354
This commit is contained in:
parent
0ffdeb3de2
commit
d93aa70801
@ -779,7 +779,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
# assume that `x` is committed. This might happen when the input is
|
||||
# a `ShapedDtypeStruct` or `types.SimpleNamespace`, etc that might
|
||||
# only have a `sharding` attribute on them.
|
||||
return aval, (pjit.to_op_sharding_sharding(x.sharding, x.ndim)
|
||||
return aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim)
|
||||
if getattr(x, '_committed', True) else None)
|
||||
else:
|
||||
return aval, None
|
||||
|
@ -110,7 +110,7 @@ def arg_spec(x: Any) -> ArgSpec:
|
||||
if config.jax_array:
|
||||
if isinstance(x.sharding, PmapSharding):
|
||||
return aval, None
|
||||
return aval, (pjit.to_op_sharding_sharding(x.sharding, x.ndim) # type: ignore
|
||||
return aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim) # type: ignore
|
||||
if x._committed else None)
|
||||
else:
|
||||
return aval, x._device
|
||||
|
@ -3353,7 +3353,7 @@ def _get_input_indices(
|
||||
return input_indices
|
||||
|
||||
|
||||
def get_op_sharding_shardings_from_executable(
|
||||
def get_gspmd_shardings_from_executable(
|
||||
xla_executable, device_assignment: Sequence[xc.Device],
|
||||
num_in_avals: int, num_out_avals: int
|
||||
) -> Tuple[Sequence[sharding_internal.XLACompatibleSharding],
|
||||
@ -3550,7 +3550,7 @@ class UnloadedMeshExecutable:
|
||||
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
|
||||
elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
|
||||
assert mesh is None
|
||||
_, out_shardings_xla = get_op_sharding_shardings_from_executable( # type: ignore
|
||||
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
|
||||
xla_executable, device_assignment,
|
||||
len(global_in_avals), len(global_out_avals))
|
||||
orig_out_shardings = out_shardings
|
||||
|
@ -1878,12 +1878,11 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
|
||||
mps = NamedSharding._from_parsed_pspec(
|
||||
resource_env.physical_mesh, ParsedPartitionSpec((), ()))
|
||||
unconstrained_dims = get_unconstrained_dims(mps)
|
||||
op_sharding_sharding = GSPMDSharding.get_replicated(
|
||||
mps._device_assignment)
|
||||
gspmd_sharding = GSPMDSharding.get_replicated(mps._device_assignment)
|
||||
new_eqns.append(core.JaxprEqn(
|
||||
[tmpvar], [outvar], sharding_constraint_p,
|
||||
dict(resource_env=resource_env,
|
||||
sharding=op_sharding_sharding,
|
||||
sharding=gspmd_sharding,
|
||||
unconstrained_dims=unconstrained_dims),
|
||||
set(),
|
||||
eqn.source_info))
|
||||
|
@ -865,7 +865,7 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
|
||||
# TODO(yashkatariya): Only check for is_auto or _is_unspecified when
|
||||
# FROM_GDA is removed.
|
||||
canonicalized_shardings = tuple(
|
||||
i if _is_unspecified_or_from_gda_or_auto(i) else to_op_sharding_sharding(i, aval.ndim)
|
||||
i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
|
||||
for i, aval in safe_zip(in_shardings_flat, global_in_avals))
|
||||
return tuple(global_in_avals), canonicalized_shardings
|
||||
|
||||
@ -908,7 +908,7 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
|
||||
# Local or global avals doesn't matter for converting to op sharding because
|
||||
# the `ndim` does not change.
|
||||
canonicalized_in_shardings_flat = tuple(
|
||||
i if _is_from_gda(i) or is_auto(i) else to_op_sharding_sharding(i, aval.ndim)
|
||||
i if _is_from_gda(i) or is_auto(i) else to_gspmd_sharding(i, aval.ndim)
|
||||
for i, aval in safe_zip(in_shardings_flat, local_in_avals))
|
||||
|
||||
global_in_avals = local_to_global(
|
||||
@ -961,7 +961,7 @@ def _check_and_canonicalize_out_shardings(
|
||||
allow_uneven_sharding=False)
|
||||
|
||||
canonicalized_out_shardings_flat = tuple(
|
||||
o if _is_unspecified(o) or is_auto(o) else to_op_sharding_sharding(o, aval.ndim)
|
||||
o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
|
||||
for o, aval in safe_zip(out_shardings_flat, global_out_avals)
|
||||
)
|
||||
return canonicalized_out_shardings_flat
|
||||
@ -1233,7 +1233,7 @@ def _resolve_in_shardings(
|
||||
if isinstance(arg_s, PmapSharding):
|
||||
resolved_in_shardings.append(_UNSPECIFIED)
|
||||
else:
|
||||
resolved_in_shardings.append(to_op_sharding_sharding(
|
||||
resolved_in_shardings.append(to_gspmd_sharding(
|
||||
cast(XLACompatibleSharding, arg_s), arg.ndim))
|
||||
else:
|
||||
if dispatch.is_single_device_sharding(arg_s):
|
||||
@ -1671,13 +1671,13 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
_allow_propagation_to_outputs=[True] * len(known_params['out_shardings']),
|
||||
_allow_compile_replicated=False)
|
||||
da = compiled._device_assignment
|
||||
_, out_op_sharding_shardings = pxla.get_op_sharding_shardings_from_executable(
|
||||
_, out_gspmd_shardings = pxla.get_gspmd_shardings_from_executable(
|
||||
compiled.xla_executable, da, len(known_jaxpr.in_avals),
|
||||
len(known_jaxpr.out_avals))
|
||||
assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
|
||||
len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
|
||||
assert len(out_gspmd_shardings) == len(known_jaxpr.out_avals), (
|
||||
len(out_gspmd_shardings), len(known_jaxpr.out_avals))
|
||||
out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
|
||||
safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
|
||||
safe_zip(out_gspmd_shardings, known_jaxpr.out_avals)]
|
||||
residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
|
||||
else:
|
||||
residual_op_shardings = ()
|
||||
@ -2057,7 +2057,7 @@ def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
|
||||
pjit_check_aval_sharding(shardings_flat, x_flat, "with_sharding_constraint arguments",
|
||||
allow_uneven_sharding=True)
|
||||
|
||||
outs = [sharding_constraint_p.bind(xf, sharding=to_op_sharding_sharding(i, xf.ndim),
|
||||
outs = [sharding_constraint_p.bind(xf, sharding=to_gspmd_sharding(i, xf.ndim),
|
||||
resource_env=resource_env,
|
||||
unconstrained_dims=ud)
|
||||
for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
|
||||
@ -2152,13 +2152,13 @@ def get_array_mapping(
|
||||
if axes is not None for axis in axes)
|
||||
|
||||
|
||||
def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
|
||||
def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
|
||||
if isinstance(s, GSPMDSharding):
|
||||
return s
|
||||
op_sharding_sharding = GSPMDSharding(
|
||||
gspmd_sharding = GSPMDSharding(
|
||||
s._device_assignment, s._to_xla_op_sharding(ndim))
|
||||
op_sharding_sharding._original_sharding = s
|
||||
return op_sharding_sharding
|
||||
gspmd_sharding._original_sharding = s
|
||||
return gspmd_sharding
|
||||
|
||||
|
||||
def get_unconstrained_dims(sharding: NamedSharding):
|
||||
@ -2242,14 +2242,14 @@ def _maybe_replace_from_gda_with_pspec(
|
||||
"use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. "
|
||||
f"Got GDA sharding: {gda_sharding} and "
|
||||
f"pjit sharding: {in_sharding._original_sharding}") # type: ignore
|
||||
return to_op_sharding_sharding(gda_sharding, ndim)
|
||||
return to_gspmd_sharding(gda_sharding, ndim)
|
||||
|
||||
out = []
|
||||
for in_sharding_flat, arg in safe_zip(in_shardings_flat, args_flat):
|
||||
if is_auto(in_sharding_flat):
|
||||
out.append(in_sharding_flat)
|
||||
elif isinstance(arg, array.ArrayImpl):
|
||||
out.append(to_op_sharding_sharding(arg.sharding, arg.ndim))
|
||||
out.append(to_gspmd_sharding(arg.sharding, arg.ndim))
|
||||
elif isinstance(arg, GDA):
|
||||
gda_sharding = pxla.create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
|
||||
out.append(_gda_check_and_get_sharding(gda_sharding, in_sharding_flat, arg.ndim))
|
||||
|
@ -137,9 +137,8 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
|
||||
op_sharding = self._to_xla_op_sharding(len(global_shape))
|
||||
op_sharding_sharding = GSPMDSharding(self._device_assignment,
|
||||
op_sharding)
|
||||
return op_sharding_sharding.devices_indices_map(global_shape)
|
||||
gspmd_sharding = GSPMDSharding(self._device_assignment, op_sharding)
|
||||
return gspmd_sharding.devices_indices_map(global_shape)
|
||||
|
||||
@functools.cached_property
|
||||
def _addressable_device_assignment(self) -> XLADeviceAssignment:
|
||||
|
@ -765,7 +765,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
self.assertTrue(issubclass(array.ArrayImpl, jax.Array))
|
||||
self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))
|
||||
|
||||
def test_op_sharding_sharding_repr(self):
|
||||
def test_gspmd_sharding_repr(self):
|
||||
op = xc.OpSharding()
|
||||
op.type = xc.OpSharding.Type.OTHER
|
||||
op.tile_assignment_dimensions = [4, 1, 2]
|
||||
|
@ -182,7 +182,7 @@ class PickleTest(jtu.JaxTestCase):
|
||||
s = jax.sharding.PmapSharding(jax.devices(), ss)
|
||||
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
|
||||
|
||||
def test_pickle_op_sharding_sharding(self):
|
||||
def test_pickle_gspmd_sharding(self):
|
||||
op_sharding = xla.xc.OpSharding()
|
||||
op_sharding.type = xla.xc.OpSharding.Type.REPLICATED
|
||||
s = jax.sharding.GSPMDSharding(jax.devices(), op_sharding)
|
||||
|
@ -393,7 +393,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
shape = (8, 8)
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P(None))
|
||||
ops = pjit_lib.to_op_sharding_sharding(
|
||||
ops = pjit_lib.to_gspmd_sharding(
|
||||
NamedSharding(mesh, P('x', 'y')), len(shape))
|
||||
|
||||
@partial(pjit, in_axis_resources=s, out_axis_resources=s)
|
||||
@ -3960,18 +3960,18 @@ class UtilTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
|
||||
mps1 = NamedSharding(mesh, P('x', 'y'))
|
||||
op_sharding_sharding = pjit_lib.to_op_sharding_sharding(mps1, ndim)
|
||||
next_loop_sharding = simulated_cached_fun(op_sharding_sharding)
|
||||
gspmd_sharding = pjit_lib.to_gspmd_sharding(mps1, ndim)
|
||||
next_loop_sharding = simulated_cached_fun(gspmd_sharding)
|
||||
cache_info1 = simulated_cached_fun.cache_info()
|
||||
|
||||
next_op_sharding_sharding = pjit_lib.to_op_sharding_sharding(
|
||||
next_gspmd_sharding = pjit_lib.to_gspmd_sharding(
|
||||
next_loop_sharding, ndim)
|
||||
simulated_cached_fun(next_op_sharding_sharding)
|
||||
simulated_cached_fun(next_gspmd_sharding)
|
||||
cache_info2 = simulated_cached_fun.cache_info()
|
||||
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
self.assertEqual(id(next_op_sharding_sharding), id(op_sharding_sharding))
|
||||
self.assertEqual(id(next_gspmd_sharding), id(gspmd_sharding))
|
||||
|
||||
def test_get_partition_spec(self):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user