mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax.dtypes: avoid erroring on non-hashable dtype
This commit is contained in:
parent
2679ece82d
commit
27893934d1
@ -752,14 +752,14 @@ def check_user_dtype_supported(dtype, fun_name=None):
|
||||
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
|
||||
msg += f" in {fun_name}" if fun_name else ""
|
||||
raise TypeError(msg)
|
||||
if dtype is not None and np_dtype != canonicalize_dtype(dtype):
|
||||
if dtype is not None and np_dtype != canonicalize_dtype(np_dtype):
|
||||
msg = ("Explicitly requested dtype {} {} is not available, "
|
||||
"and will be truncated to dtype {}. To enable more dtypes, set the "
|
||||
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
|
||||
"environment variable. "
|
||||
"See https://github.com/google/jax#current-gotchas for more.")
|
||||
fun_name = f"requested in {fun_name}" if fun_name else ""
|
||||
truncated_dtype = canonicalize_dtype(dtype).name
|
||||
truncated_dtype = canonicalize_dtype(np_dtype).name
|
||||
warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3)
|
||||
|
||||
def safe_to_cast(input_dtype_or_value: Any,
|
||||
|
@ -558,6 +558,13 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
_, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale)
|
||||
self.assertAllClose(new_scale, jnp.float32(1.0))
|
||||
|
||||
def test_check_dtype_non_hashable(self):
|
||||
# regression test for issue with checking non-hashable custom dtype
|
||||
class MyDtype:
|
||||
__hash__ = None
|
||||
dtype = np.dtype('float32')
|
||||
dtypes.check_user_dtype_supported(MyDtype())
|
||||
|
||||
class EArrayTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
|
Loading…
x
Reference in New Issue
Block a user