# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import functools
import logging
import time
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np

config.parse_flags_with_absl()
jtu.request_cpu_devices(2)

map, unsafe_map = util.safe_map, map

# Some test methods take a kwarg
# callback=[io_callback(ordered=True) | io_callback(ordered=False) | pure_callback]
io_callback_ordered = functools.partial(io_callback, ordered=True)
io_callback_unordered = functools.partial(io_callback, ordered=False)
with_pure_and_io_callbacks = parameterized.named_parameters(
    dict(testcase_name=flavor,
         callback=dict(io_unordered=io_callback_unordered,
                       io_ordered=io_callback_ordered,
                       pure=jax.pure_callback)[flavor])
    for flavor in ("io_unordered", "io_ordered", "pure")
)
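
# For reference, the helper above expands each decorated test into three named
# cases, one per callback flavor, roughly equivalent to writing:
#   test_foo_io_unordered(callback=io_callback_unordered)
#   test_foo_io_ordered(callback=io_callback_ordered)
#   test_foo_pure(callback=jax.pure_callback)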


class PythonCallbackTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    super().tearDown()
    dispatch.runtime_tokens.clear()

  @with_pure_and_io_callbacks
  def test_callback_with_scalar_values(self, *, callback):
    @jax.jit
    def f(x):
      return callback(lambda x: x + 1.0, core.ShapedArray(x.shape, x.dtype), x)

    out = f(0.)
    self.assertEqual(out, 1.)

  @parameterized.named_parameters(
      dict(
          testcase_name=f"{flavor}_expect_dtype_{expect_dtype}",
          callback=dict(
              io_unordered=io_callback_unordered,
              io_ordered=io_callback_ordered,
              pure=jax.pure_callback,
          )[flavor],
          expect_dtype=expect_dtype,
      )
      for flavor in ("io_unordered", "io_ordered", "pure")
      for expect_dtype in (np.int32, np.int64, np.float32, np.float64)
  )
  def test_callback_returning_python_literal(self, *, callback, expect_dtype):
    returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0

    @jax.jit
    def f(x):
      return callback(
          lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
      )

    if not config.enable_x64.value and expect_dtype in (np.int64, np.float64):
      ctx = self.assertRaisesRegex(Exception, "result_shape_dtypes cannot specify 64-bit types")
    elif config.enable_x64.value and expect_dtype in (np.int32, np.float32):
      ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
    else:
      ctx = contextlib.nullcontext()

    with ctx:
      out = f(0.0)
      jax.effects_barrier()
      self.assertEqual(out, returned_literal)
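
  # Note on the expected errors above: with jax_enable_x64 disabled, 64-bit
  # result specs are rejected outright; with it enabled, the Python literal
  # 42 / 42.0 converts to a 64-bit numpy value, so a 32-bit result spec no
  # longer matches what the callback actually returns. A sketch of the first
  # failing case (assuming x64 is disabled):
  #   jax.pure_callback(lambda: 42, core.ShapedArray((), np.int64))  # raises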

  @with_pure_and_io_callbacks
  def test_callback_returning_custom_array(self, *, callback):
    # Some users write the callback in TF, returning a tf.Tensor. We don't
    # want to add TF as a dependency, but simulate that use case with a
    # custom array class.
    class CustomArray:

      def __init__(self, a: np.ndarray):
        self.a = a

      @property
      def shape(self):
        return self.a.shape

      @property
      def dtype(self):
        return self.a.dtype

      def __array__(self):
        return self.a

    @jax.jit
    def f(x):
      return callback(
          lambda x: CustomArray(np.array(42.0, dtype=np.float32)),
          core.ShapedArray((), np.float32),
          x,
      )

    out = f(0.0)
    jax.effects_barrier()
    self.assertEqual(out, 42.0)
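
  # The test above relies on duck typing: exposing `shape`, `dtype`, and
  # `__array__` appears to be enough, since callback results are converted
  # with `np.asarray` before being handed back to the runtime (an assumption
  # about the conversion path, not a public contract).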

  @parameterized.named_parameters(
      dict(testcase_name=f"{flavor}_{dtype}",
           dtype=dtype,
           callback=dict(io_unordered=io_callback_unordered,
                         io_ordered=io_callback_ordered,
                         pure=jax.pure_callback)[flavor])
      for flavor in ("io_unordered", "io_ordered", "pure")
      for dtype in jtu.dtypes.all
  )
  def test_callback_works_with_all_types(self, *, callback, dtype):
    def host_func(x):
      if dtype == np.bool_:
        return ~x
      else:
        return x + x
    _received = None
    def _cb(x):
      nonlocal _received
      _received = x
      return host_func(x)

    if dtype == np.bool_:
      x = np.array([True, False, True, True], dtype=np.bool_)
    else:
      x = np.arange(4, dtype=dtype)
    @jax.jit
    def f(x):
      return callback(_cb,
                      core.ShapedArray(x.shape, x.dtype), x)

    out = f(x)
    self.assertAllClose(out, host_func(x))
    jax.effects_barrier()
    self.assertAllClose(_received, x)

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_args(self, *, callback):

    @jax.jit
    def f():
      # Calling a function that expects `x` with no arguments
      return callback(lambda x: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_number_of_returned_values(self, *, callback):

    @jax.jit
    def f(x):
      # Calling a function that returns two values when the result spec
      # expects a single return value
      return callback(lambda x: (x, np.ones(4, np.float32)), x, x)

    with self.assertRaises(RuntimeError):
      f(2.)
      jax.effects_barrier()

    @jax.jit
    def g():
      # Specifically for io_callback: calling a function that returns nothing
      # when two return values are expected
      return io_callback(lambda: None, (core.ShapedArray(
          (1,), np.float32), core.ShapedArray((2,), np.float32)))

    with self.assertRaises(RuntimeError):
      g()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_shape_outputs(self, *, callback):

    @jax.jit
    def f():
      # The result spec expects a (1,)-shaped return value but the callback
      # returns a () scalar
      return callback(lambda: np.float32(1.), core.ShapedArray((1,),
                                                               np.float32))

    with self.assertRaises(RuntimeError):
      f()
      jax.effects_barrier()
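
  # Note for the three failure tests above: argument/result mismatches are not
  # detectable at trace time, since only the declared result spec is available
  # then; they surface as a RuntimeError when the host callback actually runs,
  # which is why each test invokes the jitted function before asserting.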

  @with_pure_and_io_callbacks
  def test_callback_with_wrong_dtype_outputs(self, *, callback):

    def _cb():
      return np.array([1], np.float64)

    @jax.jit
    def f():
      # The result spec expects an f32 return value but the callback returns f64
      return callback(_cb, core.ShapedArray((1,), np.float32))

    if config.enable_x64.value:
      ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype for return value")
    else:
      ctx = contextlib.nullcontext()
    with ctx:
      res = f()
      jax.effects_barrier()
      self.assertAllClose(res, np.array([1], np.float32))

  @with_pure_and_io_callbacks
  def test_callback_with_wrongly_specified_64_bit_dtype(self, *, callback):
    if config.enable_x64.value:
      raise unittest.SkipTest("Test only needed when 64-bit mode disabled.")

    @jax.jit
    def f():
      return callback(lambda: np.float64(1.),
                      core.ShapedArray((), np.float64))

    with self.assertRaises(ValueError):
      f()
      jax.effects_barrier()

  @with_pure_and_io_callbacks
  def test_callback_with_single_return_value(self, *, callback):

    @jax.jit
    def f():
      return callback(lambda: np.ones(4, np.float32),
                      core.ShapedArray((4,), np.float32))

    out = f()
    jax.effects_barrier()
    np.testing.assert_allclose(out, np.ones(4, np.float32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_return_values(self, *, callback):

    @jax.jit
    def f():
      return callback(lambda: (np.ones(4, np.float32), np.ones(5, np.int32)),
                      (core.ShapedArray(
                          (4,), np.float32), core.ShapedArray((5,), np.int32)))

    x, y = f()
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(4, np.float32))
    np.testing.assert_allclose(y, np.ones(5, np.int32))

  @with_pure_and_io_callbacks
  def test_callback_with_multiple_arguments_and_return_values(self, *, callback):

    def _callback(x, y, z):
      return (x, y + z)

    @jax.jit
    def f(x, y, z):
      return callback(_callback, (core.ShapedArray(
          (3,), x.dtype), core.ShapedArray((3,), x.dtype)), x, y, z)

    x, y = f(jnp.ones(3), jnp.arange(3.), jnp.arange(3.) + 1.)
    jax.effects_barrier()
    np.testing.assert_allclose(x, np.ones(3))
    np.testing.assert_allclose(y, np.array([1., 3., 5]))
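
  # As the tests above show, the result spec can be a single aval or a tuple
  # of avals; any pytree of avals works (see the pytree test below), e.g. a
  # sketch, where host_fn is a stand-in returning a matching pytree:
  #   specs = {"a": core.ShapedArray((3,), np.float32),
  #            "b": [core.ShapedArray((), np.int32)]}
  #   callback(host_fn, specs, x)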

  @with_pure_and_io_callbacks
  def test_send_zero_dim_arrays(self, *, callback):
    result = np.full((2,), 42.0, dtype=np.float32)
    x = np.zeros((2, 0), np.float32)
    def _callback(x):  # x: f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)
    jax.effects_barrier()
    self.assertAllClose(f(x), result)

  @with_pure_and_io_callbacks
  def test_send_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.zeros((2, 0), np.float32)
    y = np.full((2,), 42.0, dtype=np.float32)
    result = y
    def _callback(x, y):  # x: f32[2, 0]  y: f32[2]
      return y

    @jax.jit
    def f(x, y):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x, y)
    jax.effects_barrier()
    self.assertAllClose(f(x, y), result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_arrays(self, *, callback):
    result = np.full((2, 0), 42.0, dtype=np.float32)
    x = np.zeros((2,), np.float32)
    def _callback(_):  # f32[2] -> f32[2, 0]
      return result

    @jax.jit
    def f(x):
      return callback(
          _callback, core.ShapedArray(result.shape, result.dtype), x)
    jax.effects_barrier()
    self.assertAllClose(f(x), result)

  @with_pure_and_io_callbacks
  def test_recv_zero_dim_and_non_zero_dim_arrays(self, *, callback):
    x = np.full((2,), 42., dtype=np.float32)
    result0 = np.ones((2, 0), dtype=np.float32)
    result1 = x
    result2 = np.ones((3, 0), dtype=np.int32)
    result3 = np.concatenate([x, x]) + 1.
    def _callback(x):  # x: f32[2] -> (f32[2, 0], f32[2], i32[3, 0], f32[4])
      return (result0, x, result2, np.concatenate([x, x]) + 1.)

    @jax.jit
    def f(x):
      return callback(
          _callback, (core.ShapedArray(result0.shape, result0.dtype),
                      core.ShapedArray(result1.shape, result1.dtype),
                      core.ShapedArray(result2.shape, result2.dtype),
                      core.ShapedArray(result3.shape, result3.dtype)), x)
    res = f(x)
    jax.effects_barrier()
    self.assertAllClose(res, (result0, result1, result2, result3))

  @with_pure_and_io_callbacks
  def test_callback_with_pytree_arguments_and_return_values(self, *, callback):

    def _callback(x):
      return dict(y=[x])

    @jax.jit
    def f(x):
      return callback(_callback, dict(y=[core.ShapedArray((), np.float32)]),
                      [x])

    out = f(jnp.float32(2.))
    jax.effects_barrier()
    self.assertEqual(out, dict(y=[2.]))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop_of_scalars(self, *, callback):

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):
      def cond(x):
        return x < 10
      def body(x):
        return callback(_callback, core.ShapedArray((), x.dtype), x)
      return lax.while_loop(cond, body, x)

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)
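
  # In the while_loop tests here and below, the callback runs on the host once
  # per loop iteration at execution time; the loop carry round-trips through
  # the host on every step, and `.astype(x.dtype)` keeps the carry's dtype
  # stable across iterations.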

  @with_pure_and_io_callbacks
  def test_callback_inside_of_while_loop(self, *, callback):

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):

      def cond(x):
        return jnp.any(x < 10)

      def body(x):
        return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

      return lax.while_loop(cond, body, x)

    out = f(jnp.arange(5.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(10., 15.))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond_of_scalars(self, *, callback):

    def _callback1(x):
      return (x + 1.).astype(x.dtype)

    def _callback2(x):
      return (x - 1.).astype(x.dtype)

    @jax.jit
    def f(pred, x):

      def true_fun(x):
        return callback(_callback1, core.ShapedArray((), x.dtype), x)

      def false_fun(x):
        return callback(_callback2, core.ShapedArray((), x.dtype), x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 2.)
    out = f(False, 1.)
    jax.effects_barrier()
    self.assertEqual(out, 0.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_cond(self, *, callback):

    def _callback1(x):
      return x + 1.

    def _callback2(x):
      return x - 1.

    @jax.jit
    def f(pred, x):

      def true_fun(x):
        return callback(_callback1, core.ShapedArray(x.shape, x.dtype), x)

      def false_fun(x):
        return callback(_callback2, core.ShapedArray(x.shape, x.dtype), x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.ones(2) * 2.)
    out = f(False, jnp.ones(2))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.zeros(2))

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan_of_scalars(self, *, callback):

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):

      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(0.)
    jax.effects_barrier()
    self.assertEqual(out, 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_scan(self, *, callback):

    def _callback(x):
      return x + 1.

    @jax.jit
    def f(x):

      def body(x, _):
        x = callback(_callback, core.ShapedArray(x.shape, x.dtype), x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(jnp.arange(2.))
    jax.effects_barrier()
    np.testing.assert_allclose(out, jnp.arange(2.) + 10.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap_of_scalars(self, *, callback):
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(jnp.arange(jax.local_device_count(), dtype=jnp.float32))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count(), dtype=np.float32) + 1.)

  @with_pure_and_io_callbacks
  def test_callback_inside_of_pmap(self, *, callback):
    if callback is io_callback_ordered:
      self.skipTest("N/A")

    def _callback(x):
      return x + 1.

    @jax.pmap
    def f(x):
      return callback(_callback, core.ShapedArray(x.shape, x.dtype), x)

    out = f(
        jnp.arange(2 * jax.local_device_count(),
                   dtype=jnp.float32).reshape([-1, 2]))
    jax.effects_barrier()
    np.testing.assert_allclose(
        out,
        np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
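
  # Ordered io_callback is skipped in the two pmap tests above because ordered
  # effects are not supported under pmap (see
  # test_cannot_call_ordered_io_in_pmap below); the unordered and pure flavors
  # are expected to run the host function once per participating device.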

  @with_pure_and_io_callbacks
  @jtu.thread_unsafe_test()  # logging isn't thread-safe
  def test_exception_in_callback(self, *, callback):
    def fail(x):
      raise RuntimeError("Oops")

    @jax.jit
    def f(x):
      return callback(fail, core.ShapedArray(x.shape, x.dtype), x)

    with self.assertLogs(level="ERROR") as l:
      try:
        f(0.0).block_until_ready()
      except RuntimeError:
        pass

    api_name = (
        "pure_callback" if callback is jax.pure_callback else "io_callback"
    )
    output = "\n".join(l.output)
    self.assertIn(f"jax.{api_name} failed", output)
    self.assertIn("Traceback (most recent call last)", output)

  @with_pure_and_io_callbacks
  @jtu.thread_unsafe_test()  # count_primitive_compiles isn't thread-safe
  def test_compilation_caching(self, *, callback):
    def f_outside(x):
      return 2 * x

    def fun(x):
      return callback(f_outside, x, x)

    x = np.arange(6, dtype=np.int32).reshape((2, 3))
    with jtu.count_primitive_compiles() as count:
      for _ in range(3):
        self.assertAllClose(2 * x, fun(x))
    self.assertEqual(count(), 1)
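
  # The single compile above shows that a callback does not defeat jit's
  # compilation cache: repeated calls with the same callable and avals reuse
  # the compiled executable rather than re-lowering on every call.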


class PureCallbackTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    super().tearDown()
    dispatch.runtime_tokens.clear()

  def test_pure_callback_passes_jax_arrays_without_jit(self):

    def cb(x):
      self.assertIsInstance(x, jax.Array)
      return x

    def f(x):
      return jax.pure_callback(cb, x, x)
    f(jnp.array(2.))

  def test_can_dce_pure_callback(self):

    if jax.default_backend() == "tpu":
      raise unittest.SkipTest("DCE doesn't currently happen on TPU")

    log = []
    def _callback(x):
      # Should never happen!
      log.append("hello world")
      return (x * 2.).astype(x.dtype)

    @jax.jit
    def f(x):
      _ = jax.pure_callback(_callback, x, x)
      return x * 2.
    _ = f(2.)
    self.assertEmpty(log)
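
  # The DCE test above captures the key contract of jax.pure_callback: the
  # callback is assumed to be side-effect free, so a call whose result is
  # unused may be dropped entirely. Code that needs its side effects to
  # survive should use io_callback instead.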

  def test_can_vmap_pure_callback(self):

    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
    out = f(jnp.arange(4.))
    np.testing.assert_allclose(out, np.sin(np.arange(4.)))

    @jax.jit
    def g(x):
      return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
    out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
    np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)

    @jax.jit
    @functools.partial(jax.vmap, in_axes=(0, None))
    def h(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
                               vmap_method="sequential")
    out = h(jnp.arange(4.), 4.)
    self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4.,
                              rtol=1E-7, check_dtypes=False)

    @jax.jit
    @functools.partial(jax.vmap)
    def h(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
                               vmap_method="sequential")
    out = h(jnp.arange(4.), jnp.arange(10., 14.))
    self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.),
                              rtol=1E-7, check_dtypes=False)

    @jax.jit
    @functools.partial(jax.vmap, in_axes=1, out_axes=1)
    def h(x, y):
      out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
      return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
                               vmap_method="sequential")
    out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None])
    self.assertArraysAllClose(
        out, np.sin(np.arange(4.)) + np.arange(10., 14.)[None],
        rtol=1E-7, check_dtypes=False)

  def test_vmap_vectorized_callback(self):

    def cb(x):
      self.assertTupleEqual(x.shape, ())
      return np.sin(x)

    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(cb, x, x, vmap_method="sequential")

    np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))

    def cb2(x):
      self.assertTupleEqual(x.shape, (4,))
      return np.sin(x)

    @jax.jit
    @jax.vmap
    def g(x):
      return jax.pure_callback(cb2, x, x, vmap_method="expand_dims")

    np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))

    @jax.jit
    @functools.partial(jax.vmap, in_axes=(0, None))
    def h(x, y):
      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
                               vmap_method="expand_dims")
    out = h(jnp.arange(4.), 4.)
    np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)

    @jax.jit
    @functools.partial(jax.vmap, in_axes=(1, None), out_axes=1)
    def h(x, y):
      return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
                               vmap_method="legacy_vectorized")
    out = h(jnp.arange(4.)[None], 4.)
    np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)
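
  # Rough summary of the vmap_method options exercised in this file, as
  # observed in these tests: "sequential" calls the callback once per batch
  # element with unbatched shapes; "expand_dims" calls it once, with mapped
  # arguments carrying the batch axis and unmapped arguments given a leading
  # axis of size 1; "broadcast_all" (tested further below) broadcasts every
  # argument to the full batch size; "legacy_vectorized" reproduces the
  # deprecated vectorized=True behavior, passing unmapped arguments through
  # unchanged.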

  def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):

    def cb(x):
      # Reduces over all dimensions when it shouldn't
      return np.sin(x).sum()

    @jax.jit
    @jax.vmap
    def f(x):
      return jax.pure_callback(cb, x, x, vmap_method="expand_dims")

    with self.assertRaises(RuntimeError):
      f(jnp.arange(4.))
      jax.effects_barrier()

  def test_can_pmap_pure_callback(self):

    @jax.pmap
    def f(x):
      return jax.pure_callback(np.sin, x, x)
    out = f(jnp.arange(float(jax.local_device_count())))
    np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))

  def test_cant_take_grad_of_pure_callback(self):

    def sin(x):
      return np.sin(x)

    @jax.jit
    @jax.grad
    def f(x):
      return jax.pure_callback(sin, x, x)
    with self.assertRaisesRegex(
        ValueError, "Pure callbacks do not support JVP."):
      f(2.)

  def test_error_propagation(self):
    def throws_error_fn(x):
      raise RuntimeError("Errors should propagate.")

    @jax.jit
    def f(x):
      return jax.pure_callback(throws_error_fn, x, x)

    with self.assertRaisesRegex(Exception, "Errors should propagate."):
      print(np.array(f(2.0)), flush=True)

  def test_reentrant_error_propagation(self):
    reentrant_fn = jax.jit(jnp.sin).lower(2.0).compile()

    @jax.jit
    def f(x):
      return jax.pure_callback(reentrant_fn, x, x)

    try:
      np.array(f(2.0))
    except:
      # All that matters is that the call does not deadlock.
      pass

  def test_can_take_grad_of_pure_callback_with_custom_jvp(self):

    @jax.custom_jvp
    def sin(x):
      return jax.pure_callback(np.sin, x, x)

    @sin.defjvp
    def sin_jvp(xs, ts):
      (x,), (t,), = xs, ts
      return sin(x), jax.pure_callback(np.cos, x, x) * t

    @jax.jit
    @jax.grad
    def f(x):
      return sin(x)
    out = f(2.)
    np.testing.assert_allclose(out, jnp.cos(2.))
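
  # The custom_jvp pattern above is the standard recipe for differentiating
  # through a host callback: pure_callback itself has no JVP rule, so the
  # tangent computation is supplied explicitly (here d/dx sin(x) = cos(x),
  # also evaluated via a callback).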

  def test_callback_inside_of_cond(self):

    def _callback1(x):
      return x + 1.

    def _callback2(x):
      return x - 1.

    @jax.jit
    def f(pred, x):

      def true_fun(x):
        return jax.pure_callback(_callback1, x, x)

      def false_fun(x):
        return jax.pure_callback(_callback2, x, x)

      return lax.cond(pred, true_fun, false_fun, x)

    out = f(True, jnp.ones(2))
    np.testing.assert_allclose(out, jnp.ones(2) * 2.)
    out = f(False, jnp.ones(2))
    np.testing.assert_allclose(out, jnp.zeros(2))

  def test_callback_inside_of_scan(self):

    def _callback(x):
      return x + 1.

    @jax.jit
    def f(x):

      def body(x, _):
        x = jax.pure_callback(_callback, x, x)
        return x, ()

      return lax.scan(body, x, jnp.arange(10))[0]

    out = f(jnp.arange(2.))
    np.testing.assert_allclose(out, jnp.arange(2.) + 10.)

  def test_callback_inside_of_while_loop(self):

    def _cond_callback(x):
      return np.any(x < 10)

    def _callback(x):
      return (x + 1.).astype(x.dtype)

    @jax.jit
    def f(x):

      def cond(x):
        return jax.pure_callback(
            _cond_callback, jax.ShapeDtypeStruct((), np.bool_), x)

      def body(x):
        return jax.pure_callback(_callback, x, x)

      return lax.while_loop(cond, body, x)

    out = f(jnp.arange(5.))
    np.testing.assert_allclose(out, jnp.arange(10., 15.))

  def test_callback_inside_of_pmap(self):

    def _callback(x):
      return x + 1.

    @jax.pmap
    def f(x):
      return jax.pure_callback(_callback, x, x)

    out = f(
        jnp.arange(2 * jax.local_device_count(),
                   dtype=jnp.float32).reshape([-1, 2]))
    np.testing.assert_allclose(
        out,
        np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)

  def test_callback_with_immediate_executable_destruction(self):

    def loop_body(i, x):
      del i
      return jax.pure_callback(lambda y: y + np.ones(4, np.float32),
                               x, x)

    class AClass:
      def f(self, ys):
        return lax.fori_loop(0, 10, loop_body, jnp.ones(4, np.float32))

    num_devices = jax.local_device_count()
    c = AClass()
    out = jax.pmap(c.f)(np.ones((num_devices,), np.float32))
    # c.f is an ephemeral bound method object, and it will be destroyed
    # immediately. This test verifies that the execution itself keeps the
    # callback alive.
    np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))

  def test_array_layout_is_preserved(self):

    def g(x):
      return jax.pure_callback(lambda x: x, x, x)

    x = np.arange(6, dtype=np.int32).reshape((3, 2))
    np.testing.assert_allclose(g(x), x)

  def test_can_shard_pure_callback_maximally(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    def f(x):
      return jax.pure_callback(func, x, x)

    inp = jnp.arange(float(jax.local_device_count()))
    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count()) * 2
    )

  def test_can_shard_pure_callback_maximally_with_sharding(self):
    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    callback_device = jax.devices()[-1]
    callback_device_index = sharding._device_assignment.index(callback_device)

    def f(x):
      sharding = jax.sharding.SingleDeviceSharding(callback_device)
      return jax.pure_callback(func, x, x, sharding=sharding)

    f_jit = jax.jit(f, in_shardings=sharding, out_shardings=sharding)

    inp = jnp.arange(float(jax.local_device_count()))
    out = f_jit(inp)
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, np.arange(jax.local_device_count()) * 2
    )
    stablehlo_ir = f_jit.lower(inp).as_text()
    if config.use_shardy_partitioner.value:
      self.assertIn(
          "sdy.sharding ="
          f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
          " []>]>",
          stablehlo_ir)
      self.assertIn(
          f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
          f" device_ids=[{callback_device_index}]>",
          stablehlo_ir)
    else:
      self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)

  def test_can_shard_pure_callback_manually(self):

    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    def func(x):
      return x + np.arange(x.shape[0], dtype=x.dtype)

    def f(x):
      return jax.pure_callback(func, x, x)
    f = shard_map(f, mesh=mesh, in_specs=(spec,), out_specs=spec)

    inp = jnp.arange(float(jax.local_device_count() * 2))
    out = jax.jit(f, in_shardings=sharding, out_shardings=sharding)(inp)
    y = np.tile(np.arange(2, dtype=inp.dtype), jax.local_device_count())
    jax.block_until_ready(out)
    np.testing.assert_allclose(
        out, inp + y
    )

  def test_does_not_deadlock(self):
    if jtu.device_under_test() == "tpu":
      self.skipTest("The test raises an exception on TPU")

    def f(x):
      y = jnp.asarray(x) + 1
      return np.asarray(2 * jnp.log(y))

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    np.testing.assert_allclose(out, 2 * jnp.log(x + 1))

  def test_vmap_method_raise(self):
    @jax.vmap
    def f(x):
      # Setting vectorized to None disables the current default behavior of
      # falling back on sequential.
      return jax.pure_callback(np.sin, x, x, vectorized=None)

    with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"):
      f(jnp.arange(4.))

  def test_deprecated_vectorized(self):
    def f(x, **kwargs):
      return jax.pure_callback(np.sin, x, x, **kwargs)

    with self.assertWarnsRegex(DeprecationWarning, "The default behavior"):
      jax.vmap(f)(jnp.arange(4.0))

    with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
      f(jnp.arange(4.0), vectorized=True)

    with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
      f(jnp.arange(4.0), vectorized=False)

  def test_vmap_method_expand_dims(self):
    def callback(x, y):
      self.assertTupleEqual(x.shape, (4,))
      self.assertTupleEqual(y.shape, (1,))
      return x + y

    def f(x, y):
      return jax.pure_callback(callback, x, x, y, vmap_method="expand_dims")

    jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0)  # doesn't error

  def test_vmap_method_broadcast_all(self):
    def callback(x, y):
      self.assertTupleEqual(x.shape, (4,))
      self.assertTupleEqual(y.shape, (4,))
      return x + y

    def f(x, y):
      return jax.pure_callback(callback, x, x, y,
                               vmap_method="broadcast_all")

    jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0)  # doesn't error

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.run_on_devices("cpu")
  def test_async_deadlock(self):
    self.skipTest("Too slow and memory intensive.")

    # See https://github.com/jax-ml/jax/issues/24255
    eig = jax.jit(jnp.linalg.eig)

    def callback(x):
      return jax.block_until_ready(eig(x))

    def fun(x):
      self.assertEqual(x.dtype, jnp.complex64)
      out_type = (
          jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
          jax.ShapeDtypeStruct(x.shape, x.dtype),
      )
      return jax.pure_callback(callback, out_type, x)

    result = 0.0
    for _ in range(10):
      result += fun(jnp.ones((500, 500), jnp.complex64))[1]
    jax.block_until_ready(result)  # doesn't deadlock

  def test_non_default_stride(self):
    x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)
    def callback(x):
      return np.asfortranarray(x)

    @jax.jit
    def f(x):
      return jax.pure_callback(
          callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x
      )

    result = f(x)
    np.testing.assert_array_equal(x, result)


class IOCallbackTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu", "gpu", "tpu"]):
      self.skipTest(f"Host callback not supported on {jtu.device_under_test()}")

  def tearDown(self):
    super().tearDown()
    dispatch.runtime_tokens.clear()

  def test_io_callback_can_mutate_state(self):
    x = 0
    def cb():
      nonlocal x
      x += 1
      return np.array(x, np.int32)

    def f():
      return io_callback(cb, jax.ShapeDtypeStruct((), jnp.int32))
    f()
    jax.effects_barrier()
    self.assertEqual(x, 1)
    f()
    jax.effects_barrier()
    self.assertEqual(x, 2)
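
  # As above, io_callback is the right tool when the callback mutates Python
  # state: each call is guaranteed to execute, and jax.effects_barrier() is
  # what makes the side effects observable to the test, since callback
  # dispatch is asynchronous and the assertion could otherwise race.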

  def test_io_callback_can_be_batched_if_unordered(self):
    _mut = 0
    def cb(x):
      nonlocal _mut
      _mut += 1
      return x

    x = jnp.arange(4)
    def f(x):
      return io_callback(cb, jax.ShapeDtypeStruct((), x.dtype), x)
    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 4)
    jax.vmap(f)(x)
    jax.effects_barrier()
    self.assertEqual(_mut, 8)

  def test_cannot_call_ordered_io_in_pmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
    with self.assertRaisesRegex(
        ValueError, "Ordered effects not supported in `pmap`"):
      jax.pmap(f)(jnp.arange(jax.local_device_count()))

  def test_cannot_call_ordered_io_in_vmap(self):
    def f(x):
      return io_callback(
          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
    with self.assertRaisesRegex(
        ValueError, "Cannot `vmap` ordered IO callback"):
      jax.vmap(f)(jnp.arange(4))

  def test_cannot_use_io_callback_in_jvp(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.jvp(f, (0.,), (1.,))

  def test_cannot_use_io_callback_in_linearize(self):
    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), jnp.float32), x)
    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support JVP."):
      jax.linearize(f, 0.)

  def test_cannot_use_io_callback_in_transpose(self):
    x = jnp.array(1.)

    def f(x):
      return io_callback(lambda x: x, jax.ShapeDtypeStruct((), x.dtype), x)
    with self.assertRaisesRegex(
        ValueError, "IO callbacks do not support transpose."):
      jax.linear_transpose(f, x)(x)

  def test_cannot_vmap_of_cond_io_callback(self):
    def f(pred):
      def true_fun():
        io_callback(lambda: print("true"), None)
      def false_fun():
        io_callback(lambda: print("false"), None)
      return lax.cond(pred, false_fun, true_fun)
    with self.assertRaisesRegex(NotImplementedError,
                                "IO effect not supported in vmap-of-cond."):
      jax.vmap(f)(jnp.array([True, True]))

  def test_cannot_vmap_of_while_io_callback(self):
    def check(x):
      assert np.all(x < 5)

    def f(i):
      def cond(i):
        return i < 5
      def body(i):
        io_callback(check, None, i)
        return i + 1
      return lax.while_loop(cond, body, i)
    with self.assertRaisesRegex(
        Exception, "not supported in while_loop with batched predicate"):
      jax.vmap(f)(jnp.array([0, 4]))

  def test_cannot_use_io_callback_in_checkpoint(self):
    @jax.grad
    @jax.checkpoint
    def f(x, y):
      io_callback(lambda x: x, y, y)
      return x

    with self.assertRaisesRegex(NotImplementedError,
                                "Effects not supported in partial-eval of `checkpoint`"):
      f(2., 3.)
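
  # Taken together, the tests above pin down the io_callback restrictions:
  # no JVP/linearize, no transpose, no ordered callbacks under vmap or pmap,
  # and no IO effects under vmap-of-cond, batched-predicate while_loops, or
  # the partial evaluation performed by jax.checkpoint.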

  @parameterized.named_parameters(
      dict(
          testcase_name=f'{ordered=}_{with_sharding=}',
          ordered=ordered,
          with_sharding=with_sharding,
      )
      for ordered in [True, False]
      for with_sharding in [True, False]
  )
  def test_can_use_io_callback_in_pjit(
      self, *, ordered: bool, with_sharding: bool
  ):
    devices = jax.devices()
    mesh = jax.sharding.Mesh(np.array(devices), ['dev'])

    _collected: list[int] = []
    def _cb(x):
      nonlocal _collected
      _collected.append(int(x.sum()))

    io_callback_kwargs = dict(ordered=ordered)
    callback_device = devices[0]
    if with_sharding:
      callback_device = devices[-1]
      io_callback_kwargs['sharding'] = jax.sharding.SingleDeviceSharding(
          callback_device
      )

    def f(x):
      io_callback(_cb, None, x, **io_callback_kwargs)
      io_callback(_cb, None, x + 1, **io_callback_kwargs)
      return x

    in_spec = jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec('dev')
    )
    out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    f = pjit.pjit(f, in_shardings=in_spec, out_shardings=out_spec)
    expected = []
    with mesh:
      x = jnp.arange(mesh.size)
      f(x)
      expected.extend([int(x.sum()), int((x + 1).sum())])
      f(x + 5)
      expected.extend([int((x + 5).sum()), int((x + 6).sum())])

    jax.effects_barrier()
    if ordered:
      self.assertAllClose(_collected, expected)
    else:
      self.assertEqual(len(_collected), len(expected))
      for v in expected:
        self.assertIn(v, _collected)

    callback_device_index = in_spec._device_assignment.index(callback_device)
    stablehlo_ir = f.lower(x).as_text()
    if config.use_shardy_partitioner.value:
      self.assertIn(
          "sdy.sharding ="
          f" #sdy.sharding_per_value<[<@maximal_mesh_{callback_device_index},"
          " []>]>",
          stablehlo_ir)
      self.assertIn(
          f"sdy.mesh @maximal_mesh_{callback_device_index} = <[],"
          f" device_ids=[{callback_device_index}]>",
          stablehlo_ir)
    else:
      self.assertIn(f"{{maximal device={callback_device_index}}}", stablehlo_ir)

  def test_sequence_pjit_io_callback_ordered(self):
    # A sequence of pairs of calls to pjit(io_callback(ordered=True)) with each
    # pair on a different device assignment.
    _collected: list[int] = []
    def _cb(i, x):
      nonlocal _collected
      # Sleep different amounts of time, to test the ordering.
      time.sleep([0.02, 0.03, 0.04][len(_collected) % 3])
      logging.info('Collected iteration %s: %s', i, x)
      _collected.append(int(x.sum()))

    def f_base(i, x):
      io_callback(_cb, None, i, x, ordered=True)
      io_callback(_cb, None, i, x + 1, ordered=True)

    nr_iterations = 8
    # TODO(zce): If I pin to 1 device below (jax.devices()[:1]) then this test
    # flakes. It also flakes when pinned to 2 devices. It seems that repeatedly
    # dispatching to the same device triggers the problem.
    devices = jax.devices()
    expected = []  # The expected value for _collected
    for i in range(nr_iterations):
      if len(devices) > 1:
        devices_for_iteration = [
            devices[i % len(devices)],
            devices[(i + 1) % len(devices)],
        ]
      else:
        devices_for_iteration = devices
      logging.info(
          'Running iteration %d on devices %s', i, devices_for_iteration
      )
      mesh = jax.sharding.Mesh(np.array(devices_for_iteration), ['dev'])
      in_spec = (
          jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
          jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev')),
      )
      out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
      f = pjit.pjit(f_base, in_shardings=in_spec, out_shardings=out_spec)
      with mesh:
        x = jax.device_put(
            np.arange(len(devices_for_iteration), dtype=np.int32) + 10 * i,
            in_spec[1],
        )
        f(i, x)
        expected.extend([int(x.sum()), int((x + 1).sum())])
        f(i, x + 5)
        expected.extend([int((x + 5).sum()), int((x + 6).sum())])

    jax.effects_barrier()
    self.assertEqual(_collected, expected)
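
  # The deliberately staggered sleeps above make ordering violations likely to
  # surface: ordered io_callbacks appear to be sequenced by threading a
  # runtime token through successive executions (cf. dispatch.runtime_tokens
  # in tearDown), so _collected must match program order even across device
  # assignments. The token-threading detail is an implementation assumption,
  # not a public API guarantee.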

  def test_can_shard_io_callback_manually(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("TODO(b/384938613): Failing under shardy.")

    mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

    spec = jax.sharding.PartitionSpec('x')
    sharding = jax.sharding.NamedSharding(mesh, spec)

    _collected = collections.defaultdict(list)
    def func(shard_id, x):
      nonlocal _collected
      _collected[shard_id.item()].append(x)

    def f(shard_ids, x):
      io_callback(func, None, shard_ids, x, ordered=True)
      io_callback(func, None, shard_ids, x + 1, ordered=True)
    f = shard_map(f, mesh=mesh, in_specs=spec, out_specs=None)

    shard_ids = jnp.arange(mesh.devices.size)
    inp = jnp.arange(2 * jax.local_device_count())
    jax.jit(f, in_shardings=sharding, out_shardings=None)(shard_ids, inp)
    jax.effects_barrier()

    self.assertLen(_collected, mesh.devices.size)
    # Verify the partial ordering: no specified order across shards, but strict
    # ordering between the two calls in each shard.
    for shard in _collected.values():
      self.assertLen(shard, 2)
      np.testing.assert_array_equal(shard[0] + 1, shard[1])

  def test_batching_with_side_effects(self):
    # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195
    x_lst = []
    def append_x(x):
      nonlocal x_lst
      x_lst.append(x)

    @jax.jit
    def f(x):
      io_callback(append_x, None, x, ordered=False)
      io_callback(append_x, None, 2 * x, ordered=False)

    jax.vmap(f)(jnp.arange(3.))
    jax.effects_barrier()
    self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False)

  def test_batching_with_side_effects_while_loop(self):
    # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219
    x_lst = []
    def append_x(x):
      nonlocal x_lst
      x_lst.append(x)

    @jax.jit
    def f(x):
      def body(i):
        io_callback(append_x, None, x, ordered=False)
        io_callback(append_x, None, 2 * x, ordered=False)
        return i + 1
      jax.lax.while_loop(lambda i: i < 2, body, 0)

    jax.vmap(f)(jnp.arange(3.))  # don't crash
    jax.effects_barrier()


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())