From 217f08236ebddeb359b7a27731341b60a3cee17f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 28 Feb 2024 15:21:50 -0800 Subject: [PATCH] Allow sharding propagation to input for prng keys whose sharding is not specified. Convert shardings returned by XLA (when propagation is on for input and output) for extended dtypes to user shardings which allows to remove `are_out_shardings_from_xla`. PiperOrigin-RevId: 611246986 --- jax/_src/array.py | 5 +- jax/_src/dispatch.py | 2 +- jax/_src/interpreters/pxla.py | 90 +++++++++++++---------------------- jax/_src/lax/lax.py | 10 ++-- jax/_src/op_shardings.py | 4 +- jax/_src/pjit.py | 4 +- jax/_src/prng.py | 50 +++++++++---------- tests/lax_test.py | 7 ++- 8 files changed, 73 insertions(+), 99 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 71c8914a1..594873dac 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -926,13 +926,12 @@ def _array_shard_arg(x, sharding): pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg -def _array_global_result_handler(global_aval, out_sharding, committed, - is_out_sharding_from_xla): +def _array_global_result_handler(global_aval, out_sharding, committed): if global_aval.dtype == dtypes.float0: return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore if dtypes.issubdtype(global_aval.dtype, dtypes.extended): return global_aval.dtype._rules.global_sharded_result_handler( - global_aval, out_sharding, committed, is_out_sharding_from_xla) + global_aval, out_sharding, committed) return xc.array_result_handler( global_aval, out_sharding, committed=committed, _skip_checks=True ) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 3636176cd..8ea1a4ba4 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -322,7 +322,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool): - result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False) + result_handler = pxla.global_aval_to_result_handler(aval, s, committed) return result_handler(pxla.shard_arg(x, s)) def _override_get_device_assignment(sharding, *args, **kwargs): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9c9ca6c0c..3b2065491 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -224,8 +224,7 @@ local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {} def global_aval_to_result_handler( - aval: core.AbstractValue, out_sharding, committed: bool, - is_out_sharding_from_xla: bool + aval: core.AbstractValue, out_sharding, committed: bool ) -> Callable[[Sequence[xc.ArrayImpl]], Any]: """Returns a function for handling the raw buffers of a single output aval. @@ -235,8 +234,6 @@ def global_aval_to_result_handler( Used for creating GSDAs. global_mesh: The global device mesh that generated this output. Used for creating GSDAs. - is_out_sharding_from_xla: True, if the out_sharding comes from XLA i.e. - the sharding is extracted from the HLO. Returns: A function for handling the Buffers that will eventually be produced @@ -244,8 +241,7 @@ def global_aval_to_result_handler( to the user, e.g. an Array. 
""" try: - return global_result_handlers[type(aval)]( - aval, out_sharding, committed, is_out_sharding_from_xla) + return global_result_handlers[type(aval)](aval, out_sharding, committed) except KeyError as err: raise TypeError( f"No pxla_result_handler for type: {type(aval)}") from err @@ -1139,12 +1135,10 @@ def local_avals_to_results_handler( def global_avals_to_results_handler( global_out_avals: Sequence[ShapedArray], shardings: Sequence[sharding_impls.XLACompatibleSharding], - committed: bool, - are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler: + committed: bool) -> ResultsHandler: handlers = [ - global_aval_to_result_handler(global_aval, s, committed, x) - for global_aval, s, x in safe_zip(global_out_avals, shardings, - are_out_shardings_from_xla) + global_aval_to_result_handler(global_aval, s, committed) + for global_aval, s in safe_zip(global_out_avals, shardings) ] return ResultsHandler(handlers, shardings, global_out_avals) @@ -2010,12 +2004,6 @@ def lower_sharding_computation( if xla_extension_version < 240 or hasattr(backend, "compile_replicated"): in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) - # TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after - # output sharding of XLA is partially constrained on the trailing dimensions. - in_shardings = tuple( - gs if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended) - else i for i, a in safe_zip(in_shardings, global_in_avals)) - da_object = _create_da_object(tuple(device_assignment)) all_default_mem_kind = are_all_shardings_default_mem_kind( @@ -2466,11 +2454,10 @@ _register_out_sharding_handler( def _get_out_sharding_from_orig_sharding( - out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla): + out_shardings, out_avals, orig_in_s, orig_aval): out = [] orig_handler = _orig_out_sharding_handlers[type(orig_in_s)] - for o, out_aval, from_xla in safe_zip(out_shardings, out_avals, - are_out_sharding_from_xla): + for o, out_aval in safe_zip(out_shardings, out_avals): if isinstance(o, sharding_impls.GSPMDSharding): try: # Only return the same input sharding object if the OpShardings and @@ -2482,21 +2469,19 @@ def _get_out_sharding_from_orig_sharding( and sharding_impls.are_op_shardings_equal( o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) and o.memory_kind == orig_in_s.memory_kind): - out.append((orig_in_s, False)) + out.append(orig_in_s) else: - out.append((orig_handler(o, orig_in_s), False)) + out.append(orig_handler(o, orig_in_s)) except: - out.append((o, from_xla)) + out.append(o) else: - out.append((o, from_xla)) + out.append(o) return out def maybe_get_orig_out_sharding( - in_shardings, out_shardings, are_out_shardings_from_xla, in_avals, - out_avals): + in_shardings, out_shardings, in_avals, out_avals): if all(hasattr(o, '_original_sharding') for o in out_shardings): - return ([o._original_sharding for o in out_shardings], - (False,) * len(out_shardings)) + return [o._original_sharding for o in out_shardings] orig_in_s = None orig_aval = None @@ -2507,10 +2492,10 @@ def maybe_get_orig_out_sharding( orig_aval = aval break if orig_in_s is not None: - return zip(*_get_out_sharding_from_orig_sharding( - out_shardings, out_avals, orig_in_s, orig_aval, are_out_shardings_from_xla)) + return _get_out_sharding_from_orig_sharding( + out_shardings, out_avals, orig_in_s, orig_aval) - return out_shardings, are_out_shardings_from_xla + return out_shardings def _get_layouts_from_executable( @@ -2653,6 +2638,7 @@ def 
_maybe_get_and_check_in_shardings( if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval) + xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s) new_in_shardings.append(xla_s) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore @@ -2677,17 +2663,17 @@ def _get_out_shardings_from_executable( xla_executable, device_assignment, len(global_out_avals), num_ordered_effects, all_default_mem_kind) # type: ignore if out_shardings_xla is None: - return out_shardings, (False,) * len(global_out_avals) + return out_shardings - new_out_shardings, are_out_shardings_from_xla = [], [] # type: ignore + new_out_shardings = [] for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings, global_out_avals): if is_unspecified(orig): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval) + xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s) new_out_shardings.append(xla_s) - are_out_shardings_from_xla.append(True) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore @@ -2700,17 +2686,14 @@ def _get_out_shardings_from_executable( f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " "(User sharding)") new_out_shardings.append(orig) - are_out_shardings_from_xla.append(False) - return new_out_shardings, are_out_shardings_from_xla + return new_out_shardings -def finalize_out_shardings(out_shardings, are_out_shardings_from_xla, - device_assignment): +def finalize_out_shardings(out_shardings, device_assignment): if len(device_assignment) == 1: - return ([SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind) - if isinstance(o, GSPMDSharding) else o for o in out_shardings], - are_out_shardings_from_xla) - return out_shardings, are_out_shardings_from_xla + return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind) + if isinstance(o, GSPMDSharding) else o for o in out_shardings] + return out_shardings @dataclasses.dataclass @@ -2723,7 +2706,6 @@ class UnloadedMeshExecutable: output_avals: Sequence[ShapedArray] output_shardings: Sequence[sharding_impls.XLACompatibleSharding] committed: bool - are_out_shardings_from_xla: Sequence[bool] name: str unordered_effects: list[core.Effect] ordered_effects: list[core.Effect] @@ -2744,8 +2726,7 @@ class UnloadedMeshExecutable: handle_args = InputsHandler( self.input_shardings, self.xla_executable.local_devices(), input_indices) handle_outs = global_avals_to_results_handler( - self.output_avals, self.output_shardings, self.committed, - self.are_out_shardings_from_xla) # type: ignore # arg-type + self.output_avals, self.output_shardings, self.committed) # type: ignore # arg-type unsafe_call = ExecuteReplicated( # type: ignore # assignment self.xla_executable, self.name, self.backend, handle_args, @@ -2833,11 +2814,8 @@ class UnloadedMeshExecutable: xla_executable, mesh) in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i) # type: ignore for x, i in safe_zip(in_shardings_xla, in_shardings)] - out_shardings_tuple = [ - (x, True) if is_auto(o) else (o, False) - for x, o in safe_zip(out_shardings_xla, out_shardings) - ] - out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple) + out_shardings = [x if is_auto(o) else o + for x, o in safe_zip(out_shardings_xla, out_shardings)] else: if pmap_nreps == 1: 
assert mesh is None @@ -2845,13 +2823,12 @@ class UnloadedMeshExecutable: in_shardings = _maybe_get_and_check_in_shardings( xla_executable, in_shardings, tuple(da), global_in_avals, len(ordered_effects)) - out_shardings, are_out_shardings_from_xla = _get_out_shardings_from_executable( + out_shardings = _get_out_shardings_from_executable( xla_executable, out_shardings, tuple(da), global_out_avals, len(ordered_effects), all_default_mem_kind) else: in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( xla_executable.local_devices(), len(in_shardings), len(out_shardings)) - are_out_shardings_from_xla = (False,) * len(global_out_avals) if xla_extension_version >= 217: in_layouts, out_layouts = _get_layouts_from_executable( @@ -2860,12 +2837,10 @@ class UnloadedMeshExecutable: assert all(i is None for i in in_layouts) assert all(o is None for o in out_layouts) - out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding( - in_shardings, out_shardings, are_out_shardings_from_xla, - global_in_avals, global_out_avals) + out_shardings = maybe_get_orig_out_sharding( + in_shardings, out_shardings, global_in_avals, global_out_avals) - out_shardings, are_out_shardings_from_xla = finalize_out_shardings( - out_shardings, are_out_shardings_from_xla, da) + out_shardings = finalize_out_shardings(out_shardings, da) return UnloadedMeshExecutable( xla_executable=xla_executable, @@ -2876,7 +2851,6 @@ class UnloadedMeshExecutable: output_avals=global_out_avals, output_shardings=out_shardings, # type: ignore # arg-type committed=committed, - are_out_shardings_from_xla=are_out_shardings_from_xla, name=name, unordered_effects=unordered_effects, ordered_effects=ordered_effects, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index eb62aea1c..315cbc4fa 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5082,8 +5082,7 @@ class BIntRules: return handler @staticmethod - def global_sharded_result_handler(aval, out_sharding, committed, - is_out_sharding_from_xla): + def global_sharded_result_handler(aval, out_sharding, committed): phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] @@ -5091,8 +5090,7 @@ class BIntRules: raise NotImplementedError # TODO(mattjj) else: phys_sharding = out_sharding - phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, - is_out_sharding_from_xla) + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) def handler(bufs): return core.DArray(aval, phys_handler(bufs)) @@ -5102,6 +5100,10 @@ class BIntRules: def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: return hlo_sharding + @staticmethod + def logical_op_sharding(aval, phys_sharding): + return phys_sharding + @staticmethod def convert_from(bint_dtype, other_dtype) -> bool: return other_dtype in (np.dtype('int32'), np.dtype('int64')) diff --git a/jax/_src/op_shardings.py b/jax/_src/op_shardings.py index 74d4a6320..ee3ea6b44 100644 --- a/jax/_src/op_shardings.py +++ b/jax/_src/op_shardings.py @@ -25,7 +25,7 @@ from jax._src.lib import xla_client as xc def get_num_ways_dim_sharded( - hlo_sharding: xc.HloSharding) -> tuple[Sequence[int], int]: + hlo_sharding: xc.HloSharding) -> tuple[list[int], int]: if hlo_sharding.is_replicated(): # type: ignore return [], 1 partitions = hlo_sharding.tile_assignment_dimensions() @@ -42,7 +42,7 @@ def get_num_ways_dim_sharded( if replicate_on_last_tile_dim: num_replicas = partitions[-1] partitions = partitions[:-1] - return partitions, num_replicas + 
return list(partitions), num_replicas def is_op_sharding_replicated(op: xc.OpSharding | xc.HloSharding) -> bool: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0ed8a55d2..ed1b1142e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1701,8 +1701,8 @@ def _pjit_batcher_for_sharding( else: new_gs = GSPMDSharding(s._device_assignment, new_op) # type: ignore if hasattr(s, '_original_sharding'): - vmapped_s, _ = pxla._get_out_sharding_from_orig_sharding( - [new_gs], [None], s._original_sharding, None, [False])[0] # type: ignore + vmapped_s = pxla._get_out_sharding_from_orig_sharding( + [new_gs], [None], s._original_sharding, None)[0] # type: ignore new_gs = to_gspmd_sharding(vmapped_s, ndim) return new_gs else: diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 73977d38a..a1f2906f7 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -318,7 +318,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape): base_ndim = len(impl.key_shape) return base_arr_shape[:-base_ndim] -def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla): +def make_key_array_phys_sharding(aval, sharding): if dispatch.is_single_device_sharding(sharding): return sharding elif isinstance(sharding, PmapSharding): @@ -335,8 +335,6 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla): return NamedSharding( sharding.mesh, PartitionSpec(*sharding.spec, *trailing_spec)) - elif is_sharding_from_xla: - return sharding else: hlos = sharding._to_xla_hlo_sharding(aval.ndim) return GSPMDSharding( @@ -367,11 +365,11 @@ class KeyTyRules: @staticmethod def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: key_shape = aval.dtype._impl.key_shape - op_sharding_proto = hlo_sharding.to_proto() # type: ignore - new_op_sharding = op_sharding_proto.clone() - tad = list(new_op_sharding.tile_assignment_dimensions) - suffix = [tad.pop()] if op_sharding_proto.replicate_on_last_tile_dim else [] - tad.extend([1] * len(key_shape) + suffix) + new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore + partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( + hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + tad = partitions + [1] * len(key_shape) + suffix new_op_sharding.tile_assignment_dimensions = tad return xc.HloSharding.from_proto(new_op_sharding) @@ -393,11 +391,14 @@ class KeyTyRules: PartitionSpec(*phys_sharding.spec[:-len(key_shape)])) else: key_shape = aval.dtype._impl.key_shape - phys_op_sharding = phys_sharding._to_xla_hlo_sharding( - aval.ndim + len(key_shape)).to_proto() - logical_op_sharding = phys_op_sharding.clone() - tad = list(logical_op_sharding.tile_assignment_dimensions) - tad = tad[:-len(key_shape)] + phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( + aval.ndim + len(key_shape)) + partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( + phys_hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + # Create logical sharding by cutting off the replicated trailing dims. 
+ logical_op_sharding = phys_hlo_sharding.to_proto().clone() + tad = partitions[:-len(key_shape)] + suffix logical_op_sharding.tile_assignment_dimensions = tad return GSPMDSharding(phys_sharding._device_assignment, xc.HloSharding.from_proto(logical_op_sharding)) @@ -417,8 +418,7 @@ class KeyTyRules: # set up a grounded sharding (with a grounded sharding spec) if isinstance(sharding, (PmapSharding, NamedSharding)): - phys_sharding = make_key_array_phys_sharding( - aval, sharding, is_sharding_from_xla=False) + phys_sharding = make_key_array_phys_sharding(aval, sharding) else: assert False, f'impossible sharding {sharding} in local sharded result handler' @@ -436,15 +436,12 @@ class KeyTyRules: return handler @staticmethod - def global_sharded_result_handler(aval, out_sharding, committed, - is_out_sharding_from_xla): + def global_sharded_result_handler(aval, out_sharding, committed): phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] - phys_sharding = make_key_array_phys_sharding( - aval, out_sharding, is_out_sharding_from_xla) - phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, - is_out_sharding_from_xla) + phys_sharding = make_key_array_phys_sharding(aval, out_sharding) + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) def handler(bufs): return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) return handler @@ -455,8 +452,8 @@ class KeyTyRules: phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_arrays = [random_unwrap(arr) for arr in arrays] - phys_sharding = make_key_array_phys_sharding(aval, sharding, False) - phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False) + phys_sharding = make_key_array_phys_sharding(aval, sharding) + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) @@ -464,7 +461,7 @@ class KeyTyRules: def device_put_sharded(vals, aval, sharding, devices): physical_aval = core.physical_aval(aval) physical_buffers = tree_util.tree_map(random_unwrap, vals) - physical_sharding = make_key_array_phys_sharding(aval, sharding, False) + physical_sharding = make_key_array_phys_sharding(aval, sharding) physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices)) return random_wrap(physical_result, impl=aval.dtype._impl) @@ -473,7 +470,7 @@ class KeyTyRules: physical_aval = core.physical_aval(aval) assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) - physical_sharding = make_key_array_phys_sharding(aval, sharding, False) + physical_sharding = make_key_array_phys_sharding(aval, sharding) physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices) return random_wrap(physical_result, impl=aval.dtype._impl) @@ -554,8 +551,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x def key_array_shard_arg_handler(x: PRNGKeyArray, sharding): arr = x._base_array - phys_sharding = make_key_array_phys_sharding( - x.aval, sharding, is_sharding_from_xla=False) + phys_sharding = make_key_array_phys_sharding(x.aval, sharding) return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) diff --git a/tests/lax_test.py b/tests/lax_test.py index 7f0c5a3a9..6e908d51a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2977,6 +2977,10 @@ class FooTyRules: new_op_sharding.tile_assignment_dimensions = 
[*tad, 1] return xc.HloSharding.from_proto(new_op_sharding) + @staticmethod + def logical_op_sharding(aval, phys_sharding): + return phys_sharding + @staticmethod def result_handler(sticky_device, aval): def handler(_, buf): @@ -2985,8 +2989,7 @@ class FooTyRules: return handler @staticmethod - def global_sharded_result_handler(aval, out_sharding, committed, - is_out_sharding_from_xla): + def global_sharded_result_handler(aval, out_sharding, committed): def handler(arr): from jax._src.array import ArrayImpl if isinstance(arr, ArrayImpl):
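
Note (not part of the patch): the hunks above change how extended-dtype (e.g. PRNG key) shardings flow through jit — unspecified key-array input shardings are no longer forced to a default sharding before lowering, so XLA's propagation can choose them, and shardings XLA returns for extended dtypes are converted back to logical user shardings via the new `logical_op_sharding` rule, which is what lets `are_out_shardings_from_xla` be dropped. The following is a minimal usage sketch of the intended user-visible behavior; the mesh axis name 'x', the shapes, and the `sample` function are illustrative assumptions, not code from this change.

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Hypothetical 1-D mesh over all available devices.
    mesh = Mesh(jax.devices(), axis_names=('x',))

    @jax.jit
    def sample(keys, shift):
      # One key per row of `shift`; the uniform draws inherit the batch dimension.
      draws = jax.vmap(lambda k: jax.random.uniform(k, (4,)))(keys)
      return draws + shift

    n = len(jax.devices())
    keys = jax.random.split(jax.random.key(0), n)   # key-array input, sharding left unspecified
    shift = jax.device_put(jnp.zeros((n, 4)),
                           NamedSharding(mesh, P('x', None)))  # explicitly sharded operand

    out = sample(keys, shift)
    # With propagation now allowed on the key input, its sharding can be inferred
    # from `shift` rather than being pinned to a default before lowering, and the
    # sharding reported on `out` is a logical (user-level) sharding rather than a
    # raw GSPMD/physical one.
    print(out.sharding)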