diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5f9ab2d01..c9c33e4bc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.5
 
+* Deprecations
+  * `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
+    `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
+
 ## jaxlib 0.4.5
 
 ## jax 0.4.4 (Feb 16, 2023)
diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md
index cd21a8f64..439b90914 100644
--- a/docs/Custom_Operation_for_GPUs.md
+++ b/docs/Custom_Operation_for_GPUs.md
@@ -655,7 +655,7 @@ with mesh:
     c:bf16[32,512,512] d:bf16[512,512] = pjit[
       donated_invars=(False, False)
       in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>, <_PositionalSemantics.GLOBAL: 1>)
-      in_shardings=(OpShardingSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), OpShardingSharding({replicated}))
+      in_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated}))
       jaxpr={ lambda ; e:bf16[32,512,512] f:bf16[512,512]. let
           g:bf16[8,4,512,512] = reshape[
             dimensions=None
@@ -734,7 +734,7 @@ with mesh:
             in (bn, bf) }
       name=
       out_positional_semantics=_PositionalSemantics.GLOBAL
-      out_shardings=(OpShardingSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), OpShardingSharding({replicated}))
+      out_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated}))
       resource_env=ResourceEnv(Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('x',)), ())
     ] a b
   in (c, d) }
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 0f2c90834..3c96c0ccb 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -43,7 +43,7 @@ from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.lax import control_flow as cf
-from jax._src.sharding import OpShardingSharding
+from jax._src.sharding import GSPMDSharding
 from jax._src.typing import Array
 from jax._src.util import (as_hashable_function, split_list, safe_map,
                            safe_zip, unzip3, weakref_lru_cache)
@@ -860,7 +860,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
   if jax.config.jax_array:
     sharding = pjit._UNSPECIFIED
   else:
-    sharding = OpShardingSharding.get_replicated(
+    sharding = GSPMDSharding.get_replicated(
         list(resource_env.physical_mesh.devices.flat))
 
   new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index de278224c..74e6d9529 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding
+from jax._src.sharding import Sharding, GSPMDSharding, NamedSharding
 
 # pytype: disable=import-error
 try:
@@ -310,7 +310,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
 
   def _op_sharding_callback(op_sharding: xc.OpSharding):
     if mesh.empty:
-      return callback(OpShardingSharding(
+      return callback(GSPMDSharding(
           devices, op_sharding))
     pspec = pjit.parse_flatten_op_sharding(
         op_sharding, mesh)[0].get_partition_spec()
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 56f5b9f5d..5a94d7e77 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -61,7 +61,7 @@ from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
-                               OpShardingSharding, NamedSharding, PartitionSpec,
+                               GSPMDSharding, NamedSharding, PartitionSpec,
                                Sharding)
 from jax._src.util import flatten, unflatten
 
@@ -322,7 +322,7 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
   assert len(da) == 1
   # in_shardings will be marked as replicated regardless of whatever the input
   # had. Given that only a single device is allowed above, this is correct.
-  in_shardings = [OpShardingSharding.get_replicated(da)] * num_ins
+  in_shardings = [GSPMDSharding.get_replicated(da)] * num_ins
   return da, in_shardings
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index b4c6a78f3..c8b912b15 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2896,7 +2896,7 @@ def lower_sharding_computation(
       any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or  # type: ignore
       any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
 
-  in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
+  in_shardings = tuple(sharding_internal.GSPMDSharding.get_replicated(device_assignment)
                        if _is_unspecified(i) else i for i in in_shardings)
 
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -3372,9 +3372,9 @@ def get_op_sharding_shardings_from_executable(
   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
 
-  in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
+  in_shardings_xla = [sharding_internal.GSPMDSharding(device_assignment, i)
                       for i in in_op_shardings]
-  out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
+  out_shardings_xla = [sharding_internal.GSPMDSharding(device_assignment, o)
                        for o in out_op_shardings]
   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
@@ -3720,7 +3720,7 @@ def _out_shardings_for_trivial(
   #   a replicated sharding
   from jax._src import array
 
-  rep = sharding_internal.OpShardingSharding(
+  rep = sharding_internal.GSPMDSharding(
       device_assignment, sharding_internal._get_replicated_op_sharding())
   shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
   for constvar, constval in zip(jaxpr.constvars, consts):
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index ab5d1b2cb..3f631a4c8 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -1860,7 +1860,7 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources):
 def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
   from jax._src.pjit import (
       sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
-      OpShardingSharding)
+      GSPMDSharding)
   rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
   if isinstance(jaxpr, core.ClosedJaxpr):
@@ -1878,7 +1878,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
       mps = NamedSharding._from_parsed_pspec(
           resource_env.physical_mesh, ParsedPartitionSpec((), ()))
       unconstrained_dims = get_unconstrained_dims(mps)
-      op_sharding_sharding = OpShardingSharding.get_replicated(
+      op_sharding_sharding = GSPMDSharding.get_replicated(
           mps._device_assignment)
       new_eqns.append(core.JaxprEqn(
           [tmpvar], [outvar], sharding_constraint_p,
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index d267fe786..8a59852c9 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -37,7 +37,7 @@ from jax.tree_util import (
     treedef_tuple)
 
 from jax._src.sharding import (
-    NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
+    NamedSharding, Sharding, XLACompatibleSharding, GSPMDSharding,
     XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
 from jax._src import array
 from jax._src import dispatch
@@ -90,8 +90,8 @@ def _is_unspecified_or_from_gda_or_auto(x):
   return _is_from_gda(x) or is_auto(x) or _is_unspecified(x)
 
 
-PjitSharding = Union[OpShardingSharding, _UnspecifiedValue, _AUTOAxisResource]
-PjitShardingMinusUnspecified = Union[OpShardingSharding, _AUTOAxisResource]
+PjitSharding = Union[GSPMDSharding, _UnspecifiedValue, _AUTOAxisResource]
+PjitShardingMinusUnspecified = Union[GSPMDSharding, _AUTOAxisResource]
 MeshSharding = Union[NamedSharding, _UnspecifiedValue, _AUTOAxisResource]
 MeshShardingMinusUnspecified = Union[NamedSharding, _AUTOAxisResource]
@@ -535,7 +535,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
     in_positional_semantics = (
         pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
 
-  # in_shardings and out_shardings here are all OpShardingSharding.
+  # in_shardings and out_shardings here are all GSPMDSharding.
   params = dict(
       jaxpr=jaxpr,
       in_shardings=canonicalized_in_shardings_flat,
@@ -1353,7 +1353,7 @@ class SameDeviceAssignmentTuple:
   device_assignment: Optional[XLADeviceAssignment]
 
   def __hash__(self):
-    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s  # type: ignore
+    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, GSPMDSharding) else s  # type: ignore
                            for s in self.shardings)
     if self.device_assignment is None:
       return hash(shardings_hash)
@@ -1364,7 +1364,7 @@ class SameDeviceAssignmentTuple:
     if not isinstance(other, SameDeviceAssignmentTuple):
       return False
     return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding)  # pytype: disable=attribute-error
-                if isinstance(s, OpShardingSharding) and isinstance(o, OpShardingSharding)
+                if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)
                 else s == o
                 for s, o in safe_zip(self.shardings, other.shardings)) and
             self.device_assignment == other.device_assignment)
@@ -1420,14 +1420,14 @@ def _pjit_lower_cached(
         Tuple[MeshShardingMinusUnspecified, ...], tuple(
             NamedSharding._from_parsed_pspec(
                 mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])  # type: ignore
-            if isinstance(i, OpShardingSharding) else i
+            if isinstance(i, GSPMDSharding) else i
             for i in in_shardings
     ))
     out_shardings: Tuple[MeshSharding, ...] = cast(  # type: ignore[no-redef]
         Tuple[MeshSharding, ...], tuple(
             NamedSharding._from_parsed_pspec(
                 mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])  # type: ignore
-            if isinstance(o, OpShardingSharding) else o
+            if isinstance(o, GSPMDSharding) else o
             for o in out_shardings
     ))
@@ -1560,7 +1560,7 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
 pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
 
 def _pjit_batcher_for_sharding(
-    s: Union[OpShardingSharding, _UnspecifiedValue],
+    s: Union[GSPMDSharding, _UnspecifiedValue],
     dim: int, val: Tuple[str, ...], mesh, ndim: int):
   if _is_unspecified(s):
     return s
@@ -1569,14 +1569,14 @@ def _pjit_batcher_for_sharding(
     tad = list(new_op.tile_assignment_dimensions)
     tad.insert(dim, 1)
     new_op.tile_assignment_dimensions = tad
-    return OpShardingSharding(s._device_assignment, new_op)  # type: ignore
+    return GSPMDSharding(s._device_assignment, new_op)  # type: ignore
   else:
-    assert isinstance(s, OpShardingSharding)
+    assert isinstance(s, GSPMDSharding)
     assert mesh is not None and not mesh.empty
     parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]  # type: ignore
     parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
     mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
-    return OpShardingSharding(mps._device_assignment, mps._to_xla_op_sharding(ndim))
+    return GSPMDSharding(mps._device_assignment, mps._to_xla_op_sharding(ndim))
 
 
 def _pjit_jvp(primals_in, tangents_in,
@@ -1642,7 +1642,7 @@ def _pjit_partial_eval(trace, *in_tracers,
       residual_shardings = (_UNSPECIFIED,) * num_residuals
     else:
       da = list(resource_env.physical_mesh.devices.flat)
-      residual_shardings = (OpShardingSharding.get_replicated(da),) * num_residuals
+      residual_shardings = (GSPMDSharding.get_replicated(da),) * num_residuals
     # Compute the known outputs
     known_params = dict(
         jaxpr=known_jaxpr,
@@ -1683,7 +1683,7 @@ def _pjit_partial_eval(trace, *in_tracers,
       residual_op_shardings = ()
     assert len(residual_shardings) == len(residual_op_shardings), (
         len(residual_shardings), len(residual_op_shardings))
-    residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
+    residual_shardings = tuple(GSPMDSharding(da, op) for op in residual_op_shardings)
   known_params['out_shardings'] = (
       keep_where(out_shardings, known_outs) + residual_shardings)
@@ -2080,13 +2080,13 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
   aval, = ctx.avals_in
   axis_ctx = ctx.module_context.axis_context
   # axis_ctx and manual_axes is *only used with xmap* and xmap only works with
-  # NamedSharding. So convert the OpShardingSharding to NamedSharding
+  # NamedSharding. So convert the GSPMDSharding to NamedSharding
   # and then convert it back with the added special axes.
   if isinstance(axis_ctx, mlir.SPMDAxisContext):
     mesh = resource_env.physical_mesh
     parsed_pspec = parse_flatten_op_sharding(sharding._op_sharding, mesh)[0]
     mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
-    sharding = OpShardingSharding(
+    sharding = GSPMDSharding(
         mps._device_assignment, mps._to_xla_op_sharding(aval.ndim, axis_ctx=axis_ctx))
   return [
       mlir.wrap_with_sharding_op(
@@ -2152,10 +2152,10 @@ def get_array_mapping(
               if axes is not None for axis in axes)
 
 
-def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> OpShardingSharding:
-  if isinstance(s, OpShardingSharding):
+def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
+  if isinstance(s, GSPMDSharding):
     return s
-  op_sharding_sharding = OpShardingSharding(
+  op_sharding_sharding = GSPMDSharding(
       s._device_assignment, s._to_xla_op_sharding(ndim))
   op_sharding_sharding._original_sharding = s
   return op_sharding_sharding
@@ -2232,7 +2232,7 @@ def _maybe_replace_from_gda_with_pspec(
 
 @lru_cache()
 def _gda_check_and_get_sharding(
-    gda_sharding: NamedSharding, in_sharding: OpShardingSharding, ndim: int):
+    gda_sharding: NamedSharding, in_sharding: GSPMDSharding, ndim: int):
   if not _is_from_gda(in_sharding) and not pxla.are_op_shardings_equal(
       gda_sharding._to_xla_op_sharding(ndim),
       in_sharding._to_xla_op_sharding(ndim)):
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 92e0c0755..86a98f561 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -43,7 +43,7 @@ from jax._src.lib import gpu_prng
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy import lax_numpy
 from jax._src.sharding import (
-    NamedSharding, PmapSharding, OpShardingSharding)
+    NamedSharding, PmapSharding, GSPMDSharding)
 from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
 
 map, unsafe_map = safe_map, map
@@ -375,7 +375,7 @@ class KeyTyRules:
     if is_out_sharding_from_xla:
       phys_sharding = out_sharding
     else:
-      phys_sharding = OpShardingSharding(
+      phys_sharding = GSPMDSharding(
           out_sharding._device_assignment,
           KeyTyRules.physical_op_sharding(aval, out_sharding))
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index d3d32ce8a..cb878fb11 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -24,6 +24,7 @@ import jax
 from jax._src import core
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax.interpreters import mlir
 from jax._src.interpreters import pxla
 
@@ -75,10 +76,10 @@ class Sharding(metaclass=abc.ABCMeta):
     """Returns True if two shardings put the same logical array
     (sharded/unsharded) on the same device(s).
 
-    For example, every XLACompatibleSharding lowers to OpShardingSharding which
+    For example, every XLACompatibleSharding lowers to GSPMDSharding which
     is a general representation. So `jax.sharding.NamedSharding` is equivalent
     to `jax.sharding.PositionalSharding` if both of them lower to the same
-    OpShardingSharding.
+    GSPMDSharding.
     """
     raise NotImplementedError('Subclasses should implement this method.')
@@ -136,7 +137,7 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
     op_sharding = self._to_xla_op_sharding(len(global_shape))
-    op_sharding_sharding = OpShardingSharding(self._device_assignment,
+    op_sharding_sharding = GSPMDSharding(self._device_assignment,
                                               op_sharding)
     return op_sharding_sharding.devices_indices_map(global_shape)
@@ -645,8 +646,8 @@ class DeviceIdSet:
             self._ids == other._ids)
 
 
-@use_cpp_class(xc.OpShardingSharding)
-class OpShardingSharding(XLACompatibleSharding):
+@use_cpp_class(xc.GSPMDSharding if xla_extension_version >= 129 else xc.OpShardingSharding)  # type: ignore
+class GSPMDSharding(XLACompatibleSharding):
 
   @use_cpp_method()
   def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
@@ -661,7 +662,7 @@ class OpShardingSharding(XLACompatibleSharding):
     return hash(xc.HloSharding.from_proto(self._op_sharding))
 
   def __eq__(self, other):
-    if not isinstance(other, OpShardingSharding):
+    if not isinstance(other, GSPMDSharding):
       return False
     if id(self) == id(other):
       return True
@@ -674,7 +675,7 @@ class OpShardingSharding(XLACompatibleSharding):
     return self._hash
 
   def __repr__(self):
-    return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
+    return f'GSPMDSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
 
   def is_compatible_aval(self, aval_shape: Shape):
     num_ways_dim_sharded, _ = pxla.get_num_ways_dim_sharded(self._op_sharding)
@@ -705,3 +706,8 @@ class OpShardingSharding(XLACompatibleSharding):
   def get_replicated(cls, device_assignment):
     proto = _get_replicated_op_sharding()
     return cls(device_assignment, proto)
+
+
+# TODO(yashkatariya); Remove OpShardingSharding after 3 months from Feb 17, 2023
+# per the deprecation policy.
+OpShardingSharding = GSPMDSharding
diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index 012b322d7..de7fe755d 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -385,8 +385,8 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
 
   def to_mesh_pspec_sharding(op_sharding: xc.OpSharding):
     if mesh.empty:
-      from jax._src.sharding import OpShardingSharding
-      return OpShardingSharding(devices, op_sharding)
+      from jax._src.sharding import GSPMDSharding
+      return GSPMDSharding(devices, op_sharding)
     pspec = pjit.parse_flatten_op_sharding(op_sharding, mesh)[0].get_partition_spec()
     return jax.sharding.NamedSharding(mesh, pspec)
diff --git a/jax/experimental/gda_serialization/serialization_test.py b/jax/experimental/gda_serialization/serialization_test.py
index 30edd6199..ff79a0fc8 100644
--- a/jax/experimental/gda_serialization/serialization_test.py
+++ b/jax/experimental/gda_serialization/serialization_test.py
@@ -22,7 +22,7 @@ from jax._src import util
 from jax._src import config as jax_config
 from jax.config import config
 from jax._src import array
-from jax._src.sharding import NamedSharding, OpShardingSharding
+from jax._src.sharding import NamedSharding, GSPMDSharding
 from jax.sharding import PartitionSpec as P
 from jax.experimental.gda_serialization import serialization
 import numpy as np
@@ -133,7 +133,7 @@ class CheckpointTest(jtu.JaxTestCase):
     for l in m1.addressable_shards:
       self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
 
-    new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
+    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
     m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
     for l in m2.addressable_shards:
       self.assertArraysEqual(l.data, global_input_data1.astype('float32'))
diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py
index 593f0f6de..05d69cd3d 100644
--- a/jax/experimental/multihost_utils.py
+++ b/jax/experimental/multihost_utils.py
@@ -96,7 +96,7 @@ def _identity_fn(x):
 
 def _handle_array_process_allgather(inp, tiled):
   if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
-    reps = sharding.OpShardingSharding(inp.sharding._device_assignment,
+    reps = sharding.GSPMDSharding(inp.sharding._device_assignment,
                                        sharding._get_replicated_op_sharding())
     out = pjit(_identity_fn, out_axis_resources=reps)(inp)
   else:
diff --git a/jax/sharding.py b/jax/sharding.py
index 39fffa058..9a32f1897 100644
--- a/jax/sharding.py
+++ b/jax/sharding.py
@@ -18,13 +18,16 @@
 from jax._src.sharding import (
   Sharding as Sharding,
   XLACompatibleSharding as XLACompatibleSharding,
-  # TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
+  # TODO(yashkatariya): Remove MeshPspecSharding in 3 months.
   MeshPspecSharding as MeshPspecSharding,
   # New name of MeshPspecSharding to match PositionalSharding below.
   NamedSharding as NamedSharding,
   PartitionSpec as PartitionSpec,
   SingleDeviceSharding as SingleDeviceSharding,
   PmapSharding as PmapSharding,
+  GSPMDSharding as GSPMDSharding,
+  # TODO(yashkatariya): Remove OpShardingSharding in 3 months from
+  # Feb 17, 2023.
   OpShardingSharding as OpShardingSharding,
   PositionalSharding as PositionalSharding,
 )
diff --git a/tests/array_test.py b/tests/array_test.py
index f6f254be0..63b2450fd 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -694,7 +694,7 @@ class ShardingTest(jtu.JaxTestCase):
     shape = (8, 4)
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     mps = sharding.NamedSharding(mesh, pspec)
-    ops = sharding.OpShardingSharding(
+    ops = sharding.GSPMDSharding(
         list(mesh.devices.flat), mps._to_xla_op_sharding(len(shape)))
     self.assertDictEqual(
         ops.devices_indices_map(shape), mps.devices_indices_map(shape))
@@ -771,16 +771,16 @@ class ShardingTest(jtu.JaxTestCase):
     op.tile_assignment_dimensions = [4, 1, 2]
     op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     op.replicate_on_last_tile_dim = True
-    s = sharding.OpShardingSharding(jax.devices(), op)
+    s = sharding.GSPMDSharding(jax.devices(), op)
     self.assertEqual(
         repr(s),
-        'OpShardingSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
-        'last_tile_dim_replicate})')
+        'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
+        'last_tile_dim_replicate})')
 
     op2 = xc.OpSharding()
     op2.type = xc.OpSharding.Type.REPLICATED
-    s2 = sharding.OpShardingSharding(jax.devices(), op2)
-    self.assertEqual(repr(s2), 'OpShardingSharding({replicated})')
+    s2 = sharding.GSPMDSharding(jax.devices(), op2)
+    self.assertEqual(repr(s2), 'GSPMDSharding({replicated})')
 
   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y"), (4, 2), (), False),
@@ -876,9 +876,9 @@ class ShardingTest(jtu.JaxTestCase):
 
     op1 = xc.OpSharding()
     op1.type = xc.OpSharding.Type.REPLICATED
-    s6 = jax.sharding.OpShardingSharding([jax.devices()[0]], op1)
+    s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1)
 
-    s7 = jax.sharding.OpShardingSharding(jax.devices(), op1)
+    s7 = jax.sharding.GSPMDSharding(jax.devices(), op1)
 
     # The OpSharding is replicated but the Sharding itself are on different
     # devices.
@@ -888,7 +888,7 @@ class ShardingTest(jtu.JaxTestCase):
     op2.type = xc.OpSharding.Type.OTHER
     op2.tile_assignment_devices = [0, 1]
     op2.tile_assignment_dimensions = [2, 1]
-    s8 = jax.sharding.OpShardingSharding(list(mesh2.devices.flat), op2)
+    s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2)
 
     self.assertTrue(s1.is_equivalent_to(s6, 2))
     self.assertTrue(s5.is_equivalent_to(s8, 2))
@@ -901,7 +901,7 @@ class ShardingTest(jtu.JaxTestCase):
     op3.tile_assignment_devices = [0, 1]
     op3.tile_assignment_dimensions = [1, 1, 2]
     op3.replicate_on_last_tile_dim = True
-    s10 = jax.sharding.OpShardingSharding(list(mesh2.devices.flat), op3)
+    s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3)
 
     self.assertTrue(s9.is_equivalent_to(s10, 2))
diff --git a/tests/pickle_test.py b/tests/pickle_test.py
index b61c999a7..e64b0087a 100644
--- a/tests/pickle_test.py
+++ b/tests/pickle_test.py
@@ -185,7 +185,7 @@ class PickleTest(jtu.JaxTestCase):
   def test_pickle_op_sharding_sharding(self):
     op_sharding = xla.xc.OpSharding()
     op_sharding.type = xla.xc.OpSharding.Type.REPLICATED
-    s = jax.sharding.OpShardingSharding(jax.devices(), op_sharding)
+    s = jax.sharding.GSPMDSharding(jax.devices(), op_sharding)
     self.assertEqual(s, pickle.loads(pickle.dumps(s)))
 
   @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 64c9b52c7..a6de700f9 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -44,7 +44,7 @@ from jax.experimental import global_device_array
 from jax.experimental import multihost_utils
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
-from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding
+from jax._src.sharding import NamedSharding, Sharding, GSPMDSharding
 import jax._src.pjit as pjit_lib
 from jax._src.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
 from jax._src.interpreters import pxla
@@ -1860,7 +1860,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     def _checks(out, input_data):
       self.assertIsInstance(out, array.ArrayImpl)
-      self.assertIsInstance(out.sharding, OpShardingSharding)
+      self.assertIsInstance(out.sharding, GSPMDSharding)
       self.assertEqual(out.shape, (8, 2))
      self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
      for s in out.addressable_shards:
@@ -2416,20 +2416,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
     f = pjit(lambda x: x)
 
     out1 = f(arr)
-    self.assertIsInstance(out1.sharding, OpShardingSharding)
+    self.assertIsInstance(out1.sharding, GSPMDSharding)
     out1.sharding.devices_indices_map(shape)
-    cache_info1 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info1 = GSPMDSharding.devices_indices_map.cache_info()
 
     out2 = f(out1)
-    self.assertIsInstance(out2.sharding, OpShardingSharding)
+    self.assertIsInstance(out2.sharding, GSPMDSharding)
     out2.sharding.devices_indices_map(shape)
-    cache_info2 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info2 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
 
     out3 = f(out2)
-    self.assertIsInstance(out3.sharding, OpShardingSharding)
+    self.assertIsInstance(out3.sharding, GSPMDSharding)
     out3.sharding.devices_indices_map(shape)
-    cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info3 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
 
   @jax_array(True)
@@ -3871,21 +3871,21 @@ class UtilTest(jtu.JaxTestCase):
 
     shape = (8, 4)
     devices = jax.devices()
 
-    ops = OpShardingSharding(devices, op1)
+    ops = GSPMDSharding(devices, op1)
     ops.devices_indices_map(shape)
-    cache_info1 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info1 = GSPMDSharding.devices_indices_map.cache_info()
 
     ops.devices_indices_map(shape)
-    cache_info2 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info2 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
 
-    ops = OpShardingSharding(devices, op2)
+    ops = GSPMDSharding(devices, op2)
     ops.devices_indices_map(shape)
-    cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info3 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
 
     ops.devices_indices_map(shape)
-    cache_info4 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info4 = GSPMDSharding.devices_indices_map.cache_info()
    self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
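
A minimal usage sketch (not part of the patch) of the renamed class under the jax 0.4.5 API shown in this diff, assuming at least one local device; GSPMDSharding is the new name and OpShardingSharding remains a deprecated alias until its scheduled removal:

    import jax
    from jax.sharding import GSPMDSharding
    from jax._src.lib import xla_client as xc

    # Build a fully replicated OpSharding proto, as the tests above do.
    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.REPLICATED

    # Wrap it in the renamed sharding class.
    s = GSPMDSharding(jax.devices(), op)
    print(s)  # GSPMDSharding({replicated})

    # The deprecated alias still resolves to the same class for now.
    assert jax.sharding.OpShardingSharding is GSPMDSharding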