fix opaque dtype case of dtypes.dtype

This commit is contained in:
Roy Frostig 2023-04-15 20:06:37 -07:00
parent 20896c1b2d
commit e9061953f6
2 changed files with 7 additions and 2 deletions

View File

@ -617,10 +617,10 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType:
dt = np.result_type(x)
except TypeError as err:
raise TypeError(f"Cannot determine dtype of {x}") from err
if dt not in _jax_dtype_set:
if dt not in _jax_dtype_set and not core.is_opaque_dtype(dt):
raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
return canonicalize_dtype(dt) if canonicalize else dt
return canonicalize_dtype(dt, allow_opaque_dtype=True) if canonicalize else dt
def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))

View File

@ -526,6 +526,11 @@ class PrngTest(jtu.JaxTestCase):
key = random.PRNGKey(0)
self.assertIsInstance(key, jax.Array)
def test_key_output_vjp(self):
# See https://github.com/google/jax/issues/14856
def f(seed): return random.PRNGKey(seed)
jax.vjp(f, 1) # doesn't crash
class ThreefryPrngTest(jtu.JaxTestCase):
def test_seed_no_implicit_transfers(self):