Preserve shardings on the output of pjit that were provided on the arguments.

Following are the changes:

* Make `_pjit_lower_cached` depend on exact sharding equality when `_original_sharding` exists. This top-level cache will accumulate entries if users pass different shardings into the pjit function.
* Split `lower_sharding_computation` into 3 caches:
  * `_trace_to_jaxpr_and_dce` cache -- returns a closed jaxpr that has been DCE'd.
  * `_cached_lowering_to_hlo` cache -- caches the generation of MHLO. This cache keys on the semantic equality of shardings, i.e. if two shardings lower to the same OpSharding, there will be a cache hit.
  * `_cached_compilation` cache -- caches the compilation so that we don't recompile when the shardings are semantically equal.

With this split, the out_handlers are recreated when different shardings are passed to pjit, but there is no recompilation. This allows us to preserve the shardings passed by the user.
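
As an illustration of the cache split and of output-sharding preservation, here is a minimal sketch modeled on the `test_pjit_out_sharding_preserved` test added in this change. It assumes at least two devices are available and uses the public `jax.experimental.pjit` entry point; the mesh/sharding construction mirrors the test rather than any required API.

    import numpy as np
    import jax
    from jax.experimental.pjit import pjit
    from jax.sharding import (Mesh, NamedSharding, PartitionSpec as P,
                              PositionalSharding)

    # Two shardings with the same data layout, so they lower to the same OpSharding.
    mesh = Mesh(np.asarray(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
    ns = NamedSharding(mesh, P('x'))
    ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)

    arr = jax.device_put(np.arange(8).reshape(8, 1), ns)

    def mul(x):
      return x * 2

    f_ns = pjit(mul, out_shardings=ns)
    f_ps = pjit(mul, out_shardings=ps)

    out1 = f_ns(arr)  # traces, lowers and compiles
    out2 = f_ps(arr)  # semantically equal shardings: the compilation is reused
                      # and only the out_handlers are rebuilt
    assert isinstance(out1.sharding, NamedSharding)       # user's sharding type preserved
    assert isinstance(out2.sharding, PositionalSharding)  # user's sharding type preserved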

For ops like `jnp.squeeze`, where we infer the output sharding from the executable, we try to recreate a NamedSharding from the GSPMDSharding, since a NamedSharding is available on the input (only NamedSharding for now; support for more sharding types will be added in follow-up CLs).
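
A corresponding sketch for the inferred-sharding case, mirroring the `jnp.squeeze` check in the same test (again assuming at least two devices):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.asarray(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
    arr = jax.device_put(np.arange(8).reshape(8, 1), NamedSharding(mesh, P('x')))

    # The output sharding is inferred from the compiled executable as a GSPMDSharding
    # and then rebuilt as a NamedSharding using the mesh found on the input's sharding.
    out = jnp.squeeze(arr, axis=-1)
    assert isinstance(out.sharding, NamedSharding)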

PiperOrigin-RevId: 522991145
Author: Yash Katariya, 2023-04-09 15:41:32 -07:00 (committed by jax authors)
Commit: 5d2f453094 (parent: 90d58f4572)
4 changed files with 670 additions and 266 deletions


@@ -218,7 +218,7 @@ def sharded_lowering(fun, name, donated_invars, keep_unused,
# apply it to all out_avals.
return pxla.lower_sharding_computation(
fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
in_avals, keep_unused=keep_unused, always_lower=False,
tuple(in_avals), keep_unused=keep_unused, always_lower=False,
devices_from_context=None, lowering_platform=lowering_platform)


@@ -63,7 +63,7 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction)
unzip2, HashableFunction, weakref_lru_cache)
# Built-in Python lists don't support weak refs, but subclasses of lists do.
@@ -1474,7 +1474,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
devices)
class ExecuteReplicated:
"""The logic to shard inputs, execute a replicated model, returning outputs."""
__slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
@@ -1997,38 +1996,28 @@ def _get_and_check_device_assignment(
final_device_assignment = first_sharding_info[0]
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
@profiler.annotate_function
def lower_sharding_computation(
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
keep_unused: bool,
always_lower: bool,
devices_from_context: Optional[Sequence[xc.Device]] = None,
lowering_platform: Optional[str],
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
def cache_wrap(fn):
_wrapped_with_lu_cache = lu.cache(fn)
_wrapped_with_weakref_lru_cache = weakref_lru_cache(fn)
def wrapped(f, *args, **kwargs):
if isinstance(f, lu.WrappedFun):
return _wrapped_with_lu_cache(f, *args, **kwargs)
else:
return _wrapped_with_weakref_lru_cache(f, *args, **kwargs)
return wrapped
The caller of this code can pass in a singleton _UNSPECIFIED because the
number of out_avals might not be known at that time and
lower_sharding_computation calculates the number of out_avals so it can apply
the singleton _UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
@cache_wrap
def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
keep_unused, donated_invars):
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
if isinstance(fun_or_jaxpr, lu.WrappedFun):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
with dispatch.log_elapsed_time(
f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun_or_jaxpr, global_in_avals)
else:
@@ -2045,41 +2034,44 @@ def lower_sharding_computation(
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
kept_var_idx, name_stack)
@dataclasses.dataclass(frozen=True)
class SemanticallyEqualShardings:
shardings: Tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...]
def __hash__(self):
return hash(tuple(
s._op_sharding_hash if isinstance(s, sharding_impls.GSPMDSharding) else s # type: ignore
for s in self.shardings))
def __eq__(self, other):
if not isinstance(other, SemanticallyEqualShardings):
return False
return all(op_shardings.are_op_shardings_equal(s._op_sharding, o._op_sharding)
if (isinstance(s, sharding_impls.GSPMDSharding) and
isinstance(o, sharding_impls.GSPMDSharding))
else s == o for s, o in zip(self.shardings, other.shardings))
@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
da_object, lowering_platform,
donated_invars, name_stack):
jaxpr = closed_jaxpr.jaxpr
kept_outputs = [True] * len(global_out_avals)
if _is_unspecified(out_shardings):
out_shardings = (_UNSPECIFIED,) * len(global_out_avals)
assert isinstance(out_shardings, tuple)
assert len(out_shardings) == len(global_out_avals), (
len(out_shardings), len(global_out_avals))
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
backend, device_assignment = _get_and_check_device_assignment(
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in jaxpr_sharding]),
devices_from_context)
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not _is_unspecified(i) for i in in_shardings) or
any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not _is_unspecified(o) for o in out_shardings))
in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
if _is_unspecified(i) else i for i in in_shardings)
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
global_in_avals = closed_jaxpr.in_avals
global_out_avals = closed_jaxpr.out_avals
device_assignment = da_object.device_assignment
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
@@ -2087,55 +2079,14 @@ def lower_sharding_computation(
"Argument mapping: %s.",
fun_name, global_in_avals, in_shardings)
local_device_assignment = [d for d in device_assignment
if d.process_index == d.client.process_index()]
if len(device_assignment) != len(local_device_assignment):
check_multihost_collective_allowlist(jaxpr)
# TODO(yashkatariya): Once jit and pjit's frontend is merged, use the
# argument on jit `_allow_multiprocess` (which will be added later) instead
# of the `api_name` check here.
# Furthermore, `allow_jit` is not allowed yet because `allow_jit` only
# allows explicit `jax.jit` to work but not implicitly jitted `jnp`
# operations. This restriction will be relaxed in the future when the
# default value of `spmd_mode` config changes to `allow_jit`.
if api_name == 'jit' and config.jax_spmd_mode != 'allow_all':
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. Its very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs. "
"If youre not already familiar with JAXs multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html. "
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(_is_unspecified(o) for o in out_shardings)):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=consts,
global_in_avals=global_in_avals, global_out_avals=global_out_avals,
in_shardings=in_shardings, backend=backend,
device_assignment=device_assignment, committed=committed,
kept_var_idx=kept_var_idx, keepalive=None)
# Look at the number of replicas present in the jaxpr. In
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
# handled here so as to deprecate the lower_xla_callable codepath when
# `jax.Array` is turned on by default.
# TODO(yashkatariya): Remove this when `jit(pmap)` is removed.
nreps = dispatch.jaxpr_replicas(jaxpr)
dispatch.raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)
# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
dispatch.raise_warnings_or_errors_for_jit_of_pmap(
nreps, backend, fun_name, jaxpr)
in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
@@ -2179,10 +2130,152 @@ def lower_sharding_computation(
result_shardings=out_op_shardings,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
return (module, keepalive, host_callbacks, unordered_effects,
ordered_effects, nreps, tuple_args)
@dataclasses.dataclass(frozen=True)
class _DeviceAssignment:
device_assignment: Tuple[xc.Device, ...]
@cached_property
def _hash(self):
return hash(self.device_assignment)
def __hash__(self):
return self._hash
def __eq__(self, other):
if not isinstance(other, _DeviceAssignment):
return False
if id(self) == id(other):
return True
return (self.device_assignment == other.device_assignment)
@cached_property
def is_fully_addressable(self):
return len(self.device_assignment) == len(self.addressable_device_assignment)
@cached_property
def addressable_device_assignment(self):
return [d for d in self.device_assignment
if d.process_index == d.client.process_index()]
@lru_cache(maxsize=2048)
def _create_da_object(
device_assignment: Tuple[xc.Device, ...]) -> _DeviceAssignment:
return _DeviceAssignment(device_assignment)
@profiler.annotate_function
def lower_sharding_computation(
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
keep_unused: bool,
always_lower: bool,
devices_from_context: Optional[Sequence[xc.Device]] = None,
lowering_platform: Optional[str],
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
The caller of this code can pass in a singleton _UNSPECIFIED because the
number of out_avals might not be known at that time and
lower_sharding_computation calculates the number of out_avals so it can apply
the singleton _UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
donated_invars)
jaxpr = closed_jaxpr.jaxpr
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
if _is_unspecified(out_shardings):
out_shardings = (_UNSPECIFIED,) * len(global_out_avals)
assert isinstance(out_shardings, tuple)
assert len(out_shardings) == len(global_out_avals), (
len(out_shardings), len(global_out_avals))
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
backend, device_assignment = _get_and_check_device_assignment(
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in jaxpr_sharding]),
devices_from_context)
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not _is_unspecified(i) for i in in_shardings) or
any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not _is_unspecified(o) for o in out_shardings))
in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
if _is_unspecified(i) else i for i in in_shardings)
da_object = _create_da_object(tuple(device_assignment))
if not da_object.is_fully_addressable:
check_multihost_collective_allowlist(jaxpr)
# TODO(yashkatariya): Once jit and pjit's frontend is merged, use the
# argument on jit `_allow_multiprocess` (which will be added later) instead
# of the `api_name` check here.
# Furthermore, `allow_jit` is not allowed yet because `allow_jit` only
# allows explicit `jax.jit` to work but not implicitly jitted `jnp`
# operations. This restriction will be relaxed in the future when the
# default value of `spmd_mode` config changes to `allow_jit`.
if api_name == 'jit' and config.jax_spmd_mode != 'allow_all':
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. Its very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs. "
"If youre not already familiar with JAXs multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html. "
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
kept_outputs = [True] * len(global_out_avals)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(_is_unspecified(o) for o in out_shardings)):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_shardings=in_shardings,
backend=backend, da_object=da_object,
committed=committed, kept_var_idx=kept_var_idx, keepalive=None)
# 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, da_object, lowering_platform,
donated_invars, name_stack)
# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2208,10 +2301,11 @@ def lower_sharding_computation(
keepalive=keepalive,
kept_var_idx=kept_var_idx,
backend=backend,
device_assignment=device_assignment,
device_assignment=da_object,
committed=committed,
pmap_nreps=nreps)
def _to_logical_op_sharding(
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
) -> Optional[xc.OpSharding]:
@@ -2382,7 +2476,7 @@ def lower_mesh_computation(
keepalive=keepalive,
kept_var_idx=set(range(len(global_in_avals))),
backend=backend,
device_assignment=list(mesh.devices.flat),
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True)
class MeshComputation(stages.XlaLowering):
@@ -2522,6 +2616,117 @@ def _get_mesh_pspec_shardings_from_executable(
[sharding_impls.NamedSharding(mesh, o) for o in out_pspec])
def _get_out_sharding_from_named_sharding(
out_shardings, ns, are_out_sharding_from_xla):
from jax._src import pjit
out = []
for o, from_xla in safe_zip(out_shardings, are_out_sharding_from_xla):
if isinstance(o, sharding_impls.GSPMDSharding):
try:
out.append((sharding_impls.NamedSharding._from_parsed_pspec(
ns.mesh, pjit.parse_flatten_op_sharding(o._op_sharding, ns.mesh)[0]), False))
except:
out.append((o, from_xla))
else:
out.append((o, from_xla))
return out
def maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla):
if all(hasattr(o, '_original_sharding') for o in out_shardings):
return ([o._original_sharding for o in out_shardings],
(False,) * len(out_shardings))
# TODO(yashkatariya): Handle other shardings too here.
ns = None
for i in in_shardings:
oi = getattr(i, '_original_sharding', None)
if isinstance(oi, sharding_impls.NamedSharding):
ns = oi
break
if ns is not None:
return zip(*_get_out_sharding_from_named_sharding(
out_shardings, ns, are_out_shardings_from_xla))
return out_shardings, are_out_shardings_from_xla
@weakref_lru_cache
def _cached_compilation(computation, name, mesh, num_out_avals, spmd_lowering,
tuple_args, auto_spmd_lowering,
_allow_propagation_to_outputs, host_callbacks, backend,
da, pmap_nreps, compiler_options_keys,
compiler_options_values):
device_assignment = da.device_assignment if isinstance(
da, _DeviceAssignment) else da
dev: np.ndarray
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
dev = mesh.devices
num_replicas, num_partitions = 1, mesh.size
else:
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
# If we were to optimize __getattr__ on xc.Device we might not need this
# workaround.
dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
np.arange(len(device_assignment))
)
if pmap_nreps > 1:
num_replicas, num_partitions = pmap_nreps, 1
elif spmd_lowering:
num_replicas, num_partitions = 1, dev.size
else:
num_replicas, num_partitions = dev.size, 1
if pmap_nreps > 1:
# In `jit` device_assignment is set to None when num_replicas > 1. Do
# the same thing here too.
xla_device_assignment = None
else:
xla_device_assignment = dev.reshape((num_replicas, num_partitions))
if compiler_options_keys is None:
compiler_options = None
else:
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
compile_options = xb.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=xla_device_assignment,
use_spmd_partitioning=spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
)
opts = compile_options.executable_build_options
if auto_spmd_lowering:
assert mesh is not None
opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
opts.auto_spmd_partitioning_mesh_ids = (
sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
.reshape(-1))
compile_options.parameter_is_tupled_arguments = tuple_args
if _allow_propagation_to_outputs is None:
_allow_propagation_to_outputs = [False] * num_out_avals
opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
if hasattr(backend, "compile_replicated"):
return None, compile_options
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec",
event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = dispatch.compile_or_get_cached(
backend, computation, compile_options, host_callbacks)
return xla_executable, compile_options
@dataclasses.dataclass
class UnloadedMeshExecutable:
xla_executable: Any
@@ -2584,142 +2789,111 @@ class UnloadedMeshExecutable:
keepalive: Any,
kept_var_idx: Set[int],
backend: xb.XlaBackend,
device_assignment: Sequence[xc.Device],
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]],
committed: bool,
pmap_nreps: int = 1,
compiler_options=None
) -> MeshExecutable:
dev: np.ndarray
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
dev = mesh.devices
num_replicas, num_partitions = 1, mesh.size
else:
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
# If we were to optimize __getattr__ on xc.Device we might not need this
# workaround.
dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
np.arange(len(device_assignment))
)
if pmap_nreps > 1:
num_replicas, num_partitions = pmap_nreps, 1
elif spmd_lowering:
num_replicas, num_partitions = 1, dev.size
else:
num_replicas, num_partitions = dev.size, 1
if pmap_nreps > 1:
# In `jit` device_assignment is set to None when num_replicas > 1. Do
# the same thing here too.
xla_device_assignment = None
else:
xla_device_assignment = dev.reshape((num_replicas, num_partitions))
compile_options = xb.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=xla_device_assignment,
use_spmd_partitioning=spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
)
opts = compile_options.executable_build_options
if auto_spmd_lowering:
assert mesh is not None
opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
opts.auto_spmd_partitioning_mesh_ids = (
sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
.reshape(-1))
compile_options.parameter_is_tupled_arguments = tuple_args
if _allow_propagation_to_outputs is None:
_allow_propagation_to_outputs = [False] * len(out_shardings)
opts.allow_spmd_sharding_propagation_to_output = _allow_propagation_to_outputs
allow_propagation_to_outputs = (
tuple(_allow_propagation_to_outputs)
if _allow_propagation_to_outputs is not None else None)
compiler_options_keys = tuple(
compiler_options.keys()) if compiler_options is not None else None
compiler_options_values = tuple(
compiler_options.values()) if compiler_options is not None else None
da = device_assignment if isinstance(
device_assignment, _DeviceAssignment) else tuple(device_assignment)
xla_executable, compile_options = _cached_compilation(
computation, name, mesh, len(global_out_avals), spmd_lowering,
tuple_args, auto_spmd_lowering, allow_propagation_to_outputs,
tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_keys, compiler_options_values)
if hasattr(backend, "compile_replicated"):
semantics_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantics_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore
return _compile_replicated_mesh_executable_from_hlo(
name, computation, global_in_avals, global_out_avals, in_shardings,
out_shardings, auto_spmd_lowering, compile_options,
host_callbacks, bool(unordered_effects), ordered_effects,
kept_var_idx, backend, device_assignment, committed, pmap_nreps)
computation, name, tuple(global_in_avals), tuple(global_out_avals),
semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
compile_options, tuple(host_callbacks), bool(unordered_effects),
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
pmap_nreps)
del da
device_assignment = device_assignment.device_assignment if isinstance(
device_assignment, _DeviceAssignment) else device_assignment
if auto_spmd_lowering:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if is_auto(i) else i
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings_tuple = [
(x, True) if is_auto(o) else (o, False)
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
elif (out_shardings and any(_is_unspecified(o) for o in out_shardings)
and pmap_nreps == 1):
assert mesh is None
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment,
len(global_in_avals), len(global_out_avals))
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
global_out_avals):
if _is_unspecified(orig):
out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
if not op_shardings.are_op_shardings_equal(
xla_s._to_xla_op_sharding(aval.ndim), # type: ignore
orig._to_xla_op_sharding(aval.ndim)): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
out_shardings.append(orig)
are_out_shardings_from_xla.append(False)
else:
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec",
event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = dispatch.compile_or_get_cached(
backend, computation, compile_options, host_callbacks)
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if auto_spmd_lowering:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if is_auto(i) else i
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings_tuple = [
(x, True) if is_auto(o) else (o, False)
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
elif (out_shardings and any(_is_unspecified(o) for o in out_shardings)
and pmap_nreps == 1):
assert mesh is None
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment,
len(global_in_avals), len(global_out_avals))
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
global_out_avals):
if _is_unspecified(orig):
out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
if not op_shardings.are_op_shardings_equal(
xla_s._to_xla_op_sharding(aval.ndim), # type: ignore
orig._to_xla_op_sharding(aval.ndim)): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
out_shardings.append(orig)
are_out_shardings_from_xla.append(False)
else:
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if pmap_nreps > 1:
local_devices = xla_executable.local_devices()
# Create replicated shardings for jit(pmap) path with local devices
# because multihost jit(pmap) is not allowed.
in_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(in_shardings)
out_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(out_shardings)
# jit(pmap) will generate Arrays with multi-device sharding.
# It is unsupported for these shardings to be uncommitted, so force
# the outputs to be committed.
committed = True
if pmap_nreps > 1:
local_devices = xla_executable.local_devices()
# Create replicated shardings for jit(pmap) path with local devices
# because multihost jit(pmap) is not allowed.
in_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(in_shardings)
out_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(out_shardings)
# jit(pmap) will generate Arrays with multi-device sharding.
# It is unsupported for these shardings to be uncommitted, so force
# the outputs to be committed.
committed = True
out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla)
return UnloadedMeshExecutable(
xla_executable=xla_executable,
device_assignment=device_assignment,
backend=backend,
input_avals=global_in_avals,
input_shardings=in_shardings, # type: ignore
output_avals=global_out_avals,
output_shardings=out_shardings, # type: ignore # arg-type
committed=committed,
are_out_shardings_from_xla=are_out_shardings_from_xla,
name=name,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
keepalive=keepalive,
host_callbacks=host_callbacks,
kept_var_idx=kept_var_idx,
auto_spmd_lowering=auto_spmd_lowering).load()
return UnloadedMeshExecutable(
xla_executable=xla_executable,
device_assignment=device_assignment,
backend=backend,
input_avals=global_in_avals,
input_shardings=in_shardings, # type: ignore
output_avals=global_out_avals,
output_shardings=out_shardings, # type: ignore # arg-type
committed=committed,
are_out_shardings_from_xla=are_out_shardings_from_xla,
name=name,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
keepalive=keepalive,
host_callbacks=host_callbacks,
kept_var_idx=kept_var_idx,
auto_spmd_lowering=auto_spmd_lowering).load()
class MeshExecutableFastpathData(NamedTuple):
@@ -2766,19 +2940,18 @@ class MeshExecutable(stages.XlaExecutable):
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
in_shardings, backend, device_assignment,
in_shardings, backend, da_object,
committed, kept_var_idx, keepalive) -> MeshExecutable:
assert keepalive is None
if hasattr(backend, "compile_replicated"):
return _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
backend, device_assignment, committed, kept_var_idx, 1)
backend, da_object.device_assignment, committed, kept_var_idx, 1)
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, device_assignment)
jaxpr, consts, in_shardings, da_object.device_assignment)
indices = _get_input_indices(global_out_avals, out_shardings)
local_device_assignment = [d for d in device_assignment
if d.process_index == d.client.process_index()]
local_device_assignment = da_object.addressable_device_assignment
handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
@@ -2787,7 +2960,7 @@ class MeshExecutable(stages.XlaExecutable):
handle_outs, kept_var_idx)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
device_assignment, None)
da_object.device_assignment, None)
# -- stages.XlaExecutable overrides
@@ -2861,7 +3034,16 @@ def _out_shardings_for_trivial(
# a replicated sharding
from jax._src import array
rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
if len(device_assignment) > 1:
rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(
i._original_sharding if hasattr(i, '_original_sharding') else i
for i in in_shardings)
else:
dev, = device_assignment
rep = sharding_impls.SingleDeviceSharding(dev)
in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)
shardings: Dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.ArrayImpl):
@@ -2881,19 +3063,26 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
return out_handler(in_handler(outs))
@weakref_lru_cache
def _compile_replicated_mesh_executable_from_hlo(
name, computation, global_in_avals, global_out_avals, in_shardings,
out_shardings, auto_spmd_lowering, compile_options,
computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
semantics_out_shardings, auto_spmd_lowering, compile_options,
host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
backend, device_assignment, committed, pmap_nreps):
backend, da, committed, pmap_nreps):
assert not auto_spmd_lowering
in_shardings = semantics_in_shardings.shardings
out_shardings = semantics_out_shardings.shardings
device_assignment = da.device_assignment if isinstance(
da, _DeviceAssignment) else da
input_indices = _get_input_indices(
global_in_avals, in_shardings) # type: ignore
if pmap_nreps > 1:
# For a jit wrapping a pmap, replicate each input index to match the
# devices of the replicated jit computation.
input_indices = [index * pmap_nreps for index in input_indices]
kept_var_idx = set(kept_var_idx)
# Will compute out_handler with executable information.
unsafe_call = backend.compile_replicated(
is_trivial=False, name=name, computation=computation,
@@ -3038,7 +3227,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
_check_aval(v.aval, what_thunk)
def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
# NOTE: This takes in the non-sharded avals!


@@ -1327,8 +1327,9 @@ class SameDeviceAssignmentTuple:
device_assignment: Optional[XLADeviceAssignment]
def __hash__(self):
shardings_hash = tuple(s._op_sharding_hash if isinstance(s, GSPMDSharding) else s # type: ignore
for s in self.shardings)
shardings_hash = tuple(
s._op_sharding_hash if isinstance(s, GSPMDSharding) else s # type: ignore
for s in self.shardings)
if self.device_assignment is None:
return hash(shardings_hash)
else:
@@ -1337,15 +1338,16 @@ class SameDeviceAssignmentTuple:
def __eq__(self, other):
if not isinstance(other, SameDeviceAssignmentTuple):
return False
return (
all(
op_shardings.are_op_shardings_equal(s._op_sharding, o._op_sharding) # pytype: disable=attribute-error
if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)
else s == o
for s, o in zip(self.shardings, other.shardings)
)
and self.device_assignment == other.device_assignment
)
eq = []
for s, o in zip(self.shardings, other.shardings):
s = getattr(s, "_original_sharding", s)
o = getattr(o, "_original_sharding", o)
if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding):
eq.append(op_shardings.are_op_shardings_equal(
s._op_sharding, o._op_sharding))
else:
eq.append(s == o)
return all(eq) and self.device_assignment == other.device_assignment
def _pjit_lower(
@@ -1416,8 +1418,8 @@ def _pjit_lower_cached(
lowering_platform=lowering_platform)
else:
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings, donated_invars,
jaxpr.in_avals, keep_unused=keep_unused,
jaxpr, api_name, name, in_shardings, out_shardings,
tuple(donated_invars), tuple(jaxpr.in_avals), keep_unused=keep_unused,
always_lower=always_lower,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),


@@ -44,7 +44,8 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import Sharding
from jax._src import op_shardings
from jax._src.sharding_impls import NamedSharding, GSPMDSharding
from jax._src.sharding_impls import (NamedSharding, GSPMDSharding,
PositionalSharding, SingleDeviceSharding)
import jax._src.pjit as pjit_lib
from jax._src.pjit import (pjit, pjit_p, AUTO)
from jax._src import mesh
@@ -647,8 +648,11 @@ class PJitTest(jtu.BufferDonationTestCase):
z, w = jax.vmap(f, in_axes=(None, 0), out_axes=(0, None))(x, y)
self.assertAllClose(z, x[jnp.newaxis] + y)
self.assertAllClose(w, x)
self.assertEqual(z.sharding._op_sharding.tile_assignment_dimensions, [1, 2])
self.assertEqual(w.sharding._op_sharding.tile_assignment_dimensions, [2])
self.assertEqual(
z.sharding._to_xla_op_sharding(z.ndim).tile_assignment_dimensions,
[1, 2])
self.assertEqual(
w.sharding._to_xla_op_sharding(w.ndim).tile_assignment_dimensions, [2])
@jtu.with_mesh([('x', 2)])
def testVMapShardingConstraint(self):
@@ -1379,7 +1383,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def _checks(out, input_data):
self.assertIsInstance(out, array.ArrayImpl)
self.assertIsInstance(out.sharding, GSPMDSharding)
self.assertIsInstance(out.sharding, NamedSharding)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
for s in out.addressable_shards:
@@ -1907,20 +1911,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x)
out1 = f(arr)
self.assertIsInstance(out1.sharding, GSPMDSharding)
self.assertIsInstance(out1.sharding, NamedSharding)
out1.sharding.devices_indices_map(shape)
cache_info1 = GSPMDSharding.devices_indices_map.cache_info()
cache_info1 = NamedSharding.devices_indices_map.cache_info()
out2 = f(out1)
self.assertIsInstance(out2.sharding, GSPMDSharding)
self.assertIsInstance(out2.sharding, NamedSharding)
out2.sharding.devices_indices_map(shape)
cache_info2 = GSPMDSharding.devices_indices_map.cache_info()
cache_info2 = NamedSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
out3 = f(out2)
self.assertIsInstance(out3.sharding, GSPMDSharding)
self.assertIsInstance(out3.sharding, NamedSharding)
out3.sharding.devices_indices_map(shape)
cache_info3 = GSPMDSharding.devices_indices_map.cache_info()
cache_info3 = NamedSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
def test_device_put_sharding_prng(self):
@@ -2202,8 +2206,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(f_out1, g_out1)
self.assertArraysEqual(f_out2, g_out2)
self.assertEqual(f_out1.sharding, g_out1.sharding)
self.assertEqual(f_out2.sharding, g_out2.sharding)
self.assertTrue(f_out1.sharding.is_equivalent_to(g_out1.sharding, f_out1.ndim))
self.assertTrue(f_out2.sharding.is_equivalent_to(g_out2.sharding, f_out2.ndim))
def test_pjit_on_different_default_device_with_uncommitted_inputs(self):
if jax.device_count() < 2:
@@ -2932,6 +2936,216 @@ class ArrayPjitTest(jtu.JaxTestCase):
# Test second order autodiff with src argument specified in device_put.
jtu.check_grads(g, (arr,), order=2)
def test_pjit_out_sharding_preserved(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
def mul(x):
return x * 2
f = pjit(mul, out_shardings=ns)
f2 = pjit(mul, out_shardings=ps)
with jtu.count_pjit_cpp_cache_miss() as count:
out = f(arr)
cache_info1 = pxla._cached_compilation.cache_info()
self.assertIsInstance(out.sharding, NamedSharding)
out = f(arr)
self.assertIsInstance(out.sharding, NamedSharding)
self.assertEqual(count[0], 1)
with jtu.count_pjit_cpp_cache_miss() as count:
out2 = f2(arr)
cache_info2 = pxla._cached_compilation.cache_info()
self.assertIsInstance(out2.sharding, PositionalSharding)
out2 = f2(arr)
self.assertIsInstance(out2.sharding, PositionalSharding)
self.assertEqual(count[0], 1)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
out3 = jnp.squeeze(arr, axis=-1)
cache_info3 = pxla._cached_compilation.cache_info()
self.assertIsInstance(out3.sharding, NamedSharding)
out4 = jnp.squeeze(arr2, axis=-1)
cache_info4 = pxla._cached_compilation.cache_info()
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
# GSPMDShardings can be converted to PositionalSharding.
self.assertIsInstance(out4.sharding, GSPMDSharding)
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
self.assertEqual(cache_info4.misses, cache_info3.misses)
def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
np_arr = np.arange(8, dtype=np.float32).reshape(8, 1)
arr = jax.device_put(np_arr, ns)
def mul(x):
return x * 2
f = pjit(mul, in_shardings=ns, out_shardings=ns)
with jtu.count_pjit_cpp_cache_miss() as count:
out = f(arr)
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out.sharding, NamedSharding)
out2 = f(np_arr)
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out2.sharding, NamedSharding)
# Drops out of C++ cache i.e. cache miss
self.assertEqual(count[0], 2)
# Still gets a hit on pjit_lower cache.
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
def test_sharding_preserved_trivial(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
def identity(x):
return x
out = pjit(identity)(arr)
self.assertIsInstance(out.sharding, NamedSharding)
out2 = pjit(identity)(arr2)
self.assertIsInstance(out2.sharding, PositionalSharding)
def test_sharding_preserved_aot(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
compiled = pjit(lambda x: x * 2).lower(arr).compile()
out = compiled(arr)
self.assertIsInstance(out.sharding, NamedSharding)
out2 = compiled(arr2)
# The sharding won't be PositionalSharding since the pjit was already
# compiled, which bakes in the output sharding.
self.assertIsInstance(out2.sharding, NamedSharding)
def test_sharding_on_output_with_vmap(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
arr = jax.device_put(
np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x')))
vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns))
out = vf(arr)
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out.sharding, GSPMDSharding)
out2 = vf(out)
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out2.sharding, GSPMDSharding)
out3 = vf(out2)
cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out3.sharding, GSPMDSharding)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
self.assertEqual(cache_info3.misses, cache_info2.misses)
def test_jit_mul_sum_sharding_preserved(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
f = jax.jit(lambda x: x * 2)
out = f(arr)
cache_info1 = pxla._cached_compilation.cache_info()
pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out.sharding, NamedSharding)
out2 = f(arr2)
cache_info2 = pxla._cached_compilation.cache_info()
pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
# GSPMDShardings can be converted to PositionalSharding.
self.assertIsInstance(out2.sharding, GSPMDSharding)
out3 = f(out2)
cache_info3 = pxla._cached_compilation.cache_info()
pl_cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out3.sharding, GSPMDSharding)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
self.assertEqual(cache_info3.misses, cache_info2.misses)
# TODO(yashkatariya): We will get hits here after we can convert
# GSPMDSharding to PositionalSharding.
self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1)
self.assertEqual(pl_cache_info3.misses, pl_cache_info2.misses + 1)
out4 = jnp.sum(arr)
self.assertIsInstance(out4.sharding, NamedSharding)
def test_single_device_sharding_preserved(self):
if jax.device_count() < 2:
self.skipTest('Test requires >=2 devices')
x = jnp.arange(8)
# trivial computation
out = jax.jit(lambda x: x)(x)
self.assertIsInstance(out.sharding, SingleDeviceSharding)
# trivial computation with committed inp
y = jax.device_put(x, jax.devices()[1])
out2 = jax.jit(lambda x: x)(y)
self.assertIsInstance(out2.sharding, SingleDeviceSharding)
self.assertEqual(out2.device(), jax.devices()[1])
out3 = jax.jit(lambda x: x * 2)(x)
self.assertIsInstance(out3.sharding, SingleDeviceSharding)
out4 = jax.jit(lambda x: x * 3,
out_shardings=SingleDeviceSharding(jax.devices()[1]))(x)
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
self.assertEqual(out4.device(), jax.devices()[1])
def test_sharding_preserved_apply_primitive(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
arr = jax.device_put(np.arange(8).reshape(8, 1), ns)
out = jnp.copy(arr)
self.assertIsInstance(out.sharding, NamedSharding)
# TODO(yashkatariya): Fix apply_primitive's cache on xla_primitive_callable
# to be like pjit_lower cache.
# ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
# arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
# out2 = jnp.copy(arr2)
# self.assertIsInstance(out2.sharding, PositionalSharding)
class TempSharding(Sharding):