Raise a good error message when a ShapeDtypeStruct is closed over as a const which is not a valid arg during execution.

PiperOrigin-RevId: 540296131
This commit is contained in:
Yash Katariya 2023-06-14 09:39:54 -07:00 committed by jax authors
parent 0089749f54
commit 38b9bf8cac
2 changed files with 8 additions and 1 deletions

View File

@ -165,7 +165,8 @@ def canonicalize_dtype(x):
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid "
"JAX type.")
def _canonicalize_masked_array_dtype(x):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "

View File

@ -3462,6 +3462,12 @@ class ArrayPjitTest(jtu.JaxTestCase):
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap')
jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') # doesn't crash
def test_shape_dtype_struct_as_const_error(self):
const = jax.ShapeDtypeStruct((8,), jnp.int32)
with self.assertRaisesRegex(TypeError,
r"Argument.*is not a valid JAX type"):
jax.jit(lambda x: (x, const))(jnp.arange(8))
class TempSharding(Sharding):