rocm_jax/tests/python_callback_test.py
George Necula f27816af30 [callback] Enable 64-bit types and add tests.
This takes advantage of a recent fix in XLA:TPU to enable
64-bit host transfers.

PiperOrigin-RevId: 562890507
2023-09-05 14:23:28 -07:00

1055 lines
30 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import textwrap
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
from jax import lax
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax.experimental import maps
from jax.experimental import pjit
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
from jax.experimental import io_callback
import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np
# Let the JAX config object parse its flags through absl at import time
# (presumably so command-line flags apply to these tests -- confirm).
config.parse_flags_with_absl()
def _format_multiline(text):
  """Return `text` with common indentation removed and leading whitespace stripped."""
  dedented = textwrap.dedent(text)
  return dedented.lstrip()
# Holds the callable returned by `jtu.set_host_platform_device_count`;
# `tearDownModule` invokes it.
prev_xla_flags = None


def setUpModule():
  """Request two host-platform devices and remember how to undo the change."""
  global prev_xla_flags
  # This will control the CPU devices. On TPU we always have 2 devices
  restore_flags = jtu.set_host_platform_device_count(2)
  prev_xla_flags = restore_flags
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  """Call the restore callable stored in the module-level `prev_xla_flags`."""
  restore_flags = prev_xla_flags
  restore_flags()
# Rebind `map` to `util.safe_map` for this module; the builtin `map` stays
# reachable under the name `unsafe_map`.
map, unsafe_map = util.safe_map, map
# Some test methods take a kwarg
# callback=[io_callback(ordered=True) | io_callback(ordered=False) | pure_callback]
io_callback_ordered = functools.partial(io_callback, ordered=True)
io_callback_unordered = functools.partial(io_callback, ordered=False)
# Backward-compatible alias: the original name was misspelled and is still
# referenced elsewhere in this module.
io_calback_unordered = io_callback_unordered

# Decorator that runs a test once per callback flavor, passing the callable
# under test as the keyword argument `callback`.
with_pure_and_io_callbacks = parameterized.named_parameters(
  dict(testcase_name=flavor, callback=flavor_callback)
  for flavor, flavor_callback in (
      ("io_unordered", io_callback_unordered),
      ("io_ordered", io_callback_ordered),
      ("pure", jax.pure_callback),
  )
)
class PythonCallbackTest(jtu.JaxTestCase):
  """Tests run against `pure_callback` and both `io_callback` flavors.

  Most tests receive the callback API under test through the `callback`
  keyword argument supplied by `with_pure_and_io_callbacks`.
  """

  def setUp(self):
    super().setUp()
    if xla_bridge.get_backend().runtime_type == 'stream_executor':
      raise unittest.SkipTest(
          'Host callback not supported for runtime type: stream_executor.')

  def tearDown(self):
    super().tearDown()
    # Drop any runtime tokens left behind by the test that just ran.
    dispatch.runtime_tokens.clear()

  @with_pure_and_io_callbacks
  def test_callback_with_scalar_values(self, *, callback):
    # A scalar goes to the host, is incremented there, and comes back.
    @jax.jit
    def f(x):
      return callback(lambda x: x + np.float32(1.),
                      core.ShapedArray(x.shape, x.dtype), x)

    out = f(0.)
    self.assertEqual(out, 1.)

  @parameterized.named_parameters(
    dict(testcase_name=f"{flavor}_{dtype}",
         dtype=dtype,
         callback=dict(io_unordered=io_calback_unordered,
                       io_ordered=io_callback_ordered,
                       pure=jax.pure_callback)[flavor])
    for flavor in ("io_unordered", "io_ordered", "pure")
    for dtype in jtu.dtypes.all
  )
  def test_callback_works_with_all_types(self, *, callback, dtype):
    # Round-trips an array of every dtype in `jtu.dtypes.all` and checks both
    # the value seen on the host and the value returned to the caller.
    def host_func(x):
      if dtype == np.bool_:
        return ~ x
      else:
        return x + x

    _received = None

    def _cb(x):
      nonlocal _received
      _received = x
      return host_func(x)

    if dtype == np.bool_:
      x = np.array([True, False, True, True], dtype=np.bool_)
    else:
      x = np.arange(4, dtype=dtype)

    @jax.jit
    def f(x):
      return callback(_cb,
                      core.ShapedArray(x.shape, x.dtype), x)

    out = f(x)
    self.assertAllClose(out, host_func(x))
    jax.effects_barrier()
    self.assertAllClose(_received, x)

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_args(self, *, callback):
    @jax.jit
    def f():
      # Calling a function that expects `x` with no arguments
      return callback(lambda x: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_returned_values(self, *, callback):
    @jax.jit
    def f(x):
      # Calling a function with two return values that expects one return value
      return callback(lambda x: (x, np.ones(4, np.float32)), x, x)

    with self.assertRaises(RuntimeError):
      f(2.)
      jax.effects_barrier()

    @jax.jit
    def g():
      # Specifically for io_callback: the host function returns None while
      # two return values are declared.
      return io_callback(lambda: None, (core.ShapedArray(
          (1,), np.float32), core.ShapedArray((2,), np.float32)))

    with self.assertRaises(RuntimeError):
      g()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_shape_outputs(self, *, callback):
    @jax.jit
    def f():
      # Calling a function expected a (1,) shaped return value but getting ()
      return callback(lambda: np.float32(1.), core.ShapedArray((1,),
                                                               np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_dtype_outputs(self, *, callback):
    def _cb():
      return np.array([1], np.float64)

    @jax.jit
    def f():
      # Calling a function expected a f32 return value but getting f64
      return callback(_cb, core.ShapedArray((1,), np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
    # Declaring a float64 result while 64-bit mode is off must be rejected.
    if config.jax_enable_x64:
      raise unittest.SkipTest("Test only needed when 64-bit mode disabled.")

    @jax.jit
    def f():
      return callback(lambda: np.float64(1.),
                      core.ShapedArray((), np.float64))

    with self.assertRaises(ValueError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_single_return_value(self, *, callback):
    @jax.jit
    def f():
      return callback(lambda: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    out = f()
    jax.effects_barrier()
    np.testing.assert_allclose(out, np.ones(4, np.float32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_return_values(self, *, callback):
    @jax.jit
    def f():
      return callback(lambda: (np.ones(4, np.float32), np.ones(5, np.int32)),
                      (core.ShapedArray(
                          (4,), np.float32), core.ShapedArray((5,), np.int32)))

    x, y = f()
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(4, np.float32))
    np.testing.assert_allclose(y, np.ones(5, np.int32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_arguments_and_return_values(self, *, callback):
    def _callback(x, y, z):
      return (x, y + z)

    @jax.jit
    def f(x, y, z):
      return callback(_callback, (core.ShapedArray(
          (3,), x.dtype), core.ShapedArray((3,), x.dtype)), x, y, z)

    x, y = f(jnp.ones(3), jnp.arange(3.), jnp.arange(3.) + 1.)
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(3))
    np.testing.assert_allclose(y, np.array([1., 3., 5]))

  @with_pure_and_io_callbacks
  def test_send_zero_dim_arrays(self, *, callback):
    result = np.full((2,), 42.0, dtype=np.float32)
    x = np.zeros((2, 0), np.float32)

    def _callback(x):  # x: f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)

    # Run the computation first, then wait for its effects before checking.
    out = f(x)
    jax.effects_barrier()
    self.assertAllClose(out, result)

  @with_pure_and_io_callbacks
  def test_send_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.zeros((2, 0), np.float32)
    y = np.full((2,), 42.0, dtype=np.float32)
    result = y

    def _callback(x, y):  # x: f32[2, 0]  y: f32[2]
      return y

    @jax.jit
    def f(x, y):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x, y)

    # Run the computation first, then wait for its effects before checking.
    out = f(x, y)
    jax.effects_barrier()
    self.assertAllClose(out, result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_arrays(self, *, callback):
    result = np.full((2, 0), 42.0, dtype=np.float32)
    x = np.zeros((2,), np.float32)

    def _callback(_):  # f32[2] -> f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)

    # Run the computation first, then wait for its effects before checking.
    out = f(x)
    jax.effects_barrier()
    self.assertAllClose(out, result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.full((2,), 42., dtype=np.float32)
    result0 = np.ones((2, 0), dtype=np.float32)
    result1 = x
    result2 = np.ones((3, 0), dtype=np.int32)
    result3 = np.concatenate([x, x]) + 1.

    def _callback(x):  # x: f32[2] -> (f32[2, 0], f32[2], f32[3, 0], f32[4])
      return (result0, x, result2, np.concatenate([x, x]) + 1.)

    @jax.jit
    def f(x):
      return callback(
          _callback, (core.ShapedArray(result0.shape, result0.dtype),
                      core.ShapedArray(result1.shape, result1.dtype),
                      core.ShapedArray(result2.shape, result2.dtype),
                      core.ShapedArray(result3.shape, result3.dtype)), x)

    res = f(x)
    jax.effects_barrier()
    self.assertAllClose(res, (result0, result1, result2, result3))

  @with_pure_and_io_callbacks
  def test_callback_with_pytree_arguments_and_return_values(self, *, callback):
    # The argument is a list and the result a dict holding a list.
    def _callback(x):
      return dict(y=[x])

    @jax.jit
    def f(x):
      return callback(_callback, dict(y=[core.ShapedArray((), np.float32)]),
                      [x])

    out = f(jnp.float32(2.))
    jax.effects_barrier()
    self.assertEqual(out, dict(y=[2.]))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop_of_scalars(self, *, callback):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return x < 10
      def body(x):
        return callback(_callback, core.ShapedArray((), x.dtype), x)
      return lax.while_loop(cond, body, x)

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop(self, *, callback):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return jnp.any(x < 10)
      def body(x):
        return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
      return lax.while_loop(cond, body, x)

    out = f(jnp.arange(5.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(10., 15.))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond_of_scalars(self, *, callback):
    def _callback1(x):
      return (x + 1.).astype(x.dtype)

    def _callback2(x):
      return (x - 1.).astype(x.dtype)

    @jax.jit
    def f(pred, x):
      def true_fun(x):
        return callback(_callback1, core.ShapedArray((), x.dtype), x)
      def false_fun(x):
        return callback(_callback2, core.ShapedArray((), x.dtype), x)
      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 2.)
    out = f(False, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 0.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond(self, *, callback):
    def _callback1(x):
      return x + 1.

    def _callback2(x):
      return x - 1.

    @jax.jit
    def f(pred, x):
      def true_fun(x):
        return callback(_callback1, core.ShapedArray(x.shape, x.dtype), x)
      def false_fun(x):
        return callback(_callback2, core.ShapedArray(x.shape, x.dtype), x)
      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.ones(2) * 2.)
    out = f(False, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.zeros(2))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan_of_scalars(self, *, callback):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()
      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan(self, *, callback):
    def _callback(x):
      return x + 1.

    @jax.jit
    def f(x):
      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()
      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(jnp.arange(2.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(2.) + 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap_of_scalars(self, *, callback):
    # The ordered io_callback flavor is skipped for pmap.
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(jnp.arange(jax.local_device_count(), dtype=jnp.float32))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count(), dtype=np.float32) + 1.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap(self, *, callback):
    # The ordered io_callback flavor is skipped for pmap.
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return x + 1.

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(
        jnp.arange(2 * jax.local_device_count(),
                   dtype=jnp.float32).reshape([-1, 2]))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out,
        np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
class PureCallbackTest(jtu.JaxTestCase):
  """Tests specific to `jax.pure_callback`."""

  def setUp(self):
    super().setUp()
    if xla_bridge.get_backend().runtime_type == 'stream_executor':
      raise unittest.SkipTest(
          'Host callback not supported for runtime type: stream_executor.')

  def tearDown(self):
    super().tearDown()
    # Drop any runtime tokens left behind by the test that just ran.
    dispatch.runtime_tokens.clear()

  def test_pure_callback_passes_ndarrays_without_jit(self):
    # Outside of jit the host function must still receive a NumPy ndarray.
    def cb(x):
      self.assertIs(type(x), np.ndarray)
      return x

    def f(x):
      return jax.pure_callback(cb, x, x)

    f(jnp.array(2.))

  def test_can_dce_pure_callback(self):
    if jax.default_backend() == "tpu":
      raise unittest.SkipTest("DCE doesn't currently happen on TPU")

    log = []

    def _callback(x):
      # Should never happen!
      log.append("hello world")
      return (x * 2.).astype(x.dtype)

    @jax.jit
    def f(x):
      # The callback result is discarded, so the callback should not run.
      _ = jax.pure_callback(_callback, x, x)
      return x * 2.

    _ = f(2.)
    self.assertEmpty(log)

  def test_can_vmap_pure_callback(self):
    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(np.sin, x, x)

    out = f(jnp.arange(4.))
    np.testing.assert_allclose(out, np.sin(np.arange(4.)))

    @jax.jit
    def g(x):
      return jax.pure_callback(np.sin, x, x)

    out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
    np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)

    # Only the first argument is batched.
    @jax.jit
    @functools.partial(jax.vmap, in_axes=(0, None))
    def h(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)

    out = h(jnp.arange(4.), 4.)
    self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4.,
                              rtol=1E-7, check_dtypes=False)

    # Both arguments are batched.
    @jax.jit
    @jax.vmap
    def h_both_batched(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)

    out = h_both_batched(jnp.arange(4.), jnp.arange(10., 14.))
    self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.),
                              rtol=1E-7, check_dtypes=False)

  def test_vmap_vectorized_callback(self):
    # Without `vectorized=True` the host function sees one element at a time.
    def cb(x):
      self.assertTupleEqual(x.shape, ())
      return np.sin(x)

    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(cb, x, x)

    np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))

    # With `vectorized=True` the host function sees the whole batch.
    def cb2(x):
      self.assertTupleEqual(x.shape, (4,))
      return np.sin(x)

    @jax.jit
    @jax.vmap
    def g(x):
      return jax.pure_callback(cb2, x, x, vectorized=True)

    np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))

    @jax.jit
    @functools.partial(jax.vmap, in_axes=(0, None))
    def h(x, y):
      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
                               vectorized=True)

    out = h(jnp.arange(4.), 4.)
    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)

  def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):
    def cb(x):
      # Reduces over all dimension when it shouldn't
      return np.sin(x).sum()

    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(cb, x, x, vectorized=True)

    with self.assertRaises(RuntimeError):
      f(jnp.arange(4.))
      jax.effects_barrier()

  def test_can_pmap_pure_callback(self):
    @jax.pmap
    def f(x):
      return jax.pure_callback(np.sin, x, x)

    out = f(jnp.arange(float(jax.local_device_count())))
    np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))

  def test_can_pjit_pure_callback_under_hard_xmap(self):
    if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
      raise unittest.SkipTest('Manual partitioning needed for pure_callback')

    jtu.set_spmd_lowering_flag(True)
    jtu.set_spmd_manual_lowering_flag(True)
    try:
      mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
      spec = jax.sharding.PartitionSpec('x')

      def f(x):
        axis_resources = {v: v for v in mesh.axis_names}
        return xmap(
            lambda x: jax.pure_callback(np.sin, x, x),
            in_axes=(('x',),),
            out_axes=('x',),
            axis_resources=axis_resources,
            axis_sizes=mesh.shape,
        )(x)

      with mesh:
        inp = jnp.arange(float(jax.local_device_count()))
        out = pjit.pjit(f, in_shardings=spec, out_shardings=spec)(inp)
        np.testing.assert_allclose(
            out, np.sin(np.arange(jax.local_device_count()))
        )
    finally:
      # Restore the flags in the reverse of the order they were set.
      jtu.restore_spmd_manual_lowering_flag()
      jtu.restore_spmd_lowering_flag()

  def test_cant_take_grad_of_pure_callback(self):
    def sin(x):
      return np.sin(x)

    @jax.jit
    @jax.grad
    def f(x):
      return jax.pure_callback(sin, x, x)

    with self.assertRaisesRegex(
        ValueError, "Pure callbacks do not support JVP."):
      f(2.)

  def test_can_take_grad_of_pure_callback_with_custom_jvp(self):
    # A custom JVP rule makes the callback differentiable.
    @jax.custom_jvp
    def sin(x):
      return jax.pure_callback(np.sin, x, x)

    @sin.defjvp
    def sin_jvp(xs, ts):
      (x,), (t,), = xs, ts
      return sin(x), jax.pure_callback(np.cos, x, x) * t

    @jax.jit
    @jax.grad
    def f(x):
      return sin(x)

    out = f(2.)
    np.testing.assert_allclose(out, jnp.cos(2.))

  def test_callback_inside_of_cond(self):
    def _callback1(x):
      return x + 1.

    def _callback2(x):
      return x - 1.

    @jax.jit
    def f(pred, x):
      def true_fun(x):
        return jax.pure_callback(_callback1, x, x)
      def false_fun(x):
        return jax.pure_callback(_callback2, x, x)
      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, jnp.ones(2))
    np.testing.assert_allclose(out, jnp.ones(2) * 2.)
    out = f(False, jnp.ones(2))
    np.testing.assert_allclose(out, jnp.zeros(2))

  def test_callback_inside_of_scan(self):
    def _callback(x):
      return x + 1.

    @jax.jit
    def f(x):
      def body(x, _):
        x = jax.pure_callback(_callback, x, x)
        return x, ()
      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(jnp.arange(2.))
    np.testing.assert_allclose(out, jnp.arange(2.) + 10.)

  def test_callback_inside_of_while_loop(self):
    # Both the loop condition and the loop body are computed on the host.
    def _cond_callback(x):
      return np.any(x < 10)

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return jax.pure_callback(
            _cond_callback, jax.ShapeDtypeStruct((), np.bool_), x)
      def body(x):
        return jax.pure_callback(_callback, x, x)
      return lax.while_loop(cond, body, x)

    out = f(jnp.arange(5.))
    np.testing.assert_allclose(out, jnp.arange(10., 15.))

  def test_callback_inside_of_pmap(self):
    def _callback(x):
      return x + 1.

    @jax.pmap
    def f(x):
      return jax.pure_callback(_callback, x, x)

    out = f(
        jnp.arange(2 * jax.local_device_count(),
                   dtype=jnp.float32).reshape([-1, 2]))
    np.testing.assert_allclose(
        out,
        np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)

  def test_callback_inside_xmap(self):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    def f(x):
      return jax.pure_callback(_callback, x, x)

    f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                  axis_resources={'a': 'dev'})
    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
      out = f(np.arange(40.))
    np.testing.assert_allclose(out, jnp.arange(1., 41.))

  def test_vectorized_callback_inside_xmap(self):
    def _callback(x):
      return (x + 1.).astype(x.dtype)

    def f(x):
      return jax.pure_callback(_callback, x, x, vectorized=True)

    f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                  axis_resources={'a': 'dev'})
    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
      out = f(np.arange(40.))
    np.testing.assert_allclose(out, jnp.arange(1., 41.))

  def test_array_layout_is_preserved(self):
    # An identity callback must hand a 2-D array back unchanged.
    def g(x):
      return jax.pure_callback(lambda x: x, x, x)

    x = np.arange(6, dtype=np.int32).reshape((3, 2))
    np.testing.assert_allclose(g(x), x)

  def test_can_shard_pure_callback_maximally(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    def f(x):
      return jax.pure_callback(func, x, x)

    inp = jnp.arange(float(jax.local_device_count()))
    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count()) * 2
    )

  def test_can_shard_pure_callback_maximally_with_sharding(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    # Pin the callback to the last device and remember its position in the
    # device assignment so the lowered IR can be checked for it below.
    callback_device = jax.devices()[-1]
    callback_device_index = sharding._device_assignment.index(callback_device)

    def f(x):
      sharding = jax.sharding.SingleDeviceSharding(callback_device)
      return jax.pure_callback(func, x, x, sharding=sharding)

    f_jit = jax.jit(f, in_shardings=sharding, out_shardings=sharding)

    inp = jnp.arange(float(jax.local_device_count()))
    out = f_jit(inp)
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count()) * 2
    )
    self.assertIn(
        f'{{maximal device={callback_device_index}}}',
        str(f_jit.lower(inp).compiler_ir(dialect='stablehlo')),
    )

  def test_can_shard_pure_callback_manually(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    def f(x):
      return jax.pure_callback(func, x, x)

    f = shard_map(f, mesh=mesh, in_specs=(spec,), out_specs=spec)

    inp = jnp.arange(float(jax.local_device_count() * 2))
    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
    y = np.tile(np.arange(2, dtype=inp.dtype), jax.local_device_count())
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, inp + y
    )
class IOCallbackTest(jtu.JaxTestCase):
  """Tests specific to `io_callback`."""

  def setUp(self):
    super().setUp()
    if xla_bridge.get_backend().runtime_type == 'stream_executor':
      raise unittest.SkipTest(
          'Host callback not supported for runtime type: stream_executor.')

  def tearDown(self):
    super().tearDown()
    # Drop any runtime tokens left behind by the test that just ran.
    dispatch.runtime_tokens.clear()

  def test_io_callback_can_mutate_state(self):
    # Each call must run the host function exactly once.
    x = 0

    def cb():
      nonlocal x
      x += 1
      return np.array(x, np.int32)

    def f():
      return io_callback(cb, jax.ShapeDtypeStruct((), jnp.int32))

    f()
    jax.effects_barrier()
    self.assertEqual(x, 1)
    f()
    jax.effects_barrier()
    self.assertEqual(x, 2)

  def test_io_callback_can_be_batched_if_unordered(self):
    # Under vmap the host function runs once per batch element.
    _mut = 0

    def cb(x):
      nonlocal _mut
      _mut += 1
      return x

    x = jnp.arange(4)

    def f(x):
      return io_callback(cb, jax.ShapeDtypeStruct((), x.dtype), x)

    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 4)
    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 8)

  def test_cannot_call_ordered_io_in_pmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)

    with self.assertRaisesRegex(
        ValueError, "Ordered effects not supported in `pmap`"):
      jax.pmap(f)(jnp.arange(jax.local_device_count()))

  def test_cannot_call_ordered_io_in_xmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)

    with self.assertRaisesRegex(
        ValueError, "Cannot `vmap` ordered IO callback"):
      maps.xmap(f, in_axes=([0],), out_axes=[0])(jnp.arange(16))

  def test_cannot_call_ordered_io_in_vmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)

    with self.assertRaisesRegex(
        ValueError, "Cannot `vmap` ordered IO callback"):
      jax.vmap(f)(jnp.arange(4))

  def test_cannot_use_io_callback_in_jvp(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)

    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.jvp(f, (0.,), (1.,))

  def test_cannot_use_io_callback_in_linearize(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)

    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.linearize(f, 0.)

  def test_cannot_use_io_callback_in_transpose(self):
    x = jnp.array(1.)

    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), x.dtype), x)

    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support transpose."):
      jax.linear_transpose(f, x)(x)

  def test_cannot_vmap_of_cond_io_callback(self):
    def f(pred):
      def true_fun():
        io_callback(lambda: print("true"), None)
      def false_fun():
        io_callback(lambda: print("false"), None)
      # `lax.cond` takes the true branch first, then the false branch.
      return lax.cond(pred, true_fun, false_fun)

    with self.assertRaisesRegex(NotImplementedError,
        "IO effect not supported in vmap-of-cond."):
      jax.vmap(f)(jnp.array([True, True]))

  def test_cannot_vmap_of_while_io_callback(self):
    def check(x):
      assert np.all(x < 5)

    def f(i):
      def cond(i):
        return i < 5
      def body(i):
        io_callback(check, None, i)
        return i + 1
      return lax.while_loop(cond, body, i)

    with self.assertRaisesRegex(NotImplementedError,
        "IO effect not supported in vmap-of-while."):
      jax.vmap(f)(jnp.array([0, 4]))

  def test_cannot_use_io_callback_in_checkpoint(self):
    @jax.grad
    @jax.checkpoint
    def f(x, y):
      io_callback(lambda x: x, y, y)
      return x

    with self.assertRaisesRegex(NotImplementedError,
        "Effects not supported in partial-eval of `checkpoint`"):
      f(2., 3.)

  def test_can_use_io_callback_in_pjit(self):
    # The host function records the sum of the array it receives.
    _mut = 0

    def _cb(x):
      nonlocal _mut
      _mut = x.sum()

    def f(x):
      io_callback(_cb, None, x)
      return x

    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
    spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
    out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    f = pjit.pjit(f, in_shardings=spec, out_shardings=out_spec)
    with mesh:
      f(jnp.arange(mesh.size))
      jax.effects_barrier()
    self.assertEqual(_mut, jnp.arange(mesh.size).sum())

  def test_can_use_io_callback_in_pjit_with_sharding(self):
    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
    spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))

    _mut = 0

    def _cb(x):
      nonlocal _mut
      _mut = x.sum()

    # Pin the callback to the last device and remember its position in the
    # device assignment so the lowered IR can be checked for it below.
    callback_device = jax.devices()[-1]
    callback_device_index = spec._device_assignment.index(callback_device)

    def f(x):
      sharding = jax.sharding.SingleDeviceSharding(callback_device)
      io_callback(_cb, None, x, sharding=sharding)
      return x

    out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    f = pjit.pjit(f, in_shardings=spec, out_shardings=out_spec)
    inp = jnp.arange(mesh.size)
    with mesh:
      f(inp)
      jax.effects_barrier()
    self.assertEqual(_mut, jnp.arange(mesh.size).sum())

    self.assertIn(
        f'{{maximal device={callback_device_index}}}',
        str(f.lower(inp).compiler_ir(dialect='stablehlo')),
    )
if __name__ == "__main__":
  # Hand control to absl's test runner, using the loader provided by `jtu`.
  absltest.main(testLoader=jtu.JaxTestLoader())