Mirror of https://github.com/ROCm/jax.git
Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.

A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
parent 10b2e210ed
commit 6340952e2a
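As a usage sketch of the `keep_unused` flag mentioned above (a minimal example, not taken from this commit; it only assumes the public `jax.jit` API):

import jax

# By default unused arguments are pruned from the compiled computation;
# keep_unused=True forces them to be transferred to device and kept alive.
f = jax.jit(lambda x, unused: x * 2, keep_unused=True)
print(f(1, 2))  # 2; `unused` is still resident on device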
@@ -101,9 +101,6 @@ def pure_callback_lowering(ctx, *args, callback, **params):
   if ctx.module_context.platform == "TPU" and jaxlib.version < (0, 3, 15):
     raise NotImplementedError("Pure callbacks on TPU not supported. "
                               "Please upgrade to a jaxlib >= 0.3.15.")
-  if isinstance(ctx.module_context.axis_context,
-                (mlir.SPMDAxisContext, mlir.ShardingContext)):
-    raise NotImplementedError("Sharding for pure callback not implemented.")

   def _callback(*flat_args):
     return tuple(pure_callback_p.impl(*flat_args, callback=callback, **params))
@@ -97,8 +97,6 @@ def arg_spec(x: Any) -> ArgSpec:
   aval = xla.abstractify(x)
   try:
     if config.jax_array:
-      if isinstance(x.sharding, PmapSharding):
-        return aval, None
       return aval, (x.sharding if x._committed else None)
     else:
       return aval, x._device
@@ -182,7 +180,7 @@ def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
   _, arg_devices = util.unzip2(arg_specs)
   donated_invars = (False,) * len(arg_specs)
   if config.jax_array:
-    # This will be resolved in _xla_callable_device.
+    # This will be resolved in sharded_lowering.
     device = None
   else:
     device = _device_from_arg_devices(arg_devices)
@@ -277,6 +275,10 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
 xla.xla_call_p.def_impl(_xla_call_impl)


+TracedJaxprInfo = collections.namedtuple(
+    'TracedJaxprInfo', ['jaxpr', 'out_jaxpr_avals', 'consts'])
+
+
 def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
                      *arg_specs):
   # TODO(yashkatariya): Remove the local imports from here when the functions
@@ -286,29 +288,40 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,

   in_avals, in_shardings = util.unzip2(arg_specs)

   # TODO(yashkatariya): Remove this and make `SingleDeviceSharding` go through
   # lower_sharding_computation and resolve all the errors once that happens.
-  if any(s is None or isinstance(s, sharding.SingleDeviceSharding) for s in in_shardings):
+  # For pmap, keep using the fallback by checking the jaxpr and then wrapping it
+  # in a lu.Wrappedfun again.
+  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
+                        "in {elapsed_time} sec"):
+    jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
+        fun, in_avals, debug_info=pe.debug_info_final(fun, "jit"))
+  traced_jaxpr_info = TracedJaxprInfo(jaxpr, out_jaxpr_avals, consts)
+
+  # If jaxpr has the pmap primitive or if `backend` is provided on `jit`, then
+  # take the lower_xla_callable lowering path. This is because pmap's programming
+  # model is not compatible with lower_sharding_computation.
+  # Specifying backend on `jit` is not supported when Array is enabled. So take
+  # the `lower_xla_callable` path which can handle it.
+  if (jaxpr_has_primitive(jaxpr, 'xla_pmap') or
+      any(isinstance(s, sharding.PmapSharding) for s in in_shardings) or
+      backend is not None):
     arg_specs = tuple(
         (a, s._device) if isinstance(s, sharding.SingleDeviceSharding) else (a, None)
         for a, s in zip(in_avals, in_shardings))
-    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
-                              keep_unused, *arg_specs).compile().unsafe_call
+    return lower_xla_callable(
+        fun, None, backend, name, donated_invars, False, keep_unused, *arg_specs,
+        traced_jaxpr_info=traced_jaxpr_info).compile().unsafe_call

+  committed = any(i is not None for i in in_shardings)
   da = pjit._get_and_check_device_assignment(
       (i for i in in_shardings if i is not None), pxla.EMPTY_ENV.physical_mesh)
   in_shardings = [sharding.OpShardingSharding.get_replicated(da) if i is None else i
                   for i in in_shardings]

   # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
   # the number of output avals at this stage. lower_sharding_computation will
   # apply it to all out_avals.
   return pxla.lower_sharding_computation(
       fun, 'xla_callable', name, in_shardings, pjit._UNSPECIFIED,
       donated_invars, in_avals,
-      in_is_global=(True,) * len(arg_specs)).compile(
+      in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
+      committed=committed, traced_jaxpr_info=traced_jaxpr_info).compile(
           _allow_propagation_to_outputs=True).unsafe_call
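A minimal sketch of the pmap fallback check above, using only public APIs (`jax.make_jaxpr` and `jax.core.subjaxprs` are the assumed entry points; the real `jaxpr_has_primitive` lives in dispatch.py):

import jax
import jax.numpy as jnp

def has_pmap(jaxpr) -> bool:
  # Look for the pmap primitive in this jaxpr and, recursively, its sub-jaxprs.
  if any(eqn.primitive.name == 'xla_pmap' for eqn in jaxpr.eqns):
    return True
  return any(has_pmap(sub) for sub in jax.core.subjaxprs(jaxpr))

closed = jax.make_jaxpr(lambda x: jax.pmap(jnp.sin)(x))(jnp.ones((1, 4)))
print(has_pmap(closed.jaxpr))  # True -> jit would take the lower_xla_callable path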
@@ -347,9 +360,10 @@ def should_tuple_args(num_args: int, platform: str):


 @profiler.annotate_function
-def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
-                       donated_invars, always_lower: bool, keep_unused: bool,
-                       *arg_specs):
+def lower_xla_callable(
+    fun: lu.WrappedFun, device, backend, name, donated_invars,
+    always_lower: bool, keep_unused: bool, *arg_specs,
+    traced_jaxpr_info: Optional[TracedJaxprInfo] = None):
   """Lower into XLA.

   Args:
@@ -371,11 +385,18 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   else:
     assert abstract_args == (None,) * len(abstract_args)
     abstract_args = [aval for aval, _ in fun.in_type]
-  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
-                        "for jit in {elapsed_time} sec"):
-    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
-        fun, pe.debug_info_final(fun, "jit"))
-  out_avals, kept_outputs = util.unzip2(out_type)
+
+  if traced_jaxpr_info is None:
+    with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
+                          "for jit in {elapsed_time} sec"):
+      jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
+          fun, pe.debug_info_final(fun, "jit"))
+    out_avals, kept_outputs = util.unzip2(out_type)
+  else:
+    jaxpr, out_avals, consts = traced_jaxpr_info
+    kept_outputs = [True] * len(out_avals)
+    out_type = tuple(zip(out_avals, kept_outputs))

   if any(isinstance(c, core.Tracer) for c in consts):
     raise UnexpectedTracerError("Encountered an unexpected tracer.")
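The point of threading `traced_jaxpr_info` through is that a function is traced once and the jaxpr reused. Tracing is observable from user code; a minimal sketch (nothing here is specific to this commit):

import jax

def f(x):
  print("tracing!")  # Python side effects run only during tracing
  return x + 1

jf = jax.jit(f)
jf(1)  # prints "tracing!" once
jf(2)  # no print: the cached jaxpr/executable is reused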
@@ -548,6 +548,8 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:

 def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
                           weak_type: bool = False):
+  from jax.experimental import array
+
   # Don't canonicalize old_dtype because x64 context might cause
   # un-canonicalized operands to be passed in.
   old_dtype = dtypes.dtype(operand, canonicalize=False)
@@ -575,7 +577,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
     old_weak_type = False

   if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
+      and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
     return operand
   else:
     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
@@ -794,8 +796,10 @@ def broadcast_in_dim(operand: Array, shape: Shape,
   See Also:
     jax.lax.broadcast : simpler interface to add new leading dimensions.
   """
+  from jax.experimental import array
+
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
-      and isinstance(operand, (device_array.DeviceArray, core.Tracer))):
+      and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))):
     return operand
   if config.jax_dynamic_shapes:
     # We must gate this behavior under a flag because otherwise the errors
@@ -850,6 +854,8 @@ def reshape(operand: Array, new_sizes: Shape,
   >>> reshape(y, (6,), (1, 0))
   DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
   """
+  from jax.experimental import array
+
   new_sizes = canonicalize_shape(new_sizes)  # TODO
   new_sizes = tuple(new_sizes)
   same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
@@ -860,7 +866,7 @@ def reshape(operand: Array, new_sizes: Shape,
     dims = api_util._ensure_index_tuple(dimensions)
   same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
   if (np.shape(operand) and same_shape and same_dims
-      and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
+      and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
     return operand
   else:
     dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
@@ -1873,7 +1873,7 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):

   # We can't use the ndarray class because we need to handle internal buffers
   # (See https://github.com/google/jax/issues/8950)
-  ndarray_types = (device_array.DeviceArray, core.Tracer)
+  ndarray_types = (device_array.DeviceArray, core.Tracer, Array)

   if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
     # TODO(jakevdp): falling back to numpy here fails to overflow for lists
@@ -4746,7 +4746,7 @@ _NOT_IMPLEMENTED = ['argpartition']

 # Experimental support for NumPy's module dispatch with NEP-37.
 # Currently requires https://github.com/seberg/numpy-dispatch
-_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer)
+_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, Array)
 _HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)

 def __array_module__(self, types):
@@ -348,7 +348,7 @@ class KeyTy:
       return handler

   @staticmethod
-  def global_sharded_result_handler(aval, out_sharding):
+  def global_sharded_result_handler(aval, out_sharding, committed):
     phys_aval, = KeyTy.physical_avals(aval)
     key_shape = aval.dtype.impl.key_shape
@@ -378,7 +378,7 @@ class KeyTy:
       # a new op sharding with a trivially extended `tile_assignment_dimensions`
       raise NotImplementedError

-    phys_handler = phys_handler_maker(phys_aval, phys_sharding)
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
     def handler(bufs):
       return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
     return handler
@@ -281,6 +281,15 @@ class Array:
             'named_shape': self.aval.named_shape}
     return (_reconstruct_array, (fun, args, arr_state, aval_state))

+  def unsafe_buffer_pointer(self):
+    assert len(self._arrays) == 1
+    return self._arrays[0].unsafe_buffer_pointer()
+
+  @property
+  def __cuda_array_interface__(self):
+    assert len(self._arrays) == 1
+    return self._arrays[0].__cuda_array_interface__  # pytype: disable=attribute-error  # bind-properties
+
   # TODO(yashkatariya): Remove this method when everyone is using devices().
   def device(self) -> Device:
     self._check_if_deleted()
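With `unsafe_buffer_pointer` and `__cuda_array_interface__` on single-shard Arrays, zero-copy export to other CUDA libraries becomes possible. A hedged sketch: `cupy` is an assumption (any consumer of the CUDA Array Interface would do), and it requires a GPU backend plus the `jax_array` flag of this era:

import jax
import jax.numpy as jnp
import cupy  # assumed available

jax.config.update('jax_array', True)
x = jax.device_put(jnp.arange(4.0), jax.devices('gpu')[0])
y = cupy.asarray(x)  # consumes x.__cuda_array_interface__; no host round-trip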
@@ -434,19 +443,25 @@ def _array_shard_arg(x, devices, indices, mode):
   if mode == pxla.InputsHandlerMode.pmap:
     return _array_pmap_shard_arg(x, devices, indices, mode)
   else:
-    return x._arrays
+    if isinstance(x.sharding, SingleDeviceSharding):
+      return [buf if buf.device() == d else buf.copy_to_device(d)
+              for buf, d in safe_zip(x._arrays, devices)]
+    else:
+      return x._arrays
 pxla.shard_arg_handlers[Array] = _array_shard_arg


-def _array_global_result_handler(global_aval, out_sharding):
+def _array_global_result_handler(global_aval, out_sharding, committed):
   if global_aval.dtype == dtypes.float0:
     return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
   if core.aval_has_custom_eltype(global_aval):
     return global_aval.dtype.global_sharded_result_handler(
-        global_aval, out_sharding)
+        global_aval, out_sharding, committed)
   else:
-    return lambda bufs: Array(global_aval, out_sharding, bufs, committed=True,
-                              _skip_checks=True)
+    return lambda bufs: Array(global_aval, out_sharding, bufs,
+                              committed=committed, _skip_checks=True)
 pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
 pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
+pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token


 def _array_local_result_handler(aval, sharding, indices):
@@ -618,7 +618,10 @@ def _gda_shard_arg(x, devices, indices, mode):
 pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg


-def _gda_array_result_handler(global_aval, out_sharding):
+def _gda_array_result_handler(global_aval, out_sharding, committed):
+  if core.aval_has_custom_eltype(global_aval):
+    return global_aval.dtype.global_sharded_result_handler(
+        global_aval, out_sharding, committed)
   global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec
   global_idx_rid = get_shard_indices_replica_ids(global_aval.shape, global_mesh,
                                                  out_axis_resources)
@@ -28,7 +28,7 @@ from jax.experimental.sharding import (
 from jax import core
 from jax import linear_util as lu
 from jax import stages
-from jax._src.api import _check_callable, _check_arg, devices
+from jax._src.api import _check_callable, _check_arg, local_devices
 from jax._src.config import config
 from jax._src import dispatch
 from jax._src import source_info_util
@@ -922,9 +922,12 @@ def _pjit_lower_cached(
   else:
     # Pass `in_is_global` here because this path is taken by both host local
     # avals and global avals.
+    # TODO(yashkatariya): Don't set committed to True always. Infer that from
+    # the arguments just like dispatch.py in `sharded_lowering`.
     return pxla.lower_sharding_computation(
         fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
-        jaxpr.in_avals, in_is_global=in_is_global)
+        jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
+        committed=True)


 def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
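`committed=True` here matches pjit's existing semantics: results always land on the mesh's devices. A minimal era-appropriate sketch (the `jax.experimental` imports below are assumptions about this snapshot of the API):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec as P

mesh = maps.Mesh(np.array(jax.devices()), ('x',))
with mesh:
  f = pjit(lambda a: a * 2, in_axis_resources=P('x'), out_axis_resources=P('x'))
  out = f(jnp.arange(8))  # out is committed to the mesh's devices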
@@ -1519,7 +1522,7 @@ def _get_and_check_device_assignment(shardings, pjit_mesh):
   if first_device_assignment is None and not pjit_mesh.empty:
     return mesh_devices
   if first_device_assignment is None:
-    return [config.jax_default_device or devices()[0]]
+    return [config.jax_default_device or local_devices()[0]]
   return first_device_assignment
@@ -348,7 +348,9 @@ class OpShardingSharding(XLACompatibleSharding):
     return self._hash

   def __repr__(self):
-    return repr(self._op_sharding)
+    if pxla.is_op_sharding_replicated(self._op_sharding):
+      return 'OpShardingSharding(REPLICATED)'
+    return f'OpShardingSharding({repr(self._op_sharding)})'

   def is_compatible_aval(self, aval_shape: Shape):
     num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
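A small sketch of the improved repr, assuming the `jax.experimental.sharding` module as it exists in this tree (`get_replicated` is used the same way in the dispatch hunk above):

import jax
from jax.experimental import sharding

s = sharding.OpShardingSharding.get_replicated(jax.devices())
print(s)  # OpShardingSharding(REPLICATED) instead of a raw OpSharding proto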
@@ -421,7 +421,14 @@ def shard_args(devices: Sequence[xb.xla_client.Device],


 shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any, InputsHandlerMode], Sequence[Any]]] = {}
+
+
+def _shard_token(x, devices, indices, mode):
+  return device_put(np.zeros((), dtype=np.dtype(np.bool_)), devices, replicate=True)
+shard_arg_handlers[core.Token] = _shard_token
+
+
 def _shard_array(x, devices, indices, mode):
   if x.dtype == dtypes.float0:
     x = np.zeros(x.shape, dtype=np.dtype(bool))
   return device_put([x[i] for i in indices], devices)
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array
@@ -581,7 +588,7 @@ local_result_handlers[(ConcreteArray, OutputType.ShardedDeviceArray)] = sda_array_result_handler


 def global_aval_to_result_handler(
-    aval: core.AbstractValue, out_sharding,
+    aval: core.AbstractValue, out_sharding, committed: bool
 ) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
@@ -602,7 +609,8 @@ def global_aval_to_result_handler(
   elif config.jax_parallel_functions_output_gda:
     output_type = OutputType.GlobalDeviceArray
   try:
-    return global_result_handlers[(type(aval), output_type)](aval, out_sharding)
+    return global_result_handlers[(type(aval), output_type)](
+        aval, out_sharding, committed)
   except KeyError as err:
     raise TypeError(
         f"No pxla_result_handler for type: {type(aval)}") from err
@@ -1601,7 +1609,7 @@ class PmapExecutable(stages.XlaExecutable):
     execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args,
                                     handle_outs, unordered_effects,
                                     ordered_effects, keepalive,
-                                    bool(host_callbacks))
+                                    bool(host_callbacks), set(range(len(input_indices))))
     fingerprint = getattr(compiled, "fingerprint", None)

     return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)
@@ -1836,12 +1844,13 @@ def local_avals_to_results_handler(

 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
+    shardings: Sequence[XLACompatibleSharding],
+    committed: bool) -> ResultsHandler:
   from jax.experimental.sharding import MeshPspecSharding

   if config.jax_parallel_functions_output_gda or config.jax_array:
     handlers = [
-        global_aval_to_result_handler(global_aval, s)
+        global_aval_to_result_handler(global_aval, s, committed)
         for global_aval, s in safe_zip(global_out_avals, shardings)
     ]
     return ResultsHandler(handlers, shardings, global_out_avals)
@@ -1959,13 +1968,14 @@ class ExecuteReplicated:
   """The logic to shard inputs, execute a replicated model, returning outputs."""
   __slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler',
                'has_unordered_effects', 'ordered_effects', 'keepalive',
-               'has_host_callbacks', '_local_devices', '__weakref__']
+               'has_host_callbacks', '_local_devices', 'kept_var_idx',
+               '__weakref__']

   def __init__(self, xla_executable, backend, in_handler: InputsHandler,
                out_handler: ResultsHandler,
                unordered_effects: List[core.Effect],
                ordered_effects: List[core.Effect], keepalive: Any,
-               has_host_callbacks: bool):
+               has_host_callbacks: bool, kept_var_idx: Set[int]):
     self.xla_executable = xla_executable
     self.backend = backend
     self.in_handler = in_handler
@@ -1977,6 +1987,7 @@ class ExecuteReplicated:
       assert len(self._local_devices) == 1
     self.keepalive = keepalive
     self.has_host_callbacks = has_host_callbacks
+    self.kept_var_idx = kept_var_idx

   def _call_with_tokens(self, input_bufs):
     # TODO(sharadmv): simplify this logic when minimum jaxlib version is
@@ -2013,6 +2024,7 @@ class ExecuteReplicated:

   @profiler.annotate_function
   def __call__(self, *args):
+    args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
     input_bufs = self.in_handler(args)
     if (self.ordered_effects or self.has_unordered_effects or
         self.has_host_callbacks):
@@ -2650,7 +2662,10 @@ def lower_sharding_computation(
     out_shardings: Union[Sequence[Union[XLACompatibleSharding, _UnspecifiedValue]], _UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
-    in_is_global: Sequence[bool]):
+    in_is_global: Sequence[bool],
+    keep_unused: bool,
+    committed: bool,
+    traced_jaxpr_info: Optional[dispatch.TracedJaxprInfo] = None):
   """Lowers a computation to XLA. It can take arbitrary shardings as input.

   The caller of this code can pass in a singleton _UNSPECIFIED because the
@@ -2667,6 +2682,7 @@ def lower_sharding_computation(
   # UNSPECIFIED singleton are filtered above.
   backend, first_sharding = _get_backend_from_shardings(
       it.chain(in_shardings, out_shardings))  # type: ignore
+  device_assignment = first_sharding._device_assignment

   name_stack = new_name_stack(wrap_name(fun_name, api_name))
@@ -2678,11 +2694,13 @@ def lower_sharding_computation(
       global_in_avals, in_shardings)

   # 1. Trace to jaxpr and preprocess/verify it
-  in_jaxpr_avals = global_in_avals
-
-  with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
-                                 "in {elapsed_time} sec"):
-    jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
+  if traced_jaxpr_info is None:
+    with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
+                                   "for sharded computation in {elapsed_time} sec"):
+      jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
+          fun, global_in_avals, debug_info=pe.debug_info_final(fun, "sharded computation"))
+  else:
+    jaxpr, out_jaxpr_avals, consts = traced_jaxpr_info

   if _is_unspecified(out_shardings):
     out_shardings = (_UNSPECIFIED,) * len(out_jaxpr_avals)
@@ -2692,33 +2710,49 @@ def lower_sharding_computation(

   global_out_avals = out_jaxpr_avals

+  if keep_unused:
+    kept_var_idx = set(range(len(global_in_avals)))
+  else:
+    jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
+    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
+    global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
+    in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
+    in_is_global = tuple(g for i, g in enumerate(in_is_global) if i in kept_var_idx)
+    donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
+    del kept_const_idx
+
   _sanitize_mesh_jaxpr(jaxpr)
   if not first_sharding.is_fully_addressable():
     check_multihost_collective_allowlist(jaxpr)
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

   # 2. Build up the HLO
-  tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
+  tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)

   in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
   out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
   axis_ctx: mlir.ShardingContext

-  in_op_shardings = [i._to_xla_op_sharding(aval.ndim)
-                     for aval, i in safe_zip(global_in_avals, in_shardings)]
+  in_op_shardings = [
+      None if aval is core.abstract_token else i._to_xla_op_sharding(aval.ndim)
+      for aval, i in safe_zip(global_in_avals, in_shardings)
+  ]
   # TODO(yashkatariya): Fix the HLO produced if out_partitions is
   # [None, OpShardingProto] has the sharding annotations.
-  out_op_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
-                      for aval, o in safe_zip(global_out_avals, out_shardings)]
-  replicated_args = [False] * len(in_jaxpr_avals)
+  out_op_shardings = [
+      None if _is_unspecified(o) or aval is core.abstract_token else o._to_xla_op_sharding(aval.ndim)
+      for aval, o in safe_zip(global_out_avals, out_shardings)
+  ]
+  replicated_args = [False] * len(global_in_avals)
   axis_ctx = mlir.ShardingContext(first_sharding)

   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
   module: Union[str, xc.XlaComputation]
   module_name = f"{api_name}_{fun_name}"

-  if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
-    raise ValueError("Ordered effects not supported in mesh computations.")
+  if len(device_assignment) > 1:
+    if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
+      raise ValueError("Ordered effects are not supported for more than 1 device.")
   unordered_effects = [eff for eff in closed_jaxpr.effects
                        if eff not in core.ordered_effects]
   ordered_effects = [eff for eff in closed_jaxpr.effects
@@ -2726,7 +2760,8 @@ def lower_sharding_computation(
     lowering_result = mlir.lower_jaxpr_to_module(
         module_name,
         closed_jaxpr,
-        unordered_effects, ordered_effects,
+        unordered_effects,
+        ordered_effects,
        backend,
        backend.platform,
        axis_ctx,
@@ -2735,10 +2770,16 @@ def lower_sharding_computation(
         replicated_args=replicated_args,
         arg_shardings=in_op_shardings,
         result_shardings=out_op_shardings)

   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)

+  # backend and device_assignment is passed through to MeshExecutable because
+  # if keep_unused=False and all in_shardings are pruned, then there is no way
+  # to get the device_assignment and backend. So pass it to MeshExecutable
+  # because we calculate the device_assignment and backend before in_shardings,
+  # etc are pruned.
   return MeshComputation(
       str(name_stack),
       module,
@@ -2755,7 +2796,11 @@ def lower_sharding_computation(
       unordered_effects=unordered_effects,
       ordered_effects=ordered_effects,
       host_callbacks=host_callbacks,
-      keepalive=keepalive)
+      keepalive=keepalive,
+      kept_var_idx=kept_var_idx,
+      backend=backend,
+      device_assignment=device_assignment,
+      committed=committed)


 @profiler.annotate_function
@@ -2903,7 +2948,11 @@ def lower_mesh_computation(
       unordered_effects=unordered_effects,
       ordered_effects=ordered_effects,
       host_callbacks=host_callbacks,
-      keepalive=keepalive)
+      keepalive=keepalive,
+      kept_var_idx=set(range(len(global_in_avals))),
+      backend=backend,
+      device_assignment=list(mesh.devices.flat),
+      committed=True)


 class MeshComputation(stages.XlaLowering):
@@ -2963,15 +3012,19 @@ def _get_input_metadata(
       aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
       sharding = MeshPspecSharding(i.mesh.local_mesh, i.spec)

-    # We special case this logic to support fully replicated values because
-    # the mesh is global mesh and the indices returned by `spec_to_indices` will
-    # represent index for each device in the global mesh. But here we want
-    # indices for the local devices of the global mesh.
-    proto = sharding._to_xla_op_sharding(aval.ndim)
-    if is_op_sharding_replicated(proto):
-      index = tuple((slice(None),) * aval.ndim for _ in range(len(sharding.addressable_devices)))
+    if aval is core.abstract_token:
+      index = (slice(None),)
     else:
-      index = tuple(sharding.devices_indices_map(aval.shape).values())
+      # We special case this logic to support fully replicated values because
+      # the mesh is global mesh and the indices returned by `spec_to_indices` will
+      # represent index for each device in the global mesh. But here we want
+      # indices for the local devices of the global mesh.
+      proto = sharding._to_xla_op_sharding(aval.ndim)
+      if is_op_sharding_replicated(proto):
+        index = tuple((slice(None),) * aval.ndim
+                      for _ in range(len(sharding.addressable_devices)))  # type: ignore
+      else:
+        index = tuple(sharding.devices_indices_map(aval.shape).values())  # type: ignore

     shardings.append(sharding)
     input_indices.append(index)
@@ -3044,32 +3097,28 @@ class MeshExecutable(stages.XlaExecutable):
                unordered_effects: List[core.Effect],
                ordered_effects: List[core.Effect],
                host_callbacks: List[Any],
-               keepalive: Any) -> MeshExecutable:
-    if auto_spmd_lowering:
-      assert mesh is not None
-      assert not mesh.empty
-      backend = xb.get_device_backend(mesh.devices.flat[0])
-    else:
-      backend, first_sharding = _get_backend_from_shardings(
-          it.chain(in_shardings, out_shardings))  # type: ignore
-
+               keepalive: Any,
+               kept_var_idx: Set[int],
+               backend: xb.XlaBackend,
+               device_assignment: Sequence[xc.Device],
+               committed: bool) -> MeshExecutable:
     dev: np.ndarray
     if auto_spmd_lowering:
       assert mesh is not None and spmd_lowering
       dev = mesh.devices
       num_replicas, num_partitions = 1, mesh.size
     else:
-      dev = np.array(first_sharding._device_assignment)
+      dev = np.array(device_assignment)
       if spmd_lowering:
         num_replicas, num_partitions = 1, dev.size
       else:
         num_replicas, num_partitions = dev.size, 1
-    device_assignment = dev.reshape((num_replicas, num_partitions))
+    xla_device_assignment = dev.reshape((num_replicas, num_partitions))

     compile_options = xb.get_compile_options(
         num_replicas=num_replicas,
         num_partitions=num_partitions,
-        device_assignment=device_assignment,
+        device_assignment=xla_device_assignment,
         use_spmd_partitioning=spmd_lowering,
         use_auto_spmd_partitioning=auto_spmd_lowering,
     )
@@ -3088,7 +3137,7 @@ class MeshExecutable(stages.XlaExecutable):
       in_shardings, input_indices, input_avals = _get_input_metadata(
           global_in_avals, in_shardings, in_is_global)  # type: ignore
       handle_outs = global_avals_to_results_handler(
-          global_out_avals, out_shardings)  # type: ignore  # arg-type
+          global_out_avals, out_shardings, committed)  # type: ignore  # arg-type
       unsafe_call = backend.compile_replicated(computation, compile_options,
                                                host_callbacks, input_avals,
                                                input_indices, in_shardings,
@@ -3108,7 +3157,7 @@ class MeshExecutable(stages.XlaExecutable):
       elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
         assert mesh is None
         _, out_shardings_xla = _get_op_sharding_shardings_from_executable(
-            xla_executable, first_sharding._device_assignment,
+            xla_executable, device_assignment,
             len(global_in_avals), len(global_out_avals))
         out_shardings = [x if _is_unspecified(o) else o
                          for x, o in safe_zip(out_shardings_xla, out_shardings)]
@@ -3116,13 +3165,13 @@ class MeshExecutable(stages.XlaExecutable):
       in_shardings, input_indices, input_avals = _get_input_metadata(
           global_in_avals, in_shardings, in_is_global)  # type: ignore
       handle_outs = global_avals_to_results_handler(
-          global_out_avals, out_shardings)  # type: ignore  # arg-type
+          global_out_avals, out_shardings, committed)  # type: ignore  # arg-type
       handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
                                   input_indices, InputsHandlerMode.pjit_or_xmap)
       unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
                                       handle_outs, unordered_effects,
                                       ordered_effects, keepalive,
-                                      bool(host_callbacks))
+                                      bool(host_callbacks), kept_var_idx)

     return MeshExecutable(xla_executable, unsafe_call, input_avals,
                           in_shardings, out_shardings, auto_spmd_lowering)
@@ -78,6 +78,10 @@ jax_test(
 jax_test(
     name = "custom_object_test",
     srcs = ["custom_object_test.py"],
+    # TODO(yashkatariya,mattjj,phawkins): Enable custom_object_test once
+    # `ExecuteReplicated` supports the use case of having more
+    # than 1 buffer on a single device.
+    disable_configs = ["cpu_jax_array"],
 )

 py_test(
@@ -239,13 +239,13 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     assert len(side) == 3

   def test_jit_device(self):
+    if config.jax_array:
+      self.skipTest('The device parameter of jit has been deprecated. Array '
+                    'is not compatible with it and will not work.')
     device = jax.devices()[-1]
     x = self.jit(lambda x: x, device=device)(3.)
     _check_instance(self, x)
-    if config.jax_array:
-      self.assertEqual(x.device(), device)
-    else:
-      self.assertEqual(x.device_buffer.device(), device)
+    self.assertEqual(x.device_buffer.device(), device)

   @jtu.skip_on_devices("cpu")
   def test_jit_default_device(self):
@@ -267,10 +267,13 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     self.assertEqual(f(1).device(), system_default_device)

     with jax.default_device(test_device):
-      # Explicit `device` or `backend` argument to jit overrides default_device
-      self.assertEqual(
-          jax.jit(f, device=system_default_device)(1).device(),
-          system_default_device)
+      # Skip this for jax.Array because using the device argument of `jit` is
+      # deprecated.
+      if not config.jax_array:
+        # Explicit `device` or `backend` argument to jit overrides default_device
+        self.assertEqual(
+            jax.jit(f, device=system_default_device)(1).device(),
+            system_default_device)
       out = jax.jit(f, backend="cpu")(1)
       if config.jax_array:
         self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding)
@@ -1067,7 +1070,10 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     jitted_f = self.jit(lambda x, y: x, keep_unused=True)
     with jtu.count_device_put() as count:
       _ = jitted_f(1, 2)
-    self.assertEqual(count[0], 1)
+    if config.jax_array:
+      self.assertEqual(count[0], 2)
+    else:
+      self.assertEqual(count[0], 1)

   @jtu.ignore_warning(category=DeprecationWarning)
   def test_jit_lower_compile_compiler_ir(self):
@@ -38,6 +38,7 @@ import jax.util
 from jax.interpreters import xla
 from jax.interpreters import mlir
 from jax.interpreters import batching
+from jax.interpreters import pxla
 from jax.experimental import array
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src import dispatch
@@ -3042,6 +3043,14 @@ class FooTy:
       return FooArray(aval.shape, buf)
     return handler

+  @staticmethod
+  def global_sharded_result_handler(aval, out_sharding, committed):
+    def handler(bufs):
+      buf, = bufs
+      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
+      return FooArray(aval.shape, buf)
+    return handler
+
   # eltype-polymorphic primitive lowering rules

   @staticmethod
@@ -3152,6 +3161,12 @@ def device_put_foo_array(x: FooArray, device):
     return array._device_put_array(x.data, device)
   return dispatch._device_put_array(x.data, device)

+def shard_foo_array_handler(x, devices, indices, mode):
+  device, = devices
+  if isinstance(x.data, array.Array):
+    return array._device_put_array(x.data, device)
+  return dispatch._device_put_array(x.data, device)
+
 def foo_array_constant_handler(x, c):
   if config.jax_array:
     return array._array_mlir_constant_handler(x.data, c)
@@ -3186,6 +3201,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
     xla.pytype_aval_mappings[FooArray] = \
         lambda x: core.ShapedArray(x.shape, FooTy())
     dispatch.device_put_handlers[FooArray] = device_put_foo_array
+    pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
     mlir._constant_handlers[FooArray] = foo_array_constant_handler
     mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
     mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False))
@@ -224,6 +224,33 @@ class MultiDeviceTest(jtu.JaxTestCase):
     y = jax.device_put(1, devices[2]) + jnp.ones((2, 3))
     self.assert_committed_to_device(y, devices[2])

+  def test_single_input_committed_multi_output(self):
+    if jax.device_count() < 3:
+      self.skipTest("Test requires 3 devices")
+    devices = self.get_devices()
+
+    @jax.jit
+    def f(a, b, c, d, e):
+      return a, b, c, d, e
+
+    outs = f(jax.device_put(1, devices[2]), jnp.array(2), jnp.array(3),
+             jnp.array(4), jnp.array(5))
+    for o in outs:
+      self.assert_committed_to_device(o, devices[2])
+
+  def test_different_devices_input_error(self):
+    if jax.device_count() < 2:
+      self.skipTest("Test requires 2 devices")
+    devices = self.get_devices()
+
+    a = jax.device_put(1, devices[0])
+    b = jax.device_put(2, devices[1])
+
+    # Don't look for the message because the Array and non-Array path raise
+    # slightly different error messages.
+    with self.assertRaises(ValueError):
+      _ = a + b
+
   def test_transpose(self):
     if jax.device_count() < 3:
       self.skipTest("test requires 3 devices")
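A compact recap of the commitment rule these tests exercise (a sketch; only public APIs are assumed):

import jax
import jax.numpy as jnp

devices = jax.devices()
a = jax.device_put(1, devices[-1])  # committed to devices[-1]
b = jnp.array(2)                    # uncommitted: free to follow other operands
out = jax.jit(lambda x, y: x + y)(a, b)
# out follows the committed operand; two inputs committed to *different*
# devices would raise an error instead.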
@@ -173,9 +173,12 @@ class MultiBackendTest(jtu.JaxTestCase):
     result2 = jax.jit(my_sin)(data_on_cpu)
     self.assertEqual(result2.device(), cpus[0])

-    # jit with `device` spec places the data on the specified device
-    result3 = jax.jit(my_sin, device=cpus[0])(2)
-    self.assertEqual(result3.device(), cpus[0])
+    # Skip this for jax.Array because using the device argument of `jit` is
+    # deprecated.
+    if not config.jax_array:
+      # jit with `device` spec places the data on the specified device
+      result3 = jax.jit(my_sin, device=cpus[0])(2)
+      self.assertEqual(result3.device(), cpus[0])

     # jit with `backend` spec places the data on the specified backend
     result4 = jax.jit(my_sin, backend="cpu")(2)
@@ -137,8 +137,6 @@ class PythonPmapTest(jtu.JaxTestCase):
   def pmap(self):
     return src_api._python_pmap

-  # TODO(yashkatariya): Re-enable when unsafe_buffer_pointer is implemented
-  @unittest.skipIf(config.jax_array, "Array does not yet implement unsafe_buffer_pointer")
   def testDeviceBufferToArray(self):
     sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
@@ -1451,7 +1451,10 @@ class LaxRandomTest(jtu.JaxTestCase):
     key = self.seed_prng(1).block_until_ready()
     with jtu.count_device_put() as count:
       jax.jit(random.split)(key)
-    self.assertEqual(count[0], 1)  # 1 for the argument device_put
+    if config.jax_array:
+      self.assertEqual(count[0], 0)
+    else:
+      self.assertEqual(count[0], 1)  # 1 for the argument device_put

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": f"_dtype={dtype}", "dtype": dtype}