diff --git a/jax/_src/array.py b/jax/_src/array.py
index 71c8914a1..594873dac 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -926,13 +926,12 @@ def _array_shard_arg(x, sharding):
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
 
 
-def _array_global_result_handler(global_aval, out_sharding, committed,
-                                 is_out_sharding_from_xla):
+def _array_global_result_handler(global_aval, out_sharding, committed):
   if global_aval.dtype == dtypes.float0:
     return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
   if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
     return global_aval.dtype._rules.global_sharded_result_handler(
-        global_aval, out_sharding, committed, is_out_sharding_from_xla)
+        global_aval, out_sharding, committed)
   return xc.array_result_handler(
       global_aval, out_sharding, committed=committed, _skip_checks=True
   )
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 3636176cd..8ea1a4ba4 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -322,7 +322,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
 
 
 def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
-  result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False)
+  result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
   return result_handler(pxla.shard_arg(x, s))
 
 def _override_get_device_assignment(sharding, *args, **kwargs):
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 9c9ca6c0c..3b2065491 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -224,8 +224,7 @@ local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
 
 
 def global_aval_to_result_handler(
-    aval: core.AbstractValue, out_sharding, committed: bool,
-    is_out_sharding_from_xla: bool
+    aval: core.AbstractValue, out_sharding, committed: bool
 ) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
 
@@ -235,8 +234,6 @@ def global_aval_to_result_handler(
       Used for creating GSDAs.
     global_mesh: The global device mesh that generated this output. Used
       for creating GSDAs.
-    is_out_sharding_from_xla: True, if the out_sharding comes from XLA i.e.
-      the sharding is extracted from the HLO.
 
   Returns:
     A function for handling the Buffers that will eventually be produced
@@ -244,8 +241,7 @@ def global_aval_to_result_handler(
     to the user, e.g. an Array.
   """
   try:
-    return global_result_handlers[type(aval)](
-        aval, out_sharding, committed, is_out_sharding_from_xla)
+    return global_result_handlers[type(aval)](aval, out_sharding, committed)
   except KeyError as err:
     raise TypeError(
         f"No pxla_result_handler for type: {type(aval)}") from err
@@ -1139,12 +1135,10 @@ def local_avals_to_results_handler(
 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
     shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    committed: bool,
-    are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
+    committed: bool) -> ResultsHandler:
   handlers = [
-      global_aval_to_result_handler(global_aval, s, committed, x)
-      for global_aval, s, x in safe_zip(global_out_avals, shardings,
-                                        are_out_shardings_from_xla)
+      global_aval_to_result_handler(global_aval, s, committed)
+      for global_aval, s in safe_zip(global_out_avals, shardings)
   ]
   return ResultsHandler(handlers, shardings, global_out_avals)
 
@@ -2010,12 +2004,6 @@ def lower_sharding_computation(
   if xla_extension_version < 240 or hasattr(backend, "compile_replicated"):
     in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
 
-  # TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after
-  # output sharding of XLA is partially constrained on the trailing dimensions.
-  in_shardings = tuple(
-      gs if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
-      else i for i, a in safe_zip(in_shardings, global_in_avals))
-
   da_object = _create_da_object(tuple(device_assignment))
 
   all_default_mem_kind = are_all_shardings_default_mem_kind(
@@ -2466,11 +2454,10 @@ _register_out_sharding_handler(
 
 
 def _get_out_sharding_from_orig_sharding(
-    out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
+    out_shardings, out_avals, orig_in_s, orig_aval):
   out = []
   orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
-  for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
-                                        are_out_sharding_from_xla):
+  for o, out_aval in safe_zip(out_shardings, out_avals):
     if isinstance(o, sharding_impls.GSPMDSharding):
       try:
         # Only return the same input sharding object if the OpShardings and
@@ -2482,21 +2469,19 @@ def _get_out_sharding_from_orig_sharding(
             and sharding_impls.are_op_shardings_equal(
                 o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
             and o.memory_kind == orig_in_s.memory_kind):
-          out.append((orig_in_s, False))
+          out.append(orig_in_s)
         else:
-          out.append((orig_handler(o, orig_in_s), False))
+          out.append(orig_handler(o, orig_in_s))
       except:
-        out.append((o, from_xla))
+        out.append(o)
     else:
-      out.append((o, from_xla))
+      out.append(o)
   return out
 
 
 def maybe_get_orig_out_sharding(
-    in_shardings, out_shardings, are_out_shardings_from_xla, in_avals,
-    out_avals):
+    in_shardings, out_shardings, in_avals, out_avals):
   if all(hasattr(o, '_original_sharding') for o in out_shardings):
-    return ([o._original_sharding for o in out_shardings],
-            (False,) * len(out_shardings))
+    return [o._original_sharding for o in out_shardings]
   orig_in_s = None
   orig_aval = None
@@ -2507,10 +2492,10 @@ def maybe_get_orig_out_sharding(
       orig_aval = aval
       break
   if orig_in_s is not None:
-    return zip(*_get_out_sharding_from_orig_sharding(
-        out_shardings, out_avals, orig_in_s, orig_aval, are_out_shardings_from_xla))
+    return _get_out_sharding_from_orig_sharding(
+        out_shardings, out_avals, orig_in_s, orig_aval)
 
-  return out_shardings, are_out_shardings_from_xla
+  return out_shardings
 
 
 def _get_layouts_from_executable(
@@ -2653,6 +2638,7 @@ def _maybe_get_and_check_in_shardings(
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
+        xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
       new_in_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -2677,17 +2663,17 @@ def _get_out_shardings_from_executable(
       xla_executable, device_assignment, len(global_out_avals),
       num_ordered_effects, all_default_mem_kind)  # type: ignore
   if out_shardings_xla is None:
-    return out_shardings, (False,) * len(global_out_avals)
+    return out_shardings
 
-  new_out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
+  new_out_shardings = []
   for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                     global_out_avals):
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
+        xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
-      are_out_shardings_from_xla.append(True)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
       orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -2700,17 +2686,14 @@ def _get_out_shardings_from_executable(
             f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
             "(User sharding)")
       new_out_shardings.append(orig)
-      are_out_shardings_from_xla.append(False)
-  return new_out_shardings, are_out_shardings_from_xla
+  return new_out_shardings
 
 
-def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
-                           device_assignment):
+def finalize_out_shardings(out_shardings, device_assignment):
   if len(device_assignment) == 1:
-    return ([SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
-             if isinstance(o, GSPMDSharding) else o for o in out_shardings],
-            are_out_shardings_from_xla)
-  return out_shardings, are_out_shardings_from_xla
+    return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
+            if isinstance(o, GSPMDSharding) else o for o in out_shardings]
+  return out_shardings
 
 
 @dataclasses.dataclass
@@ -2723,7 +2706,6 @@ class UnloadedMeshExecutable:
   output_avals: Sequence[ShapedArray]
   output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
   committed: bool
-  are_out_shardings_from_xla: Sequence[bool]
  name: str
   unordered_effects: list[core.Effect]
   ordered_effects: list[core.Effect]
@@ -2744,8 +2726,7 @@ class UnloadedMeshExecutable:
     handle_args = InputsHandler(
         self.input_shardings, self.xla_executable.local_devices(), input_indices)
     handle_outs = global_avals_to_results_handler(
-        self.output_avals, self.output_shardings, self.committed,
-        self.are_out_shardings_from_xla)  # type: ignore  # arg-type
+        self.output_avals, self.output_shardings, self.committed)  # type: ignore  # arg-type
 
     unsafe_call = ExecuteReplicated(  # type: ignore  # assignment
         self.xla_executable, self.name, self.backend, handle_args,
@@ -2833,11 +2814,8 @@ class UnloadedMeshExecutable:
           xla_executable, mesh)
       in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i)  # type: ignore
                       for x, i in safe_zip(in_shardings_xla, in_shardings)]
-      out_shardings_tuple = [
-          (x, True) if is_auto(o) else (o, False)
-          for x, o in safe_zip(out_shardings_xla, out_shardings)
-      ]
-      out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
+      out_shardings = [x if is_auto(o) else o
+                       for x, o in safe_zip(out_shardings_xla, out_shardings)]
     else:
       if pmap_nreps == 1:
         assert mesh is None
@@ -2845,13 +2823,12 @@ class UnloadedMeshExecutable:
         in_shardings = _maybe_get_and_check_in_shardings(
             xla_executable, in_shardings, tuple(da), global_in_avals,
             len(ordered_effects))
-        out_shardings, are_out_shardings_from_xla = _get_out_shardings_from_executable(
+        out_shardings = _get_out_shardings_from_executable(
            xla_executable, out_shardings, tuple(da), global_out_avals,
            len(ordered_effects), all_default_mem_kind)
       else:
         in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
             xla_executable.local_devices(), len(in_shardings), len(out_shardings))
-        are_out_shardings_from_xla = (False,) * len(global_out_avals)
 
     if xla_extension_version >= 217:
       in_layouts, out_layouts = _get_layouts_from_executable(
@@ -2860,12 +2837,10 @@ class UnloadedMeshExecutable:
       assert all(i is None for i in in_layouts)
       assert all(o is None for o in out_layouts)
 
-    out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
-        in_shardings, out_shardings, are_out_shardings_from_xla,
-        global_in_avals, global_out_avals)
+    out_shardings = maybe_get_orig_out_sharding(
+        in_shardings, out_shardings, global_in_avals, global_out_avals)
 
-    out_shardings, are_out_shardings_from_xla = finalize_out_shardings(
-        out_shardings, are_out_shardings_from_xla, da)
+    out_shardings = finalize_out_shardings(out_shardings, da)
 
     return UnloadedMeshExecutable(
         xla_executable=xla_executable,
@@ -2876,7 +2851,6 @@ class UnloadedMeshExecutable:
         output_avals=global_out_avals,
         output_shardings=out_shardings,  # type: ignore  # arg-type
         committed=committed,
-        are_out_shardings_from_xla=are_out_shardings_from_xla,
         name=name,
         unordered_effects=unordered_effects,
         ordered_effects=ordered_effects,
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index eb62aea1c..315cbc4fa 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -5082,8 +5082,7 @@ class BIntRules:
     return handler
 
   @staticmethod
-  def global_sharded_result_handler(aval, out_sharding, committed,
-                                    is_out_sharding_from_xla):
+  def global_sharded_result_handler(aval, out_sharding, committed):
     phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
 
@@ -5091,8 +5090,7 @@ class BIntRules:
       raise NotImplementedError  # TODO(mattjj)
     else:
       phys_sharding = out_sharding
-    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
-                                      is_out_sharding_from_xla)
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
 
     def handler(bufs):
       return core.DArray(aval, phys_handler(bufs))
@@ -5102,6 +5100,10 @@ class BIntRules:
   def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
     return hlo_sharding
 
+  @staticmethod
+  def logical_op_sharding(aval, phys_sharding):
+    return phys_sharding
+
   @staticmethod
   def convert_from(bint_dtype, other_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))
diff --git a/jax/_src/op_shardings.py b/jax/_src/op_shardings.py
index 74d4a6320..ee3ea6b44 100644
--- a/jax/_src/op_shardings.py
+++ b/jax/_src/op_shardings.py
@@ -25,7 +25,7 @@ from jax._src.lib import xla_client as xc
 
 
 def get_num_ways_dim_sharded(
-    hlo_sharding: xc.HloSharding) -> tuple[Sequence[int], int]:
+    hlo_sharding: xc.HloSharding) -> tuple[list[int], int]:
   if hlo_sharding.is_replicated():  # type: ignore
     return [], 1
   partitions = hlo_sharding.tile_assignment_dimensions()
@@ -42,7 +42,7 @@ def get_num_ways_dim_sharded(
   if replicate_on_last_tile_dim:
     num_replicas = partitions[-1]
     partitions = partitions[:-1]
-  return partitions, num_replicas
+  return list(partitions), num_replicas
 
 
 def is_op_sharding_replicated(op: xc.OpSharding | xc.HloSharding) -> bool:
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 0ed8a55d2..ed1b1142e 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1701,8 +1701,8 @@ def _pjit_batcher_for_sharding(
     else:
       new_gs = GSPMDSharding(s._device_assignment, new_op)  # type: ignore
     if hasattr(s, '_original_sharding'):
-      vmapped_s, _ = pxla._get_out_sharding_from_orig_sharding(
-          [new_gs], [None], s._original_sharding, None, [False])[0]  # type: ignore
+      vmapped_s = pxla._get_out_sharding_from_orig_sharding(
+          [new_gs], [None], s._original_sharding, None)[0]  # type: ignore
       new_gs = to_gspmd_sharding(vmapped_s, ndim)
     return new_gs
   else:
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 73977d38a..a1f2906f7 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -318,7 +318,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
   base_ndim = len(impl.key_shape)
   return base_arr_shape[:-base_ndim]
 
-def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
+def make_key_array_phys_sharding(aval, sharding):
   if dispatch.is_single_device_sharding(sharding):
     return sharding
   elif isinstance(sharding, PmapSharding):
@@ -335,8 +335,6 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
     return NamedSharding(
         sharding.mesh,
         PartitionSpec(*sharding.spec, *trailing_spec))
-  elif is_sharding_from_xla:
-    return sharding
   else:
     hlos = sharding._to_xla_hlo_sharding(aval.ndim)
     return GSPMDSharding(
@@ -367,11 +365,11 @@ class KeyTyRules:
   @staticmethod
   def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
     key_shape = aval.dtype._impl.key_shape
-    op_sharding_proto = hlo_sharding.to_proto()  # type: ignore
-    new_op_sharding = op_sharding_proto.clone()
-    tad = list(new_op_sharding.tile_assignment_dimensions)
-    suffix = [tad.pop()] if op_sharding_proto.replicate_on_last_tile_dim else []
-    tad.extend([1] * len(key_shape) + suffix)
+    new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
+    partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+        hlo_sharding)
+    suffix = [] if num_replicas == 1 else [num_replicas]
+    tad = partitions + [1] * len(key_shape) + suffix
     new_op_sharding.tile_assignment_dimensions = tad
     return xc.HloSharding.from_proto(new_op_sharding)
 
@@ -393,11 +391,14 @@ class KeyTyRules:
           PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
     else:
       key_shape = aval.dtype._impl.key_shape
-      phys_op_sharding = phys_sharding._to_xla_hlo_sharding(
-          aval.ndim + len(key_shape)).to_proto()
-      logical_op_sharding = phys_op_sharding.clone()
-      tad = list(logical_op_sharding.tile_assignment_dimensions)
-      tad = tad[:-len(key_shape)]
+      phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
+          aval.ndim + len(key_shape))
+      partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+          phys_hlo_sharding)
+      suffix = [] if num_replicas == 1 else [num_replicas]
+      # Create logical sharding by cutting off the replicated trailing dims.
+      logical_op_sharding = phys_hlo_sharding.to_proto().clone()
+      tad = partitions[:-len(key_shape)] + suffix
       logical_op_sharding.tile_assignment_dimensions = tad
       return GSPMDSharding(phys_sharding._device_assignment,
                            xc.HloSharding.from_proto(logical_op_sharding))
@@ -417,8 +418,7 @@ class KeyTyRules:
 
     # set up a grounded sharding (with a grounded sharding spec)
     if isinstance(sharding, (PmapSharding, NamedSharding)):
-      phys_sharding = make_key_array_phys_sharding(
-          aval, sharding, is_sharding_from_xla=False)
+      phys_sharding = make_key_array_phys_sharding(aval, sharding)
     else:
       assert False, f'impossible sharding {sharding} in local sharded result handler'
 
@@ -436,15 +436,12 @@ class KeyTyRules:
     return handler
 
   @staticmethod
-  def global_sharded_result_handler(aval, out_sharding, committed,
-                                    is_out_sharding_from_xla):
+  def global_sharded_result_handler(aval, out_sharding, committed):
     phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
 
-    phys_sharding = make_key_array_phys_sharding(
-        aval, out_sharding, is_out_sharding_from_xla)
-    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
-                                      is_out_sharding_from_xla)
+    phys_sharding = make_key_array_phys_sharding(aval, out_sharding)
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
    def handler(bufs):
       return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
     return handler
@@ -455,8 +452,8 @@ class KeyTyRules:
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
 
     phys_arrays = [random_unwrap(arr) for arr in arrays]
-    phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
-    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
+    phys_sharding = make_key_array_phys_sharding(aval, sharding)
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
     phys_result = phys_handler(phys_arrays)
     return PRNGKeyArray(aval.dtype._impl, phys_result)
 
@@ -464,7 +461,7 @@ class KeyTyRules:
   def device_put_sharded(vals, aval, sharding, devices):
     physical_aval = core.physical_aval(aval)
     physical_buffers = tree_util.tree_map(random_unwrap, vals)
-    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
+    physical_sharding = make_key_array_phys_sharding(aval, sharding)
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding,
                                               physical_buffers, list(devices))
     return random_wrap(physical_result, impl=aval.dtype._impl)
@@ -473,7 +470,7 @@ class KeyTyRules:
     physical_aval = core.physical_aval(aval)
     assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
     physical_buf = random_unwrap(val)
-    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
+    physical_sharding = make_key_array_phys_sharding(aval, sharding)
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding,
                                               [physical_buf] * len(devices), devices)
     return random_wrap(physical_result, impl=aval.dtype._impl)
@@ -554,8 +551,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
 
 
 def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
   arr = x._base_array
-  phys_sharding = make_key_array_phys_sharding(
-      x.aval, sharding, is_sharding_from_xla=False)
+  phys_sharding = make_key_array_phys_sharding(x.aval, sharding)
   return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 7f0c5a3a9..6e908d51a 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2977,6 +2977,10 @@ class FooTyRules:
     new_op_sharding.tile_assignment_dimensions = [*tad, 1]
     return xc.HloSharding.from_proto(new_op_sharding)
 
+  @staticmethod
+  def logical_op_sharding(aval, phys_sharding):
+    return phys_sharding
+
   @staticmethod
   def result_handler(sticky_device, aval):
     def handler(_, buf):
@@ -2985,8 +2989,7 @@ class FooTyRules:
     return handler
 
   @staticmethod
-  def global_sharded_result_handler(aval, out_sharding, committed,
-                                    is_out_sharding_from_xla):
+  def global_sharded_result_handler(aval, out_sharding, committed):
     def handler(arr):
       from jax._src.array import ArrayImpl
       if isinstance(arr, ArrayImpl):
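Note for downstream extended-dtype rules (not part of the patch itself): after this change, global_sharded_result_handler is called with only (aval, out_sharding, committed), and shardings inferred by XLA are mapped back to logical shardings through the logical_op_sharding(aval, phys_sharding) rule rather than the removed is_out_sharding_from_xla flag. A minimal sketch of a rules class written against the new signatures, modeled on FooTyRules from tests/lax_test.py above; the class name and the handler body are illustrative only, and real rules classes implement several more methods that are omitted here.

# Illustrative sketch, not from the patch: a trimmed-down rules class that
# matches the updated handler signatures introduced by this change.
class MyTyRules:

  @staticmethod
  def logical_op_sharding(aval, phys_sharding):
    # This toy type adds no trailing physical dimensions, so the physical
    # sharding is already the logical sharding.
    return phys_sharding

  @staticmethod
  def global_sharded_result_handler(aval, out_sharding, committed):
    # Note: the is_out_sharding_from_xla parameter no longer exists.
    def handler(bufs):
      return bufs  # wrap the raw buffers into the user-facing value here
    return handler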
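The key-sharding hunks above rely on get_num_ways_dim_sharded now returning a plain list, so the physical tile dimensions can be rebuilt by list concatenation. A small self-contained illustration of that arithmetic with ordinary Python lists (the helper name and values are hypothetical; the real code operates on HloSharding protos):

# Hypothetical helper mirroring the tile-dimension arithmetic used in
# KeyTyRules.physical_hlo_sharding after this patch.
def physical_tile_dims(partitions, num_replicas, key_shape):
  # partitions is a list, as get_num_ways_dim_sharded now returns.
  suffix = [] if num_replicas == 1 else [num_replicas]
  return partitions + [1] * len(key_shape) + suffix

# A rank-2 key array sharded 2-way and 4-way, with key_shape (2,):
assert physical_tile_dims([2, 4], 1, key_shape=(2,)) == [2, 4, 1]
# The same sharding with 3-way replication on the last tile dimension:
assert physical_tile_dims([2, 4], 3, key_shape=(2,)) == [2, 4, 1, 3]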