diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 8507e3e91..d3e8c22cf 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1030,6 +1030,8 @@ def _to_physical_op_sharding(
 ) -> xc.OpSharding | SdyArraySharding | None:
   if sharding is None:
     return None
+  if all_unconstrained(sharding, aval):
+    return None
   if isinstance(sharding, AUTO):
     if config.use_shardy_partitioner.value:
       return sharding._to_sdy_sharding(aval.ndim)  # type: ignore
@@ -1071,10 +1073,8 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
 
 
 def contains_unconstrained(s):
-  return (
-      isinstance(s, NamedSharding)
-      and PartitionSpec.UNCONSTRAINED in s._parsed_pspec
-  )
+  return (isinstance(s, NamedSharding)
+          and PartitionSpec.UNCONSTRAINED in s._parsed_pspec)
 
 
 def all_unconstrained(s, aval):
@@ -1084,12 +1084,19 @@ def all_unconstrained(s, aval):
     return all(p is PartitionSpec.UNCONSTRAINED for p in s._parsed_pspec)
   return False
 
-def _get_unconstrained_dimensions(s, aval):
+class UnconstrainedVariants(NamedTuple):
+  contains_unconstrained: bool
+  all_unconstrained: bool
+  unconstrained_dims: set[int] | None
+
+def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants:
   us = contains_unconstrained(s)
-  return (
-      us, all_unconstrained(s, aval),
-      ({i for i, p in enumerate(s._parsed_pspec)
-        if p is PartitionSpec.UNCONSTRAINED} if us else None))
+  unconstrained_dims = ({i for i, p in enumerate(s._parsed_pspec)
+                         if p is PartitionSpec.UNCONSTRAINED} if us else None)
+  return UnconstrainedVariants(
+      contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval),
+      unconstrained_dims=unconstrained_dims)
+
 
 def lower_jaxpr_to_module(
     module_name: str,
@@ -1511,13 +1518,13 @@ def lower_jaxpr_to_fun(
          for is_donated, types in zip(xla_donated_args, input_types)])
 
   ir_result_shardings = None
-  unconstrained_shardings = None
+  unconstrained_variants = None
   if result_shardings is not None:
     ir_result_shardings = util.flatten(
         [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
         for a, s, types in zip(output_avals, result_shardings, output_types)])
-    unconstrained_shardings = util.flatten(
-        [[_get_unconstrained_dimensions(s, a)] * len_ir_types(types)
+    unconstrained_variants = util.flatten(
+        [[_get_unconstrained_variants(s, a)] * len_ir_types(types)
         for a, s, types in zip(output_avals, result_shardings, output_types)])
 
   ir_result_memory_kinds = None
@@ -1633,9 +1640,9 @@ def lower_jaxpr_to_fun(
         attrs['jax.result_info'] = ir.StringAttr.get(name_)
 
     if use_sharding_annotations and ir_result_shardings is not None:
-      for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
-                                     unconstrained_shardings):  # type: ignore
-        if sharding is not None and not us[0]:
+      for attrs, sharding, uv in zip(result_attrs, ir_result_shardings,
+                                     unconstrained_variants):  # type: ignore
+        if sharding is not None and not uv.contains_unconstrained:
           if config.use_shardy_partitioner.value:
             attrs["sdy.sharding"] = get_sharding_attr(sharding)
           else:
@@ -1716,13 +1723,15 @@ def lower_jaxpr_to_fun(
 
     if ir_result_shardings is not None:
       temp_flat_outputs = []
-      for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
-                                  output_avals, unconstrained_shardings):  # type: ignore
-        if us[0] and not us[1]:
+      for o, s, o_aval, uv in zip(flat_outputs, ir_result_shardings,
+                                  output_avals, unconstrained_variants):  # type: ignore
+        if (s is not None and uv.contains_unconstrained and
+            not uv.all_unconstrained):
           if config.use_shardy_partitioner.value:
             s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
           temp_flat_outputs.append(wrap_with_sharding_op(
-              entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
+              entry_lowering_ctx, o, o_aval, s,
+              unspecified_dims=uv.unconstrained_dims))
         else:
           temp_flat_outputs.append(o)
       flat_outputs = temp_flat_outputs
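Note on the `UnconstrainedVariants` refactor above: the old `_get_unconstrained_dimensions` returned an anonymous 3-tuple that call sites had to index as `us[0]`, `us[1]`, `us[2]`; the `NamedTuple` gives those positions names, which is what makes the later `uv.contains_unconstrained` checks readable. A minimal self-contained sketch of the same pattern (the `UNCONSTRAINED` sentinel below is a stand-in for `PartitionSpec.UNCONSTRAINED`, not JAX's real object):

```python
from typing import NamedTuple

# Hypothetical sentinel, standing in for PartitionSpec.UNCONSTRAINED.
UNCONSTRAINED = object()

class UnconstrainedVariants(NamedTuple):
  contains_unconstrained: bool
  all_unconstrained: bool
  unconstrained_dims: set[int] | None

def get_unconstrained_variants(pspec) -> UnconstrainedVariants:
  contains = any(p is UNCONSTRAINED for p in pspec)
  dims = ({i for i, p in enumerate(pspec) if p is UNCONSTRAINED}
          if contains else None)
  return UnconstrainedVariants(
      contains_unconstrained=contains,
      all_unconstrained=all(p is UNCONSTRAINED for p in pspec),
      unconstrained_dims=dims)

uv = get_unconstrained_variants((UNCONSTRAINED, "x", None))
assert uv.contains_unconstrained and not uv.all_unconstrained
assert uv.unconstrained_dims == {0}
```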
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 07571c7ec..e6c7ac986 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2157,6 +2157,13 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
   return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
           donated_invars, out_shardings, out_layouts)
 
+@lru_cache(maxsize=1024)
+def _abstract_to_concrete_mesh(abstract_mesh, device_assignment):
+  np_dev = np.vectorize(lambda i: device_assignment[i],
+                        otypes=[object])(np.arange(len(device_assignment)))
+  return Mesh(np_dev.reshape(abstract_mesh.axis_sizes),
+              abstract_mesh.axis_names, axis_types=abstract_mesh.axis_types)
+
 def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
                                        out_mem_kinds):
   if device_assignment is None:
@@ -2164,27 +2171,20 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
   if len(device_assignment) == 1:
     return shardings
 
-  np_dev = np.vectorize(lambda i: device_assignment[i],
-                        otypes=[object])(np.arange(len(device_assignment)))
-
-  @lru_cache(maxsize=128)
-  def _abstract_to_concrete_mesh(abstract_mesh):
-    return Mesh(
-        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
-        axis_types=abstract_mesh.axis_types)
-
   out = []
   for s, a, mem_kind in zip(shardings, avals, out_mem_kinds):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       if a.sharding.mesh.empty:
         out.append(s)
+      elif a.sharding.mesh._are_all_axes_auto:
+        out.append(s)
       else:
         spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
                                 for sp in a.sharding.spec])
                 if a.sharding.mesh._any_axis_auto else a.sharding.spec)
         out.append(NamedSharding(
-            _abstract_to_concrete_mesh(a.sharding.mesh), spec,
-            memory_kind=mem_kind))
+            _abstract_to_concrete_mesh(a.sharding.mesh, device_assignment),
+            spec, memory_kind=mem_kind))
     else:
       out.append(s)
   return tuple(out)
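The two hunks above hoist `_abstract_to_concrete_mesh` from a closure inside `_concretize_abstract_out_shardings` to module scope. This matters because an `lru_cache` created inside a function is a fresh cache object on every call, so it can never hit across calls; the cost is that `device_assignment` (previously captured by the closure) must now be passed explicitly and be hashable. A toy demonstration of the pitfall, with illustrative names only:

```python
from functools import lru_cache

calls = {"inner": 0, "outer": 0}

def concretize_per_call(x):
  @lru_cache(maxsize=128)   # fresh cache object created on every call
  def expensive(v):
    calls["inner"] += 1
    return v * v
  return expensive(x)

@lru_cache(maxsize=128)     # one cache shared across all callers
def expensive_module_level(v):
  calls["outer"] += 1
  return v * v

for _ in range(3):
  concretize_per_call(7)
  expensive_module_level(7)

print(calls)  # {'inner': 3, 'outer': 1}
```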
@@ -2534,15 +2534,22 @@ def _get_mesh_pspec_shardings_from_executable(
 _orig_out_sharding_handlers = {}
 
 def _gspmd_to_named_sharding(
-    out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: NamedSharding) -> NamedSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, NamedSharding)
   assert isinstance(orig_in_s.mesh, Mesh)
-  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
+  if (out_aval is not None and not out_aval.sharding.mesh.empty and
+      out_aval.sharding.mesh._are_all_axes_auto):
+    mesh = _abstract_to_concrete_mesh(
+        out_aval.sharding.mesh, out_s._device_assignment)
+  else:
+    mesh = orig_in_s.mesh
+  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, mesh)
 _orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding
 
 def _gspmd_to_positional_sharding(
-    out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: PositionalSharding
+    ) -> PositionalSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, PositionalSharding)
   return sharding_impls._op_sharding_to_pos_sharding(
@@ -2550,7 +2557,8 @@ def _gspmd_to_positional_sharding(
 _orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding  # type: ignore
 
 def _gspmd_to_single_device_sharding(
-    out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: SingleDeviceSharding
+    ) -> SingleDeviceSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, SingleDeviceSharding)
   return SingleDeviceSharding(
@@ -2565,15 +2573,17 @@ def _get_out_sharding_from_orig_sharding(
   for o, out_aval in safe_zip(out_shardings, out_avals):
     if (isinstance(o, sharding_impls.GSPMDSharding) and
         out_aval is not core.abstract_token):
-      if (orig_aval is not None and out_aval is not None and
-          out_aval.ndim == orig_aval.ndim
-          and sharding_impls.are_op_shardings_equal(
-              o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
-          and o.memory_kind == orig_in_s.memory_kind):
+      # TODO(yashkatariya): Remove this condition and ask users to drop into
+      # explicit mode.
+      if (orig_aval is not None and out_aval is not None
+          and out_aval.ndim == orig_aval.ndim
+          and isinstance(orig_in_s, NamedSharding)
+          and out_aval.sharding.mesh == orig_in_s.mesh.abstract_mesh
+          and o.is_equivalent_to(orig_in_s, orig_aval.ndim)):
         out.append(orig_in_s)
       else:
         try:
-          out.append(orig_handler(o, orig_in_s))
+          out.append(orig_handler(o, out_aval, orig_in_s))
         except:
           out.append(o)
     else:
@@ -2581,40 +2591,9 @@ def _get_out_sharding_from_orig_sharding(
   return out
 
 
-def try_matching_out_with_in_spec_for_all_auto(
-    orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
-  recover_in_s, recover_in_aval = None, None
-  for in_s, in_aval in safe_zip(in_shardings, in_avals):
-    if isinstance(in_s, NamedSharding):
-      recover_in_s, recover_in_aval = in_s, in_aval
-      break
-  if recover_in_s is None:
-    return new_out_shardings
-
-  res = []
-  for orig_out_s, out_s, out_aval in safe_zip(
-      orig_out_shardings, new_out_shardings, out_avals):
-    if (out_aval is not core.abstract_token and
-        mlir.all_unconstrained(orig_out_s, out_aval) and
-        isinstance(orig_out_s, NamedSharding) and
-        isinstance(out_s, NamedSharding) and
-        orig_out_s.mesh._are_all_axes_auto and out_s.mesh._are_all_axes_auto and
-        out_aval.ndim == recover_in_aval.ndim and
-        out_s.is_equivalent_to(recover_in_s, out_aval.ndim)):
-      res.append(out_s.with_spec(recover_in_s.spec))
-    else:
-      res.append(out_s)
-  return res
-
-
 def maybe_recover_user_shardings(
     old_shardings, new_shardings, old_avals, new_avals,
-    intermediate_shardings=None, context_mesh: Mesh | None = None,
-    orig_out_shardings=None):
-  if orig_out_shardings is not None:
-    new_shardings = try_matching_out_with_in_spec_for_all_auto(
-        orig_out_shardings, new_shardings, new_avals, old_shardings, old_avals)
-
+    intermediate_shardings=None, context_mesh: Mesh | None = None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding)
          for o in new_shardings):
     return new_shardings
@@ -2831,7 +2810,7 @@ def _maybe_get_and_check_out_shardings(
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
       try:
-        new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig))  # pytype: disable=wrong-arg-types
+        new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig))  # pytype: disable=wrong-arg-types
       except:
         new_out_shardings.append(xla_s)
     else:
@@ -3004,7 +2983,7 @@ class UnloadedMeshExecutable:
 
     out_shardings = maybe_recover_user_shardings(
         in_shardings, out_shardings, global_in_avals, global_out_avals,
-        intermediate_shardings, context_mesh, orig_out_shardings)
+        intermediate_shardings, context_mesh)
 
     in_shardings = finalize_shardings(in_shardings, da)
     out_shardings = finalize_shardings(out_shardings, da)
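All three `_orig_out_sharding_handlers` entries change signature in lockstep because they are invoked through a single dispatch site, `orig_handler(o, out_aval, orig_in_s)`: a handler registry only works if every registered handler accepts the same arguments, including handlers (like the single-device one) that ignore the new parameter. A schematic sketch of that constraint with toy handlers, not JAX's:

```python
from typing import Any, Callable

_handlers: dict[type, Callable[..., Any]] = {}

def register(cls):
  def deco(fn):
    _handlers[cls] = fn
    return fn
  return deco

@register(int)
def _handle_int(out, out_aval, orig):
  # Ignores out_aval, but must still accept it: the dispatch site
  # below passes the same three arguments to every handler.
  return orig

@register(str)
def _handle_str(out, out_aval, orig):
  return f"{orig}/{out_aval}/{out}"

def dispatch(out, out_aval, orig):
  return _handlers[type(orig)](out, out_aval, orig)

print(dispatch(1, None, 2))        # 2
print(dispatch("o", "aval", "s"))  # s/aval/o
```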
diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py
index 9c6cf0861..2ad3e089e 100644
--- a/tests/shard_alike_test.py
+++ b/tests/shard_alike_test.py
@@ -131,7 +131,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
       return shard_alike(x, y)[1]
 
     out = f(inp)
-    self.assertEqual(out.sharding, s)
+    self.assertTrue(out.sharding.is_equivalent_to(s, out.ndim))
     self.assertArraysEqual(out, np_inp)
 
   def test_shard_map(self):
@@ -268,7 +268,8 @@ class ShardAlikeTest(jtu.JaxTestCase):
     x = jax.device_put(np.arange(8), s)
 
     _, y = shard_alike(x, jnp.arange(8))
-    self.assertEqual(y.sharding, s)
+    self.assertTrue(y.sharding.is_equivalent_to(s, y.ndim))
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
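The test updates are the user-visible consequence: an output sharding may now come back as a `NamedSharding` rebuilt on a concretized all-auto mesh, a different Python object from the sharding the test constructed, even though every shard sits on the same device. `Sharding.is_equivalent_to(other, ndim)` compares device placement rather than representation, so it is the right check here. A hedged sketch of the distinction (assumes `jax.sharding.AxisType` is available, as in the JAX version this patch targets):

```python
import jax
import numpy as np
from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P

devs = np.array(jax.devices())
explicit = Mesh(devs, ("x",), axis_types=(AxisType.Explicit,))
auto = Mesh(devs, ("x",), axis_types=(AxisType.Auto,))

s_explicit = NamedSharding(explicit, P("x"))
s_auto = NamedSharding(auto, P("x"))

# Same placement of shards onto devices, different mesh metadata:
# == compares the representation (mesh, axis types), while
# is_equivalent_to compares only where the data actually lives.
print(s_explicit == s_auto)
print(s_explicit.is_equivalent_to(s_auto, 1))
```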