Apply suggestions from code review

Co-authored-by: Matthew Johnson <mattjj@google.com>
Sharad Vikram 2022-08-17 10:43:50 -07:00
parent 393bca122d
commit b0fdf10a63
5 changed files with 93 additions and 60 deletions


@@ -15,7 +15,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
the `concrete` option, following the previous version's deprecation; see
[JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
* Changes
* Added {func}`jax.pure_callback` that enables calling back to pure Python functions from functions compiled with `jax.jit` or `jax.pmap`.
* Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
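As a brief illustration of the new API described in this entry (a sketch, not part of the changelog; the helper names here are invented):

import jax
import jax.numpy as jnp
import numpy as np

def numpy_log1p(x):
  # Runs on the host with NumPy arrays, outside of XLA.
  return np.log1p(x)

@jax.jit
def f(x):
  # result_shape_dtypes describes the callback's output; here it matches x.
  return jax.pure_callback(numpy_log1p, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

f(jnp.arange(4.0))  # [0., 0.6931, 1.0986, 1.3863]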
## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).


@@ -3233,37 +3233,45 @@ def block_until_ready(x):
def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, **kwargs: Any):
"""Calls a pure Python callback function from staged out JAX programs.
"""Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
should also return NumPy arrays. Execution takes place on the CPU host.
should also return NumPy arrays. Execution takes place on CPU, like any
Python+NumPy function.
The callback is treated as "pure" meaning it can be called multiple times when
transformed (for example in a ``vmap`` or ``pmap``), and it can also
potentially be removed from JAX programs via dead-code elimination. Pure
callbacks can also be reordered if data-dependence allows.
The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
may safely be called multiple times (e.g. when transformed by ``vmap`` or
``pmap``), or not called at all when e.g. the output of a ``jit``-decorated
function has no data dependence on its value. Pure callbacks may also be
reordered if data-dependence allows.
When ``pmap``-ed, the pure callback will be called several times (one on each axis
of the map). When `vmap`-ed the behavior will depend on the value of the
``rank_polymorphic`` keyword argument. If the callback is indicated as rank
polymorphic, the callback will be called directly on batched inputs (where the
batch axis is the leading dimension). Additionally, the callbacks should
return outputs that also have a leading batch axis. If not rank polymorphic,
``callback`` will be mapped sequentially across the batched axis.
When ``pmap``-ed, the pure callback will be called several times (once for
each element of the mapped axis). When ``vmap``-ed, the behavior depends on the
value of the ``vectorized`` keyword argument. When ``vectorized`` is ``True``,
the callback is assumed to obey
``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
Therefore, the callback will be called directly on batched inputs (where the
batch axes are the leading dimensions). Additionally, the callback should
return outputs that have corresponding leading batch axes. If not vectorized,
``callback`` will be mapped sequentially across the batched axis.
For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
to set ``vectorized=True`` because the ``np.matmul`` function handles
arbitrary leading batch dimensions.
Args:
callback: A Python callable. The callable will be passed in NumPy arrays and
should return a PyTree of NumPy arrays that matches
``result_shape_dtypes``.
result_shape_dtypes: A PyTree of Python objects that have ``shape`` and
``dtype`` properties that correspond to the shape and dtypes of the
outputs of ``callback``.
callback: A Python callable. The callable will be passed PyTrees of NumPy
arrays as arguments, and should return a PyTree of NumPy arrays that
matches ``result_shape_dtypes``.
result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
and ``dtype`` attributes which correspond to the shapes and dtypes of the
value of ``callback`` applied to ``args`` and ``kwargs``.
*args: The positional arguments to the callback. Must be PyTrees of JAX
types.
rank_polymorphic: A boolean that indicates whether or not ``callback`` is
rank polymorphic, meaning it can handle arrays with additional leading
dimensions. If ``rank_polymorphic`` is `True`, when the callback is mapped
vectorized: A boolean that indicates whether or not ``callback`` is
vectorized, meaning it can handle arrays with additional leading
dimensions. If ``vectorized`` is `True`, when the callback is mapped
via `jax.vmap`, it will be called directly on inputs with leading batch
dimensions instead of executing ``callback`` on each mapped input
individually. The callback should also return outputs batched across the

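To make the ``np.matmul`` example in the docstring concrete, here is a hedged sketch of how ``vectorized=True`` interacts with ``vmap`` (the helper names are illustrative):

import jax
import jax.numpy as jnp
import numpy as np

def matmul_cb(x, y):
  return np.matmul(x, y)

def f(x, y):
  out_shape = jax.ShapeDtypeStruct(x.shape[:-1] + y.shape[-1:], x.dtype)
  # np.matmul already handles leading batch dimensions, so vectorized=True is safe.
  return jax.pure_callback(matmul_cb, out_shape, x, y, vectorized=True)

xs = jnp.ones((8, 3, 4))
ys = jnp.ones((8, 4, 5))
# Under vmap, the callback is invoked once on the full (8, 3, 4) and (8, 4, 5)
# batches rather than once per element; the result has shape (8, 3, 5).
out = jax.vmap(f)(xs, ys)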

@@ -36,15 +36,15 @@ map, unsafe_map = util.safe_map, map
@pure_callback_p.def_impl
def pure_callback_impl(*args, result_avals, callback: Callable[..., Any],
rank_polymorphic: bool):
del rank_polymorphic, result_avals
vectorized: bool):
del vectorized, result_avals
return callback(*args)
@pure_callback_p.def_abstract_eval
def pure_callback_abstract_eval(*avals, callback: Callable[..., Any],
result_avals, rank_polymorphic: bool):
del avals, callback, rank_polymorphic
result_avals, vectorized: bool):
del avals, callback, vectorized
return result_avals
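A small user-level illustration of what the abstract-eval rule above implies (a sketch, not taken from the diff): since tracing only needs the declared ``result_avals``, shape-level evaluation never runs the Python callback.

import jax
import jax.numpy as jnp
import numpy as np

calls = []

def cb(x):
  calls.append(x.shape)  # record whether the callback actually ran
  return np.sin(x)

def f(x):
  return jax.pure_callback(cb, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

# Abstract evaluation uses the declared result avals, so `cb` is not executed here.
print(jax.eval_shape(f, jnp.zeros((3,))))  # ShapeDtypeStruct(shape=(3,), dtype=float32)
print(calls)                               # []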
@@ -67,26 +67,29 @@ def pure_callback_transpose_rule(*args, **kwargs):
ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule
def pure_callback_batching_rule(args, dims, *, callback, rank_polymorphic: bool,
def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool,
result_avals: Sequence[core.ShapedArray]):
axis_size = next(a.shape[0] for a, d in zip(args, dims)
if d is not batching.not_mapped)
new_args = []
for arg, dim in zip(args, dims):
new_args.append(batching.moveaxis(arg, dim, 0))
if rank_polymorphic:
new_args = [arg if dim is batching.not_mapped else
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
if vectorized:
result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
for aval in result_avals)
outvals = pure_callback_p.bind(
*new_args, callback=callback, rank_polymorphic=rank_polymorphic,
*new_args, callback=callback, vectorized=vectorized,
result_avals=result_avals)
else:
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(*batched_args):
merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
return pure_callback_p.bind(
*merged_args, callback=callback, result_avals=result_avals,
vectorized=vectorized)
from jax._src.lax.control_flow import map as lax_map
outvals = lax_map(
functools.partial(pure_callback_p.bind, callback=callback,
rank_polymorphic=rank_polymorphic, result_avals=result_avals),
*new_args)
outvals = lax_map(_batch_fun, *batched_args)
return tuple(outvals), (0,) * len(outvals)
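Seen from the user side, the two branches above behave as follows (a hedged sketch mirroring the tests further down; the helper names are invented): with ``vectorized=False`` the callback is mapped sequentially via ``lax.map`` and sees one element at a time, while with ``vectorized=True`` it is called once on the whole batch.

import jax
import jax.numpy as jnp
import numpy as np

seen_shapes = []

def cb(x):
  seen_shapes.append(x.shape)
  return np.sin(x)

def make(vectorized):
  @jax.jit
  @jax.vmap
  def f(x):
    return jax.pure_callback(cb, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
                             vectorized=vectorized)
  return f

make(vectorized=False)(jnp.arange(4.))  # callback sees shape () four times (lax.map branch)
make(vectorized=True)(jnp.arange(4.))   # callback sees shape (4,) once (direct bind branch)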
@@ -115,7 +118,7 @@ mlir.register_lowering(pure_callback_p, pure_callback_lowering)
def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, rank_polymorphic: bool = False, **kwargs: Any):
*args: Any, vectorized: bool = False, **kwargs: Any):
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
return tree_util.tree_leaves(callback(*args, **kwargs))
@@ -126,5 +129,5 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
out_flat = pure_callback_p.bind(
*flat_args, callback=_flat_callback,
result_avals=tuple(flat_result_avals), rank_polymorphic=rank_polymorphic)
result_avals=tuple(flat_result_avals), vectorized=vectorized)
return tree_util.tree_unflatten(out_tree, out_flat)
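Because the wrapper above flattens both the arguments and ``result_shape_dtypes`` with ``tree_util``, PyTree-valued inputs and outputs work transparently. A sketch (names are illustrative):

import jax
import jax.numpy as jnp
import numpy as np

def stats_cb(batch):
  # `batch` arrives as a dict of NumPy arrays; return a PyTree matching the
  # declared result_shape_dtypes structure.
  return {"mean": np.asarray(np.mean(batch["x"])),
          "max": np.asarray(np.max(batch["x"]))}

@jax.jit
def summarize(x):
  out_shapes = {"mean": jax.ShapeDtypeStruct((), x.dtype),
                "max": jax.ShapeDtypeStruct((), x.dtype)}
  return jax.pure_callback(stats_cb, out_shapes, {"x": x})

summarize(jnp.arange(6.0))  # {'max': 5.0, 'mean': 2.5}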


@@ -152,6 +152,7 @@ class RuntimeTokenSet(threading.local):
token[0].block_until_ready()
for token in self.output_runtime_tokens.values():
token.block_until_ready()
self.clear()
runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
@@ -711,7 +712,7 @@ def _check_special(name, xla_shape, buf):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
device: Device, input_bufs):
has_host_callbacks: bool, device: Device, input_bufs):
tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
tokens_flat = flatten(tokens)
input_bufs = [*tokens_flat, *input_bufs]
@@ -720,7 +721,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
num_output_tokens = len(ordered_effects) + (not can_execute_with_token and
has_unordered_effects)
token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
if has_unordered_effects:
if has_unordered_effects or has_host_callbacks:
if can_execute_with_token:
runtime_tokens.set_output_runtime_token(device, runtime_token)
else:
@@ -738,14 +739,15 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
result_handler: Callable,
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx, host_callbacks, *args):
kept_var_idx, has_host_callbacks: bool, *args):
device, = compiled.local_devices()
args, env = input_handler(args) if input_handler else (args, None)
in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
if has_unordered_effects or ordered_effects:
if has_unordered_effects or ordered_effects or has_host_callbacks:
in_flat, token_handler = _add_tokens(
has_unordered_effects, ordered_effects, device, in_flat)
has_unordered_effects, ordered_effects, has_host_callbacks, device,
in_flat)
if can_execute_with_token:
out_flat, runtime_token = compiled.execute_with_token(in_flat)
else:
@@ -755,7 +757,7 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
out_flat = compiled.execute(in_flat)
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
if ordered_effects or has_unordered_effects:
if ordered_effects or has_unordered_effects or has_host_callbacks:
out_bufs = token_handler(out_bufs, runtime_token)
return result_handler(env, out_bufs)
@@ -766,7 +768,7 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
result_handler: Callable,
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx, host_callbacks, *args):
kept_var_idx, has_host_callbacks: bool, *args):
if has_unordered_effects or ordered_effects:
# TODO(sharadmv): support jit-of-pmap with effects
raise NotImplementedError(
@@ -965,7 +967,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts, # type: ignore # noqa: F811
result_handler, has_unordered_effects,
ordered_effects, kept_var_idx, host_callbacks)
ordered_effects, kept_var_idx, bool(host_callbacks))
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
keepalive)
@@ -988,7 +990,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
result_handlers, has_unordered_effects,
ordered_effects, kept_var_idx, host_callbacks)
ordered_effects, kept_var_idx, bool(host_callbacks))
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
keepalive)
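The dispatch changes above make the runtime thread a token through execution whenever the compiled computation contains host callbacks, even if it has no other effects; nothing changes from the caller's perspective. A minimal sketch:

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def f(x):
  # Contains a host callback, so execution now takes the token-handling path
  # in _execute_compiled even though there are no ordered effects.
  return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x) * 2.0

y = f(jnp.arange(3.0))
y.block_until_ready()  # wait for the device result as usual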


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import io
import textwrap
import unittest
@@ -25,6 +26,7 @@ from jax import core
from jax import lax
from jax import tree_util
from jax._src import debugging
from jax._src import dispatch
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
from jax._src import util
@@ -459,6 +461,10 @@ class PythonCallbackTest(jtu.JaxTestCase):
class PurePythonCallbackTest(jtu.JaxTestCase):
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
@jtu.skip_on_devices(*disabled_backends)
def test_simple_pure_callback(self):
@@ -504,8 +510,8 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
@jax.jit
def f(x):
# Calling a function with a return value that expects no return values
return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x)
# Calling a function with two return values that expects one return value
return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x, x)
with self.assertRaises(RuntimeError):
f(2.)
@@ -513,8 +519,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
@jax.jit
def g():
# Calling a function that returns no values when return values are expected
return jax.pure_callback(lambda: None, (
return jax.pure_callback(lambda: (), (
core.ShapedArray((1,), np.float32), core.ShapedArray((2,), np.float32)))
with self.assertRaises(RuntimeError):
@@ -564,7 +569,14 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)
def test_vmap_rank_polymorphic_callback(self):
@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None))
def h(x, y):
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y)
out = h(jnp.arange(4.), 4.)
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
def test_vmap_vectorized_callback(self):
def cb(x):
self.assertTupleEqual(x.shape, ())
@@ -575,7 +587,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
def f(x):
return jax.pure_callback(cb, x, x)
_ = f(jnp.arange(4.))
np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
def cb2(x):
self.assertTupleEqual(x.shape, (4,))
@@ -585,11 +597,19 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def g(x):
return jax.pure_callback(cb2, x, x, rank_polymorphic=True)
return jax.pure_callback(cb2, x, x, vectorized=True)
_ = g(jnp.arange(4.))
np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
def test_vmap_rank_polymorphic_callback_errors_if_returns_wrong_shape(self):
@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None))
def h(x, y):
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
vectorized=True)
out = h(jnp.arange(4.), 4.)
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
def cb(x):
# Reduces over all dimensions when it shouldn't
@@ -598,7 +618,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def f(x):
return jax.pure_callback(cb, x, x, rank_polymorphic=True)
return jax.pure_callback(cb, x, x, vectorized=True)
with self.assertRaises(RuntimeError):
f(jnp.arange(4.))
@@ -743,13 +763,13 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
np.testing.assert_allclose(out, jnp.arange(1., 41.))
@jtu.skip_on_devices(*disabled_backends)
def test_rank_polymorphic_callback_inside_xmap(self):
def test_vectorized_callback_inside_xmap(self):
def _callback(x):
return (x + 1.).astype(x.dtype)
def f(x):
return jax.pure_callback(_callback, x, x, rank_polymorphic=True)
return jax.pure_callback(_callback, x, x, vectorized=True)
f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
axis_resources={'a': 'dev'})