# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import logging
import time
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import jaxlib_extension_version
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np
|
config.parse_flags_with_absl()
jtu.request_cpu_devices(2)

# JAX-style safe_map shadows the builtin `map`; the builtin stays reachable
# as `unsafe_map`.
map, unsafe_map = util.safe_map, map

# Some test methods take a kwarg
# callback=[io_callback(ordered=True) | io_callback(ordered=False) | pure_callback]
io_callback_ordered = functools.partial(io_callback, ordered=True)
io_callback_unordered = functools.partial(io_callback, ordered=False)
# Backward-compatible alias: the original name had a typo ("calback").
# Kept so any existing references to the misspelled name still resolve.
io_calback_unordered = io_callback_unordered

# Decorator that runs a test once per callback flavor, passing the flavor's
# callback constructor via the `callback` kwarg.
with_pure_and_io_callbacks = parameterized.named_parameters(
    dict(testcase_name=flavor,
         callback=dict(io_unordered=io_callback_unordered,
                       io_ordered=io_callback_ordered,
                       pure=jax.pure_callback)[flavor])
    for flavor in ("io_unordered", "io_ordered", "pure")
)
class PythonCallbackTest(jtu.JaxTestCase):
  """Tests host-Python callbacks (pure_callback / io_callback) under jit,
  control flow (while/cond/scan), pmap, and various dtypes/shapes.

  Most tests are parameterized over the three callback flavors via
  `with_pure_and_io_callbacks` and receive the flavor's constructor as the
  keyword-only `callback` argument.
  """

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    super().tearDown()
    # Drop any runtime tokens left over from ordered callbacks so tests
    # don't leak state into each other.
    dispatch.runtime_tokens.clear()

  @with_pure_and_io_callbacks
  def test_callback_with_scalar_values(self, *, callback):
    @jax.jit
    def f(x):
      return callback(lambda x: x + 1.0, core.ShapedArray(x.shape, x.dtype), x)

    out = f(0.)
    self.assertEqual(out, 1.)

  @parameterized.named_parameters(
      dict(
          testcase_name=f"{flavor}_expect_dtype_{expect_dtype}",
          callback=dict(
              io_unordered=io_calback_unordered,
              io_ordered=io_callback_ordered,
              pure=jax.pure_callback,
          )[flavor],
          expect_dtype=expect_dtype,
      )
      for flavor in ("io_unordered", "io_ordered", "pure")
      for expect_dtype in (np.int32, np.int64, np.float32, np.float64)
  )
  def test_callback_returning_python_literal(self, *, callback, expect_dtype):
    """A Python literal returned from a callback is canonicalized; mismatched
    64-bit/32-bit expectations raise instead of producing garbage."""
    returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0

    @jax.jit
    def f(x):
      return callback(
          lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
      )

    if not config.enable_x64.value and expect_dtype in (np.int64, np.float64):
      ctx = self.assertRaisesRegex(Exception, "result_shape_dtypes cannot specify 64-bit types")
    elif config.enable_x64.value and expect_dtype in (np.int32, np.float32):
      ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
    else:
      ctx = contextlib.nullcontext()

    with ctx:
      out = f(0.0)
      jax.effects_barrier()
      self.assertEqual(out, returned_literal)

  @with_pure_and_io_callbacks
  def test_callback_returning_custom_array(self, *, callback):
    # Some users write the callback in TF, returning a tf.Tensor. We don't
    # want to add TF as a dependency, but simulate that use case with a
    # custom array class.
    class CustomArray:

      def __init__(self, a: np.ndarray):
        self.a = a

      @property
      def shape(self):
        return self.a.shape

      @property
      def dtype(self):
        return self.a.dtype

      def __array__(self):
        return self.a

    @jax.jit
    def f(x):
      return callback(
          lambda x: CustomArray(np.array(42.0, dtype=np.float32)),
          core.ShapedArray((), np.float32),
          x,
      )

    out = f(0.0)
    jax.effects_barrier()
    self.assertEqual(out, 42.0)

  @parameterized.named_parameters(
      dict(testcase_name=f"{flavor}_{dtype}",
           dtype=dtype,
           callback=dict(io_unordered=io_calback_unordered,
                         io_ordered=io_callback_ordered,
                         pure=jax.pure_callback)[flavor])
      for flavor in ("io_unordered", "io_ordered", "pure")
      for dtype in jtu.dtypes.all
  )
  def test_callback_works_with_all_types(self, *, callback, dtype):
    """Round-trips every supported dtype through the callback and checks both
    the computed result and the value received on the host side."""
    def host_func(x):
      if dtype == np.bool_:
        return ~ x
      else:
        return x + x

    _received = None
    def _cb(x):
      nonlocal _received
      _received = x  # capture what the host actually saw
      return host_func(x)

    if dtype == np.bool_:
      x = np.array([True, False, True, True], dtype=np.bool_)
    else:
      x = np.arange(4, dtype=dtype)

    @jax.jit
    def f(x):
      return callback(_cb,
                      core.ShapedArray(x.shape, x.dtype), x)

    out = f(x)
    self.assertAllClose(out, host_func(x))
    jax.effects_barrier()
    self.assertAllClose(_received, x)

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_args(self, *, callback):

    @jax.jit
    def f():
      # Calling a function that expects `x` with no arguments
      return callback(lambda x: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_returned_values(self, *, callback):

    @jax.jit
    def f(x):
      # Calling a function with two return values that expects one return value
      return callback(lambda x: (x, np.ones(4, np.float32)), x, x)

    with self.assertRaises(RuntimeError):
      f(2.)
      jax.effects_barrier()

    @jax.jit
    def g():
      # Specifically for io_callback, calling a function with a return value
      # that expects no return values
      return io_callback(lambda: None, (core.ShapedArray(
          (1,), np.float32), core.ShapedArray((2,), np.float32)))

    with self.assertRaises(RuntimeError):
      g()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_shape_outputs(self, *, callback):

    @jax.jit
    def f():
      # Calling a function expected a (1,) shaped return value but getting ()
      return callback(lambda: np.float32(1.), core.ShapedArray((1,),
                                                               np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_dtype_outputs(self, *, callback):

    def _cb():
      return np.array([1], np.float64)

    @jax.jit
    def f():
      # Calling a function expected a f32 return value but getting f64
      return callback(_cb, core.ShapedArray((1,), np.float32))

    if config.enable_x64.value:
      ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
    else:
      # Without x64 the f64 result is canonicalized down to f32.
      ctx = contextlib.nullcontext()
    with ctx:
      res = f()
      jax.effects_barrier()
      self.assertAllClose(res, np.array([1], np.float32))

  @with_pure_and_io_callbacks
  def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
    if config.enable_x64.value:
      raise unittest.SkipTest("Test only needed when 64-bit mode disabled.")

    @jax.jit
    def f():
      return callback(lambda: np.float64(1.),
                      core.ShapedArray((), np.float64))

    with self.assertRaises(ValueError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_single_return_value(self, *, callback):

    @jax.jit
    def f():
      return callback(lambda: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    out = f()
    jax.effects_barrier()
    np.testing.assert_allclose(out, np.ones(4, np.float32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_return_values(self, *, callback):

    @jax.jit
    def f():
      return callback(lambda: (np.ones(4, np.float32), np.ones(5, np.int32)),
                      (core.ShapedArray(
                          (4,), np.float32), core.ShapedArray((5,), np.int32)))

    x, y = f()
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(4, np.float32))
    np.testing.assert_allclose(y, np.ones(5, np.int32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_arguments_and_return_values(self, *, callback):

    def _callback(x, y, z):
      return (x, y + z)

    @jax.jit
    def f(x, y, z):
      return callback(_callback, (core.ShapedArray(
          (3,), x.dtype), core.ShapedArray((3,), x.dtype)), x, y, z)

    x, y = f(jnp.ones(3), jnp.arange(3.), jnp.arange(3.) + 1.)
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(3))
    np.testing.assert_allclose(y, np.array([1., 3., 5]))

  @with_pure_and_io_callbacks
  def test_send_zero_dim_arrays(self, *, callback):
    result = np.full((2,), 42.0, dtype=np.float32)
    x = np.zeros((2, 0), np.float32)
    def _callback(x):  # x: f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)
    jax.effects_barrier()
    self.assertAllClose(f(x), result)

  @with_pure_and_io_callbacks
  def test_send_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.zeros((2, 0), np.float32)
    y = np.full((2,), 42.0, dtype=np.float32)
    result = y
    def _callback(x, y):  # x: f32[2, 0]  y: f32[2]
      return y

    @jax.jit
    def f(x, y):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x, y)
    jax.effects_barrier()
    self.assertAllClose(f(x, y), result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_arrays(self, *, callback):
    result = np.full((2, 0), 42.0, dtype=np.float32)
    x = np.zeros((2,), np.float32)
    def _callback(_):  # f32[2] -> f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)
    jax.effects_barrier()
    self.assertAllClose(f(x), result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.full((2,), 42., dtype=np.float32)
    result0 = np.ones((2, 0), dtype=np.float32)
    result1 = x
    result2 = np.ones((3, 0), dtype=np.int32)
    result3 = np.concatenate([x, x]) + 1.
    def _callback(x):  # x: f32[2] -> (f32[2, 0], f32[2], f32[3, 0], f32[4])
      return (result0, x, result2, np.concatenate([x, x]) + 1.)

    @jax.jit
    def f(x):
      return callback(
          _callback, (core.ShapedArray(result0.shape, result0.dtype),
                      core.ShapedArray(result1.shape, result1.dtype),
                      core.ShapedArray(result2.shape, result2.dtype),
                      core.ShapedArray(result3.shape, result3.dtype)), x)
    res = f(x)
    jax.effects_barrier()
    self.assertAllClose(res, (result0, result1, result2, result3))

  @with_pure_and_io_callbacks
  def test_callback_with_pytree_arguments_and_return_values(self, *, callback):

    def _callback(x):
      return dict(y=[x])

    @jax.jit
    def f(x):
      return callback(_callback, dict(y=[core.ShapedArray((), np.float32)]),
                      [x])

    out = f(jnp.float32(2.))
    jax.effects_barrier()
    self.assertEqual(out, dict(y=[2.]))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop_of_scalars(self, *, callback):

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return x < 10
      def body(x):
        return callback(_callback, core.ShapedArray((), x.dtype), x)
      return lax.while_loop(cond, body, x)

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop(self, *, callback):

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):

      def cond(x):
        return jnp.any(x < 10)

      def body(x):
        return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

      return lax.while_loop(cond, body, x)

    out = f(jnp.arange(5.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(10., 15.))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond_of_scalars(self, *, callback):

    def _callback1(x):
      return (x + 1.).astype(x.dtype)

    def _callback2(x):
      return (x - 1.).astype(x.dtype)

    @jax.jit
    def f(pred, x):

      def true_fun(x):
        return callback(_callback1, core.ShapedArray((), x.dtype), x)

      def false_fun(x):
        return callback(_callback2, core.ShapedArray((), x.dtype), x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 2.)
    out = f(False, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 0.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond(self, *, callback):

    def _callback1(x):
      return x + 1.

    def _callback2(x):
      return x - 1.

    @jax.jit
    def f(pred, x):

      def true_fun(x):
        return callback(_callback1, core.ShapedArray(x.shape, x.dtype), x)

      def false_fun(x):
        return callback(_callback2, core.ShapedArray(x.shape, x.dtype), x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.ones(2) * 2.)
    out = f(False, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.zeros(2))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan_of_scalars(self, *, callback):

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):

      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan(self, *, callback):

    def _callback(x):
      return x + 1.

    @jax.jit
    def f(x):

      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(jnp.arange(2.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(2.) + 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap_of_scalars(self, *, callback):
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(jnp.arange(jax.local_device_count(), dtype=jnp.float32))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count(), dtype=np.float32) + 1.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap(self, *, callback):
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return x + 1.

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(
        jnp.arange(2 * jax.local_device_count(),
                   dtype=jnp.float32).reshape([-1, 2]))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out,
        np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)

  @with_pure_and_io_callbacks
  @jtu.thread_unsafe_test()  # logging isn't thread-safe
  def test_exception_in_callback(self, *, callback):
    """An exception raised inside the callback is logged with a traceback."""
    def fail(x):
      raise RuntimeError("Ooops")

    @jax.jit
    def f(x):
      return callback(fail, core.ShapedArray(x.shape, x.dtype), x)

    with self.assertLogs(level="ERROR") as l:
      try:
        f(0.0).block_until_ready()
      except RuntimeError:
        pass

    api_name = (
        "pure_callback" if callback is jax.pure_callback else "io_callback"
    )
    output = "\n".join(l.output)
    self.assertIn(f"jax.{api_name} failed", output)
    self.assertIn("Traceback (most recent call last)", output)

  @with_pure_and_io_callbacks
  @jtu.thread_unsafe_test()  # count_primitive_compiles isn't thread-safe
  def test_compilation_caching(self, *, callback):
    """Repeated calls with the same callback should compile exactly once."""
    def f_outside(x):
      return 2 * x

    def fun(x):
      return callback(f_outside, x, x)

    x = np.arange(6, dtype=np.int32).reshape((2, 3))
    with jtu.count_primitive_compiles() as count:
      for _ in range(3):
        self.assertAllClose(2 * x, fun(x))
    self.assertEqual(count(), 1)

  @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn")
  def test_subbyte_operands(self, dtype: str):
    if jaxlib_extension_version < 323:
      self.skipTest("Requires jaxlib_extension_version >= 323.")
    if "2" in dtype and jtu.test_device_matches(["tpu"]):
      self.skipTest(
          "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and"
          " float4_e2m1fn."
      )

    def get(x):
      return x

    def f(x):
      y = jax.pure_callback(
          get,
          jax.ShapeDtypeStruct((8,), dtype=dtype),
          x,
      )
      return y

    x = np.arange(8, dtype=dtype)
    np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype))

  @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn")
  def test_subbyte_results(self, dtype: str):
    if jaxlib_extension_version < 323:
      self.skipTest("Requires jaxlib_extension_version >= 323.")
    if "2" in dtype and jtu.test_device_matches(["tpu"]):
      self.skipTest(
          "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and"
          " float4_e2m1fn."
      )

    def get():
      return np.arange(8, dtype=dtype)

    def f():
      y = jax.pure_callback(
          get,
          jax.ShapeDtypeStruct((8,), dtype)
      )
      return y

    np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype))

  @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn")
  def test_non_default_stride_subbyte_results(self, dtype: str):
    if jaxlib_extension_version < 323:
      self.skipTest("Requires jaxlib_extension_version >= 323.")
    if "2" in dtype and jtu.test_device_matches(["tpu"]):
      self.skipTest(
          "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and"
          " float4_e2m1fn."
      )
    x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4)
    def callback(x):
      # Fortran order gives the result non-default (column-major) strides.
      return np.asfortranarray(x)

    @jax.jit
    def f(x):
      return jax.pure_callback(
          callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x
      )

    result = f(x)
    np.testing.assert_array_equal(x, result)

  def test_non_default_stride(self):
    x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)
    def callback(x):
      # Fortran order gives the result non-default (column-major) strides.
      return np.asfortranarray(x)

    @jax.jit
    def f(x):
      return jax.pure_callback(
          callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x
      )

    result = f(x)
    np.testing.assert_array_equal(x, result)
2024-03-13 07:22:33 -07:00
|
|
|
|
2023-09-02 21:50:32 -07:00
|
|
|
class PureCallbackTest(jtu.JaxTestCase):
|
2022-08-12 12:39:22 -07:00
|
|
|
|
2023-08-22 16:14:25 -07:00
|
|
|
def setUp(self):
|
|
|
|
super().setUp()
|
2023-09-27 12:10:06 -07:00
|
|
|
if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
|
2023-09-13 16:35:02 +01:00
|
|
|
self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")
|
2023-08-22 16:14:25 -07:00
|
|
|
|
2022-08-17 10:43:50 -07:00
|
|
|
def tearDown(self):
|
|
|
|
super().tearDown()
|
|
|
|
dispatch.runtime_tokens.clear()
|
|
|
|
|
jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is two-fold
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
2024-04-06 09:29:16 -07:00
|
|
|
def test_pure_callback_passes_jax_arrays_without_jit(self):
|
2022-09-01 15:27:27 -07:00
|
|
|
|
|
|
|
def cb(x):
|
jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is two-fold
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
2024-04-06 09:29:16 -07:00
|
|
|
self.assertIsInstance(x, jax.Array)
|
2022-09-01 15:27:27 -07:00
|
|
|
return x
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return jax.pure_callback(cb, x, x)
|
|
|
|
f(jnp.array(2.))
|
|
|
|
|
2022-08-12 12:39:22 -07:00
|
|
|
def test_can_dce_pure_callback(self):
|
|
|
|
|
|
|
|
if jax.default_backend() == "tpu":
|
|
|
|
raise unittest.SkipTest("DCE doesn't currently happen on TPU")
|
|
|
|
|
|
|
|
log = []
|
|
|
|
def _callback(x):
|
|
|
|
# Should never happen!
|
|
|
|
log.append("hello world")
|
|
|
|
return (x * 2.).astype(x.dtype)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
_ = jax.pure_callback(_callback, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
return x * 2.
|
|
|
|
_ = f(2.)
|
|
|
|
self.assertEmpty(log)
|
|
|
|
|
|
|
|
def test_can_vmap_pure_callback(self):
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
@jax.vmap
|
|
|
|
def f(x):
|
2024-07-19 17:24:33 -04:00
|
|
|
return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
|
2022-08-12 12:39:22 -07:00
|
|
|
out = f(jnp.arange(4.))
|
|
|
|
np.testing.assert_allclose(out, np.sin(np.arange(4.)))
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def g(x):
|
2024-07-19 17:24:33 -04:00
|
|
|
return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
|
2022-08-12 12:39:22 -07:00
|
|
|
out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
|
|
|
|
np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)
|
|
|
|
|
2022-08-17 10:43:50 -07:00
|
|
|
@jax.jit
|
|
|
|
@functools.partial(jax.vmap, in_axes=(0, None))
|
|
|
|
def h(x, y):
|
2022-11-30 15:17:51 -08:00
|
|
|
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
|
2024-07-19 17:24:33 -04:00
|
|
|
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
|
|
|
|
vmap_method="sequential")
|
2022-08-17 10:43:50 -07:00
|
|
|
out = h(jnp.arange(4.), 4.)
|
2022-11-30 15:17:51 -08:00
|
|
|
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4.,
|
|
|
|
rtol=1E-7, check_dtypes=False)
|
2022-08-17 10:43:50 -07:00
|
|
|
|
2022-09-30 15:35:42 -07:00
|
|
|
@jax.jit
|
|
|
|
@functools.partial(jax.vmap)
|
|
|
|
def h(x, y):
|
2022-11-30 15:17:51 -08:00
|
|
|
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
|
2024-07-19 17:24:33 -04:00
|
|
|
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
|
|
|
|
vmap_method="sequential")
|
2022-09-30 15:35:42 -07:00
|
|
|
out = h(jnp.arange(4.), jnp.arange(10., 14.))
|
2022-11-30 15:17:51 -08:00
|
|
|
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.),
|
|
|
|
rtol=1E-7, check_dtypes=False)
|
2022-09-30 15:35:42 -07:00
|
|
|
|
2024-02-26 15:09:20 -08:00
|
|
|
@jax.jit
|
|
|
|
@functools.partial(jax.vmap, in_axes=1, out_axes=1)
|
|
|
|
def h(x, y):
|
|
|
|
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
|
2024-07-19 17:24:33 -04:00
|
|
|
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
|
|
|
|
vmap_method="sequential")
|
2024-02-26 15:09:20 -08:00
|
|
|
out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None])
|
|
|
|
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10.,
|
|
|
|
14.)[None],
|
|
|
|
rtol=1E-7, check_dtypes=False)
|
|
|
|
|
2022-08-17 10:43:50 -07:00
|
|
|
def test_vmap_vectorized_callback(self):
|
2022-08-15 17:05:27 -07:00
|
|
|
|
|
|
|
def cb(x):
|
|
|
|
self.assertTupleEqual(x.shape, ())
|
|
|
|
return np.sin(x)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
@jax.vmap
|
|
|
|
def f(x):
|
2024-07-19 17:24:33 -04:00
|
|
|
return jax.pure_callback(cb, x, x, vmap_method="sequential")
|
2022-08-15 17:05:27 -07:00
|
|
|
|
2022-08-17 10:43:50 -07:00
|
|
|
np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
|
2022-08-15 17:05:27 -07:00
|
|
|
|
|
|
|
def cb2(x):
|
|
|
|
self.assertTupleEqual(x.shape, (4,))
|
|
|
|
return np.sin(x)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
@jax.vmap
|
|
|
|
def g(x):
|
2024-10-21 11:53:31 -04:00
|
|
|
return jax.pure_callback(cb2, x, x, vmap_method="expand_dims")
|
2022-08-17 10:43:50 -07:00
|
|
|
|
|
|
|
np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
|
2022-08-15 17:05:27 -07:00
|
|
|
|
2022-08-17 10:43:50 -07:00
|
|
|
@jax.jit
|
|
|
|
@functools.partial(jax.vmap, in_axes=(0, None))
|
|
|
|
def h(x, y):
|
|
|
|
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
|
2024-10-21 11:53:31 -04:00
|
|
|
vmap_method="expand_dims")
|
2022-08-17 10:43:50 -07:00
|
|
|
out = h(jnp.arange(4.), 4.)
|
|
|
|
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
|
2022-08-15 17:05:27 -07:00
|
|
|
|
2024-02-26 15:09:20 -08:00
|
|
|
@jax.jit
|
|
|
|
@functools.partial(jax.vmap, in_axes=(1, None), out_axes=1)
|
|
|
|
def h(x, y):
|
|
|
|
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
|
2024-07-19 17:24:33 -04:00
|
|
|
vmap_method="legacy_vectorized")
|
2024-02-26 15:09:20 -08:00
|
|
|
out = h(jnp.arange(4.)[None], 4.)
|
|
|
|
np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)
|
|
|
|
|
2022-08-17 10:43:50 -07:00
|
|
|
def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
|
2022-08-15 17:05:27 -07:00
|
|
|
|
|
|
|
def cb(x):
|
|
|
|
# Reduces over all dimension when it shouldn't
|
|
|
|
return np.sin(x).sum()
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
@jax.vmap
|
|
|
|
def f(x):
|
2024-10-21 11:53:31 -04:00
|
|
|
return jax.pure_callback(cb, x, x, vmap_method="expand_dims")
|
2022-08-15 17:05:27 -07:00
|
|
|
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
|
|
f(jnp.arange(4.))
|
|
|
|
jax.effects_barrier()
|
|
|
|
|
2022-08-12 12:39:22 -07:00
|
|
|
def test_can_pmap_pure_callback(self):
|
|
|
|
|
|
|
|
@jax.pmap
|
|
|
|
def f(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(np.sin, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
out = f(jnp.arange(float(jax.local_device_count())))
|
|
|
|
np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))
|
|
|
|
|
|
|
|
def test_cant_take_grad_of_pure_callback(self):
|
|
|
|
|
|
|
|
def sin(x):
|
|
|
|
return np.sin(x)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
@jax.grad
|
|
|
|
def f(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(sin, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "Pure callbacks do not support JVP."):
|
|
|
|
f(2.)
|
|
|
|
|
2024-03-14 13:22:13 -07:00
|
|
|
def test_error_propagation(self):
|
|
|
|
def throws_error_fn(x):
|
|
|
|
raise RuntimeError("Errors should propagate.")
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
return jax.pure_callback(throws_error_fn, x, x)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(Exception, "Errors should propagate."):
|
|
|
|
print(np.array(f(2.0)), flush=True)
|
|
|
|
|
2024-03-26 18:35:49 -07:00
|
|
|
def test_reentrant_error_propagation(self):
|
|
|
|
reentrant_fn = jax.jit(jnp.sin).lower(2.0).compile()
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
return jax.pure_callback(reentrant_fn, x, x)
|
|
|
|
|
|
|
|
try:
|
|
|
|
np.array(f(2.0))
|
|
|
|
except:
|
|
|
|
# Only should not deadlock.
|
|
|
|
pass
|
|
|
|
|
2022-08-12 12:39:22 -07:00
|
|
|
def test_can_take_grad_of_pure_callback_with_custom_jvp(self):
|
|
|
|
|
|
|
|
@jax.custom_jvp
|
|
|
|
def sin(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(np.sin, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
@sin.defjvp
|
|
|
|
def sin_jvp(xs, ts):
|
|
|
|
(x,), (t,), = xs, ts
|
2022-08-15 17:05:27 -07:00
|
|
|
return sin(x), jax.pure_callback(np.cos, x, x) * t
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
@jax.grad
|
|
|
|
def f(x):
|
|
|
|
return sin(x)
|
|
|
|
out = f(2.)
|
|
|
|
np.testing.assert_allclose(out, jnp.cos(2.))
|
|
|
|
|
|
|
|
def test_callback_inside_of_cond(self):
|
|
|
|
|
|
|
|
def _callback1(x):
|
|
|
|
return x + 1.
|
|
|
|
|
|
|
|
def _callback2(x):
|
|
|
|
return x - 1.
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(pred, x):
|
|
|
|
|
|
|
|
def true_fun(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(_callback1, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
def false_fun(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(_callback2, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
return lax.cond(pred, true_fun, false_fun, x)
|
|
|
|
|
|
|
|
out = f(True, jnp.ones(2))
|
|
|
|
np.testing.assert_allclose(out, jnp.ones(2) * 2.)
|
|
|
|
out = f(False, jnp.ones(2))
|
|
|
|
np.testing.assert_allclose(out, jnp.zeros(2))
|
|
|
|
|
|
|
|
def test_callback_inside_of_scan(self):
|
|
|
|
|
|
|
|
def _callback(x):
|
|
|
|
return x + 1.
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
|
|
|
|
def body(x, _):
|
2022-08-15 17:05:27 -07:00
|
|
|
x = jax.pure_callback(_callback, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
return x, ()
|
|
|
|
|
|
|
|
return lax.scan(body, x, jnp.arange(10))[0]
|
|
|
|
|
|
|
|
out = f(jnp.arange(2.))
|
|
|
|
np.testing.assert_allclose(out, jnp.arange(2.) + 10.)
|
|
|
|
|
|
|
|
def test_callback_inside_of_while_loop(self):
|
|
|
|
|
|
|
|
def _cond_callback(x):
|
|
|
|
return np.any(x < 10)
|
|
|
|
|
|
|
|
def _callback(x):
|
|
|
|
return (x + 1.).astype(x.dtype)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
|
|
|
|
def cond(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(
|
2022-08-12 12:39:22 -07:00
|
|
|
_cond_callback, jax.ShapeDtypeStruct((), np.bool_), x)
|
|
|
|
|
|
|
|
def body(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(_callback, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
return lax.while_loop(cond, body, x)
|
|
|
|
|
|
|
|
out = f(jnp.arange(5.))
|
|
|
|
np.testing.assert_allclose(out, jnp.arange(10., 15.))
|
|
|
|
|
|
|
|
def test_callback_inside_of_pmap(self):
|
|
|
|
|
|
|
|
def _callback(x):
|
|
|
|
return x + 1.
|
|
|
|
|
|
|
|
@jax.pmap
|
|
|
|
def f(x):
|
2022-08-15 17:05:27 -07:00
|
|
|
return jax.pure_callback(_callback, x, x)
|
2022-08-12 12:39:22 -07:00
|
|
|
|
|
|
|
out = f(
|
|
|
|
jnp.arange(2 * jax.local_device_count(),
|
|
|
|
dtype=jnp.float32).reshape([-1, 2]))
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
out,
|
|
|
|
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
|
2022-08-08 11:41:46 -07:00
|
|
|
|
2023-10-05 15:06:29 -07:00
|
|
|
def test_callback_with_immediate_executable_destruction(self):
|
|
|
|
|
|
|
|
def loop_body(i, x):
|
|
|
|
del i
|
|
|
|
return jax.pure_callback(lambda y: y + np.ones(4, np.float32),
|
|
|
|
x, x)
|
|
|
|
|
|
|
|
class AClass:
|
|
|
|
def f(self, ys):
|
|
|
|
return lax.fori_loop(0, 10, loop_body, jnp.ones(4, np.float32))
|
|
|
|
|
|
|
|
num_devices = jax.local_device_count()
|
|
|
|
c = AClass()
|
|
|
|
out = jax.pmap(c.f)(np.ones((num_devices,), np.float32))
|
|
|
|
# c.f is an ephemeral bound method object, and it will be destroyed
|
|
|
|
# immediately. This test verifies that the execution itself keeps the
|
|
|
|
# callback alive.
|
|
|
|
np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))
|
|
|
|
|
2022-10-04 12:14:13 -07:00
|
|
|
def test_array_layout_is_preserved(self):
|
|
|
|
|
|
|
|
def g(x):
|
|
|
|
return jax.pure_callback(lambda x: x, x, x)
|
|
|
|
|
|
|
|
x = np.arange(6, dtype=np.int32).reshape((3, 2))
|
|
|
|
np.testing.assert_allclose(g(x), x)
|
2022-08-15 17:05:27 -07:00
|
|
|
|
2023-05-10 16:01:01 -07:00
|
|
|
def test_can_shard_pure_callback_maximally(self):
|
|
|
|
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
|
|
|
|
|
|
|
|
spec = jax.sharding.PartitionSpec('x')
|
|
|
|
sharding = jax.sharding.NamedSharding(mesh, spec)
|
|
|
|
|
|
|
|
def func(x):
|
|
|
|
return x + np.arange(x.shape[0], dtype=x.dtype)
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return jax.pure_callback(func, x, x)
|
|
|
|
|
|
|
|
inp = jnp.arange(float(jax.local_device_count()))
|
|
|
|
out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
|
|
|
|
jax.block_until_ready(out)
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
out, np.arange(jax.local_device_count()) * 2
|
|
|
|
)
|
|
|
|
|
2023-08-25 14:57:00 -07:00
|
|
|
def test_can_shard_pure_callback_maximally_with_sharding(self):
|
|
|
|
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
|
|
|
|
|
|
|
|
spec = jax.sharding.PartitionSpec('x')
|
|
|
|
sharding = jax.sharding.NamedSharding(mesh, spec)
|
|
|
|
|
|
|
|
def func(x):
|
|
|
|
return x + np.arange(x.shape[0], dtype=x.dtype)
|
|
|
|
|
|
|
|
callback_device = jax.devices()[-1]
|
|
|
|
callback_device_index = sharding._device_assignment.index(callback_device)
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
sharding = jax.sharding.SingleDeviceSharding(callback_device)
|
|
|
|
return jax.pure_callback(func, x, x, sharding=sharding)
|
|
|
|
|
|
|
|
f_jit = jax.jit(f, in_shardings=sharding, out_shardings=sharding)
|
|
|
|
|
|
|
|
inp = jnp.arange(float(jax.local_device_count()))
|
|
|
|
out = f_jit(inp)
|
|
|
|
jax.block_until_ready(out)
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
out, np.arange(jax.local_device_count()) * 2
|
|
|
|
)
|
2025-01-09 13:37:00 -08:00
|
|
|
stablehlo_ir = f_jit.lower(inp).as_text()
|
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.assertIn(
|
|
|
|
"sdy.sharding ="
|
|
|
|
f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
|
|
|
|
" []>]>",
|
|
|
|
stablehlo_ir)
|
|
|
|
self.assertIn(
|
|
|
|
f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
|
|
|
|
f" device_ids=[{callback_device_index}]>",
|
|
|
|
stablehlo_ir)
|
|
|
|
else:
|
|
|
|
self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)
|
2023-08-25 14:57:00 -07:00
|
|
|
|
2023-05-10 16:01:01 -07:00
|
|
|
def test_can_shard_pure_callback_manually(self):
|
|
|
|
|
|
|
|
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
|
|
|
|
|
|
|
|
spec = jax.sharding.PartitionSpec('x')
|
|
|
|
sharding = jax.sharding.NamedSharding(mesh, spec)
|
|
|
|
|
|
|
|
def func(x):
|
|
|
|
return x + np.arange(x.shape[0], dtype=x.dtype)
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return jax.pure_callback(func, x, x)
|
|
|
|
f = shard_map(f, mesh=mesh, in_specs=(spec,), out_specs=spec)
|
|
|
|
|
|
|
|
inp = jnp.arange(float(jax.local_device_count() * 2))
|
|
|
|
out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
|
|
|
|
y = np.tile(np.arange(2, dtype=inp.dtype), jax.local_device_count())
|
|
|
|
jax.block_until_ready(out)
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
out, inp + y
|
|
|
|
)
|
|
|
|
|
2024-03-27 14:25:19 -07:00
|
|
|
def test_does_not_deadlock(self):
|
|
|
|
if jtu.device_under_test() == "tpu":
|
|
|
|
self.skipTest("The test raises an exception on TPU")
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
y = jnp.asarray(x) + 1
|
|
|
|
return np.asarray(2 * jnp.log(y))
|
|
|
|
|
|
|
|
x = jnp.array([1.0, 2.0, 3.0, 4.0])
|
|
|
|
out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
|
|
|
|
np.testing.assert_allclose(out, 2 * jnp.log(x + 1))
|
|
|
|
|
2024-07-19 17:24:33 -04:00
|
|
|
def test_vmap_method_raise(self):
|
|
|
|
@jax.vmap
|
|
|
|
def f(x):
|
|
|
|
# Setting vectorized to None disables the current default behavior of
|
|
|
|
# falling back on sequential.
|
|
|
|
return jax.pure_callback(np.sin, x, x, vectorized=None)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"):
|
|
|
|
f(jnp.arange(4.))
|
|
|
|
|
|
|
|
def test_deprecated_vectorized(self):
|
|
|
|
def f(x, **kwargs):
|
|
|
|
return jax.pure_callback(np.sin, x, x, **kwargs)
|
|
|
|
|
|
|
|
with self.assertWarnsRegex(DeprecationWarning, "The default behavior"):
|
|
|
|
jax.vmap(f)(jnp.arange(4.0))
|
|
|
|
|
|
|
|
with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
|
|
|
|
f(jnp.arange(4.0), vectorized=True)
|
|
|
|
|
|
|
|
with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
|
|
|
|
f(jnp.arange(4.0), vectorized=False)
|
|
|
|
|
2024-10-21 11:53:31 -04:00
|
|
|
def test_vmap_method_expand_dims(self):
|
2024-07-19 17:24:33 -04:00
|
|
|
def callback(x, y):
|
|
|
|
self.assertTupleEqual(x.shape, (4,))
|
|
|
|
self.assertTupleEqual(y.shape, (1,))
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
def f(x, y):
|
2024-10-21 11:53:31 -04:00
|
|
|
return jax.pure_callback(callback, x, x, y, vmap_method="expand_dims")
|
2024-07-19 17:24:33 -04:00
|
|
|
|
|
|
|
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
|
|
|
|
|
2024-10-21 11:53:31 -04:00
|
|
|
def test_vmap_method_broadcast_all(self):
|
2024-07-19 17:24:33 -04:00
|
|
|
def callback(x, y):
|
|
|
|
self.assertTupleEqual(x.shape, (4,))
|
|
|
|
self.assertTupleEqual(y.shape, (4,))
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
def f(x, y):
|
|
|
|
return jax.pure_callback(callback, x, x, y,
|
2024-10-21 11:53:31 -04:00
|
|
|
vmap_method="broadcast_all")
|
2024-07-19 17:24:33 -04:00
|
|
|
|
|
|
|
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
|
|
|
|
|
2025-01-29 04:23:47 -08:00
|
|
|
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
|
|
|
@jtu.run_on_devices("cpu")
|
|
|
|
def test_async_deadlock(self):
|
2025-01-29 09:48:21 -08:00
|
|
|
self.skipTest("Too slow and memory intensive.")
|
2025-01-29 04:23:47 -08:00
|
|
|
|
|
|
|
# See https://github.com/jax-ml/jax/issues/24255
|
|
|
|
eig = jax.jit(jnp.linalg.eig)
|
|
|
|
|
|
|
|
def callback(x):
|
|
|
|
return jax.block_until_ready(eig(x))
|
|
|
|
|
|
|
|
def fun(x):
|
|
|
|
self.assertEqual(x.dtype, jnp.complex64)
|
|
|
|
out_type = (
|
|
|
|
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
|
|
|
|
jax.ShapeDtypeStruct(x.shape, x.dtype),
|
|
|
|
)
|
|
|
|
return jax.pure_callback(callback, out_type, x)
|
|
|
|
|
|
|
|
result = 0.0
|
|
|
|
for _ in range(10):
|
|
|
|
result += fun(jnp.ones((500, 500), jnp.complex64))[1]
|
|
|
|
jax.block_until_ready(result) # doesn't deadlock
|
|
|
|
|
class IOCallbackTest(jtu.JaxTestCase):
  """Tests for jax.experimental.io_callback."""

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    # Drop any runtime tokens left behind by ordered effects so that state
    # does not leak between test cases.
    super().tearDown()
    dispatch.runtime_tokens.clear()
|
|
|
def test_io_callback_can_mutate_state(self):
|
|
|
|
x = 0
|
|
|
|
def cb():
|
|
|
|
nonlocal x
|
|
|
|
x += 1
|
|
|
|
return np.array(x, np.int32)
|
|
|
|
|
|
|
|
def f():
|
|
|
|
return io_callback(cb, jax.ShapeDtypeStruct((), jnp.int32))
|
|
|
|
f()
|
|
|
|
jax.effects_barrier()
|
|
|
|
self.assertEqual(x, 1)
|
|
|
|
f()
|
|
|
|
jax.effects_barrier()
|
|
|
|
self.assertEqual(x, 2)
|
|
|
|
|
|
|
|
def test_io_callback_can_be_batched_if_unordered(self):
|
|
|
|
_mut = 0
|
|
|
|
def cb(x):
|
|
|
|
nonlocal _mut
|
|
|
|
_mut += 1
|
|
|
|
return x
|
|
|
|
|
|
|
|
x = jnp.arange(4)
|
|
|
|
def f(x):
|
|
|
|
return io_callback(cb, jax.ShapeDtypeStruct((), x.dtype), x)
|
|
|
|
jax.vmap(f)(x)
|
|
|
|
jax.effects_barrier()
|
|
|
|
self.assertEqual(_mut, 4)
|
|
|
|
jax.vmap(f)(x)
|
|
|
|
jax.effects_barrier()
|
|
|
|
self.assertEqual(_mut, 8)
|
|
|
|
|
|
|
|
def test_cannot_call_ordered_io_in_pmap(self):
|
|
|
|
def f(x):
|
|
|
|
return io_callback(
|
|
|
|
lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "Ordered effects not supported in `pmap`"):
|
|
|
|
jax.pmap(f)(jnp.arange(jax.local_device_count()))
|
|
|
|
|
|
|
|
def test_cannot_call_ordered_io_in_vmap(self):
|
|
|
|
def f(x):
|
|
|
|
return io_callback(
|
|
|
|
lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "Cannot `vmap` ordered IO callback"):
|
|
|
|
jax.vmap(f)(jnp.arange(4))
|
|
|
|
|
|
|
|
def test_cannot_use_io_callback_in_jvp(self):
|
|
|
|
def f(x):
|
|
|
|
return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "IO callbacks do not support JVP."):
|
|
|
|
jax.jvp(f, (0.,), (1.,))
|
|
|
|
|
|
|
|
def test_cannot_use_io_callback_in_linearize(self):
|
|
|
|
def f(x):
|
|
|
|
return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "IO callbacks do not support JVP."):
|
|
|
|
jax.linearize(f, 0.)
|
|
|
|
|
|
|
|
def test_cannot_use_io_callback_in_transpose(self):
|
|
|
|
x = jnp.array(1.)
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return io_callback(lambda x: x, jax.ShapeDtypeStruct((), x.dtype), x)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "IO callbacks do not support transpose."):
|
|
|
|
jax.linear_transpose(f, x)(x)
|
|
|
|
|
|
|
|
def test_cannot_vmap_of_cond_io_callback(self):
|
|
|
|
def f(pred):
|
|
|
|
def true_fun():
|
|
|
|
io_callback(lambda: print("true"), None)
|
|
|
|
def false_fun():
|
|
|
|
io_callback(lambda: print("false"), None)
|
|
|
|
return lax.cond(pred, false_fun, true_fun)
|
|
|
|
with self.assertRaisesRegex(NotImplementedError,
|
|
|
|
"IO effect not supported in vmap-of-cond."):
|
|
|
|
jax.vmap(f)(jnp.array([True, True]))
|
|
|
|
|
|
|
|
def test_cannot_vmap_of_while_io_callback(self):
|
|
|
|
def check(x):
|
|
|
|
assert np.all(x < 5)
|
|
|
|
|
|
|
|
def f(i):
|
|
|
|
def cond(i):
|
|
|
|
return i < 5
|
|
|
|
def body(i):
|
|
|
|
io_callback(check, None, i)
|
|
|
|
return i + 1
|
|
|
|
return lax.while_loop(cond, body, i)
|
2024-04-11 21:07:56 -07:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
Exception, "not supported in while_loop with batched predicate"):
|
2022-11-10 12:00:21 -08:00
|
|
|
jax.vmap(f)(jnp.array([0, 4]))
|
|
|
|
|
|
|
|
def test_cannot_use_io_callback_in_checkpoint(self):
|
|
|
|
@jax.grad
|
|
|
|
@jax.checkpoint
|
|
|
|
def f(x, y):
|
|
|
|
io_callback(lambda x: x, y, y)
|
|
|
|
return x
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(NotImplementedError,
|
|
|
|
"Effects not supported in partial-eval of `checkpoint`"):
|
|
|
|
f(2., 3.)
|
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
dict(
|
|
|
|
testcase_name=f'{ordered=}_{with_sharding=}',
|
|
|
|
ordered=ordered,
|
|
|
|
with_sharding=with_sharding,
|
|
|
|
)
|
|
|
|
for ordered in [True, False]
|
|
|
|
for with_sharding in [True, False]
|
|
|
|
)
|
|
|
|
def test_can_use_io_callback_in_pjit(
|
|
|
|
self, *, ordered: bool, with_sharding: bool
|
|
|
|
):
|
|
|
|
devices = jax.devices()
|
|
|
|
mesh = jax.sharding.Mesh(np.array(devices), ['dev'])
|
2023-08-25 14:57:00 -07:00
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
_collected: list[int] = []
|
2023-08-25 14:57:00 -07:00
|
|
|
def _cb(x):
|
2023-09-18 02:49:53 -07:00
|
|
|
nonlocal _collected
|
|
|
|
_collected.append(int(x.sum()))
|
2023-08-25 14:57:00 -07:00
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
io_callback_kwargs = dict(ordered=ordered)
|
|
|
|
callback_device = devices[0]
|
|
|
|
if with_sharding:
|
|
|
|
callback_device = devices[-1]
|
|
|
|
io_callback_kwargs['sharding'] = jax.sharding.SingleDeviceSharding(
|
|
|
|
callback_device
|
|
|
|
)
|
2023-08-25 14:57:00 -07:00
|
|
|
|
|
|
|
def f(x):
|
2023-09-18 02:49:53 -07:00
|
|
|
io_callback(_cb, None, x, **io_callback_kwargs)
|
|
|
|
io_callback(_cb, None, x + 1, **io_callback_kwargs)
|
2023-08-25 14:57:00 -07:00
|
|
|
return x
|
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
in_spec = jax.sharding.NamedSharding(
|
|
|
|
mesh, jax.sharding.PartitionSpec('dev')
|
|
|
|
)
|
2023-08-25 14:57:00 -07:00
|
|
|
out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
2023-09-18 02:49:53 -07:00
|
|
|
f = pjit.pjit(f, in_shardings=in_spec, out_shardings=out_spec)
|
|
|
|
expected = []
|
2023-08-25 14:57:00 -07:00
|
|
|
with mesh:
|
2023-09-18 02:49:53 -07:00
|
|
|
x = jnp.arange(mesh.size)
|
|
|
|
f(x)
|
|
|
|
expected.extend([int(x.sum()), int((x + 1).sum())])
|
|
|
|
f(x + 5)
|
|
|
|
expected.extend([int((x + 5).sum()), int((x + 6).sum())])
|
|
|
|
|
|
|
|
jax.effects_barrier()
|
|
|
|
if ordered:
|
|
|
|
self.assertAllClose(_collected, expected)
|
|
|
|
else:
|
|
|
|
self.assertEqual(len(_collected), len(expected))
|
|
|
|
for v in expected:
|
|
|
|
self.assertIn(v, _collected)
|
2023-08-25 14:57:00 -07:00
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
callback_device_index = in_spec._device_assignment.index(callback_device)
|
2025-01-09 13:37:00 -08:00
|
|
|
stablehlo_ir = f.lower(x).as_text()
|
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.assertIn(
|
|
|
|
"sdy.sharding ="
|
|
|
|
f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
|
|
|
|
" []>]>",
|
|
|
|
stablehlo_ir)
|
|
|
|
self.assertIn(
|
|
|
|
f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
|
|
|
|
f" device_ids=[{callback_device_index}]>",
|
|
|
|
stablehlo_ir)
|
|
|
|
else:
|
|
|
|
self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)
|
2023-08-25 14:57:00 -07:00
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
def test_sequence_pjit_io_callback_ordered(self):
|
|
|
|
# A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each
|
|
|
|
# pair on a different device assignment.
|
|
|
|
_collected: list[int] = []
|
|
|
|
def _cb(i, x):
|
|
|
|
nonlocal _collected
|
|
|
|
# Sleep different amounts of time, to test the ordering.
|
|
|
|
time.sleep([0.02, 0.03, 0.04][len(_collected) % 3])
|
|
|
|
logging.info('Collected iteration %s: %s', i, x)
|
|
|
|
_collected.append(int(x.sum()))
|
|
|
|
|
|
|
|
def f_base(i, x):
|
|
|
|
io_callback(_cb, None, i, x, ordered=True)
|
|
|
|
io_callback(_cb, None, i, x + 1, ordered=True)
|
|
|
|
|
|
|
|
nr_iterations = 8
|
|
|
|
# TODO(zce): If I pin to 1 device below (jax.devices()[:1]) then this test
|
|
|
|
# flakes. It also flakes when pinned to 2 devices. It seems that repeatedly
|
|
|
|
# dispatching to the same device triggers the problem.
|
|
|
|
devices = jax.devices()
|
|
|
|
expected = [] # The expected value for _collected
|
|
|
|
for i in range(nr_iterations):
|
|
|
|
if len(devices) > 1:
|
|
|
|
devices_for_iteration = [
|
|
|
|
devices[i % len(devices)],
|
|
|
|
devices[(i + 1) % len(devices)],
|
|
|
|
]
|
|
|
|
else:
|
|
|
|
devices_for_iteration = devices
|
|
|
|
logging.info(
|
|
|
|
'Running iteration %d on devices %s', i, devices_for_iteration
|
|
|
|
)
|
|
|
|
mesh = jax.sharding.Mesh(np.array(devices_for_iteration), ['dev'])
|
|
|
|
in_spec = (
|
2025-02-02 19:35:00 -08:00
|
|
|
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
|
2023-09-18 02:49:53 -07:00
|
|
|
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')),
|
|
|
|
)
|
|
|
|
out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
|
|
|
f = pjit.pjit(f_base, in_shardings=in_spec, out_shardings=out_spec)
|
|
|
|
with mesh:
|
|
|
|
x = jax.device_put(
|
|
|
|
np.arange(len(devices_for_iteration), dtype=np.int32) + 10 * i,
|
|
|
|
in_spec[1],
|
|
|
|
)
|
|
|
|
f(i, x)
|
|
|
|
expected.extend([int(x.sum()), int((x + 1).sum())])
|
|
|
|
f(i, x + 5)
|
|
|
|
expected.extend([int((x + 5).sum()), int((x + 6).sum())])
|
|
|
|
|
|
|
|
jax.effects_barrier()
|
|
|
|
self.assertEqual(_collected, expected)
|
|
|
|
|
2023-09-25 13:46:50 -07:00
|
|
|
def test_can_shard_io_callback_manually(self):
|
2025-01-09 13:37:00 -08:00
|
|
|
if config.use_shardy_partitioner.value:
|
|
|
|
self.skipTest("TODO(b/384938613): Failing under shardy.")
|
2023-09-25 13:46:50 -07:00
|
|
|
|
|
|
|
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
|
|
|
|
|
|
|
|
spec = jax.sharding.PartitionSpec('x')
|
|
|
|
sharding = jax.sharding.NamedSharding(mesh, spec)
|
|
|
|
|
|
|
|
_collected = collections.defaultdict(list)
|
|
|
|
|
|
|
|
def func(shard_id, x):
|
|
|
|
nonlocal _collected
|
|
|
|
_collected[shard_id.item()].append(x)
|
|
|
|
|
|
|
|
def f(shard_ids, x):
|
|
|
|
io_callback(func, None, shard_ids, x, ordered=True)
|
|
|
|
io_callback(func, None, shard_ids, x + 1, ordered=True)
|
|
|
|
f = shard_map(f, mesh=mesh, in_specs=spec, out_specs=None)
|
|
|
|
|
|
|
|
shard_ids = jnp.arange(mesh.devices.size)
|
|
|
|
inp = jnp.arange(2 * jax.local_device_count())
|
|
|
|
jax.jit(f, in_shardings=sharding, out_shardings=None)(shard_ids, inp)
|
|
|
|
jax.effects_barrier()
|
|
|
|
|
|
|
|
self.assertLen(_collected, mesh.devices.size)
|
|
|
|
# Verify the partial ordering: no specified order across shards, but strict
|
|
|
|
# ordering between the two calls in each shard.
|
|
|
|
for shard in _collected.values():
|
|
|
|
self.assertLen(shard, 2)
|
|
|
|
np.testing.assert_array_equal(shard[0] + 1, shard[1])
|
|
|
|
|
2024-04-11 20:45:46 -07:00
|
|
|
def test_batching_with_side_effects(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195
|
2024-04-11 20:45:46 -07:00
|
|
|
x_lst = []
|
|
|
|
def append_x(x):
|
|
|
|
nonlocal x_lst
|
|
|
|
x_lst.append(x)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
io_callback(append_x, None, x, ordered=False)
|
|
|
|
io_callback(append_x, None, 2 * x, ordered=False)
|
|
|
|
|
|
|
|
jax.vmap(f)(jnp.arange(3.))
|
|
|
|
jax.effects_barrier()
|
|
|
|
self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False)
|
|
|
|
|
2024-04-11 21:07:56 -07:00
|
|
|
def test_batching_with_side_effects_while_loop(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219
|
2024-04-11 21:07:56 -07:00
|
|
|
x_lst = []
|
|
|
|
def append_x(x):
|
|
|
|
nonlocal x_lst
|
|
|
|
x_lst.append(x)
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def f(x):
|
|
|
|
def body(i):
|
|
|
|
io_callback(append_x, None, x, ordered=False)
|
|
|
|
io_callback(append_x, None, 2 * x, ordered=False)
|
|
|
|
return i + 1
|
|
|
|
jax.lax.while_loop(lambda i: i < 2, body, 0)
|
|
|
|
|
2024-04-15 17:12:45 -07:00
|
|
|
jax.vmap(f)(jnp.arange(3.)) # don't crash
|
2024-04-11 21:07:56 -07:00
|
|
|
jax.effects_barrier()
|
|
|
|
|
2023-09-18 02:49:53 -07:00
|
|
|
|
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())