Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Add IO callback
This commit is contained in:
parent a58e59d98f
commit 3de5c2b716
@@ -3347,7 +3347,7 @@ def block_until_ready(x):
  return jax.tree_util.tree_map(try_to_block, x)


def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
                  *args: Any, **kwargs: Any):
                  *args: Any, vectorized: bool = False, **kwargs: Any):
  """Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.

  ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
@@ -3390,14 +3390,15 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
      via `jax.vmap`, it will be called directly on inputs with leading batch
      dimensions instead of executing ``callback`` on each mapped input
      individually. The callback should also return outputs batched across the
      leading axis.
      leading axis. By default, ``vectorized`` is ``False``.
    **kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
      types.

  Returns:
    The value of ``callback(*args, **kwargs)``.
  """
  return jcb.pure_callback(callback, result_shape_dtypes, *args, **kwargs)
  return jcb.pure_callback(callback, result_shape_dtypes, *args,
                           vectorized=vectorized, **kwargs)


def clear_backends():
  """
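For orientation, a minimal usage sketch of the new ``vectorized`` argument (the host callback, shapes, and function names below are illustrative and not part of this commit):

# Sketch: pure_callback with vectorized=True. Under jax.vmap the callback is
# invoked once on the whole batch, so it must accept a leading batch dimension.
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  return np.sin(x)  # elementwise, so a batched input is handled naturally

def f(x):
  return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
                           vectorized=True)

y = jax.vmap(f)(jnp.arange(4.0))  # one callback call on the full batch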
@@ -23,11 +23,11 @@ from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src import dispatch
from jax._src.lib import xla_client as xc
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
import numpy as np
from jax._src.lib import xla_client as xc

# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
pure_callback_p = core.Primitive("pure_callback")
@@ -150,3 +150,106 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
      *flat_args, callback=_flat_callback,
      result_avals=tuple(flat_result_avals), vectorized=vectorized)
  return tree_util.tree_unflatten(out_tree, out_flat)


# IO Callback

io_callback_p = core.Primitive("io_callback")
io_callback_p.multiple_results = True

class IOEffect:
  __str__ = lambda _: "IO"
class OrderedIOEffect:
  __str__ = lambda _: "OrderedIO"
_IOEffect = IOEffect()
_OrderedIOEffect = OrderedIOEffect()
mlir.lowerable_effects.add(_IOEffect)
mlir.lowerable_effects.add(_OrderedIOEffect)
core.control_flow_allowed_effects.add(_IOEffect)
core.control_flow_allowed_effects.add(_OrderedIOEffect)
core.ordered_effects.add(_OrderedIOEffect)


def io_callback_impl(*args, result_avals, callback: Callable[..., Any],
                     ordered: bool):
  del result_avals, ordered
  return callback(*args)
io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
                                         io_callback_p))

@io_callback_p.def_effectful_abstract_eval
def io_callback_abstract_eval(*avals, callback: Callable[..., Any],
                              result_avals, ordered: bool):
  del avals, callback
  effect = _OrderedIOEffect if ordered else _IOEffect
  return result_avals, {effect}

def io_callback_jvp_rule(*args, **kwargs):
  del args, kwargs
  raise ValueError("IO callbacks do not support JVP.")
ad.primitive_jvps[io_callback_p] = io_callback_jvp_rule


def io_callback_transpose_rule(*args, **kwargs):
  del args, kwargs
  raise ValueError("IO callbacks do not support transpose.")
ad.primitive_transposes[io_callback_p] = io_callback_transpose_rule


def io_callback_batching_rule(args, dims, callback, result_avals, ordered):
  if ordered:
    raise ValueError("Cannot `vmap` ordered IO callback.")
  return pure_callback_batching_rule(args, dims, callback=callback,
                                     vectorized=False, result_avals=result_avals)
batching.primitive_batchers[io_callback_p] = io_callback_batching_rule

def io_callback_lowering(ctx, *args, callback, ordered, **params):

  def _callback(*flat_args):
    return tuple(io_callback_impl(*flat_args, callback=callback,
                                  ordered=ordered, **params))

  # TODO(sharadmv): figure out the best API for sharding callbacks. For now, we
  # can only safely maximally shard. Should we allow device_index to be passed
  # in like host_callback?
  if isinstance(ctx.module_context.axis_context,
                (mlir.SPMDAxisContext, mlir.ShardingContext)):
    # Apply maximal sharding so pjit only executes the callback on device 0.
    sharding = xc.OpSharding()
    sharding.type = xc.OpSharding.Type.MAXIMAL
    sharding.tile_assignment_dimensions = [1]
    sharding.tile_assignment_devices = [0]
  else:
    sharding = None

  if ordered:
    token = ctx.tokens_in.get(_OrderedIOEffect)[0]
    result, token, keepalive = mlir.emit_python_callback(
        ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True,
        sharding=sharding)
    ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)}))
  else:
    result, token, keepalive = mlir.emit_python_callback(
        ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True,
        sharding=sharding)
  ctx.module_context.add_keepalive(keepalive)
  return result
mlir.register_lowering(io_callback_p, io_callback_lowering)

def io_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
                *args: Any, ordered: bool = False, **kwargs: Any):
  def _flat_callback(*flat_args):
    args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
    return tree_util.tree_leaves(callback(*args, **kwargs))

  flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
  tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
  flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
  flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype),
                          flat_shape_dtypes)
  flat_args = map(core.raise_as_much_as_possible, flat_args)
  out_flat = io_callback_p.bind(
      *flat_args, callback=_flat_callback,
      result_avals=tuple(flat_result_avals),
      ordered=ordered)
  return tree_util.tree_unflatten(out_tree, out_flat)
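For orientation, a minimal usage sketch of the new ``io_callback`` (the logging callback and function names below are illustrative and not part of this commit):

# Sketch: io_callback for an impure, side-effecting host call.
# ordered=True preserves the relative order of callback executions under jit.
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

log = []

def record(x):
  log.append(int(x))  # host-side side effect, no return value

@jax.jit
def f(x):
  io_callback(record, None, x, ordered=True)
  return x + 1

f(jnp.int32(0))
f(jnp.int32(1))
jax.effects_barrier()  # wait for the host callbacks to finish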
@@ -63,6 +63,7 @@ Effect = Hashable
Effects = Set[Effect]
no_effects: Effects = set()
ordered_effects: Set[Effect] = set()
control_flow_allowed_effects: Set[Effect] = set()


class Jaxpr:
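As an illustrative sketch (hypothetical effect type, not part of this commit), a library-defined effect opts into control flow the same way the callback module above registers _IOEffect:

# Sketch: registering a custom effect with the new control_flow_allowed_effects set.
from jax._src import core
from jax.interpreters import mlir

class MyLoggingEffect:  # hypothetical effect class
  __str__ = lambda _: "MyLogging"

_MyLoggingEffect = MyLoggingEffect()
mlir.lowerable_effects.add(_MyLoggingEffect)             # effect can be lowered to MLIR
core.control_flow_allowed_effects.add(_MyLoggingEffect)  # allowed inside cond/while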
@@ -15,21 +15,21 @@
import os
from functools import partial

from typing import Callable, Optional, Sequence, Set
from typing import Callable, Optional, Sequence

from jax import core
from jax._src import core
from jax._src import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src.core import control_flow_allowed_effects as allowed_effects
from jax._src.lax import lax
from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten

map, unsafe_map = safe_map, map

allowed_effects: Set[core.Effect] = set()
allowed_effects.add(lax.InOutFeedEffect.Infeed)
allowed_effects.add(lax.InOutFeedEffect.Outfeed)
@@ -337,10 +337,17 @@ def _bcast_select_n(pred, *cases):
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
  index, *ops = args
  index_dim, *op_dims = dims
  # TODO(sharadmv): clean this up by adding a specific blocklist
  if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
         branch.jaxpr.effects):
    raise NotImplementedError(
        "State effect not supported in cond vmap.")
        "State effect not supported in vmap-of-cond.")
  from jax._src.callback import _IOEffect, _OrderedIOEffect
  if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect]
         for branch in branches):
    raise NotImplementedError(
        "IO effect not supported in vmap-of-cond.")


  if index_dim is not batching.not_mapped:
    # Convert to a lax.select. While we could get away with not broadcasting
@@ -1143,6 +1143,12 @@ def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
def _while_loop_batching_rule(axis_size, axis_name, main_type, args, dims,
                              cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  from jax._src.callback import _IOEffect, _OrderedIOEffect
  if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect]
         for branch in [body_jaxpr, cond_jaxpr]):
    raise NotImplementedError(
        "IO effect not supported in vmap-of-while.")

  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
  cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts])
@@ -22,3 +22,6 @@ from jax.experimental.x64_context import (
    enable_x64 as enable_x64,
    disable_x64 as disable_x64,
)
from jax._src.callback import (
    io_callback as io_callback
)
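With this re-export, user code reaches the new callback through the experimental namespace, for example (sketch):

# Public import path added by this hunk.
from jax.experimental import io_callback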
@@ -972,6 +972,9 @@ jax_test(
jax_test(
    name = "python_callback_test",
    srcs = ["python_callback_test.py"],
    deps = [
        "//jax:experimental",
    ],
)

jax_test(
@@ -24,17 +24,18 @@ from jax import lax
from jax import tree_util
from jax._src import debugging
from jax._src import dispatch
from jax._src import sharding
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.pjit import pjit
from jax.experimental.maps import xmap
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.experimental.maps import Mesh
from jax.experimental.maps import xmap
from jax.experimental import io_callback
import jax.numpy as jnp
import numpy as np
@@ -682,7 +683,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
    try:
      mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

      spec = P('x')
      spec = pjit.PartitionSpec('x')

      def f(x):
        axis_resources = {v: v for v in mesh.axis_names}
@@ -699,7 +700,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):

      with mesh:
        inp = jnp.arange(float(jax.local_device_count()))
        out = pjit(f, in_axis_resources=spec, out_axis_resources=spec)(inp)
        out = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)(inp)
        np.testing.assert_allclose(
            out, np.sin(np.arange(jax.local_device_count()))
        )
@@ -708,7 +709,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
      with self.assertRaisesRegex(
          NotImplementedError, 'when all mesh axes are partitioned manually'
      ):
        pjit(
        pjit.pjit(
            without_xmap_f, in_axis_resources=spec, out_axis_resources=spec
        )(inp)

@@ -879,5 +880,157 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
    x = np.arange(6, dtype=np.int32).reshape((3, 2))
    np.testing.assert_allclose(g(x), x)


class IOPythonCallbackTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if xla_bridge.get_backend().runtime_type == 'stream_executor':
      raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')

  def tearDown(self):
    super().tearDown()
    dispatch.runtime_tokens.clear()

  def test_io_callback_can_mutate_state(self):
    x = 0
    def cb():
      nonlocal x
      x += 1
      return np.array(x, np.int32)

    def f():
      return io_callback(cb, jax.ShapeDtypeStruct((), jnp.int32))
    f()
    jax.effects_barrier()
    self.assertEqual(x, 1)
    f()
    jax.effects_barrier()
    self.assertEqual(x, 2)

  def test_io_callback_can_be_batched_if_unordered(self):
    _mut = 0
    def cb(x):
      nonlocal _mut
      _mut += 1
      return x

    x = jnp.arange(4)
    def f(x):
      return io_callback(cb, jax.ShapeDtypeStruct((), x.dtype), x)
    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 4)
    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 8)

  def test_cannot_call_ordered_io_in_pmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
    with self.assertRaisesRegex(
        ValueError, "Ordered effects not supported in `pmap`"):
      jax.pmap(f)(jnp.arange(jax.local_device_count()))

  def test_cannot_call_ordered_io_in_xmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
    with self.assertRaisesRegex(
        ValueError, "Cannot `vmap` ordered IO callback"):
      maps.xmap(f, in_axes=([0],), out_axes=[0])(jnp.arange(16))

  def test_cannot_call_ordered_io_in_vmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
    with self.assertRaisesRegex(
        ValueError, "Cannot `vmap` ordered IO callback"):
      jax.vmap(f)(jnp.arange(4))

  def test_cannot_use_io_callback_in_jvp(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.jvp(f, (0.,), (1.,))

  def test_cannot_use_io_callback_in_linearize(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.linearize(f, 0.)

  def test_cannot_use_io_callback_in_transpose(self):
    x = jnp.array(1.)

    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), x.dtype), x)
    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support transpose."):
      jax.linear_transpose(f, x)(x)

  def test_cannot_vmap_of_cond_io_callback(self):
    def f(pred):
      def true_fun():
        io_callback(lambda: print("true"), None)
      def false_fun():
        io_callback(lambda: print("false"), None)
      return lax.cond(pred, false_fun, true_fun)
    with self.assertRaisesRegex(NotImplementedError,
                                "IO effect not supported in vmap-of-cond."):
      jax.vmap(f)(jnp.array([True, True]))

  def test_cannot_vmap_of_while_io_callback(self):
    def check(x):
      assert np.all(x < 5)

    def f(i):
      def cond(i):
        return i < 5
      def body(i):
        io_callback(check, None, i)
        return i + 1
      return lax.while_loop(cond, body, i)
    with self.assertRaisesRegex(NotImplementedError,
                                "IO effect not supported in vmap-of-while."):
      jax.vmap(f)(jnp.array([0, 4]))

  def test_cannot_use_io_callback_in_checkpoint(self):
    @jax.grad
    @jax.checkpoint
    def f(x, y):
      io_callback(lambda x: x, y, y)
      return x

    with self.assertRaisesRegex(NotImplementedError,
                                "Effects not supported in partial-eval of `checkpoint`"):
      f(2., 3.)

  def test_can_use_io_callback_in_pjit(self):

    _mut = 0
    def _cb(x):
      nonlocal _mut
      _mut = x.sum()

    def f(x):
      io_callback(_cb, None, x)
      return x

    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
    if config.jax_array:
      spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
      out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
    else:
      spec = pjit.PartitionSpec('dev')
      out_spec = pjit.PartitionSpec()
    f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
    with mesh:
      f(jnp.arange(mesh.size))
      jax.effects_barrier()
    self.assertEqual(_mut, jnp.arange(mesh.size).sum())


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())