Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton, empty core.token object everywhere. After this change, tokens can be created and threaded in and out of computations to build up dependencies.

Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).
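
By way of illustration, here is a minimal sketch (not part of the commit) of the intended usage after this change; it mirrors the updated tests below, with `noop` as a hypothetical stand-in for any jit-compiled function that accepts and returns a token:

import jax
import jax.numpy as jnp
from jax._src import core

@jax.jit
def noop(x, tok):
  # The token is passed through unchanged; it only expresses a dependency.
  return x + 1, tok

tok = jax.lax.create_token()   # a fresh core.Token wrapping a jax.Array
y, out_tok = noop(jnp.ones(8), tok)

assert isinstance(out_tok, core.Token)
assert out_tok is not tok      # a distinct Token object comes back out
out_tok.block_until_ready()    # the wrapped buffer can be waited on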

PiperOrigin-RevId: 626091210
Yue Sheng 2024-04-18 11:09:02 -07:00 committed by jax authors
parent 9c9e805e82
commit c2d4373535
8 changed files with 123 additions and 89 deletions

View File

@@ -24,6 +24,10 @@ Remember to align the itemized text with the first line of an item within a list
to non-parallel computations, as we already do async dispatch for parallel
computations. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* `core.Token` is now a non-trivial class which wraps a `jax.Array`. It can
  be created and threaded in and out of computations to build up dependencies.
  The singleton object `core.token` has been removed; users should now create
  and use fresh `core.Token` objects instead.
* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old

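As a hedged illustration of the migration implied by this changelog entry (the sketch below is not itself part of the CHANGELOG), code that previously read the removed singleton would now create a token explicitly:

import jax

# Before (no longer works): tok = jax.core.token
# After: create a fresh token; under the new scheme it wraps a jax.Array.
tok = jax.lax.create_token()
tok.block_until_ready()  # Tokens now expose the wrapped buffer's readiness.
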
View File

@@ -2454,6 +2454,8 @@ def _infer_src_sharding(src, x) -> Sharding | None:
def _check_sharding(x, s):
if isinstance(s, Sharding):
aval = shaped_abstractify(x)
if isinstance(aval, core.AbstractToken):
aval = core.token_shaped_array
if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
pjit.pjit_check_aval_sharding(
(s,), (aval,), None, "device_put args", allow_uneven_sharding=False)

View File

@@ -952,6 +952,13 @@ def _array_shard_arg(x, sharding):
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
def _token_shard_arg(x, sharding):
return _array_shard_arg(x._buf, sharding)
pxla.shard_arg_handlers[core.Token] = _token_shard_arg
def _array_global_result_handler(global_aval, out_sharding, committed):
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
@@ -963,7 +970,21 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
)
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token
def _token_global_result_handler(global_aval, out_sharding, committed):
array_handler = _array_global_result_handler(
core.token_shaped_array, out_sharding, committed
)
def wrapper(*args, **kwargs):
out_buf = array_handler(*args, **kwargs)
return core.Token(out_buf)
return wrapper
pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
# Only used for Arrays that come out of pmap.

View File

@@ -1635,6 +1635,70 @@ class UnshapedArray(AbstractValue):
"UnshapedArray instances to ever be produced.")
raise TypeError(msg)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
try:
return operator.index(dim)
except TypeError as e:
type_error = e
if isinstance(dim, Tracer) and config.dynamic_shapes.value:
if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
or isinstance(dim.dtype, bint))):
raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
return dim
elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
type(dim._aval.dtype) is bint and not dim._aval.shape):
return dim
elif is_dim(dim):
return dim
else:
raise type_error
def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
"""Canonicalizes and checks for errors in a user-provided shape value.
Args:
shape: a Python value that represents a shape.
Returns:
A tuple of canonical dimension values.
"""
try:
return tuple(unsafe_map(_canonicalize_dimension, shape))
except TypeError:
pass
raise _invalid_shape_error(shape, context)
def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
Args:
f: a Python value that represents a dimension.
Returns:
A canonical dimension value.
"""
return canonicalize_shape((d,), context)[0]
def _invalid_shape_error(shape: Shape, context: str=""):
if config.dynamic_shapes.value:
msg = ("Shapes must be 1D sequences of integer scalars, "
f"got {shape}")
else:
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
f"got {shape}.")
if context:
msg += f" {context}."
if not config.dynamic_shapes.value and any(
isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
for x in shape:
if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
msg += x._origin_msg()
return TypeError(msg)
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'named_shape']
@@ -1960,9 +2024,18 @@ class AbstractToken(AbstractValue):
def at_least_vspace(self): return self
abstract_token: AbstractToken = AbstractToken()
# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))
# Concrete token object
class Token: pass
token: Token = Token()
class Token:
# The underlying data wrapped by the token. It can be threaded in and
# out of computations to build up data dependencies.
_buf: Array
def __init__(self, buf):
self._buf = buf
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
@@ -2121,71 +2194,6 @@ def dimension_as_value(d: DimSize):
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
return operator.index(d)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
# Dimensions are most commonly integral (by far), so we check that first.
try:
return operator.index(dim)
except TypeError as e:
type_error = e
if isinstance(dim, Tracer) and config.dynamic_shapes.value:
if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
or isinstance(dim.dtype, bint))):
raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
return dim
elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
type(dim._aval.dtype) is bint and not dim._aval.shape):
return dim
elif is_dim(dim):
return dim
else:
raise type_error
def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
"""Canonicalizes and checks for errors in a user-provided shape value.
Args:
shape: a Python value that represents a shape.
Returns:
A tuple of canonical dimension values.
"""
try:
return tuple(unsafe_map(_canonicalize_dimension, shape))
except TypeError:
pass
raise _invalid_shape_error(shape, context)
def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
Args:
f: a Python value that represents a dimension.
Returns:
A canonical dimension value.
"""
return canonicalize_shape((d,), context)[0]
def _invalid_shape_error(shape: Shape, context: str=""):
if config.dynamic_shapes.value:
msg = ("Shapes must be 1D sequences of integer scalars, "
f"got {shape}")
else:
msg = ("Shapes must be 1D sequences of concrete values of integer type, "
f"got {shape}.")
if context:
msg += f" {context}."
if not config.dynamic_shapes.value and any(
isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
for x in shape:
if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
msg += x._origin_msg()
return TypeError(msg)
class SomeTracer:
__slots__ = ()
def __repr__(self): return "[dynamic]"

View File

@@ -107,7 +107,7 @@ class RuntimeTokenSet(threading.local):
# For each ordered effect, the token returned by the last dispatched
# computation, sharded over the devices in that computation.
current_tokens: dict[core.Effect, jax.Array]
current_tokens: dict[core.Effect, core.Token]
# For each device, the runtime token returned by the last dispatched
# computation on that device.
@@ -117,11 +117,12 @@ class RuntimeTokenSet(threading.local):
self.current_tokens = {}
self.output_runtime_tokens = {}
def get_token_input(self, eff: core.Effect,
devices: list[Device]) -> jax.Array:
def get_token_input(
self, eff: core.Effect, devices: list[Device]
) -> core.Token:
tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
if isinstance(tok, jax.Array):
if isinstance(tok, core.Token):
# The order of devices may change, so we need to reshard if necessary.
# TODO(yueshengys): This might still be buggy in a multi-process SPMD
# scenario. Revise the logic later. A distributed shutdown barrier inside
@@ -131,11 +132,11 @@ class RuntimeTokenSet(threading.local):
# We only use replicated sharding for the first time when the token for the
# ordered effect hasn't been created.
s = jax.sharding.GSPMDSharding.get_replicated(devices)
sharded_tok = pxla.shard_args([s], [tok])[0]
sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
self.current_tokens[eff] = sharded_tok
return sharded_tok
def set_token_result(self, eff: core.Effect, token: jax.Array):
def set_token_result(self, eff: core.Effect, token: core.Token):
self.current_tokens[eff] = token
def set_output_runtime_token(self, device: Device, token: RuntimeToken):
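
For context, a small sketch (not from this diff) of user-level code whose ordered side effects exercise the token plumbing above: each jit dispatch with an ordered effect consumes the current `core.Token` for that effect via `get_token_input` and stores the one produced by the computation via `set_token_result`. The names used here are public JAX APIs (`jax.debug.print`, `jax.effects_barrier`); the example itself is illustrative only.

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
  # ordered=True makes this an ordered effect, sequenced via runtime tokens.
  jax.debug.print("x = {}", x, ordered=True)
  return x * 2

y = step(jnp.arange(3.0))
y = step(y)              # reuses the token stored for this effect
jax.effects_barrier()    # blocks on outstanding effect tokens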

View File

@@ -131,13 +131,6 @@ def get_addressable_devices_for_shard_arg(
def _get_replicated_slices(num_addressable_devices: int):
return ((slice(None),),) * num_addressable_devices
def _shard_token(x, sharding):
devices = get_addressable_devices_for_shard_arg(sharding)
indices = _get_replicated_slices(len(devices))
zeros = np.zeros((), dtype=np.dtype(np.bool_))
aval = api_util.shaped_abstractify(zeros)
return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
shard_arg_handlers[core.Token] = _shard_token
def _masked_array_error(x, sharding):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
@@ -1148,8 +1141,9 @@ class ExecuteReplicated:
def _add_tokens_to_inputs(self, input_bufs):
if self.ordered_effects:
tokens = [
dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
for eff in self.ordered_effects]
dispatch.runtime_tokens.get_token_input(eff, self._local_devices)._buf
for eff in self.ordered_effects
]
input_bufs = [*tokens, *input_bufs]
return input_bufs
@@ -1163,7 +1157,7 @@ class ExecuteReplicated:
for eff, token_buf in zip(self.ordered_effects, token_bufs):
assert len(token_buf) > 0
if len(token_buf) == 1:
dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
dispatch.runtime_tokens.set_token_result(eff, core.Token(token_buf[0]))
else:
token_devices = []
for token in token_buf:
@@ -1173,7 +1167,9 @@ class ExecuteReplicated:
global_token_array = jax.make_array_from_single_device_arrays(
(0,), s, token_buf
)
dispatch.runtime_tokens.set_token_result(eff, global_token_array)
dispatch.runtime_tokens.set_token_result(
eff, core.Token(global_token_array)
)
@profiler.annotate_function
def __call__(self, *args):

View File

@@ -148,7 +148,7 @@ from jax._src.core import (
subst_axis_names_var as subst_axis_names_var,
substitute_vars_in_output_ty as substitute_vars_in_output_ty,
thread_local_state as thread_local_state,
token as token,
trace_state_clean as trace_state_clean,
traverse_jaxpr_params as traverse_jaxpr_params,
typecheck as typecheck,

View File

@@ -673,8 +673,12 @@ class JitTest(jtu.BufferDonationTestCase):
arr = jnp.ones(10)
token = jax.lax.create_token()
_, out_token = noop(arr, token)
self.assertEqual(token, noop(arr, token)[1])
self.assertIsInstance(token, core.Token)
self.assertIsInstance(out_token, core.Token)
# Different token objects.
self.assertIsNot(token, out_token)
def test_jit_bad_input(self):
def f(x):
@@ -1226,7 +1230,6 @@ class JitTest(jtu.BufferDonationTestCase):
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)
@parameterized.parameters([0, 2, [(0, 2)]])
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
def f(x, y, *args, **kwargs):
@@ -3732,7 +3735,7 @@ class APITest(jtu.JaxTestCase):
self.assertIsInstance(x, core.Token)
def test_jit_capturing_token(self):
tok = core.token
tok = jax.lax.create_token()
_, y = jax.jit(lambda x: (x + 2, tok))(7)
self.assertIsInstance(y, core.Token)