diff --git a/CHANGELOG.md b/CHANGELOG.md
index a30f333d9..0121e67e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,10 @@ Remember to align the itemized text with the first line of an item within a list
     to non-parallel computations, as we already do async dispatch for parallel
     computations. You can recover the old behavior by setting
     `jax.config.update('jax_cpu_enable_async_dispatch', False)`.
+  * `core.Token` is now a non-trivial class which wraps a `jax.Array`. It can
+    be created and threaded in and out of computations to build up data
+    dependencies. The singleton object `core.token` has been removed; users
+    should now create and use fresh `core.Token` objects instead.
 
 * Deprecations & Removals
   * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 7580a9117..6f5316e29 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -2454,6 +2454,8 @@ def _infer_src_sharding(src, x) -> Sharding | None:
 def _check_sharding(x, s):
   if isinstance(s, Sharding):
     aval = shaped_abstractify(x)
+    if isinstance(aval, core.AbstractToken):
+      aval = core.token_shaped_array
     if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
       pjit.pjit_check_aval_sharding(
           (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
diff --git a/jax/_src/array.py b/jax/_src/array.py
index c1a587f34..ad8608862 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -952,6 +952,13 @@ def _array_shard_arg(x, sharding):
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
 
 
+def _token_shard_arg(x, sharding):
+  return _array_shard_arg(x._buf, sharding)
+
+
+pxla.shard_arg_handlers[core.Token] = _token_shard_arg
+
+
 def _array_global_result_handler(global_aval, out_sharding, committed):
   if global_aval.dtype == dtypes.float0:
     return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
@@ -963,7 +970,21 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
   )
 pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
 pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
-pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token
+
+
+def _token_global_result_handler(global_aval, out_sharding, committed):
+  array_handler = _array_global_result_handler(
+      core.token_shaped_array, out_sharding, committed
+  )
+
+  def wrapper(*args, **kwargs):
+    out_buf = array_handler(*args, **kwargs)
+    return core.Token(out_buf)
+
+  return wrapper
+
+
+pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
 
 
 # Only used for Arrays that come out of pmap.
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 92c48e082..3aaa3db0b 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1635,6 +1635,70 @@ class UnshapedArray(AbstractValue):
            "UnshapedArray instances to ever be produced.")
     raise TypeError(msg)
 
+def _canonicalize_dimension(dim: DimSize) -> DimSize:
+  # Dimensions are most commonly integral (by far), so we check that first.
+  try:
+    return operator.index(dim)
+  except TypeError as e:
+    type_error = e
+  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
+    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
+                               or isinstance(dim.dtype, bint))):
+      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
+    return dim
+  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
+        type(dim._aval.dtype) is bint and not dim._aval.shape):
+    return dim
+  elif is_dim(dim):
+    return dim
+  else:
+    raise type_error
+
+def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
+  """Canonicalizes and checks for errors in a user-provided shape value.
+
+  Args:
+    shape: a Python value that represents a shape.
+
+  Returns:
+    A tuple of canonical dimension values.
+  """
+  try:
+    return tuple(unsafe_map(_canonicalize_dimension, shape))
+  except TypeError:
+    pass
+  raise _invalid_shape_error(shape, context)
+
+def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
+  """Canonicalizes and checks for errors in a user-provided shape dimension value.
+
+  Args:
+    f: a Python value that represents a dimension.
+
+  Returns:
+    A canonical dimension value.
+  """
+  return canonicalize_shape((d,), context)[0]
+
+def _invalid_shape_error(shape: Shape, context: str=""):
+  if config.dynamic_shapes.value:
+    msg = ("Shapes must be 1D sequences of integer scalars, "
+           f"got {shape}")
+  else:
+    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
+           f"got {shape}.")
+  if context:
+    msg += f" {context}."
+  if not config.dynamic_shapes.value and any(
+     isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
+     and not isinstance(get_aval(x), ConcreteArray) for x in shape):
+    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
+            "smaller subfunctions.")
+  for x in shape:
+    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
+      msg += x._origin_msg()
+
+  return TypeError(msg)
 
 class ShapedArray(UnshapedArray):
   __slots__ = ['shape', 'named_shape']
@@ -1960,9 +2024,18 @@ class AbstractToken(AbstractValue):
   def at_least_vspace(self): return self
 abstract_token: AbstractToken = AbstractToken()
 
+# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
+token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))
+
 # Concrete token object
-class Token: pass
-token: Token = Token()
+class Token:
+  # The underlying data wrapped by the token; it can be threaded in and out
+  # of computations to build up data dependencies.
+  _buf: Array
+  def __init__(self, buf):
+    self._buf = buf
+  def block_until_ready(self):
+    self._buf.block_until_ready()
 pytype_aval_mappings[Token] = lambda _: abstract_token
 
 
@@ -2121,71 +2194,6 @@ def dimension_as_value(d: DimSize):
   if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
   return operator.index(d)
 
-def _canonicalize_dimension(dim: DimSize) -> DimSize:
-  # Dimensions are most commonly integral (by far), so we check that first.
-  try:
-    return operator.index(dim)
-  except TypeError as e:
-    type_error = e
-  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
-    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
-                               or isinstance(dim.dtype, bint))):
-      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
-    return dim
-  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
-        type(dim._aval.dtype) is bint and not dim._aval.shape):
-    return dim
-  elif is_dim(dim):
-    return dim
-  else:
-    raise type_error
-
-def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
-  """Canonicalizes and checks for errors in a user-provided shape value.
-
-  Args:
-    shape: a Python value that represents a shape.
-
-  Returns:
-    A tuple of canonical dimension values.
-  """
-  try:
-    return tuple(unsafe_map(_canonicalize_dimension, shape))
-  except TypeError:
-    pass
-  raise _invalid_shape_error(shape, context)
-
-def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
-  """Canonicalizes and checks for errors in a user-provided shape dimension value.
-
-  Args:
-    f: a Python value that represents a dimension.
-
-  Returns:
-    A canonical dimension value.
-  """
-  return canonicalize_shape((d,), context)[0]
-
-def _invalid_shape_error(shape: Shape, context: str=""):
-  if config.dynamic_shapes.value:
-    msg = ("Shapes must be 1D sequences of integer scalars, "
-           f"got {shape}")
-  else:
-    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
-           f"got {shape}.")
-  if context:
-    msg += f" {context}."
-  if not config.dynamic_shapes.value and any(
-     isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
-     and not isinstance(get_aval(x), ConcreteArray) for x in shape):
-    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
-            "smaller subfunctions.")
-  for x in shape:
-    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
-      msg += x._origin_msg()
-
-  return TypeError(msg)
-
 class SomeTracer:
   __slots__ = ()
   def __repr__(self): return "[dynamic]"
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 8f186f0fd..54eb03803 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -107,7 +107,7 @@ class RuntimeTokenSet(threading.local):
 
   # For each ordered effect, the token returned by the last dispatched
   # computation, sharded over the devices in that computation.
-  current_tokens: dict[core.Effect, jax.Array]
+  current_tokens: dict[core.Effect, core.Token]
 
   # For each device, the runtime token returned by the last dispatched
   # computation on that device.
@@ -117,11 +117,12 @@ class RuntimeTokenSet(threading.local):
     self.current_tokens = {}
     self.output_runtime_tokens = {}
 
-  def get_token_input(self, eff: core.Effect,
-                      devices: list[Device]) -> jax.Array:
+  def get_token_input(
+      self, eff: core.Effect, devices: list[Device]
+  ) -> core.Token:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
 
-    if isinstance(tok, jax.Array):
+    if isinstance(tok, core.Token):
       # The order of devices may change, so we need to reshard if necessary.
       # TODO(yueshengys): This might still be buggy in a multi-process SPMD
      # scenario. Revise the logic later. A distributed shutdown barrier inside
@@ -131,11 +132,11 @@ class RuntimeTokenSet(threading.local):
     # We only use replicated sharding for the first time when the token for the
     # order effect hasn't been created.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
-    sharded_tok = pxla.shard_args([s], [tok])[0]
+    sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
     self.current_tokens[eff] = sharded_tok
     return sharded_tok
 
-  def set_token_result(self, eff: core.Effect, token: jax.Array):
+  def set_token_result(self, eff: core.Effect, token: core.Token):
     self.current_tokens[eff] = token
 
   def set_output_runtime_token(self, device: Device, token: RuntimeToken):
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 85dd03ee3..5910e3ddf 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -131,13 +131,6 @@ def get_addressable_devices_for_shard_arg(
 def _get_replicated_slices(num_addressable_devices: int):
   return ((slice(None),),) * num_addressable_devices
 
-def _shard_token(x, sharding):
-  devices = get_addressable_devices_for_shard_arg(sharding)
-  indices = _get_replicated_slices(len(devices))
-  zeros = np.zeros((), dtype=np.dtype(np.bool_))
-  aval = api_util.shaped_abstractify(zeros)
-  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
-shard_arg_handlers[core.Token] = _shard_token
 
 def _masked_array_error(x, sharding):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
@@ -1148,8 +1141,9 @@ class ExecuteReplicated:
   def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
       tokens = [
-          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
-          for eff in self.ordered_effects]
+          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)._buf
+          for eff in self.ordered_effects
+      ]
       input_bufs = [*tokens, *input_bufs]
     return input_bufs
 
@@ -1163,7 +1157,7 @@
       for eff, token_buf in zip(self.ordered_effects, token_bufs):
         assert len(token_buf) > 0
         if len(token_buf) == 1:
-          dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+          dispatch.runtime_tokens.set_token_result(eff, core.Token(token_buf[0]))
         else:
           token_devices = []
           for token in token_buf:
@@ -1173,7 +1167,9 @@
           global_token_array = jax.make_array_from_single_device_arrays(
               (0,), s, token_buf
           )
-          dispatch.runtime_tokens.set_token_result(eff, global_token_array)
+          dispatch.runtime_tokens.set_token_result(
+              eff, core.Token(global_token_array)
+          )
 
   @profiler.annotate_function
   def __call__(self, *args):
diff --git a/jax/core.py b/jax/core.py
index c9cbd3310..cd0b3c8f2 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -148,7 +148,6 @@ from jax._src.core import (
   subst_axis_names_var as subst_axis_names_var,
   substitute_vars_in_output_ty as substitute_vars_in_output_ty,
   thread_local_state as thread_local_state,
-  token as token,
   trace_state_clean as trace_state_clean,
   traverse_jaxpr_params as traverse_jaxpr_params,
   typecheck as typecheck,
diff --git a/tests/api_test.py b/tests/api_test.py
index c20325539..39316d521 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -673,8 +673,12 @@ class JitTest(jtu.BufferDonationTestCase):
 
     arr = jnp.ones(10)
     token = jax.lax.create_token()
+    _, out_token = noop(arr, token)
 
-    self.assertEqual(token, noop(arr, token)[1])
+    self.assertIsInstance(token, core.Token)
+    self.assertIsInstance(out_token, core.Token)
+    # Different token objects.
+    self.assertIsNot(token, out_token)
 
   def test_jit_bad_input(self):
     def f(x):
@@ -1226,7 +1230,6 @@ class JitTest(jtu.BufferDonationTestCase):
     for s in ("\"x\"", "y['hi']", "args[0]", "args[1]",
               "kwargs['z']", "kwargs['w']"):
       self.assertNotIn(s, hlo_str)
-
   @parameterized.parameters([0, 2, [(0, 2)]])
   def test_jit_lower_arg_info_static_argnums(self, static_argnums):
     def f(x, y, *args, **kwargs):
@@ -3732,7 +3735,7 @@ class APITest(jtu.JaxTestCase):
     self.assertIsInstance(x, core.Token)
 
   def test_jit_capturing_token(self):
-    tok = core.token
+    tok = jax.lax.create_token()
     _, y = jax.jit(lambda x: (x + 2, tok))(7)
     self.assertIsInstance(y, core.Token)
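
Reviewer note (not part of the patch): a minimal usage sketch of the reworked token API, inferred from the CHANGELOG entry and the tests above. It assumes the patch is applied; `noop` is a hypothetical stand-in function, and `core` is imported from the private `jax._src` module the way the test file does.

# Usage sketch only, not part of this patch.
import jax
import jax.numpy as jnp
from jax._src import core  # the tests above reference Token via this module


@jax.jit
def noop(x, token):
  # Thread the token through the computation unchanged to express a data dependency.
  return x + 1, token


arr = jnp.ones(10)
tok = jax.lax.create_token()   # a fresh core.Token; the core.token singleton is gone
out, out_tok = noop(arr, tok)

assert isinstance(tok, core.Token)
assert isinstance(out_tok, core.Token)  # outputs come back rewrapped as core.Token
out_tok.block_until_ready()             # Token now forwards block_until_ready()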