Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.

During the jit lowering path, a fallback to `lower_xla_callable` is taken when a pmap primitive appears in the jaxpr.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.
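
Illustrative example (not part of the change): the merged path keeps the rule that a single committed input commits every output of a jitted call to that input's device, which the new multi_device test further down checks. A minimal sketch, assuming at least one JAX device and the `.device()` method used elsewhere in these tests:

    import jax
    import jax.numpy as jnp

    dev = jax.devices()[-1]

    @jax.jit
    def f(a, b):
        return a, b

    a = jax.device_put(1.0, dev)   # committed to `dev`
    b = jnp.array(2.0)             # uncommitted
    out_a, out_b = f(a, b)
    # Both outputs are committed to the device of the committed input.
    print(out_a.device(), out_b.device())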

PiperOrigin-RevId: 470896270
Yash Katariya 2022-08-29 22:02:32 -07:00 committed by jax authors
parent 10b2e210ed
commit 6340952e2a
17 changed files with 257 additions and 104 deletions

View File

@ -101,9 +101,6 @@ def pure_callback_lowering(ctx, *args, callback, **params):
if ctx.module_context.platform == "TPU" and jaxlib.version < (0, 3, 15):
raise NotImplementedError("Pure callbacks on TPU not supported. "
"Please upgrade to a jaxlib >= 0.3.15.")
if isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext)):
raise NotImplementedError("Sharding for pure callback not implemented.")
def _callback(*flat_args):
return tuple(pure_callback_p.impl(*flat_args, callback=callback, **params))
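
The three-line drop above appears to remove the blanket NotImplementedError for pure callbacks lowered under SPMD/sharding axis contexts, which jit now uses after the merge. A hedged usage sketch, assuming `jax.pure_callback` is exposed at the top level in your jax/jaxlib version (illustrative only, not part of the diff):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def host_log(x):
        # Runs on the host through the callback machinery.
        return np.log(x)

    @jax.jit
    def f(x):
        return jax.pure_callback(
            host_log, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

    print(f(jnp.arange(1.0, 4.0)))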

View File

@ -97,8 +97,6 @@ def arg_spec(x: Any) -> ArgSpec:
aval = xla.abstractify(x)
try:
if config.jax_array:
if isinstance(x.sharding, PmapSharding):
return aval, None
return aval, (x.sharding if x._committed else None)
else:
return aval, x._device
@ -182,7 +180,7 @@ def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
_, arg_devices = util.unzip2(arg_specs)
donated_invars = (False,) * len(arg_specs)
if config.jax_array:
# This will be resolved in _xla_callable_device.
# This will be resolved in sharded_lowering.
device = None
else:
device = _device_from_arg_devices(arg_devices)
@ -277,6 +275,10 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
xla.xla_call_p.def_impl(_xla_call_impl)
TracedJaxprInfo = collections.namedtuple(
'TracedJaxprInfo', ['jaxpr', 'out_jaxpr_avals', 'consts'])
def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
*arg_specs):
# TODO(yashkatariya): Remove the local imports from here when the functions
@ -286,29 +288,40 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
in_avals, in_shardings = util.unzip2(arg_specs)
# TODO(yashkatariya): Remove this and make `SingleDeviceSharding` go through
# lower_sharding_computation and resolve all the errors once that happens.
# For pmap, keep using the fallback by checking the jaxpr and then wrapping it
# in a lu.WrappedFun again.
if any(s is None or isinstance(s, sharding.SingleDeviceSharding) for s in in_shardings):
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun, in_avals, debug_info=pe.debug_info_final(fun, "jit"))
traced_jaxpr_info = TracedJaxprInfo(jaxpr, out_jaxpr_avals, consts)
# If jaxpr has the pmap primitive or if `backend` is provided on `jit`, then
# take the lower_xla_callable lowering path. This is because pmap's programming
# model is not compatible with lower_sharding_computation.
# Specifying backend on `jit` is not supported when Array is enabled. So take
# the `lower_xla_callable` path which can handle it.
if (jaxpr_has_primitive(jaxpr, 'xla_pmap') or
any(isinstance(s, sharding.PmapSharding) for s in in_shardings) or
backend is not None):
arg_specs = tuple(
(a, s._device) if isinstance(s, sharding.SingleDeviceSharding) else (a, None)
for a, s in zip(in_avals, in_shardings))
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
keep_unused, *arg_specs).compile().unsafe_call
return lower_xla_callable(
fun, None, backend, name, donated_invars, False, keep_unused, *arg_specs,
traced_jaxpr_info=traced_jaxpr_info).compile().unsafe_call
committed = any(i is not None for i in in_shardings)
da = pjit._get_and_check_device_assignment(
(i for i in in_shardings if i is not None), pxla.EMPTY_ENV.physical_mesh)
in_shardings = [sharding.OpShardingSharding.get_replicated(da) if i is None else i
for i in in_shardings]
# Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
# the number of output avals at this stage. lower_sharding_computation will
# apply it to all out_avals.
return pxla.lower_sharding_computation(
fun, 'xla_callable', name, in_shardings, pjit._UNSPECIFIED,
donated_invars, in_avals,
in_is_global=(True,) * len(arg_specs)).compile(
in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
committed=committed, traced_jaxpr_info=traced_jaxpr_info).compile(
_allow_propagation_to_outputs=True).unsafe_call
@ -347,9 +360,10 @@ def should_tuple_args(num_args: int, platform: str):
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, always_lower: bool, keep_unused: bool,
*arg_specs):
def lower_xla_callable(
fun: lu.WrappedFun, device, backend, name, donated_invars,
always_lower: bool, keep_unused: bool, *arg_specs,
traced_jaxpr_info: Optional[TracedJaxprInfo] = None):
"""Lower into XLA.
Args:
@ -371,11 +385,18 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
else:
assert abstract_args == (None,) * len(abstract_args)
abstract_args = [aval for aval, _ in fun.in_type]
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
fun, pe.debug_info_final(fun, "jit"))
out_avals, kept_outputs = util.unzip2(out_type)
if traced_jaxpr_info is None:
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
fun, pe.debug_info_final(fun, "jit"))
out_avals, kept_outputs = util.unzip2(out_type)
else:
jaxpr, out_avals, consts = traced_jaxpr_info
kept_outputs = [True] * len(out_avals)
out_type = tuple(zip(out_avals, kept_outputs))
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
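
As a user-level illustration of the fallback described in the commit message (not part of the diff): a pmap primitive inside a jitted function routes lowering through `lower_xla_callable`, so jit-of-pmap keeps working. Assumes only that `jax.device_count()` devices are visible:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        # The xla_pmap primitive in the traced jaxpr triggers the
        # lower_xla_callable fallback in sharded_lowering.
        return jax.pmap(lambda v: v * 2)(x)

    n = jax.device_count()
    print(f(jnp.arange(float(n)).reshape(n, 1)))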

View File

@ -548,6 +548,8 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:
def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
weak_type: bool = False):
from jax.experimental import array
# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
old_dtype = dtypes.dtype(operand, canonicalize=False)
@ -575,7 +577,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
old_weak_type = False
if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
return operand
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
@ -794,8 +796,10 @@ def broadcast_in_dim(operand: Array, shape: Shape,
See Also:
jax.lax.broadcast : simpler interface to add new leading dimensions.
"""
from jax.experimental import array
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
and isinstance(operand, (device_array.DeviceArray, core.Tracer))):
and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))):
return operand
if config.jax_dynamic_shapes:
# We must gate this behavior under a flag because otherwise the errors
@ -850,6 +854,8 @@ def reshape(operand: Array, new_sizes: Shape,
>>> reshape(y, (6,), (1, 0))
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
"""
from jax.experimental import array
new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
@ -860,7 +866,7 @@ def reshape(operand: Array, new_sizes: Shape,
dims = api_util._ensure_index_tuple(dimensions)
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
if (np.shape(operand) and same_shape and same_dims
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
return operand
else:
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
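
The lax.py hunks above extend an existing identity fast path to `jax.experimental.array.Array`: when the requested dtype/weak_type (or shape) already matches and the operand is a concrete array or tracer, the operand itself is returned instead of binding a primitive. A small hedged illustration via the public API (whether the fast path fires is version-dependent):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(4, dtype=jnp.int32)
    y = lax.convert_element_type(x, jnp.int32)   # dtype already matches
    z = lax.reshape(x, (4,))                     # shape already matches
    # When the fast path is taken, the very same array object comes back.
    print(y is x, z is x)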

View File

@ -1873,7 +1873,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# We can't use the ndarray class because we need to handle internal buffers
# (See https://github.com/google/jax/issues/8950)
ndarray_types = (device_array.DeviceArray, core.Tracer)
ndarray_types = (device_array.DeviceArray, core.Tracer, Array)
if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
@ -4746,7 +4746,7 @@ _NOT_IMPLEMENTED = ['argpartition']
# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer)
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, Array)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)
def __array_module__(self, types):

View File

@ -348,7 +348,7 @@ class KeyTy:
return handler
@staticmethod
def global_sharded_result_handler(aval, out_sharding):
def global_sharded_result_handler(aval, out_sharding, committed):
phys_aval, = KeyTy.physical_avals(aval)
key_shape = aval.dtype.impl.key_shape
@ -378,7 +378,7 @@ class KeyTy:
# a new op sharding with a trivially extended `tile_assignment_dimensions`
raise NotImplementedError
phys_handler = phys_handler_maker(phys_aval, phys_sharding)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
def handler(bufs):
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
return handler

View File

@ -281,6 +281,15 @@ class Array:
'named_shape': self.aval.named_shape}
return (_reconstruct_array, (fun, args, arr_state, aval_state))
def unsafe_buffer_pointer(self):
assert len(self._arrays) == 1
return self._arrays[0].unsafe_buffer_pointer()
@property
def __cuda_array_interface__(self):
assert len(self._arrays) == 1
return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties
# TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device:
self._check_if_deleted()
@ -434,19 +443,25 @@ def _array_shard_arg(x, devices, indices, mode):
if mode == pxla.InputsHandlerMode.pmap:
return _array_pmap_shard_arg(x, devices, indices, mode)
else:
return x._arrays
if isinstance(x.sharding, SingleDeviceSharding):
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
else:
return x._arrays
pxla.shard_arg_handlers[Array] = _array_shard_arg
def _array_global_result_handler(global_aval, out_sharding):
def _array_global_result_handler(global_aval, out_sharding, committed):
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
if core.aval_has_custom_eltype(global_aval):
return global_aval.dtype.global_sharded_result_handler(
global_aval, out_sharding)
else:
return lambda bufs: Array(global_aval, out_sharding, bufs, committed=True,
_skip_checks=True)
global_aval, out_sharding, committed)
return lambda bufs: Array(global_aval, out_sharding, bufs,
committed=committed, _skip_checks=True)
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
def _array_local_result_handler(aval, sharding, indices):

View File

@ -618,7 +618,10 @@ def _gda_shard_arg(x, devices, indices, mode):
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg
def _gda_array_result_handler(global_aval, out_sharding):
def _gda_array_result_handler(global_aval, out_sharding, committed):
if core.aval_has_custom_eltype(global_aval):
return global_aval.dtype.global_sharded_result_handler(
global_aval, out_sharding, committed)
global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec
global_idx_rid = get_shard_indices_replica_ids(global_aval.shape, global_mesh,
out_axis_resources)

View File

@ -28,7 +28,7 @@ from jax.experimental.sharding import (
from jax import core
from jax import linear_util as lu
from jax import stages
from jax._src.api import _check_callable, _check_arg, devices
from jax._src.api import _check_callable, _check_arg, local_devices
from jax._src.config import config
from jax._src import dispatch
from jax._src import source_info_util
@ -922,9 +922,12 @@ def _pjit_lower_cached(
else:
# Pass `in_is_global` here because this path is taken by both host local
# avals and global avals.
# TODO(yashkatariya): Don't set committed to True always. Infer that from
# the arguments just like dispatch.py in `sharded_lowering`.
return pxla.lower_sharding_computation(
fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
jaxpr.in_avals, in_is_global=in_is_global)
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
committed=True)
def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
@ -1519,7 +1522,7 @@ def _get_and_check_device_assignment(shardings, pjit_mesh):
if first_device_assignment is None and not pjit_mesh.empty:
return mesh_devices
if first_device_assignment is None:
return [config.jax_default_device or devices()[0]]
return [config.jax_default_device or local_devices()[0]]
return first_device_assignment
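
A hedged, user-level illustration of the fallback changed above (behavior only, not the internals): with no sharded or committed inputs and no mesh, results land on `config.jax_default_device` if set, otherwise on the first process-local device. The `.device()` call matches the style of the tests in this diff:

    import jax

    with jax.default_device(jax.devices()[-1]):
        out = jax.jit(lambda x: x + 1)(1.0)
    print(out.device())   # expected: the device selected above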

View File

@ -348,7 +348,9 @@ class OpShardingSharding(XLACompatibleSharding):
return self._hash
def __repr__(self):
return repr(self._op_sharding)
if pxla.is_op_sharding_replicated(self._op_sharding):
return 'OpShardingSharding(REPLICATED)'
return f'OpShardingSharding({repr(self._op_sharding)})'
def is_compatible_aval(self, aval_shape: Shape):
num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
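
A hedged example of the friendlier repr added above, using the experimental classes already referenced in this diff (module paths and availability are version-specific):

    import jax
    from jax.experimental.sharding import OpShardingSharding

    s = OpShardingSharding.get_replicated(jax.devices())
    print(repr(s))   # e.g. OpShardingSharding(REPLICATED)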

View File

@ -421,7 +421,14 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, InputsHandlerMode], Sequence[Any]]] = {}
def _shard_token(x, devices, indices, mode):
return device_put(np.zeros((), dtype=np.dtype(np.bool_)), devices, replicate=True)
shard_arg_handlers[core.Token] = _shard_token
def _shard_array(x, devices, indices, mode):
if x.dtype == dtypes.float0:
x = np.zeros(x.shape, dtype=np.dtype(bool))
return device_put([x[i] for i in indices], devices)
for _t in array_types:
shard_arg_handlers[_t] = _shard_array
@ -581,7 +588,7 @@ local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_arra
def global_aval_to_result_handler(
aval: core.AbstractValue, out_sharding,
aval: core.AbstractValue, out_sharding, committed: bool
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
@ -602,7 +609,8 @@ def global_aval_to_result_handler(
elif config.jax_parallel_functions_output_gda:
output_type = OutputType.GlobalDeviceArray
try:
return global_result_handlers[(type(aval), output_type)](aval, out_sharding)
return global_result_handlers[(type(aval), output_type)](
aval, out_sharding, committed)
except KeyError as err:
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
@ -1601,7 +1609,7 @@ class PmapExecutable(stages.XlaExecutable):
execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args,
handle_outs, unordered_effects,
ordered_effects, keepalive,
bool(host_callbacks))
bool(host_callbacks), set(range(len(input_indices))))
fingerprint = getattr(compiled, "fingerprint", None)
return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)
@ -1836,12 +1844,13 @@ def local_avals_to_results_handler(
def global_avals_to_results_handler(
global_out_avals: Sequence[ShapedArray],
shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
shardings: Sequence[XLACompatibleSharding],
committed: bool) -> ResultsHandler:
from jax.experimental.sharding import MeshPspecSharding
if config.jax_parallel_functions_output_gda or config.jax_array:
handlers = [
global_aval_to_result_handler(global_aval, s)
global_aval_to_result_handler(global_aval, s, committed)
for global_aval, s in safe_zip(global_out_avals, shardings)
]
return ResultsHandler(handlers, shardings, global_out_avals)
@ -1959,13 +1968,14 @@ class ExecuteReplicated:
"""The logic to shard inputs, execute a replicated model, returning outputs."""
__slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler',
'has_unordered_effects', 'ordered_effects', 'keepalive',
'has_host_callbacks', '_local_devices', '__weakref__']
'has_host_callbacks', '_local_devices', 'kept_var_idx',
'__weakref__']
def __init__(self, xla_executable, backend, in_handler: InputsHandler,
out_handler: ResultsHandler,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect], keepalive: Any,
has_host_callbacks: bool):
has_host_callbacks: bool, kept_var_idx: Set[int]):
self.xla_executable = xla_executable
self.backend = backend
self.in_handler = in_handler
@ -1977,6 +1987,7 @@ class ExecuteReplicated:
assert len(self._local_devices) == 1
self.keepalive = keepalive
self.has_host_callbacks = has_host_callbacks
self.kept_var_idx = kept_var_idx
def _call_with_tokens(self, input_bufs):
# TODO(sharadmv): simplify this logic when minimum jaxlib version is
@ -2013,6 +2024,7 @@ class ExecuteReplicated:
@profiler.annotate_function
def __call__(self, *args):
args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
input_bufs = self.in_handler(args)
if (self.ordered_effects or self.has_unordered_effects or
self.has_host_callbacks):
@ -2650,7 +2662,10 @@ def lower_sharding_computation(
out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
in_is_global: Sequence[bool]):
in_is_global: Sequence[bool],
keep_unused: bool,
committed: bool,
traced_jaxpr_info: Optional[dispatch.TracedJaxprInfo] = None):
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
The caller of this code can pass in a singleton _UNSPECIFIED because the
@ -2667,6 +2682,7 @@ def lower_sharding_computation(
# UNSPECIFIED singleton are filtered above.
backend, first_sharding = _get_backend_from_shardings(
it.chain(in_shardings, out_shardings)) # type: ignore
device_assignment = first_sharding._device_assignment
name_stack = new_name_stack(wrap_name(fun_name, api_name))
@ -2678,11 +2694,13 @@ def lower_sharding_computation(
global_in_avals, in_shardings)
# 1. Trace to jaxpr and preprocess/verify it
in_jaxpr_avals = global_in_avals
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
if traced_jaxpr_info is None:
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"for sharded computation in {elapsed_time} sec"):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, "sharded computation"))
else:
jaxpr, out_jaxpr_avals, consts = traced_jaxpr_info
if _is_unspecified(out_shardings):
out_shardings = (_UNSPECIFIED,) * len(out_jaxpr_avals)
@ -2692,33 +2710,49 @@ def lower_sharding_computation(
global_out_avals = out_jaxpr_avals
if keep_unused:
kept_var_idx = set(range(len(global_in_avals)))
else:
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
in_is_global = tuple(g for i, g in enumerate(in_is_global) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx
_sanitize_mesh_jaxpr(jaxpr)
if not first_sharding.is_fully_addressable():
check_multihost_collective_allowlist(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
axis_ctx: mlir.ShardingContext
in_op_shardings = [i._to_xla_op_sharding(aval.ndim)
for aval, i in safe_zip(global_in_avals, in_shardings)]
in_op_shardings = [
None if aval is core.abstract_token else i._to_xla_op_sharding(aval.ndim)
for aval, i in safe_zip(global_in_avals, in_shardings)
]
# TODO(yashkatariya): Fix the HLO produced when out_partitions is mixed,
# e.g. [None, OpShardingProto], i.e. only some outputs carry sharding annotations.
out_op_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(global_out_avals, out_shardings)]
replicated_args = [False] * len(in_jaxpr_avals)
out_op_shardings = [
None if _is_unspecified(o) or aval is core.abstract_token else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(global_out_avals, out_shardings)
]
replicated_args = [False] * len(global_in_avals)
axis_ctx = mlir.ShardingContext(first_sharding)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module: Union[str, xc.XlaComputation]
module_name = f"{api_name}_{fun_name}"
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
raise ValueError("Ordered effects not supported in mesh computations.")
if len(device_assignment) > 1:
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
unordered_effects = [eff for eff in closed_jaxpr.effects
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
@ -2726,7 +2760,8 @@ def lower_sharding_computation(
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
unordered_effects, ordered_effects,
unordered_effects,
ordered_effects,
backend,
backend.platform,
axis_ctx,
@ -2735,10 +2770,16 @@ def lower_sharding_computation(
replicated_args=replicated_args,
arg_shardings=in_op_shardings,
result_shardings=out_op_shardings)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
# backend and device_assignment are passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
# to get the device_assignment and backend. So pass it to MeshExecutable
# because we calculate the device_assignment and backend before in_shardings,
# etc are pruned.
return MeshComputation(
str(name_stack),
module,
@ -2755,7 +2796,11 @@ def lower_sharding_computation(
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
host_callbacks=host_callbacks,
keepalive=keepalive)
keepalive=keepalive,
kept_var_idx=kept_var_idx,
backend=backend,
device_assignment=device_assignment,
committed=committed)
@profiler.annotate_function
@ -2903,7 +2948,11 @@ def lower_mesh_computation(
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
host_callbacks=host_callbacks,
keepalive=keepalive)
keepalive=keepalive,
kept_var_idx=set(range(len(global_in_avals))),
backend=backend,
device_assignment=list(mesh.devices.flat),
committed=True)
class MeshComputation(stages.XlaLowering):
@ -2963,15 +3012,19 @@ def _get_input_metadata(
aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
sharding = MeshPspecSharding(i.mesh.local_mesh, i.spec)
# We special case this logic to support fully replicated values because
# the mesh is global mesh and the indices returned by `spec_to_indices` will
# represent index for each device in the global mesh. But here we want
# indices for the local devices of the global mesh.
proto = sharding._to_xla_op_sharding(aval.ndim)
if is_op_sharding_replicated(proto):
index = tuple((slice(None),) * aval.ndim for _ in range(len(sharding.addressable_devices)))
if aval is core.abstract_token:
index = (slice(None),)
else:
index = tuple(sharding.devices_indices_map(aval.shape).values())
# We special case this logic to support fully replicated values because
# the mesh is global mesh and the indices returned by `spec_to_indices` will
# represent index for each device in the global mesh. But here we want
# indices for the local devices of the global mesh.
proto = sharding._to_xla_op_sharding(aval.ndim)
if is_op_sharding_replicated(proto):
index = tuple((slice(None),) * aval.ndim
for _ in range(len(sharding.addressable_devices))) # type: ignore
else:
index = tuple(sharding.devices_indices_map(aval.shape).values()) # type: ignore
shardings.append(sharding)
input_indices.append(index)
@ -3044,32 +3097,28 @@ class MeshExecutable(stages.XlaExecutable):
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
keepalive: Any) -> MeshExecutable:
if auto_spmd_lowering:
assert mesh is not None
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
else:
backend, first_sharding = _get_backend_from_shardings(
it.chain(in_shardings, out_shardings)) # type: ignore
keepalive: Any,
kept_var_idx: Set[int],
backend: xb.XlaBackend,
device_assignment: Sequence[xc.Device],
committed: bool) -> MeshExecutable:
dev: np.ndarray
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
dev = mesh.devices
num_replicas, num_partitions = 1, mesh.size
else:
dev = np.array(first_sharding._device_assignment)
dev = np.array(device_assignment)
if spmd_lowering:
num_replicas, num_partitions = 1, dev.size
else:
num_replicas, num_partitions = dev.size, 1
device_assignment = dev.reshape((num_replicas, num_partitions))
xla_device_assignment = dev.reshape((num_replicas, num_partitions))
compile_options = xb.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,
device_assignment=xla_device_assignment,
use_spmd_partitioning=spmd_lowering,
use_auto_spmd_partitioning=auto_spmd_lowering,
)
@ -3088,7 +3137,7 @@ class MeshExecutable(stages.XlaExecutable):
in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings) # type: ignore # arg-type
global_out_avals, out_shardings, committed) # type: ignore # arg-type
unsafe_call = backend.compile_replicated(computation, compile_options,
host_callbacks, input_avals,
input_indices, in_shardings,
@ -3108,7 +3157,7 @@ class MeshExecutable(stages.XlaExecutable):
elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
assert mesh is None
_, out_shardings_xla = _get_op_sharding_shardings_from_executable(
xla_executable, first_sharding._device_assignment,
xla_executable, device_assignment,
len(global_in_avals), len(global_out_avals))
out_shardings = [x if _is_unspecified(o) else o
for x, o in safe_zip(out_shardings_xla, out_shardings)]
@ -3116,13 +3165,13 @@ class MeshExecutable(stages.XlaExecutable):
in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings) # type: ignore # arg-type
global_out_avals, out_shardings, committed) # type: ignore # arg-type
handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
input_indices, InputsHandlerMode.pjit_or_xmap)
unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
handle_outs, unordered_effects,
ordered_effects, keepalive,
bool(host_callbacks))
bool(host_callbacks), kept_var_idx)
return MeshExecutable(xla_executable, unsafe_call, input_avals,
in_shardings, out_shardings, auto_spmd_lowering)
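
A user-level sketch of the `keep_unused` plumbing added above (not part of the diff): by default unused jit arguments are pruned via `kept_var_idx`; `keep_unused=True` keeps them resident on device. Assumes a jax version with the `keep_unused` parameter on `jax.jit`:

    import jax

    def f(x, unused):
        return x * 2.0

    pruned = jax.jit(f)                      # unused arg dropped from the computation
    kept = jax.jit(f, keep_unused=True)      # unused arg transferred and kept alive
    print(pruned(1.0, 2.0), kept(1.0, 2.0))  # same numerical result either way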

View File

@ -78,6 +78,10 @@ jax_test(
jax_test(
name = "custom_object_test",
srcs = ["custom_object_test.py"],
# TODO(yashkatariya,mattjj,phawkins): Enable custom_object_test once
# `ExecuteReplicated` supports the use case of having more
# than 1 buffer on a single device.
disable_configs = ["cpu_jax_array"],
)
py_test(

View File

@ -239,13 +239,13 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert len(side) == 3
def test_jit_device(self):
if config.jax_array:
self.skipTest('The device parameter of jit has been deprecated. Array '
'is not compatible with it and will not work.')
device = jax.devices()[-1]
x = self.jit(lambda x: x, device=device)(3.)
_check_instance(self, x)
if config.jax_array:
self.assertEqual(x.device(), device)
else:
self.assertEqual(x.device_buffer.device(), device)
self.assertEqual(x.device_buffer.device(), device)
@jtu.skip_on_devices("cpu")
def test_jit_default_device(self):
@ -267,10 +267,13 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertEqual(f(1).device(), system_default_device)
with jax.default_device(test_device):
# Explicit `device` or `backend` argument to jit overrides default_device
self.assertEqual(
jax.jit(f, device=system_default_device)(1).device(),
system_default_device)
# Skip this for jax.Array because using the device argument of `jit` is
# deprecated.
if not config.jax_array:
# Explicit `device` or `backend` argument to jit overrides default_device
self.assertEqual(
jax.jit(f, device=system_default_device)(1).device(),
system_default_device)
out = jax.jit(f, backend="cpu")(1)
if config.jax_array:
self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding)
@ -1067,7 +1070,10 @@ class CPPJitTest(jtu.BufferDonationTestCase):
jitted_f = self.jit(lambda x, y: x, keep_unused=True)
with jtu.count_device_put() as count:
_ = jitted_f(1, 2)
self.assertEqual(count[0], 1)
if config.jax_array:
self.assertEqual(count[0], 2)
else:
self.assertEqual(count[0], 1)
@jtu.ignore_warning(category=DeprecationWarning)
def test_jit_lower_compile_compiler_ir(self):

View File

@ -38,6 +38,7 @@ import jax.util
from jax.interpreters import xla
from jax.interpreters import mlir
from jax.interpreters import batching
from jax.interpreters import pxla
from jax.experimental import array
from jax._src.lib.mlir.dialects import mhlo
from jax._src import dispatch
@ -3042,6 +3043,14 @@ class FooTy:
return FooArray(aval.shape, buf)
return handler
@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed):
def handler(bufs):
buf, = bufs
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return FooArray(aval.shape, buf)
return handler
# eltype-polymorphic primitive lowering rules
@staticmethod
@ -3152,6 +3161,12 @@ def device_put_foo_array(x: FooArray, device):
return array._device_put_array(x.data, device)
return dispatch._device_put_array(x.data, device)
def shard_foo_array_handler(x, devices, indices, mode):
device, = devices
if isinstance(x.data, array.Array):
return array._device_put_array(x.data, device)
return dispatch._device_put_array(x.data, device)
def foo_array_constant_handler(x, c):
if config.jax_array:
return array._array_mlir_constant_handler(x.data, c)
@ -3186,6 +3201,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
xla.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
dispatch.device_put_handlers[FooArray] = device_put_foo_array
pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
mlir._constant_handlers[FooArray] = foo_array_constant_handler
mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False))

View File

@ -224,6 +224,33 @@ class MultiDeviceTest(jtu.JaxTestCase):
y = jax.device_put(1, devices[2]) + jnp.ones((2, 3))
self.assert_committed_to_device(y, devices[2])
def test_single_input_committed_multi_output(self):
if jax.device_count() < 3:
self.skipTest("Test requires 3 devices")
devices = self.get_devices()
@jax.jit
def f(a, b, c, d, e):
return a, b, c, d, e
outs = f(jax.device_put(1, devices[2]), jnp.array(2), jnp.array(3),
jnp.array(4), jnp.array(5))
for o in outs:
self.assert_committed_to_device(o, devices[2])
def test_different_devices_input_error(self):
if jax.device_count() < 2:
self.skipTest("Test requires 2 devices")
devices = self.get_devices()
a = jax.device_put(1, devices[0])
b = jax.device_put(2, devices[1])
# Don't look for the message because the Array and non-Array paths raise
# slightly different error messages.
with self.assertRaises(ValueError):
_ = a + b
def test_transpose(self):
if jax.device_count() < 3:
self.skipTest("test requires 3 devices")

View File

@ -173,9 +173,12 @@ class MultiBackendTest(jtu.JaxTestCase):
result2 = jax.jit(my_sin)(data_on_cpu)
self.assertEqual(result2.device(), cpus[0])
# jit with `device` spec places the data on the specified device
result3 = jax.jit(my_sin, device=cpus[0])(2)
self.assertEqual(result3.device(), cpus[0])
# Skip this for jax.Array because using the device argument of `jit` is
# deprecated.
if not config.jax_array:
# jit with `device` spec places the data on the specified device
result3 = jax.jit(my_sin, device=cpus[0])(2)
self.assertEqual(result3.device(), cpus[0])
# jit with `backend` spec places the data on the specified backend
result4 = jax.jit(my_sin, backend="cpu")(2)

View File

@ -137,8 +137,6 @@ class PythonPmapTest(jtu.JaxTestCase):
def pmap(self):
return src_api._python_pmap
# TODO(yashkatariya): Re-enable when unsafe_buffer_pointer is implemented
@unittest.skipIf(config.jax_array, "Array does not yet implement unsafe_buffer_pointer")
def testDeviceBufferToArray(self):
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))

View File

@ -1451,7 +1451,10 @@ class LaxRandomTest(jtu.JaxTestCase):
key = self.seed_prng(1).block_until_ready()
with jtu.count_device_put() as count:
jax.jit(random.split)(key)
self.assertEqual(count[0], 1) # 1 for the argument device_put
if config.jax_array:
self.assertEqual(count[0], 0)
else:
self.assertEqual(count[0], 1) # 1 for the argument device_put
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}