From 704b2e5fba10af0e4e0b2566df02355847585af1 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 22 Jan 2025 16:47:58 -0800
Subject: [PATCH] [sharding_in_types] Make `vmap` work with shard_map + pallas

PiperOrigin-RevId: 718578207
---
 jax/_src/core.py                   | 13 +++++++++----
 jax/_src/lax/lax.py                | 16 ++++++++++------
 jax/_src/lax/utils.py              |  3 ++-
 jax/_src/mesh.py                   |  4 ++++
 jax/_src/pallas/mosaic/lowering.py |  3 ++-
 jax/_src/pallas/pallas_call.py     | 13 +++++++++----
 jax/_src/pjit.py                   |  4 ++--
 7 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/jax/_src/core.py b/jax/_src/core.py
index d5becbcf7..17f61680e 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -39,6 +39,7 @@ from jax._src import config
 from jax._src import effects
 from jax._src import compute_on
 from jax._src import mesh as mesh_lib
+from jax._src.mesh import AxisTypes
 from jax._src.partition_spec import PartitionSpec as P
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
@@ -1687,7 +1688,9 @@ def _invalid_shape_error(shape: Shape, context: str=""):
 
 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
-def modify_spec_for_hidden(spec, mesh) -> P:
+def modify_spec_for_hidden_collective(spec, mesh) -> P:
+  if all(s is None for s in spec):
+    return spec
   new_spec = []  # type: ignore
   for s in spec:
     if s is None:
@@ -1695,13 +1698,15 @@ def modify_spec_for_hidden(spec, mesh) -> P:
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
       new_spec.append(
-          None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Hidden else s)
+          None
+          if mesh._name_to_type[temp_s] in (AxisTypes.Hidden, AxisTypes.Collective)
+          else s)
   return P(*new_spec)
 
 
 def _maybe_modify_sharding(sharding):
-  if mesh_lib.AxisTypes.Hidden not in sharding.mesh.axis_types:
+  if sharding.mesh._are_all_axes_visible:
     return sharding
-  new_spec = modify_spec_for_hidden(sharding.spec, sharding.mesh)
+  new_spec = modify_spec_for_hidden_collective(sharding.spec, sharding.mesh)
   return sharding.with_spec(new_spec)
 
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 10dc7ca0d..17330a36c 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -5585,7 +5585,7 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes):
 
 reduce_prod_p = standard_primitive(
     _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
-    'reduce_prod')
+    'reduce_prod', sharding_rule=_reduce_op_sharding_rule)
 ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
 batching.defreducer(reduce_prod_p, _get_prod_identity)
 pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
@@ -5613,8 +5613,9 @@ pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
 batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule
 
 
-reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
-                                  'reduce_min')
+reduce_min_p = standard_primitive(
+    _reduce_op_shape_rule, _input_dtype, 'reduce_min',
+    sharding_rule=_reduce_op_sharding_rule)
 ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
 batching.defreducer(reduce_min_p, _get_min_identity)
 pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
@@ -5705,22 +5706,25 @@ def _reduce_logical_shape_rule(operand, *, axes):
     raise TypeError(f"logical reduction requires operand dtype bool or int, got {operand.dtype}.")
   return tuple(np.delete(operand.shape, axes))
 
+def _reduce_logical_sharding_rule(operand, *, axes):
+  return operand.sharding.with_spec(tuple_delete(operand.sharding.spec, axes))
+
 reduce_or_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_or',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_or_p, _get_bitwise_or_identity)
 
 
 reduce_and_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_and',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
 batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule
 
 
 reduce_xor_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_xor',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_xor_p, _get_bitwise_or_identity)
 
 
diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index c5f8f3157..6019f430b 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -50,7 +50,8 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
 def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
   if config.sharding_in_types.value:
     if rule is None:
-      if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
+      cur_mesh = mesh_lib.get_abstract_mesh()
+      if cur_mesh._are_all_axes_hidden or cur_mesh._are_all_axes_collective:  # type: ignore
         return None if num_out is None else [None] * num_out
       else:
         raise ValueError(
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 5ca0c894f..69f3cbbe9 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -475,6 +475,10 @@ class AbstractMesh:
   def _are_all_axes_hidden(self) -> bool:
     return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
+  @functools.cached_property
+  def _are_all_axes_visible(self) -> bool:
+    return all(t == AxisTypes.Visible for t in self.axis_types.keys())
+
   @functools.cached_property
   def _any_axis_collective(self) -> bool:
     return any(t == AxisTypes.Collective for t in self.axis_types.keys())
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 79fccb481..29b0bbfe8 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -921,7 +921,8 @@ def jaxpr_subcomp(
         name_stack=ctx.name_stack + eqn.source_info.name_stack
     )
     loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info)
-    with source_info_util.user_context(eqn.source_info.traceback), loc:
+    with (source_info_util.user_context(eqn.source_info.traceback), loc,
+          eqn.ctx.manager):
       if eqn.primitive in lowering_rules:
         if eqn.primitive not in skip_mlir_conversions:
           invals = [_ensure_mlir_value(x, v.aval)
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 87a63db39..fe1b25fe7 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -1117,10 +1117,15 @@ def _pallas_call_batching_rule(
     # assert ragged_axis_length is not None
     args = (ragged_axis_length, *args)
   assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)
-  batched_out_avals = tuple(
-      aval.update(shape=tuple_insert(aval.shape, 0, axis_size))
-      for aval in out_avals
-  )
+
+  batched_out_avals = []
+  for aval in out_avals:
+    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
+                if config.sharding_in_types.value else None)
+    shape = tuple_insert(aval.shape, 0, axis_size)
+    batched_out_avals.append(aval.update(shape=shape, sharding=sharding))
+  batched_out_avals = tuple(batched_out_avals)  # type: ignore
+
   out = pallas_call_p.bind(
       *dynamic_grid_args,
       *args,
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 5b0b5a967..ae975a636 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2840,7 +2840,7 @@ def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
   def decorator(*args, **kwargs):
     new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
     with mesh_lib.set_abstract_mesh(new_mesh):
-      in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
+      in_specs = tree_map(lambda a: core.modify_spec_for_hidden_collective(
           a.aval.sharding.spec, new_mesh), args)
       args = mesh_cast(args, in_specs)
       out = fun(*args, **kwargs)
@@ -2861,7 +2861,7 @@ def visible_axes(fun, *, axes: str | tuple[str, ...] | None = None,
     with mesh_lib.set_abstract_mesh(new_mesh):
       args = mesh_cast(args, in_shardings)
       out = fun(*args, **kwargs)
-    out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
+    out_specs = tree_map(lambda o: core.modify_spec_for_hidden_collective(
         o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
     return mesh_cast(out, out_specs)
   return decorator
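For context (not part of the patch): a minimal sketch of the pattern this change targets, jax.vmap over a pl.pallas_call wrapped in shard_map. The kernel, mesh, and array shapes below are illustrative assumptions, and interpret=True keeps the sketch runnable without a TPU.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def add_one_kernel(x_ref, o_ref):
  # Trivial elementwise kernel standing in for a real Pallas kernel.
  o_ref[...] = x_ref[...] + 1.0

def add_one(x):
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,  # interpreter mode so the sketch runs on CPU
  )(x)

mesh = Mesh(jax.devices(), ('x',))
sharded_add_one = shard_map(add_one, mesh=mesh, in_specs=P('x'), out_specs=P('x'))

# Batch over a leading axis. The pallas_call batching rule used to rebuild the
# output avals with an updated shape only; with sharding-in-types enabled the
# avals also carry a spec, which this patch extends with None for the new dim.
xs = jnp.ones((4, 8 * jax.device_count()), jnp.float32)
ys = jax.vmap(sharded_add_one)(xs)

The lax.py changes serve the same goal: reduce_prod, reduce_min, and the logical reductions now register sharding rules, with the logical reductions deleting the reduced axes from the operand spec via tuple_delete, so their output avals stay consistent when shardings are part of the type.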