mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #26092 from mattjj:make-array-error
PiperOrigin-RevId: 719667108
This commit is contained in:
commit
a8adf75295
@ -1014,18 +1014,25 @@ def make_array_from_single_device_arrays(
|
||||
"""
|
||||
# All input arrays should be committed. Checking it is expensive on
|
||||
# single-controller systems.
|
||||
if any(isinstance(arr, core.Tracer) for arr in arrays):
|
||||
raise ValueError(
|
||||
"jax.make_array_from_single_device_arrays requires a list of concrete"
|
||||
f" arrays as input. got types {set(map(type, arrays))}")
|
||||
aval = core.update_aval_with_sharding(
|
||||
core.ShapedArray(shape, arrays[0].dtype, weak_type=False), sharding)
|
||||
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
||||
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
|
||||
committed=True)
|
||||
# TODO(phawkins): ideally the cast() could be checked.
|
||||
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
|
||||
committed=True)
|
||||
try:
|
||||
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
|
||||
committed=True)
|
||||
except TypeError:
|
||||
if not isinstance(arrays, Sequence):
|
||||
raise TypeError("jax.make_array_from_single_device_arrays `arrays` "
|
||||
"argument must be a Sequence (list or tuple), but got "
|
||||
f"{type(arrays)}.")
|
||||
if any(isinstance(arr, core.Tracer) for arr in arrays):
|
||||
raise ValueError(
|
||||
"jax.make_array_from_single_device_arrays requires a list of concrete"
|
||||
f" arrays as input, but got types {set(map(type, arrays))}")
|
||||
raise
|
||||
|
||||
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
|
||||
|
||||
|
@ -1250,6 +1250,17 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(f)(x)
|
||||
|
||||
def test_make_array_from_single_device_arrays_nonlist_error(self):
|
||||
x = jnp.arange(10)
|
||||
sharding = x.sharding
|
||||
|
||||
def f(x):
|
||||
return jax.make_array_from_single_device_arrays(x.shape, sharding, x)
|
||||
|
||||
msg = "jax.make_array_from_single_device_arrays `arrays` argument"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jax.jit(f)(x)
|
||||
|
||||
def test_make_array_from_single_device_arrays_bad_inputs(self):
|
||||
x = jnp.arange(10)
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user