mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6816 from gnecula:bfloat16_random
PiperOrigin-RevId: 375417596
This commit is contained in:
commit
7ea7cea687
@ -198,6 +198,8 @@ def _issubclass(a, b):
|
||||
return False
|
||||
|
||||
def issubdtype(a, b):
|
||||
if a == "bfloat16":
|
||||
a = bfloat16
|
||||
if a == bfloat16:
|
||||
if isinstance(b, np.dtype):
|
||||
return b == _bfloat16_dtype
|
||||
|
@ -210,6 +210,13 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
for samples in [uncompiled_samples, compiled_samples]:
|
||||
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
|
||||
|
||||
def testNormalBfloat16(self):
|
||||
# Passing bfloat16 as dtype string.
|
||||
# https://github.com/google/jax/issues/6813
|
||||
res_bfloat16_str = random.normal(random.PRNGKey(0), dtype='bfloat16')
|
||||
res_bfloat16 = random.normal(random.PRNGKey(0), dtype=jnp.bfloat16)
|
||||
self.assertAllClose(res_bfloat16, res_bfloat16_str)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
|
||||
for dtype in complex_dtypes))
|
||||
|
Loading…
x
Reference in New Issue
Block a user