Merge pull request #26092 from mattjj:make-array-error

PiperOrigin-RevId: 719667108
This commit is contained in:
jax authors 2025-01-25 09:56:04 -08:00
commit a8adf75295
2 changed files with 24 additions and 6 deletions

View File

@ -1014,18 +1014,25 @@ def make_array_from_single_device_arrays(
"""
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
if any(isinstance(arr, core.Tracer) for arr in arrays):
raise ValueError(
"jax.make_array_from_single_device_arrays requires a list of concrete"
f" arrays as input. got types {set(map(type, arrays))}")
aval = core.update_aval_with_sharding(
core.ShapedArray(shape, arrays[0].dtype, weak_type=False), sharding)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
committed=True)
# TODO(phawkins): ideally the cast() could be checked.
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
committed=True)
try:
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
committed=True)
except TypeError:
if not isinstance(arrays, Sequence):
raise TypeError("jax.make_array_from_single_device_arrays `arrays` "
"argument must be a Sequence (list or tuple), but got "
f"{type(arrays)}.")
if any(isinstance(arr, core.Tracer) for arr in arrays):
raise ValueError(
"jax.make_array_from_single_device_arrays requires a list of concrete"
f" arrays as input, but got types {set(map(type, arrays))}")
raise
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity

View File

@ -1250,6 +1250,17 @@ class ShardingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(f)(x)
def test_make_array_from_single_device_arrays_nonlist_error(self):
x = jnp.arange(10)
sharding = x.sharding
def f(x):
return jax.make_array_from_single_device_arrays(x.shape, sharding, x)
msg = "jax.make_array_from_single_device_arrays `arrays` argument"
with self.assertRaisesRegex(TypeError, msg):
jax.jit(f)(x)
def test_make_array_from_single_device_arrays_bad_inputs(self):
x = jnp.arange(10)
mesh = jtu.create_mesh((2,), ('x',))