Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00

Apply suggestions from code review

Co-authored-by: Matthew Johnson <mattjj@google.com>

parent 393bca122d
commit b0fdf10a63
@@ -15,7 +15,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
   the `concrete` option, following the previous version's deprecation; see
   [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
 * Changes
-  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from functions compiled with `jax.jit` or `jax.pmap`.
+  * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).

 ## jax 0.3.16
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
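For context, a minimal usage sketch of the API this changelog entry describes (illustrative only, not part of the commit; assumes a jax version where `jax.pure_callback` is available, and `host_sin` is a made-up helper):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # Ordinary Python+NumPy: receives NumPy arrays, returns NumPy arrays.
  return np.sin(x)

@jax.jit
def f(x):
  # Describe the output's shape/dtype so JAX can trace without running host_sin.
  return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4.0)))  # approximately [0. 0.841 0.909 0.141]
```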
@@ -3233,37 +3233,45 @@ def block_until_ready(x):
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
                   *args: Any, **kwargs: Any):
-  """Calls a pure Python callback function from staged out JAX programs.
+  """Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.

   ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
   The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
-  should also return NumPy arrays. Execution takes place on the CPU host.
+  should also return NumPy arrays. Execution takes place on CPU, like any
+  Python+NumPy function.

-  The callback is treated as "pure" meaning it can be called multiple times when
-  transformed (for example in a ``vmap`` or ``pmap``), and it can also
-  potentially be removed from JAX programs via dead-code elimination. Pure
-  callbacks can also be reordered if data-dependence allows.
+  The callback is treated as functionally pure, meaning it has no side-effects
+  and its output value depends only on its argument values. As a consequence, it
+  is safe to be called multiple times (e.g. when transformed by ``vmap`` or
+  ``pmap``), or not to be called at all when e.g. the output of a
+  `jit`-decorated function has no data dependence on its value. Pure callbacks
+  may also be reordered if data-dependence allows.

-  When ``pmap``-ed, the pure callback will be called several times (one on each axis
-  of the map). When `vmap`-ed the behavior will depend on the value of the
-  ``rank_polymorphic`` keyword argument. If the callback is indicated as rank
-  polymorphic, the callback will be called directly on batched inputs (where the
-  batch axis is the leading dimension). Additionally, the callbacks should
-  return outputs that also have a leading batch axis. If not rank polymorphic,
-  ``callback`` will be mapped sequentially across the batched axis.
+  When ``pmap``-ed, the pure callback will be called several times (one on each
+  axis of the map). When `vmap`-ed the behavior will depend on the value of the
+  ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
+  is assumed to obey
+  ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
+  Therefore, the callback will be called directly on batched inputs (where the
+  batch axes are the leading dimensions). Additionally, the callbacks should
+  return outputs that have corresponding leading batch axes. If not vectorized,
+  ``callback`` will be mapped sequentially across the batched axis.
+  For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
+  to set ``vectorized=True`` because the ``np.matmul`` function handles
+  arbitrary leading batch dimensions.

   Args:
-    callback: A Python callable. The callable will be passed in NumPy arrays and
-      should return a PyTree of NumPy arrays that matches
-      ``result_shape_dtypes``.
-    result_shape_dtypes: A PyTree of Python objects that have ``shape`` and
-      ``dtype`` properties that correspond to the shape and dtypes of the
-      outputs of ``callback``.
+    callback: A Python callable. The callable will be passed PyTrees of NumPy
+      arrays as arguments, and should return a PyTree of NumPy arrays that
+      matches ``result_shape_dtypes``.
+    result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
+      and ``dtype`` attributes which represent the shapes and dtypes of the
+      value of ``callback`` applied to ``args`` and ``kwargs``.
     *args: The positional arguments to the callback. Must be PyTrees of JAX
       types.
-    rank_polymorphic: A boolean that indicates whether or not ``callback`` is
-      rank polymorphic, meaning it can handle arrays with additional leading
-      dimensions. If ``rank_polymorphic`` is `True`, when the callback is mapped
+    vectorized: A boolean that indicates whether or not ``callback`` is
+      vectorized, meaning it can handle arrays with additional leading
+      dimensions. If ``vectorized`` is `True`, when the callback is mapped
       via `jax.vmap`, it will be called directly on inputs with leading batch
       dimensions instead of executing ``callback`` on each mapped input
       individually. The callback should also return outputs batched across the
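To make the ``vectorized`` contract from the docstring concrete, here is a small sketch (illustrative, not part of the commit) using the docstring's own ``np.matmul`` example; `out_spec` and `cb` are made-up names:

```python
import jax
import jax.numpy as jnp
import numpy as np

def cb(x, y):
  # np.matmul handles arbitrary leading batch dimensions, so this callback
  # satisfies jax.vmap(cb)(xs, ys) == cb(xs, ys): vectorized=True is safe.
  return np.matmul(x, y)

xs = jnp.ones((8, 3, 4))  # batch of 8 matrices
ys = jnp.ones((8, 4, 5))

out_spec = jax.ShapeDtypeStruct((3, 5), xs.dtype)  # per-example output spec

batched = jax.vmap(
    lambda x, y: jax.pure_callback(cb, out_spec, x, y, vectorized=True)
)(xs, ys)
print(batched.shape)  # (8, 3, 5): cb is invoked once on the whole batch
```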
@@ -36,15 +36,15 @@ map, unsafe_map = util.safe_map, map
 @pure_callback_p.def_impl
 def pure_callback_impl(*args, result_avals, callback: Callable[..., Any],
-                       rank_polymorphic: bool):
-  del rank_polymorphic, result_avals
+                       vectorized: bool):
+  del vectorized, result_avals
   return callback(*args)


 @pure_callback_p.def_abstract_eval
 def pure_callback_abstract_eval(*avals, callback: Callable[..., Any],
-                                result_avals, rank_polymorphic: bool):
-  del avals, callback, rank_polymorphic
+                                result_avals, vectorized: bool):
+  del avals, callback, vectorized
   return result_avals
@@ -67,26 +67,29 @@ def pure_callback_transpose_rule(*args, **kwargs):
 ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule


-def pure_callback_batching_rule(args, dims, *, callback, rank_polymorphic: bool,
+def pure_callback_batching_rule(args, dims, *, callback, vectorized: bool,
                                 result_avals: Sequence[core.ShapedArray]):
   axis_size = next(a.shape[0] for a, d in zip(args, dims)
                    if d is not batching.not_mapped)
-  new_args = []
-  for arg, dim in zip(args, dims):
-    new_args.append(batching.moveaxis(arg, dim, 0))
-  if rank_polymorphic:
+  new_args = [arg if dim is batching.not_mapped else
+              batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
+  if vectorized:
     result_avals = tuple(
         core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)  # type: ignore
         for aval in result_avals)
     outvals = pure_callback_p.bind(
-        *new_args, callback=callback, rank_polymorphic=rank_polymorphic,
+        *new_args, callback=callback, vectorized=vectorized,
         result_avals=result_avals)
   else:
+    is_batched = [d is not batching.not_mapped for d in dims]
+    unbatched_args, batched_args = util.partition_list(is_batched, new_args)
+    def _batch_fun(*batched_args):
+      merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
+      return pure_callback_p.bind(
+          *merged_args, callback=callback, result_avals=result_avals,
+          vectorized=vectorized)
     from jax._src.lax.control_flow import map as lax_map
-    outvals = lax_map(
-        functools.partial(pure_callback_p.bind, callback=callback,
-                          rank_polymorphic=rank_polymorphic, result_avals=result_avals),
-        *new_args)
+    outvals = lax_map(_batch_fun, *batched_args)
   return tuple(outvals), (0,) * len(outvals)
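The two branches of this batching rule differ in how often the host callback runs. A sketch of the observable difference (illustrative, not part of the commit; `calls` is a made-up log of the shapes the callback actually sees):

```python
import jax
import jax.numpy as jnp
import numpy as np

calls = []  # hypothetical log of shapes seen by the host callback

def cb(x):
  calls.append(np.shape(x))
  return np.sin(x)

def run(vectorized):
  calls.clear()
  jax.vmap(lambda x: jax.pure_callback(
      cb, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
      vectorized=vectorized))(jnp.arange(4.0))
  return list(calls)

print(run(vectorized=True))   # roughly [(4,)]: one call on the whole batch
print(run(vectorized=False))  # roughly [(), (), (), ()]: mapped via lax.map
```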
@@ -115,7 +118,7 @@ mlir.register_lowering(pure_callback_p, pure_callback_lowering)


 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
-                  *args: Any, rank_polymorphic: bool = False, **kwargs: Any):
+                  *args: Any, vectorized: bool = False, **kwargs: Any):
   def _flat_callback(*flat_args):
     args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
     return tree_util.tree_leaves(callback(*args, **kwargs))

@@ -126,5 +129,5 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
   flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
   out_flat = pure_callback_p.bind(
       *flat_args, callback=_flat_callback,
-      result_avals=tuple(flat_result_avals), rank_polymorphic=rank_polymorphic)
+      result_avals=tuple(flat_result_avals), vectorized=vectorized)
   return tree_util.tree_unflatten(out_tree, out_flat)
@@ -152,6 +152,7 @@ class RuntimeTokenSet(threading.local):
       token[0].block_until_ready()
     for token in self.output_runtime_tokens.values():
       token.block_until_ready()
+    self.clear()

 runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()

@@ -711,7 +712,7 @@ def _check_special(name, xla_shape, buf):
       raise FloatingPointError(f"invalid value (inf) encountered in {name}")

 def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
-                device: Device, input_bufs):
+                has_host_callbacks: bool, device: Device, input_bufs):
   tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
   tokens_flat = flatten(tokens)
   input_bufs = [*tokens_flat, *input_bufs]

@@ -720,7 +721,7 @@ def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
   num_output_tokens = len(ordered_effects) + (not can_execute_with_token and
                                               has_unordered_effects)
   token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
-  if has_unordered_effects:
+  if has_unordered_effects or has_host_callbacks:
     if can_execute_with_token:
       runtime_tokens.set_output_runtime_token(device, runtime_token)
     else:

@@ -738,14 +739,15 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
                       result_handler: Callable,
                       has_unordered_effects: bool,
                       ordered_effects: List[core.Effect],
-                      kept_var_idx, host_callbacks, *args):
+                      kept_var_idx, has_host_callbacks: bool, *args):
   device, = compiled.local_devices()
   args, env = input_handler(args) if input_handler else (args, None)
   in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
                     if i in kept_var_idx)
-  if has_unordered_effects or ordered_effects:
+  if has_unordered_effects or ordered_effects or has_host_callbacks:
     in_flat, token_handler = _add_tokens(
-        has_unordered_effects, ordered_effects, device, in_flat)
+        has_unordered_effects, ordered_effects, has_host_callbacks, device,
+        in_flat)
     if can_execute_with_token:
       out_flat, runtime_token = compiled.execute_with_token(in_flat)
     else:

@@ -755,7 +757,7 @@ def _execute_compiled(name: str, compiled: XlaExecutable,
     out_flat = compiled.execute(in_flat)
     check_special(name, out_flat)
   out_bufs = unflatten(out_flat, output_buffer_counts)
-  if ordered_effects or has_unordered_effects:
+  if ordered_effects or has_unordered_effects or has_host_callbacks:
     out_bufs = token_handler(out_bufs, runtime_token)
   return result_handler(env, out_bufs)

@@ -766,7 +768,7 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
                         result_handler: Callable,
                         has_unordered_effects: bool,
                         ordered_effects: List[core.Effect],
-                        kept_var_idx, host_callbacks, *args):
+                        kept_var_idx, has_host_callbacks: bool, *args):
   if has_unordered_effects or ordered_effects:
     # TODO(sharadmv): support jit-of-pmap with effects
     raise NotImplementedError(

@@ -965,7 +967,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
     execute = _execute_compiled if nreps == 1 else _execute_replicated
     unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,  # type: ignore  # noqa: F811
                           result_handler, has_unordered_effects,
-                          ordered_effects, kept_var_idx, host_callbacks)
+                          ordered_effects, kept_var_idx, bool(host_callbacks))
     return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
                                   keepalive)

@@ -988,7 +990,7 @@ class XlaCompiledComputation(stages.XlaExecutable):
     result_handlers = map(partial(aval_to_result_handler, device), out_avals)
     unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
                           result_handlers, has_unordered_effects,
-                          ordered_effects, kept_var_idx, host_callbacks)
+                          ordered_effects, kept_var_idx, bool(host_callbacks))
     return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
                                   keepalive)
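The last two hunks pass ``bool(host_callbacks)`` rather than the callback list itself, so the partially-applied executable only carries a flag: token bookkeeping is needed whenever host callbacks are present, even if the program has no other effects. A toy illustration of that design choice (hypothetical names, not from the commit):

```python
from functools import partial

def execute(name, has_host_callbacks: bool, *args):
  # Runtime tokens are threaded through whenever host callbacks exist, so
  # block_until_ready() can also wait for outstanding callbacks.
  if has_host_callbacks:
    return f"{name}: adding runtime tokens for {args}"
  return f"{name}: no tokens needed for {args}"

host_callbacks = [print]  # callbacks gathered at lowering time (toy stand-in)
unsafe_call = partial(execute, "f", bool(host_callbacks))
print(unsafe_call(1.0))   # f: adding runtime tokens for (1.0,)
```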
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import functools
 import io
 import textwrap
 import unittest

@@ -25,6 +26,7 @@ from jax import core
 from jax import lax
 from jax import tree_util
 from jax._src import debugging
+from jax._src import dispatch
 from jax._src import lib as jaxlib
 from jax._src import test_util as jtu
 from jax._src import util
@@ -459,6 +461,10 @@ class PythonCallbackTest(jtu.JaxTestCase):

 class PurePythonCallbackTest(jtu.JaxTestCase):

+  def tearDown(self):
+    super().tearDown()
+    dispatch.runtime_tokens.clear()
+
   @jtu.skip_on_devices(*disabled_backends)
   def test_simple_pure_callback(self):

@@ -504,8 +510,8 @@ class PurePythonCallbackTest(jtu.JaxTestCase):

     @jax.jit
     def f(x):
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x)
+      # Calling a function with two return values that expects one return value
+      return jax.pure_callback(lambda x: (x, np.ones(4, np.float32)), x, x)

     with self.assertRaises(RuntimeError):
       f(2.)
@@ -513,8 +519,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):

     @jax.jit
     def g():
-      # Calling a function with a return value that expects no return values
-      return jax.pure_callback(lambda: None, (
+      return jax.pure_callback(lambda: (), (
           core.ShapedArray((1,), np.float32), core.ShapedArray((2,), np.float32)))

     with self.assertRaises(RuntimeError):
@@ -564,7 +569,14 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
     np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)

-  def test_vmap_rank_polymorphic_callback(self):
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y)
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
+
+  def test_vmap_vectorized_callback(self):

     def cb(x):
       self.assertTupleEqual(x.shape, ())
@@ -575,7 +587,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     def f(x):
       return jax.pure_callback(cb, x, x)

-    _ = f(jnp.arange(4.))
+    np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))

     def cb2(x):
       self.assertTupleEqual(x.shape, (4,))
@@ -585,11 +597,19 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def g(x):
-      return jax.pure_callback(cb2, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb2, x, x, vectorized=True)

-    _ = g(jnp.arange(4.))
+    np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
+
+    @jax.jit
+    @functools.partial(jax.vmap, in_axes=(0, None))
+    def h(x, y):
+      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
+                               vectorized=True)
+    out = h(jnp.arange(4.), 4.)
+    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)

-  def test_vmap_rank_polymorphic_callback_errors_if_returns_wrong_shape(self):
+  def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):

     def cb(x):
       # Reduces over all dimensions when it shouldn't
@@ -598,7 +618,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     @jax.jit
     @jax.vmap
     def f(x):
-      return jax.pure_callback(cb, x, x, rank_polymorphic=True)
+      return jax.pure_callback(cb, x, x, vectorized=True)

     with self.assertRaises(RuntimeError):
       f(jnp.arange(4.))
@@ -743,13 +763,13 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     np.testing.assert_allclose(out, jnp.arange(1., 41.))

   @jtu.skip_on_devices(*disabled_backends)
-  def test_rank_polymorphic_callback_inside_xmap(self):
+  def test_vectorized_callback_inside_xmap(self):

     def _callback(x):
       return (x + 1.).astype(x.dtype)

     def f(x):
-      return jax.pure_callback(_callback, x, x, rank_polymorphic=True)
+      return jax.pure_callback(_callback, x, x, vectorized=True)

     f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                   axis_resources={'a': 'dev'})