Fix jnp.unwrap() test failures on GPU.

A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16.

In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function.

PiperOrigin-RevId: 476494989
This commit is contained in:
Peter Hawkins 2022-09-23 17:11:10 -07:00 committed by jax authors
parent d2fcfb6b83
commit 8ee7129874

View File

@ -1073,7 +1073,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for discont in [None, "pi", 2]
for period in ["2pi", "pi"]
for axis in list(range(-len(shape), len(shape)))))
@jtu.skip_on_devices('gpu')
def testUnwrap(self, shape, dtype, axis, discont, period):
if numpy_version < (1, 21) and period != "2pi":
self.skipTest("numpy < 1.21 does not support the period argument to unwrap()")
@ -1082,18 +1081,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
discont = special_vals.get(discont, discont)
rng = jtu.rand_default(self.rng())
if numpy_version < (1, 21):
np_fun = partial(np.unwrap, axis=axis, discont=discont)
else:
np_fun = partial(np.unwrap, axis=axis, discont=discont, period=period)
def np_fun(x):
dtype = None
if x.dtype == dtypes.bfloat16:
dtype = x.dtype
x = x.astype(np.float32)
if numpy_version < (1, 21):
out = np.unwrap(x, axis=axis, discont=discont or np.pi)
else:
out = np.unwrap(x, axis=axis, discont=discont, period=period)
return out if dtype is None else out.astype(dtype)
jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period)
if not dtypes.issubdtype(dtype, np.inexact):
# This case requires implicit dtype promotion
jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
if dtype != jnp.bfloat16: # numpy crashes on bfloat16
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False,
atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2})
self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(