2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2022 The JAX Authors.
|
2022-08-12 12:39:22 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""Module for JAX callbacks."""
|
2022-08-15 17:05:27 -07:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-06-26 14:44:52 -04:00
|
|
|
from collections.abc import Callable, Sequence
|
jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is twofold:
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
2024-04-06 09:29:16 -07:00
|
|
|
import dataclasses
|
2022-09-01 15:27:27 -07:00
|
|
|
import functools
|
jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is twofold:
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
2024-04-06 09:29:16 -07:00
|
|
|
import logging
|
2024-06-26 14:44:52 -04:00
|
|
|
from typing import Any
|
2022-08-12 12:39:22 -07:00
|
|
|
|
2024-04-01 11:28:28 +02:00
|
|
|
import jax
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax._src import core
|
2023-04-04 11:41:00 -07:00
|
|
|
from jax._src import dispatch
|
2022-08-31 12:18:40 -07:00
|
|
|
from jax._src import dtypes
|
2023-02-01 17:50:00 -08:00
|
|
|
from jax._src import effects
|
2023-04-10 10:15:08 -07:00
|
|
|
from jax._src import sharding_impls
|
2023-04-04 11:41:00 -07:00
|
|
|
from jax._src import tree_util
|
2022-08-12 12:39:22 -07:00
|
|
|
from jax._src import util
|
2023-02-06 22:51:50 -08:00
|
|
|
from jax._src.interpreters import ad
|
2023-02-09 15:11:20 -08:00
|
|
|
from jax._src.interpreters import batching
|
2023-03-31 08:50:59 -07:00
|
|
|
from jax._src.interpreters import mlir
|
2023-04-04 11:41:00 -07:00
|
|
|
from jax._src.lax.control_flow.loops import map as lax_map
|
jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is twofold:
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
2024-04-06 09:29:16 -07:00
|
|
|
from jax._src.lib import xla_client as xc
|
2023-08-25 14:57:00 -07:00
|
|
|
from jax._src.sharding_impls import SingleDeviceSharding
|
jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is twofold:
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
2024-04-06 09:29:16 -07:00
|
|
|
import numpy as np
|
2022-08-12 12:39:22 -07:00
|
|
|
|
2024-03-13 07:22:33 -07:00
|
|
|
# Logger used to report exceptions raised by user callbacks.
logger = logging.getLogger(__name__)

# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
# It has multiple results, and it is registered as a primitive whose lowering
# requires devices.
pure_callback_p = core.Primitive("pure_callback")
pure_callback_p.multiple_results = True
dispatch.prim_requires_devices_during_lowering.add(pure_callback_p)

# Keep the builtin reachable as `unsafe_map`; within this module `map` is
# rebound to `util.safe_map`.
unsafe_map = map
map = util.safe_map
|
|
|
|
|
|
|
|
|
2024-04-01 11:28:28 +02:00
|
|
|
@dataclasses.dataclass(frozen=True)
class _FlatCallback:
  """Adapts a Python callable to take and return flat lists of leaves.

  The callback primitives store an instance of this class as their
  ``callback`` parameter. It is a frozen dataclass rather than an anonymous
  flattening closure. As a result, two wrappers built from the same Python
  function and the same argument tree structure compare and hash equal.
  """

  callback_func: Callable[..., Any]
  # Tree structure of the `(args, kwargs)` pair that `callback_func` expects.
  in_tree: tree_util.PyTreeDef

  def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]:
    """Rebuilds ``(args, kwargs)``, calls the function, flattens its result."""
    call_args, call_kwargs = tree_util.tree_unflatten(self.in_tree, flat_args)
    outputs = self.callback_func(*call_args, **call_kwargs)
    return tree_util.tree_leaves(outputs)
|
|
|
|
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
def pure_callback_impl(
    *args,
    result_avals,
    callback: _FlatCallback,
    sharding: SingleDeviceSharding | None,
    vectorized: bool,
):
  """Host-side body of a pure callback.

  This is called from the Python callback emitted by `pure_callback_lowering`.
  The flat ``args`` are placed on a local CPU device, and ``callback`` runs
  with that device as the default device. Every returned leaf is converted to
  a NumPy array. Any exception is logged and then re-raised.

  Raises:
    RuntimeError: if ``jax.local_devices(backend="cpu")`` fails.
  """
  del sharding, vectorized, result_avals  # Not needed on the host side.
  try:
    cpu_device, *_ = jax.local_devices(backend="cpu")
  except RuntimeError as e:
    raise RuntimeError(
        "jax.pure_callback failed to find a local CPU device to place the"
        " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
        " JAX_PLATFORMS environment variable."
    ) from e
  host_args = jax.device_put(args, cpu_device)
  with jax.default_device(cpu_device):
    try:
      flat_outputs = callback(*host_args)
      return tree_util.tree_map(np.asarray, flat_outputs)
    except BaseException:
      logger.exception("jax.pure_callback failed")
      raise


# Binding outside of tracing goes through `dispatch.apply_primitive` rather
# than calling `pure_callback_impl` directly.
pure_callback_p.def_impl(
    functools.partial(dispatch.apply_primitive, pure_callback_p))
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pure_callback_p.def_abstract_eval
def pure_callback_abstract_eval(
    *avals,
    callback: _FlatCallback,
    result_avals,
    sharding: SingleDeviceSharding | None,
    vectorized: bool,
):
  """Returns the user-declared ``result_avals``; the input avals are ignored."""
  del avals, callback, sharding, vectorized  # Only `result_avals` matters.
  out_avals = result_avals
  return out_avals
|
|
|
|
|
|
|
|
|
|
|
|
def pure_callback_jvp_rule(*args, **kwargs):
  """JVP rule that always raises: pure callbacks cannot be differentiated.

  Raises:
    ValueError: unconditionally, pointing the user to `jax.custom_jvp`.
  """
  del args, kwargs
  message = ("Pure callbacks do not support JVP. "
             "Please use `jax.custom_jvp` to use callbacks while taking gradients.")
  raise ValueError(message)


ad.primitive_jvps[pure_callback_p] = pure_callback_jvp_rule


def pure_callback_transpose_rule(*args, **kwargs):
  """Transpose rule that always raises: pure callbacks cannot be transposed.

  Raises:
    ValueError: unconditionally, pointing the user to `jax.custom_vjp`.
  """
  del args, kwargs
  message = ("Pure callbacks do not support transpose. "
             "Please use `jax.custom_vjp` to use callbacks while taking gradients.")
  raise ValueError(message)


ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule
|
|
|
|
|
|
|
|
|
2024-06-07 11:47:04 -07:00
|
|
|
def callback_batching_rule(
    prim,
    args,
    dims,
    *,
    vectorized: bool,
    result_avals: Sequence[core.ShapedArray],
    **kwargs: Any,
):
  """Batching rule for callback primitives.

  Args:
    prim: primitive that is re-bound on the batched operands.
    args: the primitive's operands.
    dims: batch dimension of each operand, or ``batching.not_mapped``.
    vectorized: if true, ``prim`` is bound once on the whole batch, and the
      result avals gain a leading axis of the batch size. Otherwise ``prim``
      is bound once per batch element through ``lax_map``.
    result_avals: result avals of a single (unbatched) invocation.
    **kwargs: remaining primitive parameters, forwarded unchanged.

  Returns:
    A pair ``(outputs, out_dims)``, where every entry of ``out_dims`` is ``0``.
  """
  mapped = [d is not batching.not_mapped for d in dims]
  axis_size = next(a.shape[d] for a, d, m in zip(args, dims, mapped) if m)
  # Bring every batch axis to the front so batched operands share axis 0.
  moved = [
      batching.moveaxis(a, d, 0) if m else a
      for a, d, m in zip(args, dims, mapped)
  ]
  if not vectorized:
    unbatched_args, batched_args = util.partition_list(mapped, moved)

    def _per_example(example_args):
      operands = util.merge_lists(mapped, unbatched_args, example_args)
      return prim.bind(
          *operands,
          result_avals=result_avals,
          vectorized=vectorized,
          **kwargs,
      )

    outvals = lax_map(_per_example, batched_args)
  else:
    batched_avals = tuple(
        core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)  # type: ignore
        for aval in result_avals)
    outvals = prim.bind(
        *moved,
        vectorized=vectorized,
        result_avals=batched_avals,
        **kwargs,
    )
  return tuple(outvals), (0,) * len(outvals)
|
|
|
|
|
|
|
|
|
2024-06-07 11:47:04 -07:00
|
|
|
# Register the shared batching rule for pure callbacks, with the primitive to
# re-bind (`pure_callback_p`) supplied as its first argument.
batching.primitive_batchers[pure_callback_p] = functools.partial(
    callback_batching_rule, pure_callback_p
)
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
|
2023-05-10 16:01:01 -07:00
|
|
|
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
|
|
|
# If we have fully manual sharding during lowering, that means the JAX
|
|
|
|
# program has per-device semantics, so we run the callback on each device.
|
|
|
|
if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
|
2023-01-12 15:56:10 -08:00
|
|
|
raise NotImplementedError(
|
2023-08-25 14:57:00 -07:00
|
|
|
"callbacks are only supported in spmd computations when all mesh"
|
2023-01-12 15:56:10 -08:00
|
|
|
" axes are partitioned manually (no partial automatic sharding)."
|
|
|
|
)
|
2023-08-25 14:57:00 -07:00
|
|
|
if sharding is not None:
|
|
|
|
raise NotImplementedError(
|
|
|
|
"callbacks do not support specifying sharding inside spmd"
|
|
|
|
" computations"
|
|
|
|
)
|
|
|
|
op_sharding = xc.OpSharding()
|
|
|
|
op_sharding.type = xc.OpSharding.Type.MANUAL
|
|
|
|
return op_sharding
|
|
|
|
|
|
|
|
if isinstance(axis_context, sharding_impls.ShardingContext):
|
|
|
|
if sharding is not None:
|
|
|
|
if not isinstance(sharding, SingleDeviceSharding):
|
|
|
|
raise NotImplementedError(
|
|
|
|
"pure_callback only supports SingleDeviceSharding, but got"
|
|
|
|
f" {type(sharding)}"
|
|
|
|
)
|
|
|
|
device = next(iter(sharding.device_set))
|
2023-12-08 14:35:27 -08:00
|
|
|
device_assignment = axis_context.device_assignment
|
|
|
|
if device_assignment is None:
|
|
|
|
raise AssertionError(
|
2024-09-20 07:51:48 -07:00
|
|
|
"Please file a bug at https://github.com/jax-ml/jax/issues")
|
2023-08-25 14:57:00 -07:00
|
|
|
try:
|
2023-12-08 14:35:27 -08:00
|
|
|
device_index = device_assignment.index(device)
|
2023-08-25 14:57:00 -07:00
|
|
|
except IndexError as e:
|
|
|
|
raise ValueError(
|
|
|
|
"Sharding provided to pure_callback specifies a device"
|
|
|
|
f" {device} that is not in the device assignment"
|
2023-12-08 14:35:27 -08:00
|
|
|
f" ({device_assignment})") from e
|
2023-08-25 14:57:00 -07:00
|
|
|
else:
|
|
|
|
device_index = 0
|
|
|
|
|
2023-05-10 16:01:01 -07:00
|
|
|
# If we have fully automatic sharding during lowering, that means the JAX
|
|
|
|
# program has bulk array semantics, so we run the callback with a MAXIMAL
|
|
|
|
# sharding and hence execute it only once on the full logical value).
|
2023-08-25 14:57:00 -07:00
|
|
|
op_sharding = xc.OpSharding()
|
|
|
|
op_sharding.type = xc.OpSharding.Type.MAXIMAL
|
|
|
|
op_sharding.tile_assignment_dimensions = [1]
|
|
|
|
op_sharding.tile_assignment_devices = [device_index]
|
|
|
|
return op_sharding
|
|
|
|
|
|
|
|
# When there's no SPMD partitioning going on, don't annotate a sharding.
|
|
|
|
return None
|
|
|
|
|
2023-01-12 15:56:10 -08:00
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
def pure_callback_lowering(
    ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params
):
  """Lowers ``pure_callback_p`` to a side-effect-free XLA host callback.

  The remaining primitive ``params`` (``result_avals`` and ``vectorized``)
  are forwarded to `pure_callback_impl`. That function runs when the emitted
  callback is invoked.
  """
  def _callback(*flat_args):
    outs = pure_callback_impl(
        *flat_args,
        callback=callback,
        sharding=None,  # unused.
        **params,
    )
    return tuple(outs)

  op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
  result, _, _ = mlir.emit_python_callback(
      ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out,
      has_side_effect=False, sharding=op_sharding)
  return result


mlir.register_lowering(pure_callback_p, pure_callback_lowering)
|
|
|
|
|
2022-08-31 12:18:40 -07:00
|
|
|
def _check_shape_dtype(shape_dtype):
  """Rejects a result dtype that JAX's dtype canonicalization would change.

  ``shape_dtype`` is any object with a ``dtype`` attribute, such as a
  ``jax.ShapeDtypeStruct``. When ``jax_enable_x64`` is disabled, 64-bit
  dtypes are refused.

  Raises:
    ValueError: if ``dtypes.canonicalize_dtype`` maps the dtype to a
      different dtype.
  """
  requested = np.dtype(shape_dtype.dtype)
  if dtypes.canonicalize_dtype(requested) == requested:
    return
  raise ValueError(
      "result_shape_dtypes cannot specify 64-bit types when `jax_enable_x64` is disabled")
|
2022-08-12 12:39:22 -07:00
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
|
|
|
|
def pure_callback(
    callback: Callable[..., Any],
    result_shape_dtypes: Any,
    *args: Any,
    sharding: SingleDeviceSharding | None = None,
    vectorized: bool = False,
    **kwargs: Any,
):
  """Calls a pure Python callback. Works under :func:`jit`/:func:`~vmap`/etc.

  For more explanation, see `External Callbacks`_.

  ``pure_callback`` makes it possible to call a Python function from inside
  JIT-ed JAX functions. ``callback`` receives JAX arrays placed on a local CPU
  device, and it should likewise return JAX arrays on CPU.

  JAX treats the callback as functionally pure. That is, it has no
  side-effects, and its outputs depend only on its arguments. It may therefore
  be called several times (e.g. when transformed by :func:`~vmap` or
  :func:`~pmap`). It may also not be called at all, e.g. when nothing in the
  output of a `jit`-decorated function depends on its value. Pure callbacks
  may also be reordered if data-dependence allows.

  Under `vmap`, the behavior depends on the ``vectorized`` keyword argument.
  With ``vectorized=True``, the callback is assumed to obey
  ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
  It is then called directly on the batched inputs, with the batch axes as the
  leading dimensions, and it should return outputs with corresponding leading
  batch axes. Otherwise, ``callback`` is mapped sequentially across the batched
  axis. For example, if ``callback = lambda x, y: np.matmul(x, y)``, then
  ``vectorized=True`` is safe, because ``np.matmul`` handles arbitrary leading
  batch dimensions.

  Args:
    callback: function to execute on the host. The callback is assumed to be
      a pure function (one without side-effects). An impure function may
      behave in unexpected ways, particularly under transformation. It is
      passed PyTrees of arrays as arguments, and it should return a PyTree of
      arrays that matches ``result_shape_dtypes``.
    result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype``
      attributes, and whose structure matches the expected output of the
      callback function at runtime. :class:`jax.ShapeDtypeStruct` is often
      used to define leaf values.
    *args: arguments to be passed to the callback function
    sharding: optional sharding that specifies the device from which the
      callback should be invoked.
    vectorized: boolean specifying whether the callback function can operate
      in a vectorized manner.
    **kwargs: keyword arguments to be passed to the callback function

  Returns:
    result: a pytree of :class:`jax.Array` objects whose structure matches that
    of ``result_shape_dtypes``.

  See Also:
    - :func:`jax.experimental.io_callback`: callback designed for impure functions.
    - :func:`jax.debug.callback`: callback designed for general-purpose debugging.
    - :func:`jax.debug.print`: callback designed for printing.

  .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
  """
  flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
  # Refuse dtypes that canonicalization would silently change.
  tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)

  def _to_aval(shape_dtype):
    return core.ShapedArray(shape_dtype.shape, shape_dtype.dtype)

  flat_result_avals, out_tree = tree_util.tree_flatten(
      tree_util.tree_map(_to_aval, result_shape_dtypes))
  out_flat = pure_callback_p.bind(
      *flat_args,
      callback=_FlatCallback(callback, in_tree),
      result_avals=tuple(flat_result_avals),
      sharding=sharding,
      vectorized=vectorized,
  )
  return tree_util.tree_unflatten(out_tree, out_flat)
|
2022-11-10 12:00:21 -08:00
|
|
|
|
|
|
|
|
|
|
|
# IO Callback
|
|
|
|
|
|
|
|
# `io_callback_p` stages out impure Python callbacks. It has multiple results,
# and it is registered as a primitive whose lowering requires devices.
io_callback_p = core.Primitive("io_callback")
io_callback_p.multiple_results = True
dispatch.prim_requires_devices_during_lowering.add(io_callback_p)


class IOEffect(effects.Effect):
  """Effect attached to an unordered `io_callback`."""

  def __str__(self):
    return "IO"


class OrderedIOEffect(effects.Effect):
  """Effect attached to an `io_callback` created with ``ordered=True``."""

  def __str__(self):
    return "OrderedIO"


_IOEffect = IOEffect()
_OrderedIOEffect = OrderedIOEffect()

# Both IO effects are lowerable and allowed inside control flow.
effects.lowerable_effects.add_type(IOEffect)
effects.control_flow_allowed_effects.add_type(IOEffect)
effects.lowerable_effects.add_type(OrderedIOEffect)
effects.control_flow_allowed_effects.add_type(OrderedIOEffect)
# Only the ordered effect is registered as ordered (and shardable-ordered).
effects.ordered_effects.add_type(OrderedIOEffect)
effects.shardable_ordered_effects.add_type(OrderedIOEffect)
|
2022-11-10 12:00:21 -08:00
|
|
|
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
def io_callback_impl(
    *args,
    result_avals,
    callback: _FlatCallback,
    sharding: SingleDeviceSharding | None,
    ordered: bool,
):
  """Host-side body of an IO callback.

  This is called from the Python callback emitted by `io_callback_lowering`.
  The flat ``args`` are placed on a local CPU device, and ``callback`` runs
  with that device as the default device. Every returned leaf is converted to
  a NumPy array. Any exception is logged and then re-raised.

  Raises:
    RuntimeError: if ``jax.local_devices(backend="cpu")`` fails.
  """
  del result_avals, sharding, ordered  # Not needed on the host side.
  try:
    cpu_device, *_ = jax.local_devices(backend="cpu")
  except RuntimeError as e:
    raise RuntimeError(
        "jax.io_callback failed to find a local CPU device to place the"
        " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
        " JAX_PLATFORMS environment variable."
    ) from e
  host_args = jax.device_put(args, cpu_device)
  with jax.default_device(cpu_device):
    try:
      flat_outputs = callback(*host_args)
      return tree_util.tree_map(np.asarray, flat_outputs)
    except BaseException:
      logger.exception("jax.io_callback failed")
      raise


# Binding outside of tracing goes through `dispatch.apply_primitive` rather
# than calling `io_callback_impl` directly.
io_callback_p.def_impl(
    functools.partial(dispatch.apply_primitive, io_callback_p))
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
|
2022-11-10 12:00:21 -08:00
|
|
|
@io_callback_p.def_effectful_abstract_eval
def io_callback_abstract_eval(
    *avals,
    callback: _FlatCallback,
    result_avals,
    sharding: SingleDeviceSharding | None,
    ordered: bool,
):
  """Returns the declared result avals together with the callback's effect.

  The effect is ``_OrderedIOEffect`` when ``ordered`` is true, and
  ``_IOEffect`` otherwise.
  """
  del avals, sharding, callback
  if ordered:
    return result_avals, {_OrderedIOEffect}
  return result_avals, {_IOEffect}
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
|
2022-11-10 12:00:21 -08:00
|
|
|
def io_callback_jvp_rule(*args, **kwargs):
  """JVP rule that always raises: IO callbacks cannot be differentiated."""
  del args, kwargs
  message = "IO callbacks do not support JVP."
  raise ValueError(message)


ad.primitive_jvps[io_callback_p] = io_callback_jvp_rule


def io_callback_transpose_rule(*args, **kwargs):
  """Transpose rule that always raises: IO callbacks cannot be transposed."""
  del args, kwargs
  message = "IO callbacks do not support transpose."
  raise ValueError(message)


ad.primitive_transposes[io_callback_p] = io_callback_transpose_rule
|
|
|
|
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
def io_callback_batching_rule(
    args, dims, callback, result_avals, sharding, ordered
):
  """Batches an unordered IO callback by calling it once per batch element.

  Batched operands have their batch axis moved to the front. The primitive is
  then re-bound (with ``ordered=False``) inside ``lax_map``.

  Raises:
    ValueError: if the callback is ordered.
  """
  if ordered:
    raise ValueError("Cannot `vmap` ordered IO callback.")
  batched_mask = [d is not batching.not_mapped for d in dims]
  front_args = [
      batching.moveaxis(a, d, 0) if d is not batching.not_mapped else a
      for a, d in zip(args, dims)
  ]
  fixed_args, mapped_args = util.partition_list(batched_mask, front_args)

  def _one_call(mapped_slice):
    operands = util.merge_lists(batched_mask, fixed_args, mapped_slice)
    return io_callback_p.bind(*operands, callback=callback, sharding=sharding,
                              result_avals=result_avals, ordered=False)

  out_vals = lax_map(_one_call, mapped_args)
  return out_vals, (0,) * len(out_vals)


batching.primitive_batchers[io_callback_p] = io_callback_batching_rule
|
|
|
|
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
  """Lowers ``io_callback_p`` to an XLA host callback with side effects.

  When ``ordered`` is true, the incoming ``_OrderedIOEffect`` token is passed
  to the callback, and the resulting token is set as the outgoing token.
  Unordered callbacks pass no token.
  """
  def _callback(*flat_args):
    return tuple(
        io_callback_impl(
            *flat_args,
            callback=callback,
            sharding=None,  # unused.
            ordered=ordered,
            **params,
        )
    )

  op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
  token_in = ctx.tokens_in.get(_OrderedIOEffect) if ordered else None
  result, token_out, _ = mlir.emit_python_callback(
      ctx, _callback, token_in, list(args), ctx.avals_in, ctx.avals_out,
      has_side_effect=True, sharding=op_sharding)
  if ordered:
    ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token_out}))
  return result


mlir.register_lowering(io_callback_p, io_callback_lowering)
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
|
|
|
|
def io_callback(
    callback: Callable[..., Any],
    result_shape_dtypes: Any,
    *args: Any,
    sharding: SingleDeviceSharding | None = None,
    ordered: bool = False,
    **kwargs: Any,
):
  """Calls an impure Python callback.

  For more explanation, see `External Callbacks`_.

  Args:
    callback: function to execute on the host. It is assumed to be an impure
      function. If ``callback`` is pure, using :func:`jax.pure_callback`
      instead may lead to more efficient execution.
    result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype``
      attributes, and whose structure matches the expected output of the
      callback function at runtime. :class:`jax.ShapeDtypeStruct` is often
      used to define leaf values.
    *args: arguments to be passed to the callback function
    sharding: optional sharding that specifies the device from which the
      callback should be invoked.
    ordered: boolean specifying whether sequential calls to callback must be
      ordered. Ordered callbacks cannot be transformed with :func:`~vmap`.
    **kwargs: keyword arguments to be passed to the callback function

  Returns:
    result: a pytree of :class:`jax.Array` objects whose structure matches that
    of ``result_shape_dtypes``.

  See Also:
    - :func:`jax.pure_callback`: callback designed for pure functions.
    - :func:`jax.debug.callback`: callback designed for general-purpose debugging.
    - :func:`jax.debug.print`: callback designed for printing.

  .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
  """
  flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
  # Refuse dtypes that canonicalization would silently change.
  tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
  flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
  flat_result_avals = [
      core.ShapedArray(sd.shape, sd.dtype) for sd in flat_shape_dtypes
  ]
  flat_args = [core.raise_as_much_as_possible(a) for a in flat_args]
  out_flat = io_callback_p.bind(
      *flat_args,
      callback=_FlatCallback(callback, in_tree),
      result_avals=tuple(flat_result_avals),
      sharding=sharding,
      ordered=ordered,
  )
  return tree_util.tree_unflatten(out_tree, out_flat)
|