mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
fix opaque dtype case of dtypes.dtype
This commit is contained in:
parent
20896c1b2d
commit
e9061953f6
@ -617,10 +617,10 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType:
|
||||
dt = np.result_type(x)
|
||||
except TypeError as err:
|
||||
raise TypeError(f"Cannot determine dtype of {x}") from err
|
||||
if dt not in _jax_dtype_set:
|
||||
if dt not in _jax_dtype_set and not core.is_opaque_dtype(dt):
|
||||
raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
|
||||
"type. Only arrays of numeric types are supported by JAX.")
|
||||
return canonicalize_dtype(dt) if canonicalize else dt
|
||||
return canonicalize_dtype(dt, allow_opaque_dtype=True) if canonicalize else dt
|
||||
|
||||
def _lattice_result_type(*args: Any) -> Tuple[DType, bool]:
|
||||
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
|
||||
|
@ -526,6 +526,11 @@ class PrngTest(jtu.JaxTestCase):
|
||||
key = random.PRNGKey(0)
|
||||
self.assertIsInstance(key, jax.Array)
|
||||
|
||||
def test_key_output_vjp(self):
|
||||
# See https://github.com/google/jax/issues/14856
|
||||
def f(seed): return random.PRNGKey(seed)
|
||||
jax.vjp(f, 1) # doesn't crash
|
||||
|
||||
|
||||
class ThreefryPrngTest(jtu.JaxTestCase):
|
||||
def test_seed_no_implicit_transfers(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user