From 5d2f4530940f247fc7a4a6bc7e0f2b41fa39f7fb Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sun, 9 Apr 2023 15:41:32 -0700 Subject: [PATCH] Preserve shardings on the output of pjit that were provided on the arguments. Following are the changes: * Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top level cache should fill up eventually if users are passing different shardings into the pjit function. * Split lower_sharding_computation into 3 caches: * _trace_to_jaxpr_and_dce cache -- This will return a closed jaxpr which is DCE'd * _cached_lowering_to_hlo cache -- This will cache the generation of MHLO. This cache is dependent on the semantic equality of shardings i.e. if 2 shardings lower to the same OpSharding, then there will be a cache hit * _cached_compilation cache -- This caches the compilation so that we don't recompile if the shardings are semantically equal. The way this works is the out_handlers are created again if we pass in different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user. For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding (right now, more support will be added in following CLs) from the GSPMDSharding since it will be available on the input. PiperOrigin-RevId: 522991145 --- jax/_src/dispatch.py | 2 +- jax/_src/interpreters/pxla.py | 668 ++++++++++++++++++++++------------ jax/_src/pjit.py | 28 +- tests/pjit_test.py | 238 +++++++++++- 4 files changed, 670 insertions(+), 266 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index bdd2a7325..26a5aed73 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -218,7 +218,7 @@ def sharded_lowering(fun, name, donated_invars, keep_unused, # apply it to all out_avals. return pxla.lower_sharding_computation( fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars, - in_avals, keep_unused=keep_unused, always_lower=False, + tuple(in_avals), keep_unused=keep_unused, always_lower=False, devices_from_context=None, lowering_platform=lowering_platform) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e8e81bc11..1aba390b8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -63,7 +63,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec from jax._src.util import (unzip3, safe_map, safe_zip, partition_list, wrap_name, tuple_delete, distributed_debug_log, - unzip2, HashableFunction) + unzip2, HashableFunction, weakref_lru_cache) # Built in Python lists don't support weak refs but subclasses of lists do. @@ -1474,7 +1474,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0): devices) - class ExecuteReplicated: """The logic to shard inputs, execute a replicated model, returning outputs.""" __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler', @@ -1997,38 +1996,28 @@ def _get_and_check_device_assignment( final_device_assignment = first_sharding_info[0] return xb.get_device_backend(final_device_assignment[0]), final_device_assignment - MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue] -@profiler.annotate_function -def lower_sharding_computation( - fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr], - api_name: str, - fun_name: str, - in_shardings: Sequence[MaybeSharding], - out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue], - donated_invars: Sequence[bool], - global_in_avals: Sequence[core.ShapedArray], - *, - keep_unused: bool, - always_lower: bool, - devices_from_context: Optional[Sequence[xc.Device]] = None, - lowering_platform: Optional[str], -) -> MeshComputation: - """Lowers a computation to XLA. It can take arbitrary shardings as input. +def cache_wrap(fn): + _wrapped_with_lu_cache = lu.cache(fn) + _wrapped_with_weakref_lru_cache = weakref_lru_cache(fn) + def wrapped(f, *args, **kwargs): + if isinstance(f, lu.WrappedFun): + return _wrapped_with_lu_cache(f, *args, **kwargs) + else: + return _wrapped_with_weakref_lru_cache(f, *args, **kwargs) + return wrapped - The caller of this code can pass in a singleton _UNSPECIFIED because the - number of out_avals might not be known at that time and - lower_sharding_computation calculates the number of out_avals so it can apply - the singleton _UNSPECIFIED to all out_avals. - """ - # 1. Trace to jaxpr and preprocess/verify it + +@cache_wrap +def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name, + keep_unused, donated_invars): name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) if isinstance(fun_or_jaxpr, lu.WrappedFun): - with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " - "in {elapsed_time} sec", - event=dispatch.JAXPR_TRACE_EVENT): + with dispatch.log_elapsed_time( + f"Finished tracing + transforming {name_stack} " + "in {elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT): jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final( fun_or_jaxpr, global_in_avals) else: @@ -2045,41 +2034,44 @@ def lower_sharding_computation( jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr) consts = [c for i, c in enumerate(consts) if i in kept_const_idx] global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx) - in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx) del kept_const_idx jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars, + kept_var_idx, name_stack) + + +@dataclasses.dataclass(frozen=True) +class SemanticallyEqualShardings: + shardings: Tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...] + + def __hash__(self): + return hash(tuple( + s._op_sharding_hash if isinstance(s, sharding_impls.GSPMDSharding) else s # type: ignore + for s in self.shardings)) + + def __eq__(self, other): + if not isinstance(other, SemanticallyEqualShardings): + return False + return all(op_shardings.are_op_shardings_equal(s._op_sharding, o._op_sharding) + if (isinstance(s, sharding_impls.GSPMDSharding) and + isinstance(o, sharding_impls.GSPMDSharding)) + else s == o for s, o in zip(self.shardings, other.shardings)) + + +@weakref_lru_cache +def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, + semantic_in_shardings, semantic_out_shardings, + da_object, lowering_platform, + donated_invars, name_stack): jaxpr = closed_jaxpr.jaxpr - - kept_outputs = [True] * len(global_out_avals) - - if _is_unspecified(out_shardings): - out_shardings = (_UNSPECIFIED,) * len(global_out_avals) - assert isinstance(out_shardings, tuple) - assert len(out_shardings) == len(global_out_avals), ( - len(out_shardings), len(global_out_avals)) - - # Device assignment across all inputs, outputs and shardings inside jaxpr - # should be the same. - jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr)) - backend, device_assignment = _get_and_check_device_assignment( - it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings], - [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings], - [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) - for js, source_info in jaxpr_sharding]), - devices_from_context) - - committed = bool( - devices_from_context or - len(device_assignment) > 1 or - any(not _is_unspecified(i) for i in in_shardings) or - any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or - any(not _is_unspecified(o) for o in out_shardings)) - - in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment) - if _is_unspecified(i) else i for i in in_shardings) + in_shardings = semantic_in_shardings.shardings + out_shardings = semantic_out_shardings.shardings + global_in_avals = closed_jaxpr.in_avals + global_out_avals = closed_jaxpr.out_avals + device_assignment = da_object.device_assignment log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG logger.log(log_priority, @@ -2087,55 +2079,14 @@ def lower_sharding_computation( "Argument mapping: %s.", fun_name, global_in_avals, in_shardings) - local_device_assignment = [d for d in device_assignment - if d.process_index == d.client.process_index()] - if len(device_assignment) != len(local_device_assignment): - check_multihost_collective_allowlist(jaxpr) - # TODO(yashkatariya): Once jit and pjit's frontend is merged, use the - # argument on jit `_allow_multiprocess` (which will be added later) instead - # of the `api_name` check here. - # Furthermore, `allow_jit` is not allowed yet because `allow_jit` only - # allows explicit `jax.jit` to work but not implicitly jitted `jnp`. - # operations. This restriction will be relaxed in the future when the - # default value of `spmd_mode` config changes to `allow_jit`. - if api_name == 'jit' and config.jax_spmd_mode != 'allow_all': - raise RuntimeError( - "Running operations on `Array`s that are not fully addressable by this " - "process (i.e. `Array`s with data sharded across multiple devices and " - "processes.) is dangerous. It’s very important that all processes run " - "the same cross-process computations in the same order otherwise it " - "can lead to hangs. " - "If you’re not already familiar with JAX’s multi-process " - "programming model, please read " - "https://jax.readthedocs.io/en/latest/multi_process.html. " - "To fix this error, run your `jitted` computation inside " - "`with jax.spmd_mode('allow_all'):` context manager.") - - has_outfeed = core.jaxpr_uses_outfeed(jaxpr) - - # Computations that only produce constants and/or only rearrange their inputs, - # which are often produced from partial evaluation, don't need compilation, - # and don't need to evaluate their arguments. - if (not always_lower and not (jaxpr.effects or has_outfeed) and - (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and - all(_is_unspecified(o) for o in out_shardings)): - return MeshComputation( - str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=consts, - global_in_avals=global_in_avals, global_out_avals=global_out_avals, - in_shardings=in_shardings, backend=backend, - device_assignment=device_assignment, committed=committed, - kept_var_idx=kept_var_idx, keepalive=None) - # Look at the number of replcas present in the jaxpr. In # lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is # handled here so as to deprecate the lower_xla_callable codepath when # `jax.Array` is turned on by default. # TODO(yashkatariya): Remove this when `jit(pmap)` is removed. nreps = dispatch.jaxpr_replicas(jaxpr) - dispatch.raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr) - - # 2. Build up the HLO - tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform) + dispatch.raise_warnings_or_errors_for_jit_of_pmap( + nreps, backend, fun_name, jaxpr) in_op_shardings: Optional[List[Optional[xc.OpSharding]]] out_op_shardings: Optional[List[Optional[xc.OpSharding]]] @@ -2179,10 +2130,152 @@ def lower_sharding_computation( result_shardings=out_op_shardings, arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names, result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths) - module, keepalive, host_callbacks = ( lowering_result.module, lowering_result.keepalive, lowering_result.host_callbacks) + tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform) + return (module, keepalive, host_callbacks, unordered_effects, + ordered_effects, nreps, tuple_args) + + +@dataclasses.dataclass(frozen=True) +class _DeviceAssignment: + device_assignment: Tuple[xc.Device, ...] + + @cached_property + def _hash(self): + return hash(self.device_assignment) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if not isinstance(other, _DeviceAssignment): + return False + if id(self) == id(other): + return True + return (self.device_assignment == other.device_assignment) + + @cached_property + def is_fully_addressable(self): + return len(self.device_assignment) == len(self.addressable_device_assignment) + + @cached_property + def addressable_device_assignment(self): + return [d for d in self.device_assignment + if d.process_index == d.client.process_index()] + + +@lru_cache(maxsize=2048) +def _create_da_object( + device_assignment: Tuple[xc.Device, ...]) -> _DeviceAssignment: + return _DeviceAssignment(device_assignment) + + +@profiler.annotate_function +def lower_sharding_computation( + fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr], + api_name: str, + fun_name: str, + in_shardings: Sequence[MaybeSharding], + out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue], + donated_invars: Sequence[bool], + global_in_avals: Sequence[core.ShapedArray], + *, + keep_unused: bool, + always_lower: bool, + devices_from_context: Optional[Sequence[xc.Device]] = None, + lowering_platform: Optional[str], +) -> MeshComputation: + """Lowers a computation to XLA. It can take arbitrary shardings as input. + + The caller of this code can pass in a singleton _UNSPECIFIED because the + number of out_avals might not be known at that time and + lower_sharding_computation calculates the number of out_avals so it can apply + the singleton _UNSPECIFIED to all out_avals. + """ + # 1. Trace to jaxpr and preprocess/verify it + (closed_jaxpr, global_in_avals, global_out_avals, donated_invars, + kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce( + fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused, + donated_invars) + jaxpr = closed_jaxpr.jaxpr + in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) + + if _is_unspecified(out_shardings): + out_shardings = (_UNSPECIFIED,) * len(global_out_avals) + assert isinstance(out_shardings, tuple) + assert len(out_shardings) == len(global_out_avals), ( + len(out_shardings), len(global_out_avals)) + + # Device assignment across all inputs, outputs and shardings inside jaxpr + # should be the same. + jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr)) + backend, device_assignment = _get_and_check_device_assignment( + it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings], + [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings], + [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) + for js, source_info in jaxpr_sharding]), + devices_from_context) + + committed = bool( + devices_from_context or + len(device_assignment) > 1 or + any(not _is_unspecified(i) for i in in_shardings) or + any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or + any(not _is_unspecified(o) for o in out_shardings)) + + in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment) + if _is_unspecified(i) else i for i in in_shardings) + + da_object = _create_da_object(tuple(device_assignment)) + + if not da_object.is_fully_addressable: + check_multihost_collective_allowlist(jaxpr) + # TODO(yashkatariya): Once jit and pjit's frontend is merged, use the + # argument on jit `_allow_multiprocess` (which will be added later) instead + # of the `api_name` check here. + # Furthermore, `allow_jit` is not allowed yet because `allow_jit` only + # allows explicit `jax.jit` to work but not implicitly jitted `jnp`. + # operations. This restriction will be relaxed in the future when the + # default value of `spmd_mode` config changes to `allow_jit`. + if api_name == 'jit' and config.jax_spmd_mode != 'allow_all': + raise RuntimeError( + "Running operations on `Array`s that are not fully addressable by this " + "process (i.e. `Array`s with data sharded across multiple devices and " + "processes.) is dangerous. It’s very important that all processes run " + "the same cross-process computations in the same order otherwise it " + "can lead to hangs. " + "If you’re not already familiar with JAX’s multi-process " + "programming model, please read " + "https://jax.readthedocs.io/en/latest/multi_process.html. " + "To fix this error, run your `jitted` computation inside " + "`with jax.spmd_mode('allow_all'):` context manager.") + + has_outfeed = core.jaxpr_uses_outfeed(jaxpr) + kept_outputs = [True] * len(global_out_avals) + + # Computations that only produce constants and/or only rearrange their inputs, + # which are often produced from partial evaluation, don't need compilation, + # and don't need to evaluate their arguments. + if (not always_lower and not (jaxpr.effects or has_outfeed) and + (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and + all(_is_unspecified(o) for o in out_shardings)): + return MeshComputation( + str(name_stack), None, True, donated_invars, jaxpr=jaxpr, + consts=closed_jaxpr.consts, global_in_avals=global_in_avals, + global_out_avals=global_out_avals, in_shardings=in_shardings, + backend=backend, da_object=da_object, + committed=committed, kept_var_idx=kept_var_idx, keepalive=None) + + # 2. Build up the HLO + semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore + semantic_out_shardings = SemanticallyEqualShardings(out_shardings) + (module, keepalive, host_callbacks, unordered_effects, ordered_effects, + nreps, tuple_args) = _cached_lowering_to_hlo( + closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, + semantic_out_shardings, da_object, lowering_platform, + donated_invars, name_stack) # backend and device_assignment is passed through to MeshExecutable because # if keep_unused=False and all in_shardings are pruned, then there is no way @@ -2208,10 +2301,11 @@ def lower_sharding_computation( keepalive=keepalive, kept_var_idx=kept_var_idx, backend=backend, - device_assignment=device_assignment, + device_assignment=da_object, committed=committed, pmap_nreps=nreps) + def _to_logical_op_sharding( aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource] ) -> Optional[xc.OpSharding]: @@ -2382,7 +2476,7 @@ def lower_mesh_computation( keepalive=keepalive, kept_var_idx=set(range(len(global_in_avals))), backend=backend, - device_assignment=list(mesh.devices.flat), + device_assignment=_create_da_object(tuple(mesh.devices.flat)), committed=True) class MeshComputation(stages.XlaLowering): @@ -2522,6 +2616,117 @@ def _get_mesh_pspec_shardings_from_executable( [sharding_impls.NamedSharding(mesh, o) for o in out_pspec]) +def _get_out_sharding_from_named_sharding( + out_shardings, ns, are_out_sharding_from_xla): + from jax._src import pjit + out = [] + for o, from_xla in safe_zip(out_shardings, are_out_sharding_from_xla): + if isinstance(o, sharding_impls.GSPMDSharding): + try: + out.append((sharding_impls.NamedSharding._from_parsed_pspec( + ns.mesh, pjit.parse_flatten_op_sharding(o._op_sharding, ns.mesh)[0]), False)) + except: + out.append((o, from_xla)) + else: + out.append((o, from_xla)) + return out + + +def maybe_get_orig_out_sharding( + in_shardings, out_shardings, are_out_shardings_from_xla): + if all(hasattr(o, '_original_sharding') for o in out_shardings): + return ([o._original_sharding for o in out_shardings], + (False,) * len(out_shardings)) + + # TODO(yashkatariya): Handle other shardings too here. + ns = None + for i in in_shardings: + oi = getattr(i, '_original_sharding', None) + if isinstance(oi, sharding_impls.NamedSharding): + ns = oi + break + if ns is not None: + return zip(*_get_out_sharding_from_named_sharding( + out_shardings, ns, are_out_shardings_from_xla)) + + return out_shardings, are_out_shardings_from_xla + + +@weakref_lru_cache +def _cached_compilation(computation, name, mesh, num_out_avals, spmd_lowering, + tuple_args, auto_spmd_lowering, + _allow_propagation_to_outputs, host_callbacks, backend, + da, pmap_nreps, compiler_options_keys, + compiler_options_values): + device_assignment = da.device_assignment if isinstance( + da, _DeviceAssignment) else da + + dev: np.ndarray + if auto_spmd_lowering: + assert mesh is not None and spmd_lowering + dev = mesh.devices + num_replicas, num_partitions = 1, mesh.size + else: + # TODO(phawkins): One would normally just write: + # dev = np.array(device_assignment) + # The formulation below is substantially faster if there are many devices. + # If we were to optimize __getattr__ on xc.Device we might not need this + # workaround. + dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])( + np.arange(len(device_assignment)) + ) + if pmap_nreps > 1: + num_replicas, num_partitions = pmap_nreps, 1 + elif spmd_lowering: + num_replicas, num_partitions = 1, dev.size + else: + num_replicas, num_partitions = dev.size, 1 + + if pmap_nreps > 1: + # In `jit` device_assignment is set to None when num_replicas > 1. Do + # the same thing here too. + xla_device_assignment = None + else: + xla_device_assignment = dev.reshape((num_replicas, num_partitions)) + + if compiler_options_keys is None: + compiler_options = None + else: + compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) + + compile_options = xb.get_compile_options( + num_replicas=num_replicas, + num_partitions=num_partitions, + device_assignment=xla_device_assignment, + use_spmd_partitioning=spmd_lowering, + use_auto_spmd_partitioning=auto_spmd_lowering, + env_options_overrides=compiler_options, + ) + + opts = compile_options.executable_build_options + if auto_spmd_lowering: + assert mesh is not None + opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values()) + opts.auto_spmd_partitioning_mesh_ids = ( + sharding_specs.get_logical_mesh_ids(list(mesh.shape.values())) + .reshape(-1)) + compile_options.parameter_is_tupled_arguments = tuple_args + + if _allow_propagation_to_outputs is None: + _allow_propagation_to_outputs = [False] * num_out_avals + opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs) + + if hasattr(backend, "compile_replicated"): + return None, compile_options + + with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} " + "in {elapsed_time} sec", + event=dispatch.BACKEND_COMPILE_EVENT): + xla_executable = dispatch.compile_or_get_cached( + backend, computation, compile_options, host_callbacks) + return xla_executable, compile_options + + @dataclasses.dataclass class UnloadedMeshExecutable: xla_executable: Any @@ -2584,142 +2789,111 @@ class UnloadedMeshExecutable: keepalive: Any, kept_var_idx: Set[int], backend: xb.XlaBackend, - device_assignment: Sequence[xc.Device], + device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]], committed: bool, pmap_nreps: int = 1, compiler_options=None ) -> MeshExecutable: - - dev: np.ndarray - if auto_spmd_lowering: - assert mesh is not None and spmd_lowering - dev = mesh.devices - num_replicas, num_partitions = 1, mesh.size - else: - # TODO(phawkins): One would normally just write: - # dev = np.array(device_assignment) - # The formulation below is substantially faster if there are many devices. - # If we were to optimize __getattr__ on xc.Device we might not need this - # workaround. - dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])( - np.arange(len(device_assignment)) - ) - if pmap_nreps > 1: - num_replicas, num_partitions = pmap_nreps, 1 - elif spmd_lowering: - num_replicas, num_partitions = 1, dev.size - else: - num_replicas, num_partitions = dev.size, 1 - - if pmap_nreps > 1: - # In `jit` device_assignment is set to None when num_replicas > 1. Do - # the same thing here too. - xla_device_assignment = None - else: - xla_device_assignment = dev.reshape((num_replicas, num_partitions)) - - compile_options = xb.get_compile_options( - num_replicas=num_replicas, - num_partitions=num_partitions, - device_assignment=xla_device_assignment, - use_spmd_partitioning=spmd_lowering, - use_auto_spmd_partitioning=auto_spmd_lowering, - env_options_overrides=compiler_options, - ) - opts = compile_options.executable_build_options - if auto_spmd_lowering: - assert mesh is not None - opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values()) - opts.auto_spmd_partitioning_mesh_ids = ( - sharding_specs.get_logical_mesh_ids(list(mesh.shape.values())) - .reshape(-1)) - compile_options.parameter_is_tupled_arguments = tuple_args - - if _allow_propagation_to_outputs is None: - _allow_propagation_to_outputs = [False] * len(out_shardings) - opts.allow_spmd_sharding_propagation_to_output = _allow_propagation_to_outputs + allow_propagation_to_outputs = ( + tuple(_allow_propagation_to_outputs) + if _allow_propagation_to_outputs is not None else None) + compiler_options_keys = tuple( + compiler_options.keys()) if compiler_options is not None else None + compiler_options_values = tuple( + compiler_options.values()) if compiler_options is not None else None + da = device_assignment if isinstance( + device_assignment, _DeviceAssignment) else tuple(device_assignment) + xla_executable, compile_options = _cached_compilation( + computation, name, mesh, len(global_out_avals), spmd_lowering, + tuple_args, auto_spmd_lowering, allow_propagation_to_outputs, + tuple(host_callbacks), backend, da, pmap_nreps, + compiler_options_keys, compiler_options_values) if hasattr(backend, "compile_replicated"): + semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore + semantics_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore return _compile_replicated_mesh_executable_from_hlo( - name, computation, global_in_avals, global_out_avals, in_shardings, - out_shardings, auto_spmd_lowering, compile_options, - host_callbacks, bool(unordered_effects), ordered_effects, - kept_var_idx, backend, device_assignment, committed, pmap_nreps) + computation, name, tuple(global_in_avals), tuple(global_out_avals), + semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering, + compile_options, tuple(host_callbacks), bool(unordered_effects), + tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed, + pmap_nreps) + + del da + device_assignment = device_assignment.device_assignment if isinstance( + device_assignment, _DeviceAssignment) else device_assignment + + if auto_spmd_lowering: + assert mesh is not None + in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( + xla_executable, mesh) + in_shardings = [x if is_auto(i) else i + for x, i in safe_zip(in_shardings_xla, in_shardings)] + out_shardings_tuple = [ + (x, True) if is_auto(o) else (o, False) + for x, o in safe_zip(out_shardings_xla, out_shardings) + ] + out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple) + elif (out_shardings and any(_is_unspecified(o) for o in out_shardings) + and pmap_nreps == 1): + assert mesh is None + _, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore + xla_executable, device_assignment, + len(global_in_avals), len(global_out_avals)) + orig_out_shardings = out_shardings + out_shardings, are_out_shardings_from_xla = [], [] # type: ignore + for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings, + global_out_avals): + if _is_unspecified(orig): + out_shardings.append(xla_s) + are_out_shardings_from_xla.append(True) + else: + if not op_shardings.are_op_shardings_equal( + xla_s._to_xla_op_sharding(aval.ndim), # type: ignore + orig._to_xla_op_sharding(aval.ndim)): # type: ignore + raise AssertionError( + f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " + "(User sharding)") + out_shardings.append(orig) + are_out_shardings_from_xla.append(False) else: - with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} " - "in {elapsed_time} sec", - event=dispatch.BACKEND_COMPILE_EVENT): - xla_executable = dispatch.compile_or_get_cached( - backend, computation, compile_options, host_callbacks) + are_out_shardings_from_xla = (False,) * len(global_out_avals) - if auto_spmd_lowering: - assert mesh is not None - in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( - xla_executable, mesh) - in_shardings = [x if is_auto(i) else i - for x, i in safe_zip(in_shardings_xla, in_shardings)] - out_shardings_tuple = [ - (x, True) if is_auto(o) else (o, False) - for x, o in safe_zip(out_shardings_xla, out_shardings) - ] - out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple) - elif (out_shardings and any(_is_unspecified(o) for o in out_shardings) - and pmap_nreps == 1): - assert mesh is None - _, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore - xla_executable, device_assignment, - len(global_in_avals), len(global_out_avals)) - orig_out_shardings = out_shardings - out_shardings, are_out_shardings_from_xla = [], [] # type: ignore - for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings, - global_out_avals): - if _is_unspecified(orig): - out_shardings.append(xla_s) - are_out_shardings_from_xla.append(True) - else: - if not op_shardings.are_op_shardings_equal( - xla_s._to_xla_op_sharding(aval.ndim), # type: ignore - orig._to_xla_op_sharding(aval.ndim)): # type: ignore - raise AssertionError( - f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " - "(User sharding)") - out_shardings.append(orig) - are_out_shardings_from_xla.append(False) - else: - are_out_shardings_from_xla = (False,) * len(global_out_avals) + if pmap_nreps > 1: + local_devices = xla_executable.local_devices() + # Create replicated shardings for jit(pmap) path with local devices + # because multihost jit(pmap) is not allowed. + in_shardings = [ + sharding_impls.GSPMDSharding.get_replicated(local_devices) + ] * len(in_shardings) + out_shardings = [ + sharding_impls.GSPMDSharding.get_replicated(local_devices) + ] * len(out_shardings) + # jit(pmap) will generate Arrays with multi-device sharding. + # It is unsupported for these shardings to be uncommited, so force + # the outputs to be committed. + committed = True - if pmap_nreps > 1: - local_devices = xla_executable.local_devices() - # Create replicated shardings for jit(pmap) path with local devices - # because multihost jit(pmap) is not allowed. - in_shardings = [ - sharding_impls.GSPMDSharding.get_replicated(local_devices) - ] * len(in_shardings) - out_shardings = [ - sharding_impls.GSPMDSharding.get_replicated(local_devices) - ] * len(out_shardings) - # jit(pmap) will generate Arrays with multi-device sharding. - # It is unsupported for these shardings to be uncommited, so force - # the outputs to be committed. - committed = True + out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding( + in_shardings, out_shardings, are_out_shardings_from_xla) - return UnloadedMeshExecutable( - xla_executable=xla_executable, - device_assignment=device_assignment, - backend=backend, - input_avals=global_in_avals, - input_shardings=in_shardings, # type: ignore - output_avals=global_out_avals, - output_shardings=out_shardings, # type: ignore # arg-type - committed=committed, - are_out_shardings_from_xla=are_out_shardings_from_xla, - name=name, - unordered_effects=unordered_effects, - ordered_effects=ordered_effects, - keepalive=keepalive, - host_callbacks=host_callbacks, - kept_var_idx=kept_var_idx, - auto_spmd_lowering=auto_spmd_lowering).load() + return UnloadedMeshExecutable( + xla_executable=xla_executable, + device_assignment=device_assignment, + backend=backend, + input_avals=global_in_avals, + input_shardings=in_shardings, # type: ignore + output_avals=global_out_avals, + output_shardings=out_shardings, # type: ignore # arg-type + committed=committed, + are_out_shardings_from_xla=are_out_shardings_from_xla, + name=name, + unordered_effects=unordered_effects, + ordered_effects=ordered_effects, + keepalive=keepalive, + host_callbacks=host_callbacks, + kept_var_idx=kept_var_idx, + auto_spmd_lowering=auto_spmd_lowering).load() class MeshExecutableFastpathData(NamedTuple): @@ -2766,19 +2940,18 @@ class MeshExecutable(stages.XlaExecutable): @staticmethod def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals, - in_shardings, backend, device_assignment, + in_shardings, backend, da_object, committed, kept_var_idx, keepalive) -> MeshExecutable: assert keepalive is None if hasattr(backend, "compile_replicated"): return _compile_replicated_mesh_executable_from_trivial_jaxpr( jaxpr, consts, global_in_avals, global_out_avals, in_shardings, - backend, device_assignment, committed, kept_var_idx, 1) + backend, da_object.device_assignment, committed, kept_var_idx, 1) out_shardings = _out_shardings_for_trivial( - jaxpr, consts, in_shardings, device_assignment) + jaxpr, consts, in_shardings, da_object.device_assignment) indices = _get_input_indices(global_out_avals, out_shardings) - local_device_assignment = [d for d in device_assignment - if d.process_index == d.client.process_index()] + local_device_assignment = da_object.addressable_device_assignment handle_ins = InputsHandler(local_device_assignment, out_shardings, indices) handle_outs = global_avals_to_results_handler( global_out_avals, out_shardings, committed, @@ -2787,7 +2960,7 @@ class MeshExecutable(stages.XlaExecutable): handle_outs, kept_var_idx) return MeshExecutable(None, lambda: unsafe_call, global_in_avals, in_shardings, out_shardings, False, kept_var_idx, - device_assignment, None) + da_object.device_assignment, None) # -- stages.XlaExecutable overrides @@ -2861,7 +3034,16 @@ def _out_shardings_for_trivial( # a replicated sharding from jax._src import array - rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment) + if len(device_assignment) > 1: + rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment) + in_shardings = tuple( + i._original_sharding if hasattr(i, '_original_sharding') else i + for i in in_shardings) + else: + dev, = device_assignment + rep = sharding_impls.SingleDeviceSharding(dev) + in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings) + shardings: Dict[core.Var, sharding_impls.XLACompatibleSharding] = {} for constvar, constval in zip(jaxpr.constvars, consts): if isinstance(constval, array.ArrayImpl): @@ -2881,19 +3063,26 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args return out_handler(in_handler(outs)) +@weakref_lru_cache def _compile_replicated_mesh_executable_from_hlo( - name, computation, global_in_avals, global_out_avals, in_shardings, - out_shardings, auto_spmd_lowering, compile_options, + computation, name, global_in_avals, global_out_avals, semantics_in_shardings, + semantics_out_shardings, auto_spmd_lowering, compile_options, host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx, - backend, device_assignment, committed, pmap_nreps): + backend, da, committed, pmap_nreps): assert not auto_spmd_lowering + + in_shardings = semantics_in_shardings.shardings + out_shardings = semantics_out_shardings.shardings + device_assignment = da.device_assignment if isinstance( + da, _DeviceAssignment) else da + input_indices = _get_input_indices( global_in_avals, in_shardings) # type: ignore if pmap_nreps > 1: # For a jit wrapping a pmap, replicate each input index to match the # devices of the replicated jit computation. input_indices = [index * pmap_nreps for index in input_indices] - + kept_var_idx = set(kept_var_idx) # Will compute out_handler with executable information. unsafe_call = backend.compile_replicated( is_trivial=False, name=name, computation=computation, @@ -3038,7 +3227,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk): _check_aval(v.aval, what_thunk) - def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False): mesh_axis_pos = {name: i for i, name in enumerate(axis_names)} # NOTE: This takes in the non-sharded avals! diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0f4fe5a77..395a5ca3d 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1327,8 +1327,9 @@ class SameDeviceAssignmentTuple: device_assignment: Optional[XLADeviceAssignment] def __hash__(self): - shardings_hash = tuple(s._op_sharding_hash if isinstance(s, GSPMDSharding) else s # type: ignore - for s in self.shardings) + shardings_hash = tuple( + s._op_sharding_hash if isinstance(s, GSPMDSharding) else s # type: ignore + for s in self.shardings) if self.device_assignment is None: return hash(shardings_hash) else: @@ -1337,15 +1338,16 @@ class SameDeviceAssignmentTuple: def __eq__(self, other): if not isinstance(other, SameDeviceAssignmentTuple): return False - return ( - all( - op_shardings.are_op_shardings_equal(s._op_sharding, o._op_sharding) # pytype: disable=attribute-error - if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding) - else s == o - for s, o in zip(self.shardings, other.shardings) - ) - and self.device_assignment == other.device_assignment - ) + eq = [] + for s, o in zip(self.shardings, other.shardings): + s = getattr(s, "_original_sharding", s) + o = getattr(o, "_original_sharding", o) + if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding): + eq.append(op_shardings.are_op_shardings_equal( + s._op_sharding, o._op_sharding)) + else: + eq.append(s == o) + return all(eq) and self.device_assignment == other.device_assignment def _pjit_lower( @@ -1416,8 +1418,8 @@ def _pjit_lower_cached( lowering_platform=lowering_platform) else: return pxla.lower_sharding_computation( - jaxpr, api_name, name, in_shardings, out_shardings, donated_invars, - jaxpr.in_avals, keep_unused=keep_unused, + jaxpr, api_name, name, in_shardings, out_shardings, + tuple(donated_invars), tuple(jaxpr.in_avals), keep_unused=keep_unused, always_lower=always_lower, devices_from_context=( None if mesh is None or mesh.empty else list(mesh.devices.flat)), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 84314c6e3..3e1fcc04e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -44,7 +44,8 @@ from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array from jax._src.sharding import Sharding from jax._src import op_shardings -from jax._src.sharding_impls import NamedSharding, GSPMDSharding +from jax._src.sharding_impls import (NamedSharding, GSPMDSharding, + PositionalSharding, SingleDeviceSharding) import jax._src.pjit as pjit_lib from jax._src.pjit import (pjit, pjit_p, AUTO) from jax._src import mesh @@ -647,8 +648,11 @@ class PJitTest(jtu.BufferDonationTestCase): z, w = jax.vmap(f, in_axes=(None, 0), out_axes=(0, None))(x, y) self.assertAllClose(z, x[jnp.newaxis] + y) self.assertAllClose(w, x) - self.assertEqual(z.sharding._op_sharding.tile_assignment_dimensions, [1, 2]) - self.assertEqual(w.sharding._op_sharding.tile_assignment_dimensions, [2]) + self.assertEqual( + z.sharding._to_xla_op_sharding(z.ndim).tile_assignment_dimensions, + [1, 2]) + self.assertEqual( + w.sharding._to_xla_op_sharding(w.ndim).tile_assignment_dimensions, [2]) @jtu.with_mesh([('x', 2)]) def testVMapShardingConstraint(self): @@ -1379,7 +1383,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def _checks(out, input_data): self.assertIsInstance(out, array.ArrayImpl) - self.assertIsInstance(out.sharding, GSPMDSharding) + self.assertIsInstance(out.sharding, NamedSharding) self.assertEqual(out.shape, (8, 2)) self.assertEqual(out.addressable_shards[0].data.shape, (2, 1)) for s in out.addressable_shards: @@ -1907,20 +1911,20 @@ class ArrayPjitTest(jtu.JaxTestCase): f = pjit(lambda x: x) out1 = f(arr) - self.assertIsInstance(out1.sharding, GSPMDSharding) + self.assertIsInstance(out1.sharding, NamedSharding) out1.sharding.devices_indices_map(shape) - cache_info1 = GSPMDSharding.devices_indices_map.cache_info() + cache_info1 = NamedSharding.devices_indices_map.cache_info() out2 = f(out1) - self.assertIsInstance(out2.sharding, GSPMDSharding) + self.assertIsInstance(out2.sharding, NamedSharding) out2.sharding.devices_indices_map(shape) - cache_info2 = GSPMDSharding.devices_indices_map.cache_info() + cache_info2 = NamedSharding.devices_indices_map.cache_info() self.assertEqual(cache_info2.hits, cache_info1.hits + 1) out3 = f(out2) - self.assertIsInstance(out3.sharding, GSPMDSharding) + self.assertIsInstance(out3.sharding, NamedSharding) out3.sharding.devices_indices_map(shape) - cache_info3 = GSPMDSharding.devices_indices_map.cache_info() + cache_info3 = NamedSharding.devices_indices_map.cache_info() self.assertEqual(cache_info3.hits, cache_info2.hits + 1) def test_device_put_sharding_prng(self): @@ -2202,8 +2206,8 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertArraysEqual(f_out1, g_out1) self.assertArraysEqual(f_out2, g_out2) - self.assertEqual(f_out1.sharding, g_out1.sharding) - self.assertEqual(f_out2.sharding, g_out2.sharding) + self.assertTrue(f_out1.sharding.is_equivalent_to(g_out1.sharding, f_out1.ndim)) + self.assertTrue(f_out2.sharding.is_equivalent_to(g_out2.sharding, f_out2.ndim)) def test_pjit_on_different_default_device_with_uncommitted_inputs(self): if jax.device_count() < 2: @@ -2932,6 +2936,216 @@ class ArrayPjitTest(jtu.JaxTestCase): # Test second order autodiff with src argument specified in device_put. jtu.check_grads(g, (arr,), order=2) + def test_pjit_out_sharding_preserved(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + ns = NamedSharding(mesh, P('x')) + ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + + arr = jax.device_put(np.arange(8).reshape(8, 1), ns) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + + def mul(x): + return x * 2 + + f = pjit(mul, out_shardings=ns) + f2 = pjit(mul, out_shardings=ps) + + with jtu.count_pjit_cpp_cache_miss() as count: + out = f(arr) + cache_info1 = pxla._cached_compilation.cache_info() + self.assertIsInstance(out.sharding, NamedSharding) + + out = f(arr) + self.assertIsInstance(out.sharding, NamedSharding) + self.assertEqual(count[0], 1) + + with jtu.count_pjit_cpp_cache_miss() as count: + out2 = f2(arr) + cache_info2 = pxla._cached_compilation.cache_info() + self.assertIsInstance(out2.sharding, PositionalSharding) + + out2 = f2(arr) + self.assertIsInstance(out2.sharding, PositionalSharding) + self.assertEqual(count[0], 1) + + self.assertEqual(cache_info2.hits, cache_info1.hits + 1) + self.assertEqual(cache_info2.misses, cache_info1.misses) + + out3 = jnp.squeeze(arr, axis=-1) + cache_info3 = pxla._cached_compilation.cache_info() + self.assertIsInstance(out3.sharding, NamedSharding) + + out4 = jnp.squeeze(arr2, axis=-1) + cache_info4 = pxla._cached_compilation.cache_info() + # TODO(yashkatariya): Handle PositionalSharding inside pxla so that + # GSPMDShardings can be converted to PositionalSharding. + self.assertIsInstance(out4.sharding, GSPMDSharding) + + self.assertEqual(cache_info4.hits, cache_info3.hits + 1) + self.assertEqual(cache_info4.misses, cache_info3.misses) + + def test_cache_hit_pjit_lower_with_cpp_cache_miss(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + ns = NamedSharding(mesh, P('x')) + np_arr = np.arange(8, dtype=np.float32).reshape(8, 1) + arr = jax.device_put(np_arr, ns) + + def mul(x): + return x * 2 + + f = pjit(mul, in_shardings=ns, out_shardings=ns) + + with jtu.count_pjit_cpp_cache_miss() as count: + out = f(arr) + cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out.sharding, NamedSharding) + + out2 = f(np_arr) + cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out2.sharding, NamedSharding) + + # Drops out of C++ cache i.e. cache miss + self.assertEqual(count[0], 2) + # Still gets a hit on pjit_lower cache. + self.assertEqual(cache_info2.hits, cache_info1.hits + 1) + self.assertEqual(cache_info2.misses, cache_info1.misses) + + def test_sharding_preserved_trivial(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + ns = NamedSharding(mesh, P('x')) + ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + + arr = jax.device_put(np.arange(8).reshape(8, 1), ns) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + + def identity(x): + return x + + out = pjit(identity)(arr) + self.assertIsInstance(out.sharding, NamedSharding) + + out2 = pjit(identity)(arr2) + self.assertIsInstance(out2.sharding, PositionalSharding) + + def test_sharding_preserved_aot(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + ns = NamedSharding(mesh, P('x')) + ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + + arr = jax.device_put(np.arange(8).reshape(8, 1), ns) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + + compiled = pjit(lambda x: x * 2).lower(arr).compile() + out = compiled(arr) + self.assertIsInstance(out.sharding, NamedSharding) + + out2 = compiled(arr2) + # The sharding won't be PositionalSharding since the pjit was already + # Compiled which bakes in the output sharding. + self.assertIsInstance(out2.sharding, NamedSharding) + + def test_sharding_on_output_with_vmap(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + ns = NamedSharding(mesh, P('x')) + arr = jax.device_put( + np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x'))) + vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns)) + out = vf(arr) + cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out.sharding, GSPMDSharding) + + out2 = vf(out) + cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out2.sharding, GSPMDSharding) + + out3 = vf(out2) + cache_info3 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out3.sharding, GSPMDSharding) + + self.assertEqual(cache_info2.hits, cache_info1.hits + 1) + self.assertEqual(cache_info3.hits, cache_info2.hits + 1) + self.assertEqual(cache_info2.misses, cache_info1.misses) + self.assertEqual(cache_info3.misses, cache_info2.misses) + + def test_jit_mul_sum_sharding_preserved(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + ns = NamedSharding(mesh, P('x')) + ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + + arr = jax.device_put(np.arange(8).reshape(8, 1), ns) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + + f = jax.jit(lambda x: x * 2) + out = f(arr) + cache_info1 = pxla._cached_compilation.cache_info() + pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out.sharding, NamedSharding) + + out2 = f(arr2) + cache_info2 = pxla._cached_compilation.cache_info() + pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + # TODO(yashkatariya): Handle PositionalSharding inside pxla so that + # GSPMDShardings can be converted to PositionalSharding. + self.assertIsInstance(out2.sharding, GSPMDSharding) + + out3 = f(out2) + cache_info3 = pxla._cached_compilation.cache_info() + pl_cache_info3 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out3.sharding, GSPMDSharding) + + self.assertEqual(cache_info2.hits, cache_info1.hits + 1) + self.assertEqual(cache_info3.hits, cache_info2.hits + 1) + self.assertEqual(cache_info2.misses, cache_info1.misses) + self.assertEqual(cache_info3.misses, cache_info2.misses) + + # TODO(yashkatariya): We will get hits here after we can convert + # GSPMDSharding to PositionalSharding. + self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1) + self.assertEqual(pl_cache_info3.misses, pl_cache_info2.misses + 1) + + out4 = jnp.sum(arr) + self.assertIsInstance(out4.sharding, NamedSharding) + + def test_single_device_sharding_preserved(self): + if jax.device_count() < 2: + self.skipTest('Test requires >=2 devices') + + x = jnp.arange(8) + + # trivial computation + out = jax.jit(lambda x: x)(x) + self.assertIsInstance(out.sharding, SingleDeviceSharding) + + # trivial computation with committed inp + y = jax.device_put(x, jax.devices()[1]) + out2 = jax.jit(lambda x: x)(y) + self.assertIsInstance(out2.sharding, SingleDeviceSharding) + self.assertEqual(out2.device(), jax.devices()[1]) + + out3 = jax.jit(lambda x: x * 2)(x) + self.assertIsInstance(out3.sharding, SingleDeviceSharding) + + out4 = jax.jit(lambda x: x * 3, + out_shardings=SingleDeviceSharding(jax.devices()[1]))(x) + self.assertIsInstance(out4.sharding, SingleDeviceSharding) + self.assertEqual(out4.device(), jax.devices()[1]) + + def test_sharding_preserved_apply_primitive(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + ns = NamedSharding(mesh, P('x')) + + arr = jax.device_put(np.arange(8).reshape(8, 1), ns) + + out = jnp.copy(arr) + self.assertIsInstance(out.sharding, NamedSharding) + + # TODO(yashkatariya): Fix apply_primitive's cache on xla_primitive_callable + # to be like pjit_lower cache. + # ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + # arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + # out2 = jnp.copy(arr2) + # self.assertIsInstance(out2.sharding, PositionalSharding) + class TempSharding(Sharding):