jax.dtypes: avoid erroring on non-hashable dtype

This commit is contained in:
Jake VanderPlas 2024-06-13 10:44:42 -07:00
parent 2679ece82d
commit 27893934d1
2 changed files with 9 additions and 2 deletions

View File

@ -752,14 +752,14 @@ def check_user_dtype_supported(dtype, fun_name=None):
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg)
if dtype is not None and np_dtype != canonicalize_dtype(dtype):
if dtype is not None and np_dtype != canonicalize_dtype(np_dtype):
msg = ("Explicitly requested dtype {} {} is not available, "
"and will be truncated to dtype {}. To enable more dtypes, set the "
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
"environment variable. "
"See https://github.com/google/jax#current-gotchas for more.")
fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = canonicalize_dtype(dtype).name
truncated_dtype = canonicalize_dtype(np_dtype).name
warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3)
def safe_to_cast(input_dtype_or_value: Any,

View File

@ -558,6 +558,13 @@ class DtypesTest(jtu.JaxTestCase):
_, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale)
self.assertAllClose(new_scale, jnp.float32(1.0))
def test_check_dtype_non_hashable(self):
# regression test for issue with checking non-hashable custom dtype
class MyDtype:
__hash__ = None
dtype = np.dtype('float32')
dtypes.check_user_dtype_supported(MyDtype())
class EArrayTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])