mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #20514 from gnecula:callback_cache
PiperOrigin-RevId: 621160168
This commit is contained in:
commit
4c41c12e21
@ -14,6 +14,7 @@
|
||||
"""Module for JAX callbacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import functools
|
||||
@ -21,6 +22,7 @@ from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
@ -46,10 +48,27 @@ dispatch.prim_requires_devices_during_lowering.add(pure_callback_p)
|
||||
map, unsafe_map = util.safe_map, map
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _FlatCallback:
|
||||
"""A Python function callable with flat arguments and results.
|
||||
|
||||
An instance of this class is used as a parameter for the callback primitives.
|
||||
We prefer it to an anonymous flattened function because it produces
|
||||
equal objects when we call the same Python function with the same argument
|
||||
structure.
|
||||
"""
|
||||
callback_func: Callable[..., Any]
|
||||
in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`.
|
||||
|
||||
def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]:
|
||||
args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args)
|
||||
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
|
||||
|
||||
|
||||
def pure_callback_impl(
|
||||
*args,
|
||||
result_avals,
|
||||
callback: Callable[..., Any],
|
||||
callback: _FlatCallback,
|
||||
sharding: SingleDeviceSharding | None,
|
||||
vectorized: bool,
|
||||
):
|
||||
@ -68,7 +87,7 @@ pure_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
|
||||
@pure_callback_p.def_abstract_eval
|
||||
def pure_callback_abstract_eval(
|
||||
*avals,
|
||||
callback: Callable[..., Any],
|
||||
callback: _FlatCallback,
|
||||
result_avals,
|
||||
sharding: SingleDeviceSharding | None,
|
||||
vectorized: bool,
|
||||
@ -100,7 +119,7 @@ def pure_callback_batching_rule(
|
||||
args,
|
||||
dims,
|
||||
*,
|
||||
callback,
|
||||
callback: _FlatCallback,
|
||||
sharding: SingleDeviceSharding | None,
|
||||
vectorized: bool,
|
||||
result_avals: Sequence[core.ShapedArray],
|
||||
@ -193,7 +212,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
|
||||
|
||||
|
||||
def pure_callback_lowering(
|
||||
ctx, *args, callback, sharding: SingleDeviceSharding | None, **params
|
||||
ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params
|
||||
):
|
||||
def _callback(*flat_args):
|
||||
return tuple(
|
||||
@ -265,10 +284,6 @@ def pure_callback(
|
||||
|
||||
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
|
||||
"""
|
||||
def _flat_callback(*flat_args):
|
||||
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
|
||||
return tree_util.tree_leaves(callback(*args, **kwargs))
|
||||
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
|
||||
result_avals = tree_util.tree_map(
|
||||
@ -276,7 +291,7 @@ def pure_callback(
|
||||
flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
|
||||
out_flat = pure_callback_p.bind(
|
||||
*flat_args,
|
||||
callback=_flat_callback,
|
||||
callback=_FlatCallback(callback, in_tree),
|
||||
result_avals=tuple(flat_result_avals),
|
||||
sharding=sharding,
|
||||
vectorized=vectorized,
|
||||
@ -378,7 +393,7 @@ effects.shardable_ordered_effects.add_type(OrderedIOEffect)
|
||||
def io_callback_impl(
|
||||
*args,
|
||||
result_avals,
|
||||
callback: Callable[..., Any],
|
||||
callback: _FlatCallback,
|
||||
sharding: SingleDeviceSharding | None,
|
||||
ordered: bool,
|
||||
):
|
||||
@ -397,7 +412,7 @@ io_callback_p.def_impl(functools.partial(dispatch.apply_primitive,
|
||||
@io_callback_p.def_effectful_abstract_eval
|
||||
def io_callback_abstract_eval(
|
||||
*avals,
|
||||
callback: Callable[..., Any],
|
||||
callback: _FlatCallback,
|
||||
result_avals,
|
||||
sharding: SingleDeviceSharding | None,
|
||||
ordered: bool,
|
||||
@ -516,10 +531,6 @@ def io_callback(
|
||||
|
||||
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
|
||||
"""
|
||||
def _flat_callback(*flat_args):
|
||||
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
|
||||
return tree_util.tree_leaves(callback(*args, **kwargs))
|
||||
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
|
||||
flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
|
||||
@ -528,7 +539,7 @@ def io_callback(
|
||||
flat_args = map(core.raise_as_much_as_possible, flat_args)
|
||||
out_flat = io_callback_p.bind(
|
||||
*flat_args,
|
||||
callback=_flat_callback,
|
||||
callback=_FlatCallback(callback, in_tree),
|
||||
result_avals=tuple(flat_result_avals),
|
||||
sharding=sharding,
|
||||
ordered=ordered,
|
||||
|
@ -586,6 +586,20 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
self.assertIn(f"jax.{api_name} failed", output)
|
||||
self.assertIn("Traceback (most recent call last)", output)
|
||||
|
||||
@with_pure_and_io_callbacks
|
||||
def test_compilation_caching(self, *, callback):
|
||||
def f_outside(x):
|
||||
return 2 * x
|
||||
|
||||
def fun(x):
|
||||
return callback(f_outside, x, x)
|
||||
|
||||
x = np.arange(6, dtype=np.int32).reshape((2, 3))
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
for _ in range(3):
|
||||
self.assertAllClose(2 * x, fun(x))
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
|
||||
class PureCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user