diff --git a/jax/_src/array.py b/jax/_src/array.py
index 097a8d36a..d92f21b16 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -330,7 +330,7 @@ class ArrayImpl(basearray.Array):
 
   @property
   def is_fully_replicated(self) -> bool:
-    return self.shape == self._arrays[0].shape
+    return self.sharding.is_fully_replicated
 
   def __repr__(self):
     prefix = 'Array('
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 670807b65..f26ec7312 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -34,7 +34,6 @@ from jax._src import core
 from jax._src import dtypes
 from jax._src import effects as effects_lib
 from jax._src import linear_util as lu
-from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src import source_info_util
 from jax._src import util
@@ -47,6 +46,7 @@ from jax._src.lib.mlir import dialects
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect
+from jax._src.sharding_impls import XLACompatibleSharding
 
 
 map, unsafe_map = util.safe_map, map
@@ -491,7 +491,7 @@ def flatten_lowering_ir_args(
 _module_name_regex = re.compile(r"[^\w.-]")
 
 def sharded_aval(aval: core.AbstractValue,
-                 sharding: Optional[xc.OpSharding]) -> core.AbstractValue:
+                 sharding: Optional[XLACompatibleSharding]) -> core.AbstractValue:
   """Returns the new aval sharded based on sharding proto."""
   if sharding is None:
     return aval
@@ -499,18 +499,7 @@ def sharded_aval(aval: core.AbstractValue,
     return aval
   if not isinstance(aval, core.ShapedArray):
     raise NotImplementedError
-
-  if (op_shardings.is_op_sharding_replicated(sharding) or
-      sharding.type == xc.OpSharding.Type.MANUAL):
-    return aval
-
-  partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding)
-  out = []
-  for s, p in zip(aval.shape, partitions):
-    quotient, remainder = divmod(s, p)
-    assert remainder == 0
-    out.append(quotient)
-  return aval.update(tuple(out))
+  return aval.update(sharding.shard_shape(aval.shape))
 
 
 def eval_dynamic_shape(ctx: LoweringRuleContext,
@@ -537,6 +526,16 @@ class LoweringResult(NamedTuple):
 _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
 
 
+def _to_logical_op_sharding(
+    aval: core.AbstractValue, sharding: Optional[XLACompatibleSharding],
+) -> Optional[xc.OpSharding]:
+  if sharding is None:
+    return None
+  assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
+  assert isinstance(aval, core.ShapedArray)
+  return sharding._to_xla_op_sharding(aval.ndim)
+
+
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
@@ -547,8 +546,8 @@ def lower_jaxpr_to_module(
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],
     replicated_args: Optional[Sequence[bool]] = None,
-    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
-    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
+    arg_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
+    result_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
     arg_names: Optional[Sequence[Optional[str]]] = None,
     result_names: Optional[Sequence[Optional[str]]] = None,
     num_replicas: int = 1,
@@ -596,6 +595,13 @@
   else:
     dim_vars = ()
+  arg_op_shardings = (
+      map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
+      if arg_shardings is not None else arg_shardings)
+  result_op_shardings = (
+      map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
+      if result_shardings is not None else result_shardings)
+
   ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
                       keepalives, channel_iter, host_callbacks,
                       dim_vars=dim_vars)
   with ctx.context, ir.Location.unknown(ctx.context):
@@ -611,9 +617,11 @@
         replace_tokens_with_dummy=True,
         num_output_tokens=0,
         replicated_args=replicated_args,
-        arg_shardings=arg_shardings, result_shardings=result_shardings,
+        arg_shardings=arg_op_shardings,
+        result_shardings=result_op_shardings,
         input_output_aliases=input_output_aliases,
-        arg_names=arg_names, result_names=result_names)
+        arg_names=arg_names,
+        result_names=result_names)
 
   if not ctx.module.operation.verify():
     module_string = module_to_string(ctx.module)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index d0c7cd0ae..c9d400c2b 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1892,21 +1892,21 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     dispatch.raise_warnings_or_errors_for_jit_of_pmap(
         nreps, backend, fun_name, jaxpr)
 
-  in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
-  out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
+  in_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
+  out_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
   axis_ctx: mlir.AxisContext
 
   if nreps == 1:
-    in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings)
-    out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
+    in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
+    out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
     axis_ctx = sharding_impls.ShardingContext(device_assignment)
     num_partitions = len(device_assignment)
   else:
     # This path is triggered for `jit(pmap)` cases.
     replicated_args = None
-    in_op_shardings = None
-    out_op_shardings = None
+    in_mlir_shardings = None
+    out_mlir_shardings = None
     axis_env = sharding_impls.AxisEnv(nreps, (), ())
     axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
     num_partitions = 1
@@ -1929,8 +1929,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       name_stack,
       donated_invars,
       replicated_args=replicated_args,
-      arg_shardings=in_op_shardings,
-      result_shardings=out_op_shardings,
+      arg_shardings=in_mlir_shardings,
+      result_shardings=out_mlir_shardings,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
       num_replicas=nreps,
@@ -2113,14 +2113,14 @@ def lower_sharding_computation(
       pmap_nreps=nreps)
 
 
-def _to_logical_op_sharding(
+def _to_logical_sharding(
     aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
-) -> Optional[xc.OpSharding]:
+) -> Optional[sharding_impls.XLACompatibleSharding]:
   if is_unspecified(sharding) or is_auto(sharding):
     return None
   elif isinstance(aval, ShapedArray):
     assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
-    return sharding._to_xla_op_sharding(aval.ndim)
+    return sharding
   elif isinstance(aval, core.AbstractToken):
     return None
   else:
@@ -2219,12 +2219,12 @@ def lower_mesh_computation(
   # 2. Build up the HLO
   tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
 
-  in_partitions: Optional[List[Optional[xc.OpSharding]]]
-  out_partitions: Optional[List[Optional[xc.OpSharding]]]
+  in_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
+  out_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
   axis_ctx: mlir.AxisContext
   if spmd_lowering:
-    in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings)
-    out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
+    in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
+    out_partitions = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(in_jaxpr_avals)
     axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
     num_replicas = 1
@@ -2370,12 +2370,7 @@ def _get_input_indices(
       index = tuple(
           (slice(None),) for _ in range(num_addressable_devices))
     else:
-      # We special case this logic to support fully replicated values because
-      # the mesh is global mesh and the indices returned by `spec_to_indices` will
-      # represent index for each device in the global mesh. But here we want
-      # indices for the local devices of the global mesh.
-      proto = sharding._to_xla_op_sharding(aval.ndim)
-      if op_shardings.is_op_sharding_replicated(proto):
+      if sharding.is_fully_replicated:
        index = tuple(
            (slice(None),) * aval.ndim for _ in range(num_addressable_devices))  # type: ignore
      else:
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 82abfbe8b..49ae4a749 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1029,9 +1029,8 @@ def _resolve_in_shardings(
               'multiple devices is not supported.')
       else:
         if (isinstance(arg, np.ndarray) and
-            not op_shardings.is_op_sharding_replicated(
-                pjit_in_s._to_xla_op_sharding(arg.ndim))  # type: ignore
-            and xb.process_count() > 1):
+            not pjit_in_s.is_fully_replicated and  # type: ignore
+            xb.process_count() > 1):
           raise ValueError(
               'Passing non-trivial shardings for numpy '
               'inputs is not allowed. To fix this error, either specify a '
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index 60d3277da..17c031faa 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -80,6 +80,11 @@ class Sharding:
     """
     raise NotImplementedError('Subclasses should implement this method.')
 
+  @property
+  def is_fully_replicated(self) -> bool:
+    """Returns if a sharding is fully replicated on all the devices."""
+    raise NotImplementedError('Subclasses should implement this method.')
+
   #############################################################################
   # Default implementations below that all subclasses will inherit.
 
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index a98703210..05ba70b92 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -191,7 +191,7 @@ class NamedSharding(XLACompatibleSharding):
 
   mesh: mesh_lib.Mesh
   spec: PartitionSpec
-  _parsed_pspec: Optional[Any]
+  _parsed_pspec: ParsedPartitionSpec
 
   @use_cpp_method()
   def __init__(
@@ -269,6 +269,17 @@ class NamedSharding(XLACompatibleSharding):
     # across multiple NamedSharding objects will be the same.
     return self.mesh._local_devices_set
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    if self.mesh.size == 1:
+      return True
+    array_mapping = cast(ParsedPartitionSpec, get_array_mapping(self._parsed_pspec))
+    mesh_shape = self.mesh.shape
+    num_partitions = 1
+    for name in array_mapping:
+      num_partitions *= mesh_shape[name]
+    return num_partitions == 1
+
   @functools.lru_cache(maxsize=4096)
   def _to_xla_op_sharding(
       self,
@@ -350,6 +361,10 @@ class SingleDeviceSharding(XLACompatibleSharding):
   def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
     return get_replicated_op_sharding()
 
+  @property
+  def is_fully_replicated(self) -> bool:
+    return True
+
 
 @use_cpp_class(xc.PmapSharding)
 class PmapSharding(XLACompatibleSharding):
@@ -447,6 +462,13 @@ class PmapSharding(XLACompatibleSharding):
   def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
     raise NotImplementedError("pmap doesn't use OpSharding.")
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    for s in self.sharding_spec.sharding:
+      if isinstance(s, sharding_specs.Unstacked):
+        return False
+    return True
+
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
     sharded_dim = None
@@ -554,6 +576,10 @@ class PositionalSharding(XLACompatibleSharding):
   def device_set(self) -> set[xc.Device]:
     return set(self._devices)
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    return self.shape == (1,) * self.ndim
+
   # XLACompatibleSharding interface
 
   @property
@@ -670,6 +696,10 @@ class GSPMDSharding(XLACompatibleSharding):
   def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
     return self._op_sharding
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    return is_op_sharding_replicated(self._op_sharding)
+
   @classmethod
   def get_replicated(cls, device_assignment):
     proto = get_replicated_op_sharding()
diff --git a/tests/array_test.py b/tests/array_test.py
index e4717aac9..66f183c5d 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -860,7 +860,7 @@ class ShardingTest(jtu.JaxTestCase):
       ("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)),
       ("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)),
       ("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))),
-      ("3d_mesh2_none_y_none", (1, 2, 4), P(None, None, 'z')),
+      ("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')),
       ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
   )
   def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
@@ -875,6 +875,42 @@
     self.assertTrue(op_shardings.are_op_shardings_equal(
         original_op_sharding, out_op_sharding))
 
+  @parameterized.named_parameters(
+      ("2d_mesh_x", (1, 1), P("x", "y")),
+      ("2d_mesh_x_y", (4, 2), P("x", "y")),
+      ("2d_mesh_empty", (2, 1), P()),
+      ("2d_mesh_p_none", (2, 1), P(None)),
+      ("2d_mesh_none_none", (2, 1), P(None, None)),
+      ("2d_mesh_tuple_empty", (2, 1), P((),)),
+      ("2d_mesh_x_none", (2, 1), P(('x',), None)),
+      ("2d_mesh_xy_none", (2, 1), P(('x', 'y'), None)),
+      ("2d_mesh_none", (2, 1), None),
+      ("2d_mesh_x_tuple_empty", (2, 1), P('x', (), (), ())),
+      ("2d_mesh_3_tuple_empty", (2, 1), P((), (), ())),
+      ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
+      ("3d_mesh2_x_y_none", (1, 1, 4), P('x', 'y', None)),
+      ("3d_mesh2_xy_none", (1, 1, 4), P(('x', 'y'), None)),
+  )
+  def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
+    if len(mesh_shape) == 2:
+      axis_names = ('x', 'y')
+    elif len(mesh_shape) == 3:
+      axis_names = ('x', 'y', 'z')
+    else:
+      axis_names = ('x',)
+    mesh = jtu.create_global_mesh(mesh_shape, axis_names)
+    mps = jax.sharding.NamedSharding(mesh, pspec)
+    shape = (8, 2, 4)
+    mps_op_sharding = mps._to_xla_op_sharding(len(shape))
+    ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
+    self.assertEqual(mps.is_fully_replicated, ops_ifr)
+
+    ps = _from_op_sharding_to_pos_sharding(mps_op_sharding,
+                                           mps._device_assignment)
+    self.assertEqual(ps.is_fully_replicated,
+                     op_shardings.is_op_sharding_replicated(
+                         ps._to_xla_op_sharding(len(shape))))
+
   def test_devices_sharding_respects_init_mesh_shape(self):
     value_shape = (8, 4)
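
Usage sketch (not part of the patch): a minimal example of the `is_fully_replicated` property this change adds to the `Sharding` interface, assuming a single-process run with at least one JAX device. The mesh shape and axis names below are illustrative only.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a tiny 1x1 mesh out of whatever device is available; on a real
# multi-device host this could be, say, a (4, 2) mesh instead.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, axis_names=('x', 'y'))

replicated = NamedSharding(mesh, P())            # nothing is partitioned
row_sharded = NamedSharding(mesh, P('x', None))  # first dim split over 'x'

print(replicated.is_fully_replicated)   # True on any mesh
# On this 1x1 mesh P('x', None) still partitions nothing, so this is also
# True; with more than one device along 'x' it would be False.
print(row_sharded.is_fully_replicated)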