Merge pull request #6816 from gnecula:bfloat16_random

PiperOrigin-RevId: 375417596
This commit is contained in:
jax authors 2021-05-23 22:58:35 -07:00
commit 7ea7cea687
2 changed files with 9 additions and 0 deletions

View File

@ -198,6 +198,8 @@ def _issubclass(a, b):
return False
def issubdtype(a, b):
if a == "bfloat16":
a = bfloat16
if a == bfloat16:
if isinstance(b, np.dtype):
return b == _bfloat16_dtype

View File

@ -210,6 +210,13 @@ class LaxRandomTest(jtu.JaxTestCase):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)
def testNormalBfloat16(self):
# Passing bfloat16 as dtype string.
# https://github.com/google/jax/issues/6813
res_bfloat16_str = random.normal(random.PRNGKey(0), dtype='bfloat16')
res_bfloat16 = random.normal(random.PRNGKey(0), dtype=jnp.bfloat16)
self.assertAllClose(res_bfloat16, res_bfloat16_str)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in complex_dtypes))