diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 7731eaa3e..de7bb6d70 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -101,9 +101,6 @@ def pure_callback_lowering(ctx, *args, callback, **params): if ctx.module_context.platform == "TPU" and jaxlib.version < (0, 3, 15): raise NotImplementedError("Pure callbacks on TPU not supported. " "Please upgrade to a jaxlib >= 0.3.15.") - if isinstance(ctx.module_context.axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)): - raise NotImplementedError("Sharding for pure callback not implemented.") def _callback(*flat_args): return tuple(pure_callback_p.impl(*flat_args, callback=callback, **params)) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index a32aa6bf1..ebd6f31d7 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -97,8 +97,6 @@ def arg_spec(x: Any) -> ArgSpec: aval = xla.abstractify(x) try: if config.jax_array: - if isinstance(x.sharding, PmapSharding): - return aval, None return aval, (x.sharding if x._committed else None) else: return aval, x._device @@ -182,7 +180,7 @@ def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params): _, arg_devices = util.unzip2(arg_specs) donated_invars = (False,) * len(arg_specs) if config.jax_array: - # This will be resolved in _xla_callable_device. + # This will be resolved in sharded_lowering. device = None else: device = _device_from_arg_devices(arg_devices) @@ -277,6 +275,10 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, xla.xla_call_p.def_impl(_xla_call_impl) +TracedJaxprInfo = collections.namedtuple( + 'TracedJaxprInfo', ['jaxpr', 'out_jaxpr_avals', 'consts']) + + def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused, *arg_specs): # TODO(yashkatariya): Remove the local imports from here when the functions @@ -286,29 +288,40 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused, in_avals, in_shardings = util.unzip2(arg_specs) - # TODO(yashkatariya): Remove this and make `SingleDeviceSharding` go through - # lower_sharding_computation and resolve all the errors once that happens. - # For pmap, keep using the fallback by checking the jaxpr and then wrapping it - # in a lu.Wrappedfun again. - if any(s is None or isinstance(s, sharding.SingleDeviceSharding) for s in in_shardings): + with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " + "in {elapsed_time} sec"): + jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final( + fun, in_avals, debug_info=pe.debug_info_final(fun, "jit")) + traced_jaxpr_info = TracedJaxprInfo(jaxpr, out_jaxpr_avals, consts) + + # If jaxpr has the pmap primitive or if `backend` is provided on `jit`, then + # take the lower_xla_callable lowering path. This is because pmap's programming + # model is not compatible with lower_sharding_computation. + # Specifying backend on `jit` is not supported when Array is enabled. So take + # the `lower_xla_callable` path which can handle it. 
+ if (jaxpr_has_primitive(jaxpr, 'xla_pmap') or + any(isinstance(s, sharding.PmapSharding) for s in in_shardings) or + backend is not None): arg_specs = tuple( (a, s._device) if isinstance(s, sharding.SingleDeviceSharding) else (a, None) for a, s in zip(in_avals, in_shardings)) - return lower_xla_callable(fun, device, backend, name, donated_invars, False, - keep_unused, *arg_specs).compile().unsafe_call + return lower_xla_callable( + fun, None, backend, name, donated_invars, False, keep_unused, *arg_specs, + traced_jaxpr_info=traced_jaxpr_info).compile().unsafe_call + committed = any(i is not None for i in in_shardings) da = pjit._get_and_check_device_assignment( (i for i in in_shardings if i is not None), pxla.EMPTY_ENV.physical_mesh) in_shardings = [sharding.OpShardingSharding.get_replicated(da) if i is None else i for i in in_shardings] - # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know # the number of output avals at this stage. lower_sharding_computation will # apply it to all out_avals. return pxla.lower_sharding_computation( fun, 'xla_callable', name, in_shardings, pjit._UNSPECIFIED, donated_invars, in_avals, - in_is_global=(True,) * len(arg_specs)).compile( + in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused, + committed=committed, traced_jaxpr_info=traced_jaxpr_info).compile( _allow_propagation_to_outputs=True).unsafe_call @@ -347,9 +360,10 @@ def should_tuple_args(num_args: int, platform: str): @profiler.annotate_function -def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, - donated_invars, always_lower: bool, keep_unused: bool, - *arg_specs): +def lower_xla_callable( + fun: lu.WrappedFun, device, backend, name, donated_invars, + always_lower: bool, keep_unused: bool, *arg_specs, + traced_jaxpr_info: Optional[TracedJaxprInfo] = None): """Lower into XLA. Args: @@ -371,11 +385,18 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, else: assert abstract_args == (None,) * len(abstract_args) abstract_args = [aval for aval, _ in fun.in_type] - with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for jit in {elapsed_time} sec"): - jaxpr, out_type, consts = pe.trace_to_jaxpr_final2( - fun, pe.debug_info_final(fun, "jit")) - out_avals, kept_outputs = util.unzip2(out_type) + + if traced_jaxpr_info is None: + with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " + "for jit in {elapsed_time} sec"): + jaxpr, out_type, consts = pe.trace_to_jaxpr_final2( + fun, pe.debug_info_final(fun, "jit")) + out_avals, kept_outputs = util.unzip2(out_type) + else: + jaxpr, out_avals, consts = traced_jaxpr_info + kept_outputs = [True] * len(out_avals) + out_type = tuple(zip(out_avals, kept_outputs)) + if any(isinstance(c, core.Tracer) for c in consts): raise UnexpectedTracerError("Encountered an unexpected tracer.") diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a3a1fd316..056775b84 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -548,6 +548,8 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array: def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None, weak_type: bool = False): + from jax.experimental import array + # Don't canonicalize old_dtype because x64 context might cause # un-canonicalized operands to be passed in. 
old_dtype = dtypes.dtype(operand, canonicalize=False) @@ -575,7 +577,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None, old_weak_type = False if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type) - and isinstance(operand, (core.Tracer, device_array.DeviceArray))): + and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))): return operand else: return convert_element_type_p.bind(operand, new_dtype=new_dtype, @@ -794,8 +796,10 @@ def broadcast_in_dim(operand: Array, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. """ + from jax.experimental import array + if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) - and isinstance(operand, (device_array.DeviceArray, core.Tracer))): + and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))): return operand if config.jax_dynamic_shapes: # We must gate this behavior under a flag because otherwise the errors @@ -850,6 +854,8 @@ def reshape(operand: Array, new_sizes: Shape, >>> reshape(y, (6,), (1, 0)) DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32) """ + from jax.experimental import array + new_sizes = canonicalize_shape(new_sizes) # TODO new_sizes = tuple(new_sizes) same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes) @@ -860,7 +866,7 @@ def reshape(operand: Array, new_sizes: Shape, dims = api_util._ensure_index_tuple(dimensions) same_dims = tuple(dims) == tuple(range(np.ndim(operand))) if (np.shape(operand) and same_shape and same_dims - and isinstance(operand, (core.Tracer, device_array.DeviceArray))): + and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))): return operand else: dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f065945d0..e2eefce11 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1873,7 +1873,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0): # We can't use the ndarray class because we need to handle internal buffers # (See https://github.com/google/jax/issues/8950) - ndarray_types = (device_array.DeviceArray, core.Tracer) + ndarray_types = (device_array.DeviceArray, core.Tracer, Array) if not _any(isinstance(leaf, ndarray_types) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists @@ -4746,7 +4746,7 @@ _NOT_IMPLEMENTED = ['argpartition'] # Experimental support for NumPy's module dispatch with NEP-37. 
# Currently requires https://github.com/seberg/numpy-dispatch -_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer) +_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, Array) _HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,) def __array_module__(self, types): diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 83706b3e9..452dca819 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -348,7 +348,7 @@ class KeyTy: return handler @staticmethod - def global_sharded_result_handler(aval, out_sharding): + def global_sharded_result_handler(aval, out_sharding, committed): phys_aval, = KeyTy.physical_avals(aval) key_shape = aval.dtype.impl.key_shape @@ -378,7 +378,7 @@ class KeyTy: # a new op sharding with a trivially extended `tile_assignment_dimensions` raise NotImplementedError - phys_handler = phys_handler_maker(phys_aval, phys_sharding) + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) def handler(bufs): return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs)) return handler diff --git a/jax/experimental/array.py b/jax/experimental/array.py index 1008a43e5..40c99fd16 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -281,6 +281,15 @@ class Array: 'named_shape': self.aval.named_shape} return (_reconstruct_array, (fun, args, arr_state, aval_state)) + def unsafe_buffer_pointer(self): + assert len(self._arrays) == 1 + return self._arrays[0].unsafe_buffer_pointer() + + @property + def __cuda_array_interface__(self): + assert len(self._arrays) == 1 + return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties + # TODO(yashkatariya): Remove this method when everyone is using devices(). def device(self) -> Device: self._check_if_deleted() @@ -434,19 +443,25 @@ def _array_shard_arg(x, devices, indices, mode): if mode == pxla.InputsHandlerMode.pmap: return _array_pmap_shard_arg(x, devices, indices, mode) else: - return x._arrays + if isinstance(x.sharding, SingleDeviceSharding): + return [buf if buf.device() == d else buf.copy_to_device(d) + for buf, d in safe_zip(x._arrays, devices)] + else: + return x._arrays pxla.shard_arg_handlers[Array] = _array_shard_arg -def _array_global_result_handler(global_aval, out_sharding): +def _array_global_result_handler(global_aval, out_sharding, committed): + if global_aval.dtype == dtypes.float0: + return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore if core.aval_has_custom_eltype(global_aval): return global_aval.dtype.global_sharded_result_handler( - global_aval, out_sharding) - else: - return lambda bufs: Array(global_aval, out_sharding, bufs, committed=True, - _skip_checks=True) + global_aval, out_sharding, committed) + return lambda bufs: Array(global_aval, out_sharding, bufs, + committed=committed, _skip_checks=True) pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler +pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token def _array_local_result_handler(aval, sharding, indices): diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index 5749dc3b3..26ddcd0c3 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -618,7 +618,10 @@ def _gda_shard_arg(x, devices, indices, mode): pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg -def 
_gda_array_result_handler(global_aval, out_sharding): +def _gda_array_result_handler(global_aval, out_sharding, committed): + if core.aval_has_custom_eltype(global_aval): + return global_aval.dtype.global_sharded_result_handler( + global_aval, out_sharding, committed) global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec global_idx_rid = get_shard_indices_replica_ids(global_aval.shape, global_mesh, out_axis_resources) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index f7bf16d6f..7596e1baa 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -28,7 +28,7 @@ from jax.experimental.sharding import ( from jax import core from jax import linear_util as lu from jax import stages -from jax._src.api import _check_callable, _check_arg, devices +from jax._src.api import _check_callable, _check_arg, local_devices from jax._src.config import config from jax._src import dispatch from jax._src import source_info_util @@ -922,9 +922,12 @@ def _pjit_lower_cached( else: # Pass `in_is_global` here because this path is taken by both host local # avals and global avals. + # TODO(yashkatariya): Don't set committed to True always. Infer that from + # the arguments just like dispatch.py in `sharded_lowering`. return pxla.lower_sharding_computation( fun, 'pjit', name, in_shardings, out_shardings, donated_invars, - jaxpr.in_avals, in_is_global=in_is_global) + jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True, + committed=True) def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, @@ -1519,7 +1522,7 @@ def _get_and_check_device_assignment(shardings, pjit_mesh): if first_device_assignment is None and not pjit_mesh.empty: return mesh_devices if first_device_assignment is None: - return [config.jax_default_device or devices()[0]] + return [config.jax_default_device or local_devices()[0]] return first_device_assignment diff --git a/jax/experimental/sharding.py b/jax/experimental/sharding.py index 1ac2343df..3a119698d 100644 --- a/jax/experimental/sharding.py +++ b/jax/experimental/sharding.py @@ -348,7 +348,9 @@ class OpShardingSharding(XLACompatibleSharding): return self._hash def __repr__(self): - return repr(self._op_sharding) + if pxla.is_op_sharding_replicated(self._op_sharding): + return 'OpShardingSharding(REPLICATED)' + return f'OpShardingSharding({repr(self._op_sharding)})' def is_compatible_aval(self, aval_shape: Shape): num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 0121ebe76..9edc20559 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -421,7 +421,14 @@ def shard_args(devices: Sequence[xb.xla_client.Device], shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, InputsHandlerMode], Sequence[Any]]] = {} + +def _shard_token(x, devices, indices, mode): + return device_put(np.zeros((), dtype=np.dtype(np.bool_)), devices, replicate=True) +shard_arg_handlers[core.Token] = _shard_token + def _shard_array(x, devices, indices, mode): + if x.dtype == dtypes.float0: + x = np.zeros(x.shape, dtype=np.dtype(bool)) return device_put([x[i] for i in indices], devices) for _t in array_types: shard_arg_handlers[_t] = _shard_array @@ -581,7 +588,7 @@ local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_arra def global_aval_to_result_handler( - aval: core.AbstractValue, out_sharding, + aval: core.AbstractValue, out_sharding, committed: bool ) -> Callable[[List[xb.xla_client.Buffer]], Any]: """Returns a function 
for handling the raw buffers of a single output aval. @@ -602,7 +609,8 @@ def global_aval_to_result_handler( elif config.jax_parallel_functions_output_gda: output_type = OutputType.GlobalDeviceArray try: - return global_result_handlers[(type(aval), output_type)](aval, out_sharding) + return global_result_handlers[(type(aval), output_type)]( + aval, out_sharding, committed) except KeyError as err: raise TypeError( f"No pxla_result_handler for type: {type(aval)}") from err @@ -1601,7 +1609,7 @@ class PmapExecutable(stages.XlaExecutable): execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args, handle_outs, unordered_effects, ordered_effects, keepalive, - bool(host_callbacks)) + bool(host_callbacks), set(range(len(input_indices)))) fingerprint = getattr(compiled, "fingerprint", None) return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals) @@ -1836,12 +1844,13 @@ def local_avals_to_results_handler( def global_avals_to_results_handler( global_out_avals: Sequence[ShapedArray], - shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler: + shardings: Sequence[XLACompatibleSharding], + committed: bool) -> ResultsHandler: from jax.experimental.sharding import MeshPspecSharding if config.jax_parallel_functions_output_gda or config.jax_array: handlers = [ - global_aval_to_result_handler(global_aval, s) + global_aval_to_result_handler(global_aval, s, committed) for global_aval, s in safe_zip(global_out_avals, shardings) ] return ResultsHandler(handlers, shardings, global_out_avals) @@ -1959,13 +1968,14 @@ class ExecuteReplicated: """The logic to shard inputs, execute a replicated model, returning outputs.""" __slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler', 'has_unordered_effects', 'ordered_effects', 'keepalive', - 'has_host_callbacks', '_local_devices', '__weakref__'] + 'has_host_callbacks', '_local_devices', 'kept_var_idx', + '__weakref__'] def __init__(self, xla_executable, backend, in_handler: InputsHandler, out_handler: ResultsHandler, unordered_effects: List[core.Effect], ordered_effects: List[core.Effect], keepalive: Any, - has_host_callbacks: bool): + has_host_callbacks: bool, kept_var_idx: Set[int]): self.xla_executable = xla_executable self.backend = backend self.in_handler = in_handler @@ -1977,6 +1987,7 @@ class ExecuteReplicated: assert len(self._local_devices) == 1 self.keepalive = keepalive self.has_host_callbacks = has_host_callbacks + self.kept_var_idx = kept_var_idx def _call_with_tokens(self, input_bufs): # TODO(sharadmv): simplify this logic when minimum jaxlib version is @@ -2013,6 +2024,7 @@ class ExecuteReplicated: @profiler.annotate_function def __call__(self, *args): + args = [x for i, x in enumerate(args) if i in self.kept_var_idx] input_bufs = self.in_handler(args) if (self.ordered_effects or self.has_unordered_effects or self.has_host_callbacks): @@ -2650,7 +2662,10 @@ def lower_sharding_computation( out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue], donated_invars: Sequence[bool], global_in_avals: Sequence[core.ShapedArray], - in_is_global: Sequence[bool]): + in_is_global: Sequence[bool], + keep_unused: bool, + committed: bool, + traced_jaxpr_info: Optional[dispatch.TracedJaxprInfo] = None): """Lowers a computation to XLA. It can take arbitrary shardings as input. The caller of this code can pass in a singleton _UNSPECIFIED because the @@ -2667,6 +2682,7 @@ def lower_sharding_computation( # UNSPECIFIED singleton are filtered above. 
backend, first_sharding = _get_backend_from_shardings( it.chain(in_shardings, out_shardings)) # type: ignore + device_assignment = first_sharding._device_assignment name_stack = new_name_stack(wrap_name(fun_name, api_name)) @@ -2678,11 +2694,13 @@ def lower_sharding_computation( global_in_avals, in_shardings) # 1. Trace to jaxpr and preprocess/verify it - in_jaxpr_avals = global_in_avals - - with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " - "in {elapsed_time} sec"): - jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals) + if traced_jaxpr_info is None: + with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " + "for sharded computation in {elapsed_time} sec"): + jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final( + fun, global_in_avals, debug_info=pe.debug_info_final(fun, "sharded computation")) + else: + jaxpr, out_jaxpr_avals, consts = traced_jaxpr_info if _is_unspecified(out_shardings): out_shardings = (_UNSPECIFIED,) * len(out_jaxpr_avals) @@ -2692,33 +2710,49 @@ def lower_sharding_computation( global_out_avals = out_jaxpr_avals + if keep_unused: + kept_var_idx = set(range(len(global_in_avals))) + else: + jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr) + consts = [c for i, c in enumerate(consts) if i in kept_const_idx] + global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx) + in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) + in_is_global = tuple(g for i, g in enumerate(in_is_global) if i in kept_var_idx) + donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx) + del kept_const_idx + _sanitize_mesh_jaxpr(jaxpr) if not first_sharding.is_fully_addressable(): check_multihost_collective_allowlist(jaxpr) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) # 2. Build up the HLO - tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform) + tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform) in_op_shardings: Optional[List[Optional[xc.OpSharding]]] out_op_shardings: Optional[List[Optional[xc.OpSharding]]] axis_ctx: mlir.ShardingContext - in_op_shardings = [i._to_xla_op_sharding(aval.ndim) - for aval, i in safe_zip(global_in_avals, in_shardings)] + in_op_shardings = [ + None if aval is core.abstract_token else i._to_xla_op_sharding(aval.ndim) + for aval, i in safe_zip(global_in_avals, in_shardings) + ] # TODO(yashkatariya): Fix the HLO produced if out_partitions is # [None, OpShardingProto] has the sharding annotations. 
- out_op_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim) - for aval, o in safe_zip(global_out_avals, out_shardings)] - replicated_args = [False] * len(in_jaxpr_avals) + out_op_shardings = [ + None if _is_unspecified(o) or aval is core.abstract_token else o._to_xla_op_sharding(aval.ndim) + for aval, o in safe_zip(global_out_avals, out_shardings) + ] + replicated_args = [False] * len(global_in_avals) axis_ctx = mlir.ShardingContext(first_sharding) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) module: Union[str, xc.XlaComputation] module_name = f"{api_name}_{fun_name}" - if any(eff in core.ordered_effects for eff in closed_jaxpr.effects): - raise ValueError("Ordered effects not supported in mesh computations.") + if len(device_assignment) > 1: + if any(eff in core.ordered_effects for eff in closed_jaxpr.effects): + raise ValueError("Ordered effects are not supported for more than 1 device.") unordered_effects = [eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects] ordered_effects = [eff for eff in closed_jaxpr.effects @@ -2726,7 +2760,8 @@ def lower_sharding_computation( lowering_result = mlir.lower_jaxpr_to_module( module_name, closed_jaxpr, - unordered_effects, ordered_effects, + unordered_effects, + ordered_effects, backend, backend.platform, axis_ctx, @@ -2735,10 +2770,16 @@ def lower_sharding_computation( replicated_args=replicated_args, arg_shardings=in_op_shardings, result_shardings=out_op_shardings) + module, keepalive, host_callbacks = ( lowering_result.module, lowering_result.keepalive, lowering_result.host_callbacks) + # backend and device_assignment is passed through to MeshExecutable because + # if keep_unused=False and all in_shardings are pruned, then there is no way + # to get the device_assignment and backend. So pass it to MeshExecutable + # because we calculate the device_assignment and backend before in_shardings, + # etc are pruned. return MeshComputation( str(name_stack), module, @@ -2755,7 +2796,11 @@ def lower_sharding_computation( unordered_effects=unordered_effects, ordered_effects=ordered_effects, host_callbacks=host_callbacks, - keepalive=keepalive) + keepalive=keepalive, + kept_var_idx=kept_var_idx, + backend=backend, + device_assignment=device_assignment, + committed=committed) @profiler.annotate_function @@ -2903,7 +2948,11 @@ def lower_mesh_computation( unordered_effects=unordered_effects, ordered_effects=ordered_effects, host_callbacks=host_callbacks, - keepalive=keepalive) + keepalive=keepalive, + kept_var_idx=set(range(len(global_in_avals))), + backend=backend, + device_assignment=list(mesh.devices.flat), + committed=True) class MeshComputation(stages.XlaLowering): @@ -2963,15 +3012,19 @@ def _get_input_metadata( aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval) sharding = MeshPspecSharding(i.mesh.local_mesh, i.spec) - # We special case this logic to support fully replicated values because - # the mesh is global mesh and the indices returned by `spec_to_indices` will - # represent index for each device in the global mesh. But here we want - # indices for the local devices of the global mesh. 
- proto = sharding._to_xla_op_sharding(aval.ndim) - if is_op_sharding_replicated(proto): - index = tuple((slice(None),) * aval.ndim for _ in range(len(sharding.addressable_devices))) + if aval is core.abstract_token: + index = (slice(None),) else: - index = tuple(sharding.devices_indices_map(aval.shape).values()) + # We special case this logic to support fully replicated values because + # the mesh is global mesh and the indices returned by `spec_to_indices` will + # represent index for each device in the global mesh. But here we want + # indices for the local devices of the global mesh. + proto = sharding._to_xla_op_sharding(aval.ndim) + if is_op_sharding_replicated(proto): + index = tuple((slice(None),) * aval.ndim + for _ in range(len(sharding.addressable_devices))) # type: ignore + else: + index = tuple(sharding.devices_indices_map(aval.shape).values()) # type: ignore shardings.append(sharding) input_indices.append(index) @@ -3044,32 +3097,28 @@ class MeshExecutable(stages.XlaExecutable): unordered_effects: List[core.Effect], ordered_effects: List[core.Effect], host_callbacks: List[Any], - keepalive: Any) -> MeshExecutable: - if auto_spmd_lowering: - assert mesh is not None - assert not mesh.empty - backend = xb.get_device_backend(mesh.devices.flat[0]) - else: - backend, first_sharding = _get_backend_from_shardings( - it.chain(in_shardings, out_shardings)) # type: ignore - + keepalive: Any, + kept_var_idx: Set[int], + backend: xb.XlaBackend, + device_assignment: Sequence[xc.Device], + committed: bool) -> MeshExecutable: dev: np.ndarray if auto_spmd_lowering: assert mesh is not None and spmd_lowering dev = mesh.devices num_replicas, num_partitions = 1, mesh.size else: - dev = np.array(first_sharding._device_assignment) + dev = np.array(device_assignment) if spmd_lowering: num_replicas, num_partitions = 1, dev.size else: num_replicas, num_partitions = dev.size, 1 - device_assignment = dev.reshape((num_replicas, num_partitions)) + xla_device_assignment = dev.reshape((num_replicas, num_partitions)) compile_options = xb.get_compile_options( num_replicas=num_replicas, num_partitions=num_partitions, - device_assignment=device_assignment, + device_assignment=xla_device_assignment, use_spmd_partitioning=spmd_lowering, use_auto_spmd_partitioning=auto_spmd_lowering, ) @@ -3088,7 +3137,7 @@ class MeshExecutable(stages.XlaExecutable): in_shardings, input_indices, input_avals = _get_input_metadata( global_in_avals, in_shardings, in_is_global) # type: ignore handle_outs = global_avals_to_results_handler( - global_out_avals, out_shardings) # type: ignore # arg-type + global_out_avals, out_shardings, committed) # type: ignore # arg-type unsafe_call = backend.compile_replicated(computation, compile_options, host_callbacks, input_avals, input_indices, in_shardings, @@ -3108,7 +3157,7 @@ class MeshExecutable(stages.XlaExecutable): elif out_shardings and any(_is_unspecified(o) for o in out_shardings): assert mesh is None _, out_shardings_xla = _get_op_sharding_shardings_from_executable( - xla_executable, first_sharding._device_assignment, + xla_executable, device_assignment, len(global_in_avals), len(global_out_avals)) out_shardings = [x if _is_unspecified(o) else o for x, o in safe_zip(out_shardings_xla, out_shardings)] @@ -3116,13 +3165,13 @@ class MeshExecutable(stages.XlaExecutable): in_shardings, input_indices, input_avals = _get_input_metadata( global_in_avals, in_shardings, in_is_global) # type: ignore handle_outs = global_avals_to_results_handler( - global_out_avals, out_shardings) # type: ignore # 
arg-type + global_out_avals, out_shardings, committed) # type: ignore # arg-type handle_args = InputsHandler(xla_executable.local_devices(), in_shardings, input_indices, InputsHandlerMode.pjit_or_xmap) unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs, unordered_effects, ordered_effects, keepalive, - bool(host_callbacks)) + bool(host_callbacks), kept_var_idx) return MeshExecutable(xla_executable, unsafe_call, input_avals, in_shardings, out_shardings, auto_spmd_lowering) diff --git a/tests/BUILD b/tests/BUILD index 5ae6e5a01..7c9a8d3d8 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -78,6 +78,10 @@ jax_test( jax_test( name = "custom_object_test", srcs = ["custom_object_test.py"], + # TODO(yashkatariya,mattjj,phawkins): Enable custom_object_test once + # `ExecuteReplicated` supports the use case of having more + # than 1 buffer on a single device. + disable_configs = ["cpu_jax_array"], ) py_test( diff --git a/tests/api_test.py b/tests/api_test.py index dca393fc7..2a8680f62 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -239,13 +239,13 @@ class CPPJitTest(jtu.BufferDonationTestCase): assert len(side) == 3 def test_jit_device(self): + if config.jax_array: + self.skipTest('The device parameter of jit has been deprecated. Array ' + 'is not compatible with it and will not work.') device = jax.devices()[-1] x = self.jit(lambda x: x, device=device)(3.) _check_instance(self, x) - if config.jax_array: - self.assertEqual(x.device(), device) - else: - self.assertEqual(x.device_buffer.device(), device) + self.assertEqual(x.device_buffer.device(), device) @jtu.skip_on_devices("cpu") def test_jit_default_device(self): @@ -267,10 +267,13 @@ class CPPJitTest(jtu.BufferDonationTestCase): self.assertEqual(f(1).device(), system_default_device) with jax.default_device(test_device): - # Explicit `device` or `backend` argument to jit overrides default_device - self.assertEqual( - jax.jit(f, device=system_default_device)(1).device(), - system_default_device) + # Skip this for jax.Array because using the device argument of `jit` is + # deprecated. 
+ if not config.jax_array: + # Explicit `device` or `backend` argument to jit overrides default_device + self.assertEqual( + jax.jit(f, device=system_default_device)(1).device(), + system_default_device) out = jax.jit(f, backend="cpu")(1) if config.jax_array: self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding) @@ -1067,7 +1070,10 @@ class CPPJitTest(jtu.BufferDonationTestCase): jitted_f = self.jit(lambda x, y: x, keep_unused=True) with jtu.count_device_put() as count: _ = jitted_f(1, 2) - self.assertEqual(count[0], 1) + if config.jax_array: + self.assertEqual(count[0], 2) + else: + self.assertEqual(count[0], 1) @jtu.ignore_warning(category=DeprecationWarning) def test_jit_lower_compile_compiler_ir(self): diff --git a/tests/lax_test.py b/tests/lax_test.py index d998f99ce..cd67a8483 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -38,6 +38,7 @@ import jax.util from jax.interpreters import xla from jax.interpreters import mlir from jax.interpreters import batching +from jax.interpreters import pxla from jax.experimental import array from jax._src.lib.mlir.dialects import mhlo from jax._src import dispatch @@ -3042,6 +3043,14 @@ class FooTy: return FooArray(aval.shape, buf) return handler + @staticmethod + def global_sharded_result_handler(aval, out_sharding, committed): + def handler(bufs): + buf, = bufs + buf.aval = core.ShapedArray(buf.shape, buf.dtype) + return FooArray(aval.shape, buf) + return handler + # eltype-polymorphic primitive lowering rules @staticmethod @@ -3152,6 +3161,12 @@ def device_put_foo_array(x: FooArray, device): return array._device_put_array(x.data, device) return dispatch._device_put_array(x.data, device) +def shard_foo_array_handler(x, devices, indices, mode): + device, = devices + if isinstance(x.data, array.Array): + return array._device_put_array(x.data, device) + return dispatch._device_put_array(x.data, device) + def foo_array_constant_handler(x, c): if config.jax_array: return array._array_mlir_constant_handler(x.data, c) @@ -3186,6 +3201,7 @@ class CustomElementTypesTest(jtu.JaxTestCase): xla.pytype_aval_mappings[FooArray] = \ lambda x: core.ShapedArray(x.shape, FooTy()) dispatch.device_put_handlers[FooArray] = device_put_foo_array + pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler mlir._constant_handlers[FooArray] = foo_array_constant_handler mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False)) mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False)) diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 24e75f60c..c254e7b34 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -224,6 +224,33 @@ class MultiDeviceTest(jtu.JaxTestCase): y = jax.device_put(1, devices[2]) + jnp.ones((2, 3)) self.assert_committed_to_device(y, devices[2]) + def test_single_input_committed_multi_output(self): + if jax.device_count() < 3: + self.skipTest("Test requires 3 devices") + devices = self.get_devices() + + @jax.jit + def f(a, b, c, d, e): + return a, b, c, d, e + + outs = f(jax.device_put(1, devices[2]), jnp.array(2), jnp.array(3), + jnp.array(4), jnp.array(5)) + for o in outs: + self.assert_committed_to_device(o, devices[2]) + + def test_different_devices_input_error(self): + if jax.device_count() < 2: + self.skipTest("Test requires 2 devices") + devices = self.get_devices() + + a = jax.device_put(1, devices[0]) + b = jax.device_put(2, devices[1]) + + # Don't look for the message because the Array and non-Array path raise + # slightly different error messages. 
+    with self.assertRaises(ValueError):
+      _ = a + b
+
   def test_transpose(self):
     if jax.device_count() < 3:
       self.skipTest("test requires 3 devices")
diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py
index 54409a257..9925cc796 100644
--- a/tests/multibackend_test.py
+++ b/tests/multibackend_test.py
@@ -173,9 +173,12 @@ class MultiBackendTest(jtu.JaxTestCase):
     result2 = jax.jit(my_sin)(data_on_cpu)
     self.assertEqual(result2.device(), cpus[0])
-    # jit with `device` spec places the data on the specified device
-    result3 = jax.jit(my_sin, device=cpus[0])(2)
-    self.assertEqual(result3.device(), cpus[0])
+    # Skip this for jax.Array because using the device argument of `jit` is
+    # deprecated.
+    if not config.jax_array:
+      # jit with `device` spec places the data on the specified device
+      result3 = jax.jit(my_sin, device=cpus[0])(2)
+      self.assertEqual(result3.device(), cpus[0])
     # jit with `backend` spec places the data on the specified backend
     result4 = jax.jit(my_sin, backend="cpu")(2)
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 33101e291..4cd15f788 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -137,8 +137,6 @@ class PythonPmapTest(jtu.JaxTestCase):
   def pmap(self):
     return src_api._python_pmap
 
-  # TODO(yashkatariya): Re-enable when unsafe_buffer_pointer is implemented
-  @unittest.skipIf(config.jax_array, "Array does not yet implement unsafe_buffer_pointer")
   def testDeviceBufferToArray(self):
     sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
diff --git a/tests/random_test.py b/tests/random_test.py
index 2773bb4c7..ab86418e4 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -1451,7 +1451,10 @@ class LaxRandomTest(jtu.JaxTestCase):
     key = self.seed_prng(1).block_until_ready()
     with jtu.count_device_put() as count:
       jax.jit(random.split)(key)
-    self.assertEqual(count[0], 1)  # 1 for the argument device_put
+    if config.jax_array:
+      self.assertEqual(count[0], 0)
+    else:
+      self.assertEqual(count[0], 1)  # 1 for the argument device_put
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": f"_dtype={dtype}", "dtype": dtype}
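
Notes on the dispatch.py change above: `sharded_lowering` now traces the jitted function to a jaxpr exactly once, packages the result in the new `TracedJaxprInfo` namedtuple, and hands it to whichever lowering path it picks (`lower_xla_callable` for the pmap and explicit-backend cases, `lower_sharding_computation` otherwise), so neither path has to re-trace. A minimal sketch of that pattern, with `trace_to_jaxpr`, `lower_for_pmap` and `lower_sharded` as placeholder names rather than the real JAX internals:

# Illustrative sketch only; the callables are stand-ins for
# pe.trace_to_jaxpr_final, lower_xla_callable and lower_sharding_computation.
import collections
from typing import Any, Callable, Sequence

TracedJaxprInfo = collections.namedtuple(
    'TracedJaxprInfo', ['jaxpr', 'out_jaxpr_avals', 'consts'])

def sharded_lowering_sketch(fun: Any,
                            in_avals: Sequence[Any],
                            uses_pmap: bool,
                            backend: Any,
                            trace_to_jaxpr: Callable[..., tuple],
                            lower_for_pmap: Callable[..., Any],
                            lower_sharded: Callable[..., Any]) -> Any:
  # Trace once, up front, and package the jaxpr so either lowering path
  # can reuse it instead of tracing again.
  jaxpr, out_avals, consts = trace_to_jaxpr(fun, in_avals)
  traced = TracedJaxprInfo(jaxpr, out_avals, consts)
  # pmap-style programs and an explicit `backend=` still take the old
  # XLA-callable path; everything else goes through the sharded path.
  if uses_pmap or backend is not None:
    return lower_for_pmap(fun, traced_jaxpr_info=traced)
  return lower_sharded(fun, traced_jaxpr_info=traced)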
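
The `committed` flag threaded through `global_sharded_result_handler`, `_array_global_result_handler` and `global_avals_to_results_handler` follows the rule `sharded_lowering` applies to its inputs: outputs are committed exactly when at least one argument carries a known sharding (`committed = any(i is not None for i in in_shardings)`), which is what the new `test_single_input_committed_multi_output` exercises. A toy illustration of that rule, using an invented `Shard` container rather than real JAX arrays:

# Hedged sketch; `Shard` is a made-up stand-in, not a JAX type.
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class Shard:
  value: object
  sharding: Optional[object]  # None means "uncommitted", free to move devices

def outputs_committed(inputs: Sequence[Shard]) -> bool:
  # One committed input is enough to pin the whole computation (and all of
  # its outputs) to that device assignment.
  return any(s.sharding is not None for s in inputs)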
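
With `keep_unused=False`, `lower_sharding_computation` now prunes unused jaxpr inputs via `dispatch._prune_unused_inputs` and records `kept_var_idx`, and `ExecuteReplicated.__call__` drops the corresponding runtime arguments before building input buffers. A self-contained toy executor showing the same filtering step (not the real `ExecuteReplicated`):

# Toy sketch of call-time argument pruning keyed by kept_var_idx.
from typing import Callable, Set

class PrunedExecutor:
  def __init__(self, run: Callable[..., object], kept_var_idx: Set[int]):
    self.run = run                    # compiled computation over kept args only
    self.kept_var_idx = kept_var_idx  # argument positions that survived pruning

  def __call__(self, *args):
    # Drop arguments whose jaxpr inputs were pruned at lowering time,
    # then execute with the remaining ones.
    kept = [x for i, x in enumerate(args) if i in self.kept_var_idx]
    return self.run(*kept)

# Example: a computation lowered from f(a, b, c) that only uses a and c.
ex = PrunedExecutor(lambda a, c: a + c, kept_var_idx={0, 2})
assert ex(1, 10, 2) == 3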