rocm_jax/tests/python_callback_test.py
2025-02-12 13:53:36 -08:00

1385 lines
41 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import logging
import time
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np
config.parse_flags_with_absl()
jtu.request_cpu_devices(2)

# Rebind `map` to `util.safe_map` for this module; the builtin stays reachable
# as `unsafe_map`.
map, unsafe_map = util.safe_map, map

# Some test methods take a kwarg
# callback=[io_callback(ordered=True) | io_callback(ordered=False) | pure_callback]
io_callback_ordered = functools.partial(io_callback, ordered=True)
io_callback_unordered = functools.partial(io_callback, ordered=False)
# Backward-compatible alias: the helper was originally published under this
# misspelled name and is still referenced by it elsewhere in this module.
io_calback_unordered = io_callback_unordered

# Decorator that runs a test once per callback flavor, passing the flavor's
# entry point as the `callback` keyword argument. The test case names are
# "io_unordered", "io_ordered" and "pure".
with_pure_and_io_callbacks = parameterized.named_parameters(
    dict(testcase_name=flavor, callback=flavor_callback)
    for flavor, flavor_callback in (
        ("io_unordered", io_callback_unordered),
        ("io_ordered", io_callback_ordered),
        ("pure", jax.pure_callback),
    )
)
class PythonCallbackTest(jtu.JaxTestCase):
  """Tests that run against `jax.pure_callback` and both `io_callback` flavors.

  Most tests are parameterized over the `callback` keyword argument, which is
  one of `jax.pure_callback`, `io_callback(ordered=True)` or
  `io_callback(ordered=False)`.
  """

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    super().tearDown()
    # Drop any runtime tokens left behind by the test that just ran.
    dispatch.runtime_tokens.clear()

  @with_pure_and_io_callbacks
  def test_callback_with_scalar_values(self, *, callback):
    @jax.jit
    def f(x):
      return callback(lambda x: x + 1.0, core.ShapedArray(x.shape, x.dtype), x)

    out = f(0.)
    self.assertEqual(out, 1.)

  @parameterized.named_parameters(
      dict(
          testcase_name=f"{flavor}_expect_dtype_{expect_dtype}",
          callback=dict(
              io_unordered=io_calback_unordered,
              io_ordered=io_callback_ordered,
              pure=jax.pure_callback,
          )[flavor],
          expect_dtype=expect_dtype,
      )
      for flavor in ("io_unordered", "io_ordered", "pure")
      for expect_dtype in (np.int32, np.int64, np.float32, np.float64)
  )
  def test_callback_returning_python_literal(self, *, callback, expect_dtype):
    returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0

    @jax.jit
    def f(x):
      return callback(
          lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
      )

    # Which outcome is expected depends on whether 64-bit mode is enabled and
    # on the width of the declared result dtype.
    if not config.enable_x64.value and expect_dtype in (np.int64, np.float64):
      ctx = self.assertRaisesRegex(Exception, "result_shape_dtypes cannot specify 64-bit types")
    elif config.enable_x64.value and expect_dtype in (np.int32, np.float32):
      ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
    else:
      ctx = contextlib.nullcontext()

    with ctx:
      out = f(0.0)
      jax.effects_barrier()
      self.assertEqual(out, returned_literal)

  @with_pure_and_io_callbacks
  def test_callback_returning_custom_array(self, *, callback):
    # Some users write the callback in TF, returning a tf.Tensor. We don't
    # want to add TF as a dependency, but simulate that use case with a
    # custom array class.
    class CustomArray:

      def __init__(self, a: np.ndarray):
        self.a = a

      @property
      def shape(self):
        return self.a.shape

      @property
      def dtype(self):
        return self.a.dtype

      def __array__(self):
        return self.a

    @jax.jit
    def f(x):
      return callback(
          lambda x: CustomArray(np.array(42.0, dtype=np.float32)),
          core.ShapedArray((), np.float32),
          x,
      )

    out = f(0.0)
    jax.effects_barrier()
    self.assertEqual(out, 42.0)

  @parameterized.named_parameters(
      dict(testcase_name=f"{flavor}_{dtype}",
           dtype=dtype,
           callback=dict(io_unordered=io_calback_unordered,
                         io_ordered=io_callback_ordered,
                         pure=jax.pure_callback)[flavor])
      for flavor in ("io_unordered", "io_ordered", "pure")
      for dtype in jtu.dtypes.all
  )
  def test_callback_works_with_all_types(self, *, callback, dtype):
    def host_func(x):
      if dtype == np.bool_:
        return ~x
      else:
        return x + x

    # Records the argument the host callback was actually invoked with.
    _received = None

    def _cb(x):
      nonlocal _received
      _received = x
      return host_func(x)

    if dtype == np.bool_:
      x = np.array([True, False, True, True], dtype=np.bool_)
    else:
      x = np.arange(4, dtype=dtype)

    @jax.jit
    def f(x):
      return callback(_cb,
                      core.ShapedArray(x.shape, x.dtype), x)

    out = f(x)
    self.assertAllClose(out, host_func(x))
    jax.effects_barrier()
    self.assertAllClose(_received, x)

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_args(self, *, callback):
    @jax.jit
    def f():
      # Calling a function that expects `x` with no arguments
      return callback(lambda x: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_returned_values(self, *, callback):
    @jax.jit
    def f(x):
      # Calling a function with two return values that expects one return value
      return callback(lambda x: (x, np.ones(4, np.float32)), x, x)

    with self.assertRaises(RuntimeError):
      f(2.)
      jax.effects_barrier()

    @jax.jit
    def g():
      # Specifically for io_callback, calling a function with a return value
      # that expects no return values
      return io_callback(lambda: None, (core.ShapedArray(
          (1,), np.float32), core.ShapedArray((2,), np.float32)))

    with self.assertRaises(RuntimeError):
      g()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_shape_outputs(self, *, callback):
    @jax.jit
    def f():
      # Calling a function expected a (1,) shaped return value but getting ()
      return callback(lambda: np.float32(1.), core.ShapedArray((1,),
                                                               np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_dtype_outputs(self, *, callback):
    def _cb():
      return np.array([1], np.float64)

    @jax.jit
    def f():
      # Calling a function expected a f32 return value but getting f64
      return callback(_cb, core.ShapedArray((1,), np.float32))

    if config.enable_x64.value:
      ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
    else:
      ctx = contextlib.nullcontext()

    with ctx:
      res = f()
      jax.effects_barrier()
      self.assertAllClose(res, np.array([1], np.float32))

  @with_pure_and_io_callbacks
  def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
    if config.enable_x64.value:
      raise unittest.SkipTest("Test only needed when 64-bit mode disabled.")

    @jax.jit
    def f():
      return callback(lambda: np.float64(1.),
                      core.ShapedArray((), np.float64))

    with self.assertRaises(ValueError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_single_return_value(self, *, callback):
    @jax.jit
    def f():
      return callback(lambda: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    out = f()
    jax.effects_barrier()
    np.testing.assert_allclose(out, np.ones(4, np.float32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_return_values(self, *, callback):
    @jax.jit
    def f():
      return callback(lambda: (np.ones(4, np.float32), np.ones(5, np.int32)),
                      (core.ShapedArray(
                          (4,), np.float32), core.ShapedArray((5,), np.int32)))

    x, y = f()
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(4, np.float32))
    np.testing.assert_allclose(y, np.ones(5, np.int32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_arguments_and_return_values(self, *, callback):
    def _callback(x, y, z):
      return (x, y + z)

    @jax.jit
    def f(x, y, z):
      return callback(_callback, (core.ShapedArray(
          (3,), x.dtype), core.ShapedArray((3,), x.dtype)), x, y, z)

    x, y = f(jnp.ones(3), jnp.arange(3.), jnp.arange(3.) + 1.)
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(3))
    np.testing.assert_allclose(y, np.array([1., 3., 5]))

  @with_pure_and_io_callbacks
  def test_send_zero_dim_arrays(self, *, callback):
    result = np.full((2,), 42.0, dtype=np.float32)
    x = np.zeros((2, 0), np.float32)

    def _callback(x):  # x: f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)

    # Run the computation first, then wait for its effects before checking.
    out = f(x)
    jax.effects_barrier()
    self.assertAllClose(out, result)

  @with_pure_and_io_callbacks
  def test_send_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.zeros((2, 0), np.float32)
    y = np.full((2,), 42.0, dtype=np.float32)
    result = y

    def _callback(x, y):  # x: f32[2, 0]  y: f32[2]
      return y

    @jax.jit
    def f(x, y):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x, y)

    # Run the computation first, then wait for its effects before checking.
    out = f(x, y)
    jax.effects_barrier()
    self.assertAllClose(out, result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_arrays(self, *, callback):
    result = np.full((2, 0), 42.0, dtype=np.float32)
    x = np.zeros((2,), np.float32)

    def _callback(_):  # f32[2] -> f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)

    # Run the computation first, then wait for its effects before checking.
    out = f(x)
    jax.effects_barrier()
    self.assertAllClose(out, result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.full((2,), 42., dtype=np.float32)
    result0 = np.ones((2, 0), dtype=np.float32)
    result1 = x
    result2 = np.ones((3, 0), dtype=np.int32)
    result3 = np.concatenate([x, x]) + 1.

    def _callback(x):  # x: f32[2] -> (f32[2, 0], f32[2], f32[3, 0], f32[4])
      return (result0, x, result2, np.concatenate([x, x]) + 1.)

    @jax.jit
    def f(x):
      return callback(
          _callback, (core.ShapedArray(result0.shape, result0.dtype),
                      core.ShapedArray(result1.shape, result1.dtype),
                      core.ShapedArray(result2.shape, result2.dtype),
                      core.ShapedArray(result3.shape, result3.dtype)), x)

    res = f(x)
    jax.effects_barrier()
    self.assertAllClose(res, (result0, result1, result2, result3))

  @with_pure_and_io_callbacks
  def test_callback_with_pytree_arguments_and_return_values(self, *, callback):
    def _callback(x):
      return dict(y=[x])

    @jax.jit
    def f(x):
      return callback(_callback, dict(y=[core.ShapedArray((), np.float32)]),
                      [x])

    out = f(jnp.float32(2.))
    jax.effects_barrier()
    self.assertEqual(out, dict(y=[2.]))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop_of_scalars(self, *, callback):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return x < 10

      def body(x):
        return callback(_callback, core.ShapedArray((), x.dtype), x)

      return lax.while_loop(cond, body, x)

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop(self, *, callback):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return jnp.any(x < 10)

      def body(x):
        return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

      return lax.while_loop(cond, body, x)

    out = f(jnp.arange(5.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(10., 15.))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond_of_scalars(self, *, callback):
    def _callback1(x):
      return (x + 1.).astype(x.dtype)

    def _callback2(x):
      return (x - 1.).astype(x.dtype)

    @jax.jit
    def f(pred, x):
      def true_fun(x):
        return callback(_callback1, core.ShapedArray((), x.dtype), x)

      def false_fun(x):
        return callback(_callback2, core.ShapedArray((), x.dtype), x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 2.)
    out = f(False, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 0.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond(self, *, callback):
    def _callback1(x):
      return x + 1.

    def _callback2(x):
      return x - 1.

    @jax.jit
    def f(pred, x):
      def true_fun(x):
        return callback(_callback1, core.ShapedArray(x.shape, x.dtype), x)

      def false_fun(x):
        return callback(_callback2, core.ShapedArray(x.shape, x.dtype), x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.ones(2) * 2.)
    out = f(False, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.zeros(2))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan_of_scalars(self, *, callback):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan(self, *, callback):
    def _callback(x):
      return x + 1.

    @jax.jit
    def f(x):
      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(jnp.arange(2.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(2.) + 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap_of_scalars(self, *, callback):
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(jnp.arange(jax.local_device_count(), dtype=jnp.float32))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count(), dtype=np.float32) + 1.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap(self, *, callback):
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return x + 1.

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(
        jnp.arange(2 * jax.local_device_count(),
                   dtype=jnp.float32).reshape([-1, 2]))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out,
        np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)

  @with_pure_and_io_callbacks
  @jtu.thread_unsafe_test()  # logging isn't thread-safe
  def test_exception_in_callback(self, *, callback):
    def fail(x):
      raise RuntimeError("Ooops")

    @jax.jit
    def f(x):
      return callback(fail, core.ShapedArray(x.shape, x.dtype), x)

    with self.assertLogs(level="ERROR") as l:
      # The RuntimeError itself is tolerated; the assertions below are about
      # what was logged at ERROR level.
      try:
        f(0.0).block_until_ready()
      except RuntimeError:
        pass

      api_name = (
          "pure_callback" if callback is jax.pure_callback else "io_callback"
      )
      output = "\n".join(l.output)
      self.assertIn(f"jax.{api_name} failed", output)
      self.assertIn("Traceback (most recent call last)", output)

  @with_pure_and_io_callbacks
  @jtu.thread_unsafe_test()  # count_primitive_compiles isn't thread-safe
  def test_compilation_caching(self, *, callback):
    def f_outside(x):
      return 2 * x

    def fun(x):
      return callback(f_outside, x, x)

    x = np.arange(6, dtype=np.int32).reshape((2, 3))
    with jtu.count_primitive_compiles() as count:
      for _ in range(3):
        self.assertAllClose(2 * x, fun(x))
    # Three identical calls must produce exactly one counted compilation.
    self.assertEqual(count(), 1)
class PureCallbackTest(jtu.JaxTestCase):
  """Tests specific to `jax.pure_callback`."""

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    super().tearDown()
    # Drop any runtime tokens left behind by the test that just ran.
    dispatch.runtime_tokens.clear()

  def test_pure_callback_passes_jax_arrays_without_jit(self):
    def cb(x):
      self.assertIsInstance(x, jax.Array)
      return x

    def f(x):
      return jax.pure_callback(cb, x, x)

    f(jnp.array(2.))

  def test_can_dce_pure_callback(self):
    if jax.default_backend() == "tpu":
      raise unittest.SkipTest("DCE doesn't currently happen on TPU")

    log = []

    def _callback(x):
      # Should never happen!
      log.append("hello world")
      return (x * 2.).astype(x.dtype)

    @jax.jit
    def f(x):
      # The callback result is discarded, so the test expects it never runs.
      _ = jax.pure_callback(_callback, x, x)
      return x * 2.

    _ = f(2.)
    self.assertEmpty(log)

  def test_can_vmap_pure_callback(self):
    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(np.sin, x, x, vmap_method="sequential")

    out = f(jnp.arange(4.))
    np.testing.assert_allclose(out, np.sin(np.arange(4.)))

    @jax.jit
    def g(x):
      return jax.pure_callback(np.sin, x, x, vmap_method="sequential")

    out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
    np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)

    @jax.jit
    @functools.partial(jax.vmap, in_axes=(0, None))
    def h(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
                               vmap_method="sequential")

    out = h(jnp.arange(4.), 4.)
    self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4.,
                              rtol=1E-7, check_dtypes=False)

    @jax.jit
    @functools.partial(jax.vmap)
    def h(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
                               vmap_method="sequential")

    out = h(jnp.arange(4.), jnp.arange(10., 14.))
    self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.),
                              rtol=1E-7, check_dtypes=False)

    @jax.jit
    @functools.partial(jax.vmap, in_axes=1, out_axes=1)
    def h(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
                               vmap_method="sequential")

    out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None])
    self.assertArraysAllClose(
        out, np.sin(np.arange(4.)) + np.arange(10., 14.)[None],
        rtol=1E-7, check_dtypes=False)

  def test_vmap_vectorized_callback(self):
    def cb(x):
      self.assertTupleEqual(x.shape, ())
      return np.sin(x)

    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(cb, x, x, vmap_method="sequential")

    np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))

    def cb2(x):
      self.assertTupleEqual(x.shape, (4,))
      return np.sin(x)

    @jax.jit
    @jax.vmap
    def g(x):
      return jax.pure_callback(cb2, x, x, vmap_method="expand_dims")

    np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))

    @jax.jit
    @functools.partial(jax.vmap, in_axes=(0, None))
    def h(x, y):
      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
                               vmap_method="expand_dims")

    out = h(jnp.arange(4.), 4.)
    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)

    @jax.jit
    @functools.partial(jax.vmap, in_axes=(1, None), out_axes=1)
    def h(x, y):
      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
                               vmap_method="legacy_vectorized")

    out = h(jnp.arange(4.)[None], 4.)
    np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)

  def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
    def cb(x):
      # Reduces over all dimension when it shouldn't
      return np.sin(x).sum()

    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(cb, x, x, vmap_method="expand_dims")

    with self.assertRaises(RuntimeError):
      f(jnp.arange(4.))
      jax.effects_barrier()

  def test_can_pmap_pure_callback(self):
    @jax.pmap
    def f(x):
      return jax.pure_callback(np.sin, x, x)

    out = f(jnp.arange(float(jax.local_device_count())))
    np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))

  def test_cant_take_grad_of_pure_callback(self):
    def sin(x):
      return np.sin(x)

    @jax.jit
    @jax.grad
    def f(x):
      return jax.pure_callback(sin, x, x)

    with self.assertRaisesRegex(
        ValueError, "Pure callbacks do not support JVP."):
      f(2.)

  def test_error_propagation(self):
    def throws_error_fn(x):
      raise RuntimeError("Errors should propagate.")

    @jax.jit
    def f(x):
      return jax.pure_callback(throws_error_fn, x, x)

    with self.assertRaisesRegex(Exception, "Errors should propagate."):
      print(np.array(f(2.0)), flush=True)

  def test_reentrant_error_propagation(self):
    reentrant_fn = jax.jit(jnp.sin).lower(2.0).compile()

    @jax.jit
    def f(x):
      return jax.pure_callback(reentrant_fn, x, x)

    try:
      np.array(f(2.0))
    except Exception:
      # Any error is acceptable here: the test only checks that the call
      # returns instead of deadlocking. `Exception` (rather than a bare
      # `except`) lets KeyboardInterrupt/SystemExit still abort the run.
      pass

  def test_can_take_grad_of_pure_callback_with_custom_jvp(self):
    @jax.custom_jvp
    def sin(x):
      return jax.pure_callback(np.sin, x, x)

    @sin.defjvp
    def sin_jvp(xs, ts):
      (x,), (t,), = xs, ts
      return sin(x), jax.pure_callback(np.cos, x, x) * t

    @jax.jit
    @jax.grad
    def f(x):
      return sin(x)

    out = f(2.)
    np.testing.assert_allclose(out, jnp.cos(2.))

  def test_callback_inside_of_cond(self):
    def _callback1(x):
      return x + 1.

    def _callback2(x):
      return x - 1.

    @jax.jit
    def f(pred, x):
      def true_fun(x):
        return jax.pure_callback(_callback1, x, x)

      def false_fun(x):
        return jax.pure_callback(_callback2, x, x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, jnp.ones(2))
    np.testing.assert_allclose(out, jnp.ones(2) * 2.)
    out = f(False, jnp.ones(2))
    np.testing.assert_allclose(out, jnp.zeros(2))

  def test_callback_inside_of_scan(self):
    def _callback(x):
      return x + 1.

    @jax.jit
    def f(x):
      def body(x, _):
        x = jax.pure_callback(_callback, x, x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(jnp.arange(2.))
    np.testing.assert_allclose(out, jnp.arange(2.) + 10.)

  def test_callback_inside_of_while_loop(self):
    # Both the loop predicate and the loop body are host callbacks here.
    def _cond_callback(x):
      return np.any(x < 10)

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return jax.pure_callback(
            _cond_callback, jax.ShapeDtypeStruct((), np.bool_), x)

      def body(x):
        return jax.pure_callback(_callback, x, x)

      return lax.while_loop(cond, body, x)

    out = f(jnp.arange(5.))
    np.testing.assert_allclose(out, jnp.arange(10., 15.))

  def test_callback_inside_of_pmap(self):
    def _callback(x):
      return x + 1.

    @jax.pmap
    def f(x):
      return jax.pure_callback(_callback, x, x)

    out = f(
        jnp.arange(2 * jax.local_device_count(),
                   dtype=jnp.float32).reshape([-1, 2]))
    np.testing.assert_allclose(
        out,
        np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)

  def test_callback_with_immediate_executable_destruction(self):
    def loop_body(i, x):
      del i
      return jax.pure_callback(lambda y: y + np.ones(4, np.float32),
                               x, x)

    class AClass:
      def f(self, ys):
        return lax.fori_loop(0, 10, loop_body, jnp.ones(4, np.float32))

    num_devices = jax.local_device_count()
    c = AClass()
    out = jax.pmap(c.f)(np.ones((num_devices,), np.float32))
    # c.f is an ephemeral bound method object, and it will be destroyed
    # immediately. This test verifies that the execution itself keeps the
    # callback alive.
    np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))

  def test_array_layout_is_preserved(self):
    def g(x):
      return jax.pure_callback(lambda x: x, x, x)

    x = np.arange(6, dtype=np.int32).reshape((3, 2))
    np.testing.assert_allclose(g(x), x)

  def test_can_shard_pure_callback_maximally(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    def f(x):
      return jax.pure_callback(func, x, x)

    inp = jnp.arange(float(jax.local_device_count()))
    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count()) * 2
    )

  def test_can_shard_pure_callback_maximally_with_sharding(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    # Pin the callback to the last device and remember that device's position
    # in the device assignment for the IR checks below.
    callback_device = jax.devices()[-1]
    callback_device_index = sharding._device_assignment.index(callback_device)

    def f(x):
      sharding = jax.sharding.SingleDeviceSharding(callback_device)
      return jax.pure_callback(func, x, x, sharding=sharding)

    f_jit = jax.jit(f, in_shardings=sharding, out_shardings=sharding)

    inp = jnp.arange(float(jax.local_device_count()))
    out = f_jit(inp)
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count()) * 2
    )

    # The lowered text must mention the chosen device, in whichever syntax
    # the active partitioner emits.
    stablehlo_ir = f_jit.lower(inp).as_text()
    if config.use_shardy_partitioner.value:
      self.assertIn(
          "sdy.sharding ="
          f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
          " []>]>",
          stablehlo_ir)
      self.assertIn(
          f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
          f" device_ids=[{callback_device_index}]>",
          stablehlo_ir)
    else:
      self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)

  def test_can_shard_pure_callback_manually(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    def f(x):
      return jax.pure_callback(func, x, x)

    f = shard_map(f, mesh=mesh, in_specs=(spec,), out_specs=spec)

    inp = jnp.arange(float(jax.local_device_count() * 2))
    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
    y = np.tile(np.arange(2, dtype=inp.dtype), jax.local_device_count())
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, inp + y
    )

  def test_does_not_deadlock(self):
    if jtu.device_under_test() == "tpu":
      self.skipTest("The test raises an exception on TPU")

    # The host callback itself performs JAX operations.
    def f(x):
      y = jnp.asarray(x) + 1
      return np.asarray(2 * jnp.log(y))

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    np.testing.assert_allclose(out, 2 * jnp.log(x + 1))

  def test_vmap_method_raise(self):
    @jax.vmap
    def f(x):
      # Setting vectorized to None disables the current default behavior of
      # falling back on sequential.
      return jax.pure_callback(np.sin, x, x, vectorized=None)

    with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"):
      f(jnp.arange(4.))

  def test_deprecated_vectorized(self):
    def f(x, **kwargs):
      return jax.pure_callback(np.sin, x, x, **kwargs)

    with self.assertWarnsRegex(DeprecationWarning, "The default behavior"):
      jax.vmap(f)(jnp.arange(4.0))

    with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
      f(jnp.arange(4.0), vectorized=True)

    with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
      f(jnp.arange(4.0), vectorized=False)

  def test_vmap_method_expand_dims(self):
    def callback(x, y):
      self.assertTupleEqual(x.shape, (4,))
      self.assertTupleEqual(y.shape, (1,))
      return x + y

    def f(x, y):
      return jax.pure_callback(callback, x, x, y, vmap_method="expand_dims")

    jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0)  # doesn't error

  def test_vmap_method_broadcast_all(self):
    def callback(x, y):
      self.assertTupleEqual(x.shape, (4,))
      self.assertTupleEqual(y.shape, (4,))
      return x + y

    def f(x, y):
      return jax.pure_callback(callback, x, x, y,
                               vmap_method="broadcast_all")

    jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0)  # doesn't error

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.run_on_devices("cpu")
  def test_async_deadlock(self):
    # Unconditionally skipped; the body below is kept as the reproduction.
    self.skipTest("Too slow and memory intensive.")

    # See https://github.com/jax-ml/jax/issues/24255
    eig = jax.jit(jnp.linalg.eig)

    def callback(x):
      return jax.block_until_ready(eig(x))

    def fun(x):
      self.assertEqual(x.dtype, jnp.complex64)
      out_type = (
          jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
          jax.ShapeDtypeStruct(x.shape, x.dtype),
      )
      return jax.pure_callback(callback, out_type, x)

    result = 0.0
    for _ in range(10):
      result += fun(jnp.ones((500, 500), jnp.complex64))[1]
    jax.block_until_ready(result)  # doesn't deadlock

  def test_non_default_stride(self):
    x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)

    # The callback hands back a Fortran-ordered copy of its input.
    def callback(x):
      return np.asfortranarray(x)

    @jax.jit
    def f(x):
      return jax.pure_callback(
          callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x
      )

    result = f(x)
    np.testing.assert_array_equal(x, result)
class IOCallbackTest(jtu.JaxTestCase):
  """Tests specific to `jax.experimental.io_callback`."""

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    super().tearDown()
    # Drop any runtime tokens left behind by the test that just ran.
    dispatch.runtime_tokens.clear()

  def test_io_callback_can_mutate_state(self):
    x = 0

    def cb():
      nonlocal x
      x += 1
      return np.array(x, np.int32)

    def f():
      return io_callback(cb, jax.ShapeDtypeStruct((), jnp.int32))

    f()
    jax.effects_barrier()
    self.assertEqual(x, 1)
    f()
    jax.effects_barrier()
    self.assertEqual(x, 2)

  def test_io_callback_can_be_batched_if_unordered(self):
    # Counts how many times the host callback is invoked.
    _mut = 0

    def cb(x):
      nonlocal _mut
      _mut += 1
      return x

    x = jnp.arange(4)

    def f(x):
      return io_callback(cb, jax.ShapeDtypeStruct((), x.dtype), x)

    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 4)
    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 8)

  def test_cannot_call_ordered_io_in_pmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)

    with self.assertRaisesRegex(
        ValueError, "Ordered effects not supported in `pmap`"):
      jax.pmap(f)(jnp.arange(jax.local_device_count()))

  def test_cannot_call_ordered_io_in_vmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)

    with self.assertRaisesRegex(
        ValueError, "Cannot `vmap` ordered IO callback"):
      jax.vmap(f)(jnp.arange(4))

  def test_cannot_use_io_callback_in_jvp(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)

    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.jvp(f, (0.,), (1.,))

  def test_cannot_use_io_callback_in_linearize(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)

    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.linearize(f, 0.)

  def test_cannot_use_io_callback_in_transpose(self):
    x = jnp.array(1.)

    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), x.dtype), x)

    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support transpose."):
      jax.linear_transpose(f, x)(x)

  def test_cannot_vmap_of_cond_io_callback(self):
    def f(pred):
      def true_fun():
        io_callback(lambda: print("true"), None)

      def false_fun():
        io_callback(lambda: print("false"), None)

      # Branches are passed in (true, false) order so each name matches its
      # role; the expected error is raised regardless of branch order.
      return lax.cond(pred, true_fun, false_fun)

    with self.assertRaisesRegex(NotImplementedError,
                                "IO effect not supported in vmap-of-cond."):
      jax.vmap(f)(jnp.array([True, True]))

  def test_cannot_vmap_of_while_io_callback(self):
    def check(x):
      assert np.all(x < 5)

    def f(i):
      def cond(i):
        return i < 5

      def body(i):
        io_callback(check, None, i)
        return i + 1

      return lax.while_loop(cond, body, i)

    with self.assertRaisesRegex(
        Exception, "not supported in while_loop with batched predicate"):
      jax.vmap(f)(jnp.array([0, 4]))

  def test_cannot_use_io_callback_in_checkpoint(self):
    @jax.grad
    @jax.checkpoint
    def f(x, y):
      io_callback(lambda x: x, y, y)
      return x

    with self.assertRaisesRegex(NotImplementedError,
                                "Effects not supported in partial-eval of `checkpoint`"):
      f(2., 3.)

  @parameterized.named_parameters(
      dict(
          testcase_name=f'{ordered=}_{with_sharding=}',
          ordered=ordered,
          with_sharding=with_sharding,
      )
      for ordered in [True, False]
      for with_sharding in [True, False]
  )
  def test_can_use_io_callback_in_pjit(
      self, *, ordered: bool, with_sharding: bool
  ):
    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), ['dev'])

    # Sums observed by the host callback, in invocation order.
    _collected: list[int] = []

    def _cb(x):
      # The list is mutated in place, so no `nonlocal` declaration is needed.
      _collected.append(int(x.sum()))

    io_callback_kwargs = dict(ordered=ordered)
    callback_device = devices[0]
    if with_sharding:
      callback_device = devices[-1]
      io_callback_kwargs['sharding'] = jax.sharding.SingleDeviceSharding(
          callback_device
      )

    def f(x):
      io_callback(_cb, None, x, **io_callback_kwargs)
      io_callback(_cb, None, x + 1, **io_callback_kwargs)
      return x

    in_spec = jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec('dev')
    )
    out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    f = pjit.pjit(f, in_shardings=in_spec, out_shardings=out_spec)

    expected = []
    with mesh:
      x = jnp.arange(mesh.size)
      f(x)
      expected.extend([int(x.sum()), int((x + 1).sum())])
      f(x + 5)
      expected.extend([int((x + 5).sum()), int((x + 6).sum())])

    jax.effects_barrier()
    if ordered:
      self.assertAllClose(_collected, expected)
    else:
      # Unordered callbacks: check only that every value showed up.
      self.assertEqual(len(_collected), len(expected))
      for v in expected:
        self.assertIn(v, _collected)

    # The lowered text must mention the callback device, in whichever syntax
    # the active partitioner emits.
    callback_device_index = in_spec._device_assignment.index(callback_device)
    stablehlo_ir = f.lower(x).as_text()
    if config.use_shardy_partitioner.value:
      self.assertIn(
          "sdy.sharding ="
          f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
          " []>]>",
          stablehlo_ir)
      self.assertIn(
          f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
          f" device_ids=[{callback_device_index}]>",
          stablehlo_ir)
    else:
      self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)

  def test_sequence_pjit_io_callback_ordered(self):
    # A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each
    # pair on a different device assignment.
    _collected: list[int] = []

    def _cb(i, x):
      # The list is mutated in place, so no `nonlocal` declaration is needed.
      # Sleep different amounts of time, to test the ordering.
      time.sleep([0.02, 0.03, 0.04][len(_collected) % 3])
      logging.info('Collected iteration %s: %s', i, x)
      _collected.append(int(x.sum()))

    def f_base(i, x):
      io_callback(_cb, None, i, x, ordered=True)
      io_callback(_cb, None, i, x + 1, ordered=True)

    nr_iterations = 8
    # TODO(zce): If I pin to 1 device below (jax.devices()[:1]) then this test
    # flakes. It also flakes when pinned to 2 devices. It seems that repeatedly
    # dispatching to the same device triggers the problem.
    devices = jax.devices()
    expected = []  # The expected value for _collected
    for i in range(nr_iterations):
      if len(devices) > 1:
        devices_for_iteration = [
            devices[i % len(devices)],
            devices[(i + 1) % len(devices)],
        ]
      else:
        devices_for_iteration = devices
      logging.info(
          'Running iteration %d on devices %s', i, devices_for_iteration
      )
      mesh = jax.sharding.Mesh(np.array(devices_for_iteration), ['dev'])
      in_spec = (
          jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
          jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')),
      )
      out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
      f = pjit.pjit(f_base, in_shardings=in_spec, out_shardings=out_spec)
      with mesh:
        x = jax.device_put(
            np.arange(len(devices_for_iteration), dtype=np.int32) + 10 * i,
            in_spec[1],
        )
        f(i, x)
        expected.extend([int(x.sum()), int((x + 1).sum())])
        f(i, x + 5)
        expected.extend([int((x + 5).sum()), int((x + 6).sum())])

    jax.effects_barrier()
    self.assertEqual(_collected, expected)

  def test_can_shard_io_callback_manually(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/384938613): Failing under shardy.")

    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    # Maps shard id -> list of values that shard's callbacks received.
    _collected = collections.defaultdict(list)

    def func(shard_id, x):
      # The dict is mutated in place, so no `nonlocal` declaration is needed.
      _collected[shard_id.item()].append(x)

    def f(shard_ids, x):
      io_callback(func, None, shard_ids, x, ordered=True)
      io_callback(func, None, shard_ids, x + 1, ordered=True)

    f = shard_map(f, mesh=mesh, in_specs=spec, out_specs=None)

    shard_ids = jnp.arange(mesh.devices.size)
    inp = jnp.arange(2 * jax.local_device_count())
    jax.jit(f, in_shardings=sharding, out_shardings=None)(shard_ids, inp)
    jax.effects_barrier()

    self.assertLen(_collected, mesh.devices.size)
    # Verify the partial ordering: no specified order across shards, but strict
    # ordering between the two calls in each shard.
    for shard in _collected.values():
      self.assertLen(shard, 2)
      np.testing.assert_array_equal(shard[0] + 1, shard[1])

  def test_batching_with_side_effects(self):
    # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195
    x_lst = []

    def append_x(x):
      # The list is mutated in place, so no `nonlocal` declaration is needed.
      x_lst.append(x)

    @jax.jit
    def f(x):
      io_callback(append_x, None, x, ordered=False)
      io_callback(append_x, None, 2 * x, ordered=False)

    jax.vmap(f)(jnp.arange(3.))
    jax.effects_barrier()
    self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False)

  def test_batching_with_side_effects_while_loop(self):
    # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219
    x_lst = []

    def append_x(x):
      # The list is mutated in place, so no `nonlocal` declaration is needed.
      x_lst.append(x)

    @jax.jit
    def f(x):
      def body(i):
        io_callback(append_x, None, x, ordered=False)
        io_callback(append_x, None, 2 * x, ordered=False)
        return i + 1

      jax.lax.while_loop(lambda i: i < 2, body, 0)

    jax.vmap(f)(jnp.arange(3.))  # don't crash
    jax.effects_barrier()
# Script entry point: run this module's tests through absltest, using the
# test loader supplied by jax's test utilities.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())