mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Raise a good error message when a ShapeDtypeStruct is closed over as a const which is not a valid arg during execution.
PiperOrigin-RevId: 540296131
This commit is contained in:
parent
0089749f54
commit
38b9bf8cac
@ -165,7 +165,8 @@ def canonicalize_dtype(x):
|
||||
if handler: return handler(x)
|
||||
if hasattr(x, '__jax_array__'):
|
||||
return canonicalize_dtype(x.__jax_array__())
|
||||
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
|
||||
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid "
|
||||
"JAX type.")
|
||||
|
||||
def _canonicalize_masked_array_dtype(x):
|
||||
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
|
||||
|
@ -3462,6 +3462,12 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap')
|
||||
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') # doesn't crash
|
||||
|
||||
def test_shape_dtype_struct_as_const_error(self):
|
||||
const = jax.ShapeDtypeStruct((8,), jnp.int32)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
r"Argument.*is not a valid JAX type"):
|
||||
jax.jit(lambda x: (x, const))(jnp.arange(8))
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user