mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[callback] Fix io_callback for callbacks that return Python literals.
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal. Fixed the checks that the callback returns values of expected shape and dtype, and added tests. Reverts 19e6156ccec0df7a900471df7840bc421da2898b PiperOrigin-RevId: 619156176
This commit is contained in:
parent
33cf53c413
commit
75db481299
@ -2461,15 +2461,22 @@ def emit_python_callback(
|
||||
raise RuntimeError(
|
||||
"Mismatched number of outputs from callback. "
|
||||
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
|
||||
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
|
||||
out_vals = tuple(np.asarray(a) for a in out_vals)
|
||||
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
|
||||
if out_val.shape != out_aval.shape:
|
||||
raise RuntimeError(
|
||||
f"Incorrect output shape for return value {i}: "
|
||||
"Expected: {}, Actual: {}".format(out_aval.shape, out_val.shape))
|
||||
f"Incorrect output shape for return value #{i}: "
|
||||
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
|
||||
if out_val.dtype != dtypes.canonicalize_dtype(out_val.dtype):
|
||||
raise RuntimeError(
|
||||
"Cannot return 64-bit values when `jax_enable_x64` is disabled. "
|
||||
f"Actual: {out_val.dtype}")
|
||||
if out_val.dtype != out_aval.dtype:
|
||||
raise RuntimeError(
|
||||
f"Incorrect output dtype for return value {i}: "
|
||||
"Expected: {}, Actual: {}".format(out_aval.dtype, out_val.dtype))
|
||||
f"Incorrect output dtype for return value #{i}: "
|
||||
f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")
|
||||
|
||||
if platform == "tpu":
|
||||
# On TPU we cannot receive empty arrays. So, we return from the wrapped
|
||||
# callback only the non-empty results, and we will create empty constants
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import textwrap
|
||||
@ -27,11 +28,11 @@ from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import maps
|
||||
from jax._src.maps import xmap
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.maps import xmap
|
||||
from jax.experimental import io_callback
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental.shard_map import shard_map
|
||||
@ -72,6 +73,7 @@ with_pure_and_io_callbacks = parameterized.named_parameters(
|
||||
for flavor in ("io_unordered", "io_ordered", "pure")
|
||||
)
|
||||
|
||||
|
||||
class PythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -93,6 +95,73 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
out = f(0.)
|
||||
self.assertEqual(out, 1.)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name=f"{flavor}_expect_dtype_{expect_dtype}",
|
||||
callback=dict(
|
||||
io_unordered=io_calback_unordered,
|
||||
io_ordered=io_callback_ordered,
|
||||
pure=jax.pure_callback,
|
||||
)[flavor],
|
||||
expect_dtype=expect_dtype,
|
||||
)
|
||||
for flavor in ("io_unordered", "io_ordered", "pure")
|
||||
for expect_dtype in (np.int32, np.int64, np.float32, np.float64)
|
||||
)
|
||||
def test_callback_returning_python_literal(self, *, callback, expect_dtype):
|
||||
returned_literal = 42 if expect_dtype in (np.int32, np.int64) else 42.0
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return callback(
|
||||
lambda x: returned_literal, core.ShapedArray((), expect_dtype), x
|
||||
)
|
||||
|
||||
if not config.enable_x64.value:
|
||||
ctx = self.assertRaisesRegex(Exception, "Cannot return 64-bit values")
|
||||
elif expect_dtype in (np.int32, np.float32):
|
||||
ctx = self.assertRaisesRegex(Exception, "Incorrect output dtype")
|
||||
else:
|
||||
ctx = contextlib.nullcontext()
|
||||
|
||||
with ctx:
|
||||
out = f(0.0)
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(out, returned_literal)
|
||||
|
||||
@with_pure_and_io_callbacks
|
||||
def test_callback_returning_custom_array(self, *, callback):
|
||||
# Some users write the callback in TF, returning a tf.Tensor. We don't
|
||||
# want to add TF as a dependency, but simulate that use case with a
|
||||
# custom array class.
|
||||
class CustomArray:
|
||||
|
||||
def __init__(self, a: np.ndarray):
|
||||
self.a = a
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.a.shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.a.dtype
|
||||
|
||||
def __array__(self):
|
||||
return self.a
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return callback(
|
||||
lambda x: CustomArray(np.array(42.0, dtype=np.float32)),
|
||||
core.ShapedArray((), np.float32),
|
||||
x,
|
||||
)
|
||||
|
||||
out = f(0.0)
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(out, 42.0)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"{flavor}_{dtype}",
|
||||
dtype=dtype,
|
||||
|
Loading…
x
Reference in New Issue
Block a user