Merge pull request #19580 from jakevdp:callback-error

PiperOrigin-RevId: 603228092
This commit is contained in:
jax authors 2024-01-31 19:09:09 -08:00
commit 26995ad800
2 changed files with 41 additions and 1 deletions

View File

@ -32,6 +32,7 @@ from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
@ -642,7 +643,9 @@ def make_array_from_callback(
"""Returns a ``jax.Array`` via data fetched from ``data_callback``.
``data_callback`` is used to fetch the data for each addressable shard of the
returned ``jax.Array``.
returned ``jax.Array``. This function must return concrete arrays, meaning that
``make_array_from_callback`` has limited compatibility with JAX transformations
like :func:`jit` or :func:`vmap`.
Args:
shape : Shape of the ``jax.Array``.
@ -688,6 +691,10 @@ def make_array_from_callback(
per_device_values = [data_callback(device_to_index_map[device])
for device in devices]
if isinstance(per_device_values[0], core.Tracer):
raise errors.UnexpectedTracerError(
"jax.make_array_from_callback cannot be called within a traced context.")
first_value = xla.canonicalize_dtype(per_device_values[0])
aval = core.ShapedArray(shape, first_value.dtype, weak_type=False)
@ -794,6 +801,9 @@ def make_array_from_single_device_arrays(
"""
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
if any(isinstance(arr, core.Tracer) for arr in arrays):
raise ValueError("jax.make_array_from_single_device_arrays requires a list of concrete arrays as input. "
f"got types {set(map(type, arrays))}")
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)

View File

@ -1157,6 +1157,36 @@ class ShardingTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
self.assertEqual(str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2)")
def test_make_array_from_callback_error(self):
mesh_shape = (2, 3)
global_shape = tuple(np.square(mesh_shape))
mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
pspec = P('x', 'y')
sharding = jax.sharding.NamedSharding(mesh, pspec)
n = math.prod(global_shape)
global_x = jnp.arange(n).astype('uint32').reshape(global_shape)
def f(arr):
return array.make_array_from_callback(arr.shape, sharding, lambda i: arr[i])
out = f(global_x)
self.assertEqual(out.shape, global_shape)
msg = "jax.make_array_from_callback cannot be called within a traced context"
with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, msg):
jax.jit(f)(global_x)
def test_make_array_from_single_device_arrays_error(self):
x = jnp.arange(10)
sharding = x.sharding
def f(x):
return jax.make_array_from_single_device_arrays(x.shape, sharding, [x])
msg = "jax.make_array_from_single_device_arrays requires a list of concrete arrays"
with self.assertRaisesRegex(ValueError, msg):
jax.jit(f)(x)
class RngShardingTest(jtu.JaxTestCase):
# tests that the PRNGs are automatically sharded as expected