mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19580 from jakevdp:callback-error
PiperOrigin-RevId: 603228092
This commit is contained in:
commit
26995ad800
@ -32,6 +32,7 @@ from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import errors
|
||||
from jax._src import profiler
|
||||
from jax._src import tree_util
|
||||
from jax._src import xla_bridge
|
||||
@ -642,7 +643,9 @@ def make_array_from_callback(
|
||||
"""Returns a ``jax.Array`` via data fetched from ``data_callback``.
|
||||
|
||||
``data_callback`` is used to fetch the data for each addressable shard of the
|
||||
returned ``jax.Array``.
|
||||
returned ``jax.Array``. This function must return concrete arrays, meaning that
|
||||
``make_array_from_callback`` has limited compatibility with JAX transformations
|
||||
like :func:`jit` or :func:`vmap`.
|
||||
|
||||
Args:
|
||||
shape : Shape of the ``jax.Array``.
|
||||
@ -688,6 +691,10 @@ def make_array_from_callback(
|
||||
per_device_values = [data_callback(device_to_index_map[device])
|
||||
for device in devices]
|
||||
|
||||
if isinstance(per_device_values[0], core.Tracer):
|
||||
raise errors.UnexpectedTracerError(
|
||||
"jax.make_array_from_callback cannot be called within a traced context.")
|
||||
|
||||
first_value = xla.canonicalize_dtype(per_device_values[0])
|
||||
aval = core.ShapedArray(shape, first_value.dtype, weak_type=False)
|
||||
|
||||
@ -794,6 +801,9 @@ def make_array_from_single_device_arrays(
|
||||
"""
|
||||
# All input arrays should be committed. Checking it is expensive on
|
||||
# single-controller systems.
|
||||
if any(isinstance(arr, core.Tracer) for arr in arrays):
|
||||
raise ValueError("jax.make_array_from_single_device_arrays requires a list of concrete arrays as input. "
|
||||
f"got types {set(map(type, arrays))}")
|
||||
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
|
||||
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
||||
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
|
||||
|
@ -1157,6 +1157,36 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
self.assertEqual(str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2)")
|
||||
|
||||
def test_make_array_from_callback_error(self):
|
||||
mesh_shape = (2, 3)
|
||||
global_shape = tuple(np.square(mesh_shape))
|
||||
mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
|
||||
pspec = P('x', 'y')
|
||||
sharding = jax.sharding.NamedSharding(mesh, pspec)
|
||||
n = math.prod(global_shape)
|
||||
global_x = jnp.arange(n).astype('uint32').reshape(global_shape)
|
||||
|
||||
def f(arr):
|
||||
return array.make_array_from_callback(arr.shape, sharding, lambda i: arr[i])
|
||||
|
||||
out = f(global_x)
|
||||
self.assertEqual(out.shape, global_shape)
|
||||
|
||||
msg = "jax.make_array_from_callback cannot be called within a traced context"
|
||||
with self.assertRaisesRegex(jax.errors.UnexpectedTracerError, msg):
|
||||
jax.jit(f)(global_x)
|
||||
|
||||
def test_make_array_from_single_device_arrays_error(self):
|
||||
x = jnp.arange(10)
|
||||
sharding = x.sharding
|
||||
|
||||
def f(x):
|
||||
return jax.make_array_from_single_device_arrays(x.shape, sharding, [x])
|
||||
|
||||
msg = "jax.make_array_from_single_device_arrays requires a list of concrete arrays"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(f)(x)
|
||||
|
||||
|
||||
class RngShardingTest(jtu.JaxTestCase):
|
||||
# tests that the PRNGs are automatically sharded as expected
|
||||
|
Loading…
x
Reference in New Issue
Block a user