Apply suggestions from code review

Co-authored-by: Matthew Johnson <mattjj@google.com>
Sharad Vikram 2022-08-17 10:43:50 -07:00
parent 393bca122d
commit b0fdf10a63
5 changed files with 93 additions and 60 deletions


@@ -15,7 +15,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
   the `concrete` option, following the previous version's deprecation; see
   [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
 * Changes
-  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from functions compiled with `jax.jit` or `jax.pmap`.
+  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
 
 ## jax 0.3.16
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
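As context for the changelog entry above, a minimal usage sketch of the new API (illustrative only, not part of the commit; the `host_log1p` helper is made up):

import jax
import jax.numpy as jnp
import numpy as np

def host_log1p(x):
  # Receives NumPy arrays, runs as ordinary Python/NumPy on the host,
  # and returns NumPy arrays.
  return np.log1p(x)

@jax.jit
def f(x):
  # result_shape_dtypes describes the callback output: same shape/dtype as x here.
  return jax.pure_callback(host_log1p, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

f(jnp.arange(4.0))  # ~[0., 0.6931, 1.0986, 1.3863]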


@@ -3233,37 +3233,45 @@ def block_until_ready(x):
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
                   *args: Any, **kwargs: Any):
-  """Calls a pure Python callback function from staged out JAX programs.
+  """Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.
 
   ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
   The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
-  should also return NumPy arrays. Execution takes place on the CPU host.
+  should also return NumPy arrays. Execution takes place on CPU, like any
+  Python+NumPy function.
 
-  The callback is treated as "pure" meaning it can be called multiple times when
-  transformed (for example in a ``vmap`` or ``pmap``), and it can also
-  potentially be removed from JAX programs via dead-code elimination. Pure
-  callbacks can also be reordered if data-dependence allows.
+  The callback is treated as functionally pure, meaning it has no side-effects
+  and its output value depends only on its argument values. As a consequence, it
+  is safe to be called multiple times (e.g. when transformed by ``vmap`` or
+  ``pmap``), or not to be called at all when e.g. the output of a
+  `jit`-decorated function has no data dependence on its value. Pure callbacks
+  may also be reordered if data-dependence allows.
 
-  When ``pmap``-ed, the pure callback will be called several times (one on each axis
-  of the map). When `vmap`-ed the behavior will depend on the value of the
-  ``rank_polymorphic`` keyword argument. If the callback is indicated as rank
-  polymorphic, the callback will be called directly on batched inputs (where the
-  batch axis is the leading dimension). Additionally, the callbacks should
-  return outputs that also have a leading batch axis. If not rank polymorphic,
-  ``callback`` will be mapped sequentially across the batched axis.
+  When ``pmap``-ed, the pure callback will be called several times (one on each
+  axis of the map). When `vmap`-ed the behavior will depend on the value of the
+  ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
+  is assumed to obey
+  ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
+  Therefore, the callback will be called directly on batched inputs (where the
+  batch axes are the leading dimensions). Additionally, the callbacks should
+  return outputs that have corresponding leading batch axes. If not vectorized,
+  ``callback`` will be mapped sequentially across the batched axis.
+
+  For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
+  to set ``vectorized=True`` because the ``np.matmul`` function handles
+  arbitrary leading batch dimensions.
 
   Args:
-    callback: A Python callable. The callable will be passed in NumPy arrays and
-      should return a PyTree of NumPy arrays that matches
-      ``result_shape_dtypes``.
-    result_shape_dtypes: A PyTree of Python objects that have ``shape`` and
-      ``dtype`` properties that correspond to the shape and dtypes of the
-      outputs of ``callback``.
+    callback: A Python callable. The callable will be passed PyTrees of NumPy
+      arrays as arguments, and should return a PyTree of NumPy arrays that
+      matches ``result_shape_dtypes``.
+    result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
+      and ``dtype`` attributes which represent the shapes and dtypes of the
+      value of ``callback`` applied to ``args`` and ``kwargs``.
     *args: The positional arguments to the callback. Must be PyTrees of JAX
       types.
-    rank_polymorphic: A boolean that indicates whether or not ``callback`` is
-      rank polymorphic, meaning it can handle arrays with additional leading
-      dimensions. If ``rank_polymorphic`` is `True`, when the callback is mapped
+    vectorized: A boolean that indicates whether or not ``callback`` is
+      vectorized, meaning it can handle arrays with additional leading
+      dimensions. If ``vectorized`` is `True`, when the callback is mapped
       via `jax.vmap`, it will be called directly on inputs with leading batch
       dimensions instead of executing ``callback`` on each mapped input
       individually. The callback should also return outputs batched across the
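To make the ``vectorized`` contract in the rewritten docstring concrete, a small sketch (my own example, mirroring the ``np.matmul`` case the docstring mentions; not part of the diff):

import jax
import jax.numpy as jnp
import numpy as np

def np_matmul(x, y):
  # np.matmul already broadcasts over leading batch dimensions, so it satisfies
  # jax.vmap(cb)(xs) == cb(xs) and may be marked vectorized=True.
  return np.matmul(x, y)

def f(x, y):
  out_type = jax.ShapeDtypeStruct((x.shape[0], y.shape[-1]), x.dtype)
  return jax.pure_callback(np_matmul, out_type, x, y, vectorized=True)

xs = jnp.ones((8, 3, 4))  # a batch of 8 (3, 4) matrices
ys = jnp.ones((8, 4, 5))
out = jax.jit(jax.vmap(f))(xs, ys)  # callback is invoked once on the whole batch
print(out.shape)  # (8, 3, 5)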


@@ -36,15 +36,15 @@ map, unsafe_map = util.safe_map, map
 @pure_callback_p.def_impl
 def pure_callback_impl(*args, result_avals, callback: Callable[..., Any],
-                       rank_polymorphic: bool):
-  del rank_polymorphic, result_avals
+                       vectorized: bool):
+  del vectorized, result_avals
   return callback(*args)
 
 
 @pure_callback_p.def_abstract_eval
 def pure_callback_abstract_eval(*avals, callback: Callable[..., Any],
-                                result_avals, rank_polymorphic: bool):
-  del avals, callback, rank_polymorphic
+                                result_avals, vectorized: bool):
+  del avals, callback, vectorized
   return result_avals
@@ -67,26 +67,29 @@ def pure_callback_transpose_rule(*args, **kwargs):
 ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule
 
 
-def pure_callback_batching_rule(args, dims, *, callback, rank_polymorphic: bool,
+def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool,
                                 result_avals: Sequence[core.ShapedArray]):
   axis_size = next(a.shape[0] for a, d in zip(args, dims)
                    if d is not batching.not_mapped)
-  new_args = []
-  for arg, dim in zip(args, dims):
-    new_args.append(batching.moveaxis(arg, dim, 0))
-  if rank_polymorphic:
+  new_args = [arg if dim is batching.not_mapped else
+              batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
+  if vectorized:
     result_avals = tuple(
         core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)  # type: ignore
         for aval in result_avals)
     outvals = pure_callback_p.bind(
-        *new_args, callback=callback, rank_polymorphic=rank_polymorphic,
+        *new_args, callback=callback, vectorized=vectorized,
         result_avals=result_avals)
   else:
+    is_batched = [d is not batching.not_mapped for d in dims]
+    unbatched_args, batched_args = util.partition_list(is_batched, new_args)
+    def _batch_fun(*batched_args):
+      merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
+      return pure_callback_p.bind(
+          *merged_args, callback=callback, result_avals=result_avals,
+          vectorized=vectorized)
     from jax._src.lax.control_flow import map as lax_map
-    outvals = lax_map(
-        functools.partial(pure_callback_p.bind, callback=callback,
-                          rank_polymorphic=rank_polymorphic,
-                          result_avals=result_avals),
-        *new_args)
+    outvals = lax_map(_batch_fun, *batched_args)
   return tuple(outvals), (0,) * len(outvals)
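The user-visible effect of the two branches in this batching rule (a direct bind on the stacked inputs when ``vectorized=True``, a sequential ``lax.map`` otherwise) is what the tests further down in this commit check; a condensed sketch of the same behavior (illustrative, not part of the diff):

import jax
import jax.numpy as jnp
import numpy as np

def cb(x):
  # Non-vectorized path: vmap maps the callback sequentially, so each
  # invocation sees an unbatched (scalar) input here.
  assert x.shape == ()
  return np.sin(x)

def cb2(x):
  # Vectorized path: the callback is handed the whole batch at once.
  assert x.shape == (4,)
  return np.sin(x)

out_type = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
f = jax.jit(jax.vmap(lambda x: jax.pure_callback(cb, out_type(x), x)))
g = jax.jit(jax.vmap(lambda x: jax.pure_callback(cb2, out_type(x), x, vectorized=True)))

np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))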
@@ -115,7 +118,7 @@ mlir.register_lowering(pure_callback_p, pure_callback_lowering)
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
-                  *args: Any, rank_polymorphic: bool = False, **kwargs: Any):
+                  *args: Any, vectorized: bool = False, **kwargs: Any):
   def _flat_callback(*flat_args):
     args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
     return tree_util.tree_leaves(callback(*args, **kwargs))
@@ -126,5 +129,5 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
   flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
   out_flat = pure_callback_p.bind(
       *flat_args, callback=_flat_callback,
-      result_avals=tuple(flat_result_avals), rank_polymorphic=rank_polymorphic)
+      result_avals=tuple(flat_result_avals), vectorized=vectorized)
   return tree_util.tree_unflatten(out_tree, out_flat)


@@ -152,6 +152,7 @@ class RuntimeTokenSet(threading.local):
       token[0].block_until_ready()
     for token in self.output_runtime_tokens.values():
       token.block_until_ready()
+    self.clear()
 
 runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
@@ -711,7 +712,7 @@ def _check_special(name, xla_shape, buf):
     raise FloatingPointError(f"invalid value (inf) encountered in {name}")
 
 def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
-                device: Device, input_bufs):
+                has_host_callbacks: bool, device: Device, input_bufs):
   tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
   tokens_flat = flatten(tokens)
   input_bufs = [*tokens_flat, *input_bufs]
@@ -720,7 +721,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
     num_output_tokens = len(ordered_effects) + (not can_execute_with_token and
                                                 has_unordered_effects)
     token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
-    if has_unordered_effects:
+    if has_unordered_effects or has_host_callbacks:
      if can_execute_with_token:
        runtime_tokens.set_output_runtime_token(device, runtime_token)
      else:
@@ -738,14 +739,15 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
                       result_handler: Callable,
                       has_unordered_effects: bool,
                       ordered_effects: List[core.Effect],
-                      kept_var_idx, host_callbacks, *args):
+                      kept_var_idx, has_host_callbacks: bool, *args):
   device, = compiled.local_devices()
   args, env = input_handler(args) if input_handler else (args, None)
   in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
                     if i in kept_var_idx)
-  if has_unordered_effects or ordered_effects:
+  if has_unordered_effects or ordered_effects or has_host_callbacks:
     in_flat, token_handler = _add_tokens(
-        has_unordered_effects, ordered_effects, device, in_flat)
+        has_unordered_effects, ordered_effects, has_host_callbacks, device,
+        in_flat)
     if can_execute_with_token:
       out_flat, runtime_token = compiled.execute_with_token(in_flat)
     else:
@@ -755,7 +757,7 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
     out_flat = compiled.execute(in_flat)
   check_special(name, out_flat)
   out_bufs = unflatten(out_flat, output_buffer_counts)
-  if ordered_effects or has_unordered_effects:
+  if ordered_effects or has_unordered_effects or has_host_callbacks:
     out_bufs = token_handler(out_bufs, runtime_token)
   return result_handler(env, out_bufs)
@@ -766,7 +768,7 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
                         result_handler: Callable,
                         has_unordered_effects: bool,
                         ordered_effects: List[core.Effect],
-                        kept_var_idx, host_callbacks, *args):
+                        kept_var_idx, has_host_callbacks: bool, *args):
   if has_unordered_effects or ordered_effects:
     # TODO(sharadmv): support jit-of-pmap with effects
     raise NotImplementedError(
@@ -965,7 +967,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
     execute = _execute_compiled if nreps == 1 else _execute_replicated
     unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,  # type: ignore  # noqa: F811
                           result_handler, has_unordered_effects,
-                          ordered_effects, kept_var_idx, host_callbacks)
+                          ordered_effects, kept_var_idx, bool(host_callbacks))
     return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
                                   keepalive)
@@ -988,7 +990,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
     result_handlers = map(partial(aval_to_result_handler, device), out_avals)
     unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
                           result_handlers, has_unordered_effects,
-                          ordered_effects, kept_var_idx, host_callbacks)
+                          ordered_effects, kept_var_idx, bool(host_callbacks))
     return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
                                   keepalive)
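The dispatch changes above thread a ``has_host_callbacks`` flag through execution so that a computation whose only host interaction is a pure callback still gets an output runtime token, and ``RuntimeTokenSet`` now clears itself after blocking. A rough sketch of what this enables (internal APIs, shown only for orientation; the attribute and method names are taken from the hunks above and the test ``tearDown`` below, and are not a public interface):

import jax
import jax.numpy as jnp
import numpy as np
from jax._src import dispatch

@jax.jit
def f(x):
  return jax.pure_callback(np.square, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

_ = f(jnp.arange(3.0))
# The computation contains a host callback, so a runtime token was registered;
# blocking on the token set waits for the callback and then resets the token
# state (the new self.clear() call in the first hunk).
dispatch.runtime_tokens.block_until_ready()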


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import functools
 import io
 import textwrap
 import unittest
@@ -25,6 +26,7 @@ from jax import core
 from jax import lax
 from jax import tree_util
 from jax._src import debugging
+from jax._src import dispatch
 from jax._src import lib as jaxlib
 from jax._src import test_util as jtu
 from jax._src import util
@@ -459,6 +461,10 @@ class PythonCallbackTest(jtu.JaxTestCase):
 class PurePythonCallbackTest(jtu.JaxTestCase):
 
+  def tearDown(self):
+    super().tearDown()
+    dispatch.runtime_tokens.clear()
+
   @jtu.skip_on_devices(*disabled_backends)
   def test_simple_pure_callback(self):
@@ -504,8 +510,8 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     def f(x):
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x)
+      # Calling a function with two return values that expects one return value
+      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x, x)
 
     with self.assertRaises(RuntimeError):
       f(2.)
@@ -513,8 +519,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     def g():
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda: None, (
+      return jax.pure_callback(lambda: (), (
           core.ShapedArray((1,), np.float32), core.ShapedArray((2,), np.float32)))
 
     with self.assertRaises(RuntimeError):
@@ -564,7 +569,14 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
     np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)
 
-  def test_vmap_rank_polymorphic_callback(self):
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y)
+
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
+
+  def test_vmap_vectorized_callback(self):
 
     def cb(x):
       self.assertTupleEqual(x.shape, ())
@@ -575,7 +587,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     def f(x):
       return jax.pure_callback(cb, x, x)
 
-    _ = f(jnp.arange(4.))
+    np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
 
     def cb2(x):
       self.assertTupleEqual(x.shape, (4,))
@@ -585,11 +597,19 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def g(x):
-      return jax.pure_callback(cb2, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb2, x, x, vectorized=True)
 
-    _ = g(jnp.arange(4.))
+    np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
 
-  def test_vmap_rank_polymorphic_callback_errors_if_returns_wrong_shape(self):
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
+                               vectorized=True)
+
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
+
+  def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
 
     def cb(x):
       # Reduces over all dimension when it shouldn't
@@ -598,7 +618,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def f(x):
-      return jax.pure_callback(cb, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb, x, x, vectorized=True)
 
     with self.assertRaises(RuntimeError):
       f(jnp.arange(4.))
@@ -743,13 +763,13 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     np.testing.assert_allclose(out, jnp.arange(1., 41.))
 
   @jtu.skip_on_devices(*disabled_backends)
-  def test_rank_polymorphic_callback_inside_xmap(self):
+  def test_vectorized_callback_inside_xmap(self):
 
     def _callback(x):
       return (x + 1.).astype(x.dtype)
 
     def f(x):
-      return jax.pure_callback(_callback, x, x, rank_polymorphic=True)
+      return jax.pure_callback(_callback, x, x, vectorized=True)
 
     f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                   axis_resources={'a': 'dev'})