Make core.Token a non-trivial class which wraps a jax.Array. Currently we use a singleton, empty core.token object everywhere; after this change, tokens can be created and threaded in and out of computations to build up dependencies.

Also update ordered side effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
parent 9c9e805e82
commit c2d4373535
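Illustrative usage, pieced together from the changelog entry and the tests touched in this commit (a sketch, not part of the diff; it assumes `jax.lax.create_token` now returns the new `core.Token`):

```python
import jax
import jax.numpy as jnp
from jax._src import core

@jax.jit
def noop(arr, token):
  return arr, token

arr = jnp.ones(10)
token = jax.lax.create_token()     # a fresh core.Token wrapping a jax.Array
_, out_token = noop(arr, token)    # tokens thread in and out of computations

assert isinstance(out_token, core.Token)
assert out_token is not token      # a new token object comes back
```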
@@ -24,6 +24,10 @@ Remember to align the itemized text with the first line of an item within a list
    to non-parallel computations, as we already do async dispatch for parallel
    computations. You can recover the old behavior by setting
    `jax.config.update('jax_cpu_enable_async_dispatch', False)`.
  * `core.Token` now is a non-trivial class which wraps a `jax.Array`. It could
    be created and threaded in and out of computations to build up dependency.
    The singleton object `core.token` has been removed, users now should create
    and use fresh `core.Token` objects instead.

* Deprecations & Removals
  * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
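A small hedged sketch of the two changelog items above (the config flag name is taken verbatim from the entry; behavior may differ across JAX versions):

```python
import jax

# Recover the old synchronous CPU dispatch behavior (flag named in the entry above).
jax.config.update('jax_cpu_enable_async_dispatch', False)

# The core.token singleton is gone; create fresh tokens instead.
token = jax.lax.create_token()
```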
@@ -2454,6 +2454,8 @@ def _infer_src_sharding(src, x) -> Sharding | None:
def _check_sharding(x, s):
  if isinstance(s, Sharding):
    aval = shaped_abstractify(x)
    if isinstance(aval, core.AbstractToken):
      aval = core.token_shaped_array
    if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
      pjit.pjit_check_aval_sharding(
          (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
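For context, the placeholder aval substituted above is defined elsewhere in this commit as `ShapedArray((0,), bool)`. A tiny sketch, relying on private `jax._src` names that may change:

```python
from jax._src import core

# Tokens carry no data of their own, so sharding checks substitute this
# placeholder aval: a ShapedArray with shape (0,) and dtype bool.
aval = core.token_shaped_array
print(aval.shape, aval.dtype)   # (0,) bool
```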
@@ -952,6 +952,13 @@ def _array_shard_arg(x, sharding):
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg


def _token_shard_arg(x, sharding):
  return _array_shard_arg(x._buf, sharding)


pxla.shard_arg_handlers[core.Token] = _token_shard_arg


def _array_global_result_handler(global_aval, out_sharding, committed):
  if global_aval.dtype == dtypes.float0:
    return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
@@ -963,7 +970,21 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
  )
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token


def _token_global_result_handler(global_aval, out_sharding, committed):
  array_handler = _array_global_result_handler(
      core.token_shaped_array, out_sharding, committed
  )

  def wrapper(*args, **kwargs):
    out_buf = array_handler(*args, **kwargs)
    return core.Token(out_buf)

  return wrapper


pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler


# Only used for Arrays that come out of pmap.
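Both token handlers above use the same delegation pattern: unwrap (or re-wrap) the token's buffer and reuse the plain-array handler. A generic, self-contained sketch of that pattern with stand-in names (none of these are JAX internals):

```python
from typing import Any, Callable

result_handlers: dict[str, Callable[..., Any]] = {}

class Wrapped:
  """Stand-in for core.Token: a thin wrapper holding an underlying buffer."""
  def __init__(self, buf: Any):
    self._buf = buf

def make_array_handler() -> Callable[[Any], Any]:
  # Stand-in for _array_global_result_handler: turns a raw result into an "array".
  return lambda raw: ("array", raw)

def make_wrapped_handler() -> Callable[[Any], Wrapped]:
  # Mirrors _token_global_result_handler: build the array handler, then wrap its output.
  array_handler = make_array_handler()
  def wrapper(raw: Any) -> Wrapped:
    return Wrapped(array_handler(raw))
  return wrapper

result_handlers["array"] = make_array_handler
result_handlers["token"] = make_wrapped_handler

out = result_handlers["token"]()(object())
assert isinstance(out, Wrapped) and out._buf[0] == "array"
```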
jax/_src/core.py (142 changed lines)
@@ -1635,6 +1635,70 @@ class UnshapedArray(AbstractValue):
           "UnshapedArray instances to ever be produced.")
    raise TypeError(msg)

def _canonicalize_dimension(dim: DimSize) -> DimSize:
  # Dimensions are most commonly integral (by far), so we check that first.
  try:
    return operator.index(dim)
  except TypeError as e:
    type_error = e
  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
                               or isinstance(dim.dtype, bint))):
      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
    return dim
  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
        type(dim._aval.dtype) is bint and not dim._aval.shape):
    return dim
  elif is_dim(dim):
    return dim
  else:
    raise type_error

def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.

  Returns:
    A tuple of canonical dimension values.
  """
  try:
    return tuple(unsafe_map(_canonicalize_dimension, shape))
  except TypeError:
    pass
  raise _invalid_shape_error(shape, context)

def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
  """Canonicalizes and checks for errors in a user-provided shape dimension value.

  Args:
    f: a Python value that represents a dimension.

  Returns:
    A canonical dimension value.
  """
  return canonicalize_shape((d,), context)[0]

def _invalid_shape_error(shape: Shape, context: str=""):
  if config.dynamic_shapes.value:
    msg = ("Shapes must be 1D sequences of integer scalars, "
           f"got {shape}")
  else:
    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
           f"got {shape}.")
  if context:
    msg += f" {context}."
  if not config.dynamic_shapes.value and any(
      isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
      and not isinstance(get_aval(x), ConcreteArray) for x in shape):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  for x in shape:
    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
      msg += x._origin_msg()

  return TypeError(msg)

class ShapedArray(UnshapedArray):
  __slots__ = ['shape', 'named_shape']
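A rough illustration of what the canonicalization helpers above accept and reject (they are private `jax._src.core` utilities, so exact behavior and error text may vary):

```python
import numpy as np
from jax._src import core

print(core.canonicalize_shape((2, np.int64(3))))   # (2, 3) -- plain Python ints

try:
  core.canonicalize_shape((2, 3.5))                 # non-integer dimensions are rejected
except TypeError as e:
  print(type(e).__name__)                           # TypeError
```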
@@ -1960,9 +2024,18 @@ class AbstractToken(AbstractValue):
  def at_least_vspace(self): return self
abstract_token: AbstractToken = AbstractToken()

# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))

# Concrete token object
class Token: pass
token: Token = Token()
class Token:
  # The underlying data wrapped by the token, could be used to threaded in and
  # out of computations to build up data dependency.
  _buf: Array
  def __init__(self, buf):
    self._buf = buf
  def block_until_ready(self):
    self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
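A minimal sketch of the new `Token` surface defined above, as exercised by the tests later in this diff (`_buf` is a private attribute and may change):

```python
import jax
from jax._src import core

tok = jax.lax.create_token()
assert isinstance(tok, core.Token)       # no more core.token singleton
assert isinstance(tok._buf, jax.Array)   # the token wraps a concrete jax.Array
tok.block_until_ready()                  # delegates to the wrapped buffer
```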
@@ -2121,71 +2194,6 @@ def dimension_as_value(d: DimSize):
  if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
  return operator.index(d)

def _canonicalize_dimension(dim: DimSize) -> DimSize:
  # Dimensions are most commonly integral (by far), so we check that first.
  try:
    return operator.index(dim)
  except TypeError as e:
    type_error = e
  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
                               or isinstance(dim.dtype, bint))):
      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
    return dim
  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
        type(dim._aval.dtype) is bint and not dim._aval.shape):
    return dim
  elif is_dim(dim):
    return dim
  else:
    raise type_error

def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.

  Returns:
    A tuple of canonical dimension values.
  """
  try:
    return tuple(unsafe_map(_canonicalize_dimension, shape))
  except TypeError:
    pass
  raise _invalid_shape_error(shape, context)

def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
  """Canonicalizes and checks for errors in a user-provided shape dimension value.

  Args:
    f: a Python value that represents a dimension.

  Returns:
    A canonical dimension value.
  """
  return canonicalize_shape((d,), context)[0]

def _invalid_shape_error(shape: Shape, context: str=""):
  if config.dynamic_shapes.value:
    msg = ("Shapes must be 1D sequences of integer scalars, "
           f"got {shape}")
  else:
    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
           f"got {shape}.")
  if context:
    msg += f" {context}."
  if not config.dynamic_shapes.value and any(
      isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
      and not isinstance(get_aval(x), ConcreteArray) for x in shape):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  for x in shape:
    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
      msg += x._origin_msg()

  return TypeError(msg)

class SomeTracer:
  __slots__ = ()
  def __repr__(self): return "[dynamic]"
@@ -107,7 +107,7 @@ class RuntimeTokenSet(threading.local):

  # For each ordered effect, the token returned by the last dispatched
  # computation, sharded over the devices in that computation.
  current_tokens: dict[core.Effect, jax.Array]
  current_tokens: dict[core.Effect, core.Token]

  # For each device, the runtime token returned by the last dispatched
  # computation on that device.
@@ -117,11 +117,12 @@ class RuntimeTokenSet(threading.local):
    self.current_tokens = {}
    self.output_runtime_tokens = {}

  def get_token_input(self, eff: core.Effect,
                      devices: list[Device]) -> jax.Array:
  def get_token_input(
      self, eff: core.Effect, devices: list[Device]
  ) -> core.Token:
    tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))

    if isinstance(tok, jax.Array):
    if isinstance(tok, core.Token):
      # The order of devices may change, so we need to reshard if necessary.
      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
      # scenario. Revise the logic later. A distributed shutdown barrier inside
@@ -131,11 +132,11 @@ class RuntimeTokenSet(threading.local):
    # We only use replicated sharding for the first time when the token for the
    # order effect hasn't been created.
    s = jax.sharding.GSPMDSharding.get_replicated(devices)
    sharded_tok = pxla.shard_args([s], [tok])[0]
    sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
    self.current_tokens[eff] = sharded_tok
    return sharded_tok

  def set_token_result(self, eff: core.Effect, token: jax.Array):
  def set_token_result(self, eff: core.Effect, token: core.Token):
    self.current_tokens[eff] = token

  def set_output_runtime_token(self, device: Device, token: RuntimeToken):
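These runtime tokens back ordered side effects under `jit`; a hedged example of user code that exercises this path (assuming `jax.debug.print(..., ordered=True)`, which registers an ordered effect):

```python
import jax

@jax.jit
def f(x):
  # Ordered side effects: consecutive prints must run in program order, which
  # the runtime enforces by threading tokens between dispatched computations.
  jax.debug.print("step 1: {}", x, ordered=True)
  jax.debug.print("step 2: {}", x + 1, ordered=True)
  return x + 2

f(1)
```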
@@ -131,13 +131,6 @@ def get_addressable_devices_for_shard_arg(
def _get_replicated_slices(num_addressable_devices: int):
  return ((slice(None),),) * num_addressable_devices

def _shard_token(x, sharding):
  devices = get_addressable_devices_for_shard_arg(sharding)
  indices = _get_replicated_slices(len(devices))
  zeros = np.zeros((), dtype=np.dtype(np.bool_))
  aval = api_util.shaped_abstractify(zeros)
  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
shard_arg_handlers[core.Token] = _shard_token

def _masked_array_error(x, sharding):
  raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
@@ -1148,8 +1141,9 @@ class ExecuteReplicated:
  def _add_tokens_to_inputs(self, input_bufs):
    if self.ordered_effects:
      tokens = [
          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
          for eff in self.ordered_effects]
          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)._buf
          for eff in self.ordered_effects
      ]
      input_bufs = [*tokens, *input_bufs]
    return input_bufs

@@ -1163,7 +1157,7 @@ class ExecuteReplicated:
    for eff, token_buf in zip(self.ordered_effects, token_bufs):
      assert len(token_buf) > 0
      if len(token_buf) == 1:
        dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
        dispatch.runtime_tokens.set_token_result(eff, core.Token(token_buf[0]))
      else:
        token_devices = []
        for token in token_buf:
@@ -1173,7 +1167,9 @@ class ExecuteReplicated:
        global_token_array = jax.make_array_from_single_device_arrays(
            (0,), s, token_buf
        )
        dispatch.runtime_tokens.set_token_result(eff, global_token_array)
        dispatch.runtime_tokens.set_token_result(
            eff, core.Token(global_token_array)
        )

  @profiler.annotate_function
  def __call__(self, *args):
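For reference, a hedged sketch of the `jax.make_array_from_single_device_arrays` call used above, reduced to a single-device shard (the real code assembles one shard per device in the computation):

```python
import jax
import numpy as np
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
shard = jax.device_put(np.zeros((0,), np.bool_), dev)   # one per-device token buffer
s = SingleDeviceSharding(dev)
global_token_array = jax.make_array_from_single_device_arrays((0,), s, [shard])
print(global_token_array.shape, global_token_array.dtype)  # (0,) bool
```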
@@ -148,7 +148,6 @@ from jax._src.core import (
  subst_axis_names_var as subst_axis_names_var,
  substitute_vars_in_output_ty as substitute_vars_in_output_ty,
  thread_local_state as thread_local_state,
  token as token,
  trace_state_clean as trace_state_clean,
  traverse_jaxpr_params as traverse_jaxpr_params,
  typecheck as typecheck,
@@ -673,8 +673,12 @@ class JitTest(jtu.BufferDonationTestCase):

    arr = jnp.ones(10)
    token = jax.lax.create_token()
    _, out_token = noop(arr, token)

    self.assertEqual(token, noop(arr, token)[1])
    self.assertIsInstance(token, core.Token)
    self.assertIsInstance(out_token, core.Token)
    # Different token objects.
    self.assertIsNot(token, out_token)

  def test_jit_bad_input(self):
    def f(x):
@@ -1226,7 +1230,6 @@ class JitTest(jtu.BufferDonationTestCase):
    for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
      self.assertNotIn(s, hlo_str)


  @parameterized.parameters([0, 2, [(0, 2)]])
  def test_jit_lower_arg_info_static_argnums(self, static_argnums):
    def f(x, y, *args, **kwargs):
@@ -3732,7 +3735,7 @@ class APITest(jtu.JaxTestCase):
    self.assertIsInstance(x, core.Token)

  def test_jit_capturing_token(self):
    tok = core.token
    tok = jax.lax.create_token()
    _, y = jax.jit(lambda x: (x + 2, tok))(7)
    self.assertIsInstance(y, core.Token)