Merge pull request #20514 from gnecula:callback_cache

PiperOrigin-RevId: 621160168
This commit is contained in:
jax authors 2024-04-02 06:55:45 -07:00
commit 4c41c12e21
2 changed files with 41 additions and 16 deletions

View File

@ -14,6 +14,7 @@
"""Module for JAX callbacks."""
from __future__ import annotations
import dataclasses
from collections.abc import Sequence
import logging
import functools
@ -21,6 +22,7 @@ from typing import Any, Callable
import numpy as np
import jax
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@ -46,10 +48,27 @@ dispatch.prim_requires_devices_during_lowering.add(pure_callback_p)
map, unsafe_map = util.safe_map, map
@dataclasses.dataclass(frozen=True)
class _FlatCallback:
"""A Python function callable with flat arguments and results.
An instance of this class is used as a parameter for the callback primitives.
We prefer it to an anonymous flattened function because it produces
equal objects when we call the same Python function with the same argument
structure.
"""
callback_func: Callable[..., Any]
in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`.
def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]:
args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args)
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
def pure_callback_impl(
*args,
result_avals,
callback: Callable[..., Any],
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
):
@ -68,7 +87,7 @@ pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
@pure_callback_p.def_abstract_eval
def pure_callback_abstract_eval(
*avals,
callback: Callable[..., Any],
callback: _FlatCallback,
result_avals,
sharding: SingleDeviceSharding | None,
vectorized: bool,
@ -100,7 +119,7 @@ def pure_callback_batching_rule(
args,
dims,
*,
callback,
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
result_avals: Sequence[core.ShapedArray],
@ -193,7 +212,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
def pure_callback_lowering(
ctx, *args, callback, sharding: SingleDeviceSharding | None, **params
ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params
):
def _callback(*flat_args):
return tuple(
@ -265,10 +284,6 @@ def pure_callback(
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
"""
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
return tree_util.tree_leaves(callback(*args, **kwargs))
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
result_avals = tree_util.tree_map(
@ -276,7 +291,7 @@ def pure_callback(
flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
out_flat = pure_callback_p.bind(
*flat_args,
callback=_flat_callback,
callback=_FlatCallback(callback, in_tree),
result_avals=tuple(flat_result_avals),
sharding=sharding,
vectorized=vectorized,
@ -378,7 +393,7 @@ effects.shardable_ordered_effects.add_type(OrderedIOEffect)
def io_callback_impl(
*args,
result_avals,
callback: Callable[..., Any],
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
ordered: bool,
):
@ -397,7 +412,7 @@ io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
@io_callback_p.def_effectful_abstract_eval
def io_callback_abstract_eval(
*avals,
callback: Callable[..., Any],
callback: _FlatCallback,
result_avals,
sharding: SingleDeviceSharding | None,
ordered: bool,
@ -516,10 +531,6 @@ def io_callback(
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
"""
def _flat_callback(*flat_args):
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
return tree_util.tree_leaves(callback(*args, **kwargs))
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
@ -528,7 +539,7 @@ def io_callback(
flat_args = map(core.raise_as_much_as_possible, flat_args)
out_flat = io_callback_p.bind(
*flat_args,
callback=_flat_callback,
callback=_FlatCallback(callback, in_tree),
result_avals=tuple(flat_result_avals),
sharding=sharding,
ordered=ordered,

View File

@ -586,6 +586,20 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertIn(f"jax.{api_name} failed", output)
self.assertIn("Traceback (most recent call last)", output)
@with_pure_and_io_callbacks
def test_compilation_caching(self, *, callback):
def f_outside(x):
return 2 * x
def fun(x):
return callback(f_outside, x, x)
x = np.arange(6, dtype=np.int32).reshape((2, 3))
with jtu.count_primitive_compiles() as count:
for _ in range(3):
self.assertAllClose(2 * x, fun(x))
self.assertEqual(count[0], 1)
class PureCallbackTest(jtu.JaxTestCase):